1use crate::daemon::error::DaemonError;
2use crate::daemon::lifecycle::{bootstrap_daemon, ensure_daemon_running, socket_path};
3use crate::envd::error::EnvDaemonError;
4use crate::envd::lifecycle::{ensure_env_running, socket_path as env_socket_path};
5use crate::protocol::{InstanceAction, Request, RequestAction, Response};
6use thiserror::Error;
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::net::UnixStream;
9use tokio::time::{Duration, sleep, timeout};
10
11#[derive(Debug, Error)]
12pub enum ConnectionError {
13 #[error(transparent)]
14 Daemon(#[from] DaemonError),
15 #[error(transparent)]
16 EnvDaemon(#[from] EnvDaemonError),
17 #[error("connection timeout")]
18 Timeout,
19 #[error("connection error: {0}")]
20 Io(#[from] std::io::Error),
21 #[error("serialization error: {0}")]
22 Serde(#[from] serde_json::Error),
23 #[error("response missing")]
24 MissingResponse,
25}
26
27pub async fn send_request(session: &str, request: &Request) -> Result<Response, ConnectionError> {
28 SessionConnector.prepare(session, request).await?;
29 RequestTransport::default()
30 .send_to_socket(&socket_path(session), request)
31 .await
32}
33
34pub async fn send_env_request(env: &str, request: &Request) -> Result<Response, ConnectionError> {
35 EnvConnector.prepare(env).await?;
36 RequestTransport::default()
37 .send_to_socket(&env_socket_path(env), request)
38 .await
39}
40
/// Retry and timeout settings for one request/response exchange over a Unix socket.
#[derive(Clone, Copy, Debug)]
struct RequestTransport {
    /// Number of failed attempts after which the last error is returned to the caller.
    max_attempts: u32,
    /// Pause between a failed attempt and the next one.
    retry_delay: Duration,
    /// Time limit for establishing the socket connection.
    connect_timeout: Duration,
    /// Time limit for writing the request payload.
    write_timeout: Duration,
    /// Time limit for reading the response line.
    read_timeout: Duration,
}
49
50impl Default for RequestTransport {
51 fn default() -> Self {
52 Self {
53 max_attempts: 5,
54 retry_delay: Duration::from_millis(200),
55 connect_timeout: Duration::from_secs(30),
56 write_timeout: Duration::from_secs(5),
57 read_timeout: Duration::from_secs(30),
58 }
59 }
60}
61
62impl RequestTransport {
63 async fn send_to_socket(
64 self,
65 socket: &std::path::Path,
66 request: &Request,
67 ) -> Result<Response, ConnectionError> {
68 let payload = {
69 let mut line = serde_json::to_string(request)?;
70 line.push('\n');
71 line
72 };
73
74 let mut attempt = 0_u32;
75 loop {
76 match self.send_once(socket, &payload).await {
77 Ok(response) => return Ok(response),
78 Err(err) => {
79 attempt += 1;
80 if attempt >= self.max_attempts {
81 return Err(err);
82 }
83 sleep(self.retry_delay).await;
84 }
85 }
86 }
87 }
88
89 async fn send_once(
90 self,
91 socket: &std::path::Path,
92 payload: &str,
93 ) -> Result<Response, ConnectionError> {
94 let mut stream = timeout(self.connect_timeout, UnixStream::connect(socket))
95 .await
96 .map_err(|_| ConnectionError::Timeout)??;
97 timeout(self.write_timeout, stream.write_all(payload.as_bytes()))
98 .await
99 .map_err(|_| ConnectionError::Timeout)??;
100
101 let mut reader = BufReader::new(stream);
102 let mut line = String::new();
103 timeout(self.read_timeout, reader.read_line(&mut line))
104 .await
105 .map_err(|_| ConnectionError::Timeout)??;
106 if line.is_empty() {
107 return Err(ConnectionError::MissingResponse);
108 }
109 let response = serde_json::from_str::<Response>(line.trim_end())?;
110 Ok(response)
111 }
112}
113
114#[derive(Debug, Clone, Copy, Default)]
115struct SessionConnector;
116
117impl SessionConnector {
118 async fn prepare(self, session: &str, request: &Request) -> Result<(), ConnectionError> {
119 match &request.action {
120 RequestAction::Instance(InstanceAction::Load { load_spec }) => {
121 bootstrap_daemon(session, load_spec).await?;
122 }
123 _ => ensure_daemon_running(session).await?,
124 }
125 Ok(())
126 }
127}
128
129#[derive(Debug, Clone, Copy, Default)]
130struct EnvConnector;
131
132impl EnvConnector {
133 async fn prepare(self, env: &str) -> Result<(), ConnectionError> {
134 ensure_env_running(env).await?;
135 Ok(())
136 }
137}