Skip to main content

agent_sim/
connection.rs

1use crate::daemon::error::DaemonError;
2use crate::daemon::lifecycle::{bootstrap_daemon, ensure_daemon_running, socket_path};
3use crate::envd::error::EnvDaemonError;
4use crate::envd::lifecycle::{ensure_env_running, socket_path as env_socket_path};
5use crate::protocol::{InstanceAction, Request, RequestAction, Response};
6use thiserror::Error;
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::net::UnixStream;
9use tokio::time::{Duration, sleep, timeout};
10
11#[derive(Debug, Error)]
12pub enum ConnectionError {
13    #[error(transparent)]
14    Daemon(#[from] DaemonError),
15    #[error(transparent)]
16    EnvDaemon(#[from] EnvDaemonError),
17    #[error("connection timeout")]
18    Timeout,
19    #[error("connection error: {0}")]
20    Io(#[from] std::io::Error),
21    #[error("serialization error: {0}")]
22    Serde(#[from] serde_json::Error),
23    #[error("response missing")]
24    MissingResponse,
25}
26
27pub async fn send_request(session: &str, request: &Request) -> Result<Response, ConnectionError> {
28    SessionConnector.prepare(session, request).await?;
29    RequestTransport::default()
30        .send_to_socket(&socket_path(session), request)
31        .await
32}
33
34pub async fn send_env_request(env: &str, request: &Request) -> Result<Response, ConnectionError> {
35    EnvConnector.prepare(env).await?;
36    RequestTransport::default()
37        .send_to_socket(&env_socket_path(env), request)
38        .await
39}
40
41#[derive(Debug, Clone, Copy)]
42struct RequestTransport {
43    max_attempts: u32,
44    retry_delay: Duration,
45    connect_timeout: Duration,
46    write_timeout: Duration,
47    read_timeout: Duration,
48}
49
50impl Default for RequestTransport {
51    fn default() -> Self {
52        Self {
53            max_attempts: 5,
54            retry_delay: Duration::from_millis(200),
55            connect_timeout: Duration::from_secs(30),
56            write_timeout: Duration::from_secs(5),
57            read_timeout: Duration::from_secs(30),
58        }
59    }
60}
61
62impl RequestTransport {
63    async fn send_to_socket(
64        self,
65        socket: &std::path::Path,
66        request: &Request,
67    ) -> Result<Response, ConnectionError> {
68        let payload = {
69            let mut line = serde_json::to_string(request)?;
70            line.push('\n');
71            line
72        };
73
74        let mut attempt = 0_u32;
75        loop {
76            match self.send_once(socket, &payload).await {
77                Ok(response) => return Ok(response),
78                Err(err) => {
79                    attempt += 1;
80                    if attempt >= self.max_attempts {
81                        return Err(err);
82                    }
83                    sleep(self.retry_delay).await;
84                }
85            }
86        }
87    }
88
89    async fn send_once(
90        self,
91        socket: &std::path::Path,
92        payload: &str,
93    ) -> Result<Response, ConnectionError> {
94        let mut stream = timeout(self.connect_timeout, UnixStream::connect(socket))
95            .await
96            .map_err(|_| ConnectionError::Timeout)??;
97        timeout(self.write_timeout, stream.write_all(payload.as_bytes()))
98            .await
99            .map_err(|_| ConnectionError::Timeout)??;
100
101        let mut reader = BufReader::new(stream);
102        let mut line = String::new();
103        timeout(self.read_timeout, reader.read_line(&mut line))
104            .await
105            .map_err(|_| ConnectionError::Timeout)??;
106        if line.is_empty() {
107            return Err(ConnectionError::MissingResponse);
108        }
109        let response = serde_json::from_str::<Response>(line.trim_end())?;
110        Ok(response)
111    }
112}
113
114#[derive(Debug, Clone, Copy, Default)]
115struct SessionConnector;
116
117impl SessionConnector {
118    async fn prepare(self, session: &str, request: &Request) -> Result<(), ConnectionError> {
119        match &request.action {
120            RequestAction::Instance(InstanceAction::Load { load_spec }) => {
121                bootstrap_daemon(session, load_spec).await?;
122            }
123            _ => ensure_daemon_running(session).await?,
124        }
125        Ok(())
126    }
127}
128
129#[derive(Debug, Clone, Copy, Default)]
130struct EnvConnector;
131
132impl EnvConnector {
133    async fn prepare(self, env: &str) -> Result<(), ConnectionError> {
134        ensure_env_running(env).await?;
135        Ok(())
136    }
137}