Skip to main content

agent_can/
connection.rs

1use crate::daemon::error::DaemonError;
2use crate::daemon::{config::DaemonConfig, lifecycle};
3use crate::ipc;
4use crate::protocol::{Request, Response};
5use thiserror::Error;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::time::{Duration, sleep, timeout};
8
/// Errors returned when sending a request to the daemon and reading its response.
9#[derive(Debug, Error)]
10pub enum ConnectionError {
    /// A `DaemonError`, converted via `#[from]`; its `Display` output is forwarded unchanged.
11    #[error(transparent)]
12    Daemon(#[from] DaemonError),
    /// A connect, write, or read step did not complete within its configured timeout.
13    #[error("connection timeout")]
14    Timeout,
    /// An I/O error raised while connecting to, writing to, or reading from the endpoint.
15    #[error("connection error: {0}")]
16    Io(#[from] std::io::Error),
    /// The request could not be serialized, or the response line could not be parsed as JSON.
17    #[error("serialization error: {0}")]
18    Serde(#[from] serde_json::Error),
    /// The response read finished without producing any data.
19    #[error("response missing")]
20    MissingResponse,
21}
22
23pub async fn send_request(bus: &str, request: &Request) -> Result<Response, ConnectionError> {
24    lifecycle::ensure_daemon_running(bus).await?;
25    RequestTransport::default()
26        .send_to_endpoint(&lifecycle::socket_path(bus), request)
27        .await
28}
29
30pub async fn bootstrap_and_send_request(
31    bus: &str,
32    config: &DaemonConfig,
33    request: &Request,
34) -> Result<Response, ConnectionError> {
35    lifecycle::bootstrap_daemon(bus, config).await?;
36    RequestTransport::default()
37        .send_to_endpoint(&lifecycle::socket_path(bus), request)
38        .await
39}
40
/// Retry and timeout settings used when sending a single request to a daemon endpoint.
41#[derive(Debug, Clone, Copy)]
42struct RequestTransport {
    /// Number of attempts after which the last error is returned (one attempt is always made).
43    max_attempts: u32,
    /// Wait before the first retry; later retries multiply it by powers of two, up to 8x.
44    retry_delay_base: Duration,
    /// Time allowed for `ipc::connect` to complete.
45    connect_timeout: Duration,
    /// Time allowed for writing the request payload.
46    write_timeout: Duration,
    /// Time allowed for reading the response line.
47    read_timeout: Duration,
48}
49
50impl Default for RequestTransport {
51    fn default() -> Self {
52        Self {
53            max_attempts: 3,
54            retry_delay_base: Duration::from_millis(100),
55            connect_timeout: Duration::from_secs(2),
56            write_timeout: Duration::from_secs(5),
57            read_timeout: Duration::from_secs(30),
58        }
59    }
60}
61
62impl RequestTransport {
63    async fn send_to_endpoint(
64        self,
65        endpoint: &std::path::Path,
66        request: &Request,
67    ) -> Result<Response, ConnectionError> {
68        let payload = {
69            let mut line = serde_json::to_string(request)?;
70            line.push('\n');
71            line
72        };
73
74        let mut attempt = 0_u32;
75        loop {
76            match self.send_once(endpoint, &payload).await {
77                Ok(response) => return Ok(response),
78                Err(err) => {
79                    attempt += 1;
80                    if attempt >= self.max_attempts || !self.should_retry(&err) {
81                        return Err(err);
82                    }
83                    sleep(self.retry_delay_for_attempt(attempt)).await;
84                }
85            }
86        }
87    }
88
89    fn should_retry(self, err: &ConnectionError) -> bool {
90        match err {
91            ConnectionError::Timeout => true,
92            ConnectionError::Io(io_err) => matches!(
93                io_err.kind(),
94                std::io::ErrorKind::NotFound
95                    | std::io::ErrorKind::ConnectionRefused
96                    | std::io::ErrorKind::ConnectionAborted
97                    | std::io::ErrorKind::ConnectionReset
98                    | std::io::ErrorKind::TimedOut
99                    | std::io::ErrorKind::Interrupted
100            ),
101            ConnectionError::Daemon(_)
102            | ConnectionError::Serde(_)
103            | ConnectionError::MissingResponse => false,
104        }
105    }
106
107    fn retry_delay_for_attempt(self, attempt: u32) -> Duration {
108        let shift = attempt.saturating_sub(1).min(3);
109        self.retry_delay_base
110            .checked_mul(1_u32 << shift)
111            .unwrap_or(Duration::MAX)
112    }
113
114    async fn send_once(
115        self,
116        endpoint: &std::path::Path,
117        payload: &str,
118    ) -> Result<Response, ConnectionError> {
119        let mut stream = timeout(self.connect_timeout, ipc::connect(endpoint))
120            .await
121            .map_err(|_| ConnectionError::Timeout)??;
122        timeout(self.write_timeout, stream.write_all(payload.as_bytes()))
123            .await
124            .map_err(|_| ConnectionError::Timeout)??;
125
126        let mut reader = BufReader::new(stream);
127        let mut line = String::new();
128        timeout(self.read_timeout, reader.read_line(&mut line))
129            .await
130            .map_err(|_| ConnectionError::Timeout)??;
131        if line.is_empty() {
132            return Err(ConnectionError::MissingResponse);
133        }
134        Ok(serde_json::from_str::<Response>(line.trim_end())?)
135    }
136}