Skip to main content

agent_can/
connection.rs

1use crate::daemon::error::DaemonError;
2use crate::daemon::{config::DaemonConfig, lifecycle};
3use crate::ipc;
4use crate::protocol::{
5    ConnectResult, Request, RequestAction, Response, ResponseData, SessionStatus,
6    normalize_connect_request_paths, normalize_trace_start_request_path,
7};
8use thiserror::Error;
9use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
10use tokio::time::{Duration, sleep, timeout};
11
12#[derive(Debug, Error)]
13pub enum ConnectionError {
14    #[error(transparent)]
15    Daemon(#[from] DaemonError),
16    #[error("connection timeout")]
17    Timeout,
18    #[error("connection error: {0}")]
19    Io(#[from] std::io::Error),
20    #[error("serialization error: {0}")]
21    Serde(#[from] serde_json::Error),
22    #[error("response missing")]
23    MissingResponse,
24    #[error("unexpected response: {0}")]
25    UnexpectedResponse(&'static str),
26    #[error(transparent)]
27    PathNormalization(#[from] crate::protocol::PathNormalizationError),
28}
29
30pub async fn execute_request(request: Request) -> Result<Response, ConnectionError> {
31    let request = normalize_request(request)?;
32    match &request.action {
33        RequestAction::AdaptersList => Ok(Response::ok(
34            request.id,
35            ResponseData::AdaptersList {
36                adapters: crate::can::available_adapters(),
37            },
38        )),
39        RequestAction::Connect(connect) => connect_request(request.id, connect.clone()).await,
40        RequestAction::Disconnect => {
41            let response = send_to_running_daemon(&request).await?;
42            if response.success {
43                lifecycle::wait_for_daemon_shutdown(Duration::from_secs(5)).await?;
44            }
45            Ok(response)
46        }
47        _ => send_to_running_daemon(&request).await,
48    }
49}
50
51fn normalize_request(mut request: Request) -> Result<Request, ConnectionError> {
52    request.action = match request.action {
53        RequestAction::Connect(connect) => {
54            RequestAction::Connect(normalize_connect_request_paths(connect)?)
55        }
56        RequestAction::TraceStart(trace) => {
57            RequestAction::TraceStart(normalize_trace_start_request_path(trace)?)
58        }
59        other => other,
60    };
61    Ok(request)
62}
63
64async fn connect_request(
65    request_id: uuid::Uuid,
66    connect: crate::protocol::ConnectRequest,
67) -> Result<Response, ConnectionError> {
68    let config = DaemonConfig { connect };
69    if lifecycle::ensure_daemon_running().await.is_err() {
70        match lifecycle::bootstrap_daemon(&config).await {
71            Ok(()) => {
72                let status = read_status().await?;
73                return Ok(Response::ok(
74                    request_id,
75                    ResponseData::Connected(ConnectResult {
76                        created: true,
77                        already_connected: false,
78                        status,
79                    }),
80                ));
81            }
82            Err(DaemonError::AlreadyRunning) => {}
83            Err(err) => return Err(err.into()),
84        }
85    }
86    send_to_running_daemon(&Request {
87        id: request_id,
88        action: RequestAction::Connect(config.connect),
89    })
90    .await
91}
92
93async fn read_status() -> Result<SessionStatus, ConnectionError> {
94    let response = send_to_running_daemon(&Request {
95        id: uuid::Uuid::new_v4(),
96        action: RequestAction::Status,
97    })
98    .await?;
99    match response.data {
100        Some(ResponseData::Status(status)) => Ok(status),
101        _ => Err(ConnectionError::UnexpectedResponse(
102            "expected status after bootstrap",
103        )),
104    }
105}
106
107async fn send_to_running_daemon(request: &Request) -> Result<Response, ConnectionError> {
108    lifecycle::ensure_daemon_running().await?;
109    RequestTransport::default()
110        .send_to_endpoint(&lifecycle::socket_path(), request)
111        .await
112}
113
/// Retry and timeout policy for one request/response exchange with the daemon.
#[derive(Debug, Clone, Copy)]
struct RequestTransport {
    /// Total number of attempts, including the first; at least one attempt is
    /// always made.
    max_attempts: u32,
    /// Delay before the first retry; later retries double it, capped at eight
    /// times this value.
    retry_delay_base: Duration,
    /// Deadline for establishing the connection.
    connect_timeout: Duration,
    /// Deadline for writing the request line.
    write_timeout: Duration,
    /// Deadline for reading the response line.
    read_timeout: Duration,
}

impl Default for RequestTransport {
    /// Three attempts with a 100 ms base back-off, and 2 s / 5 s / 30 s
    /// connect / write / read deadlines.
    fn default() -> Self {
        let retry_delay_base = Duration::from_millis(100);
        let connect_timeout = Duration::from_secs(2);
        let write_timeout = Duration::from_secs(5);
        let read_timeout = Duration::from_secs(30);
        Self {
            max_attempts: 3,
            retry_delay_base,
            connect_timeout,
            write_timeout,
            read_timeout,
        }
    }
}
134
135impl RequestTransport {
136    async fn send_to_endpoint(
137        self,
138        endpoint: &std::path::Path,
139        request: &Request,
140    ) -> Result<Response, ConnectionError> {
141        let payload = {
142            let mut line = serde_json::to_string(request)?;
143            line.push('\n');
144            line
145        };
146
147        let mut attempt = 0_u32;
148        loop {
149            match self.send_once(endpoint, &payload).await {
150                Ok(response) => return Ok(response),
151                Err(err) => {
152                    attempt += 1;
153                    if attempt >= self.max_attempts || !self.should_retry(&err) {
154                        return Err(err);
155                    }
156                    sleep(self.retry_delay_for_attempt(attempt)).await;
157                }
158            }
159        }
160    }
161
162    fn should_retry(self, err: &ConnectionError) -> bool {
163        match err {
164            ConnectionError::Timeout => true,
165            ConnectionError::Io(io_err) => matches!(
166                io_err.kind(),
167                std::io::ErrorKind::NotFound
168                    | std::io::ErrorKind::ConnectionRefused
169                    | std::io::ErrorKind::ConnectionAborted
170                    | std::io::ErrorKind::ConnectionReset
171                    | std::io::ErrorKind::TimedOut
172                    | std::io::ErrorKind::Interrupted
173            ),
174            ConnectionError::Daemon(_)
175            | ConnectionError::Serde(_)
176            | ConnectionError::MissingResponse
177            | ConnectionError::UnexpectedResponse(_)
178            | ConnectionError::PathNormalization(_) => false,
179        }
180    }
181
182    fn retry_delay_for_attempt(self, attempt: u32) -> Duration {
183        let shift = attempt.saturating_sub(1).min(3);
184        self.retry_delay_base
185            .checked_mul(1_u32 << shift)
186            .unwrap_or(Duration::MAX)
187    }
188
189    async fn send_once(
190        self,
191        endpoint: &std::path::Path,
192        payload: &str,
193    ) -> Result<Response, ConnectionError> {
194        let mut stream = timeout(self.connect_timeout, ipc::connect(endpoint))
195            .await
196            .map_err(|_| ConnectionError::Timeout)??;
197        timeout(self.write_timeout, stream.write_all(payload.as_bytes()))
198            .await
199            .map_err(|_| ConnectionError::Timeout)??;
200
201        let mut reader = BufReader::new(stream);
202        let mut line = String::new();
203        timeout(self.read_timeout, reader.read_line(&mut line))
204            .await
205            .map_err(|_| ConnectionError::Timeout)??;
206        if line.is_empty() {
207            return Err(ConnectionError::MissingResponse);
208        }
209        Ok(serde_json::from_str::<Response>(line.trim_end())?)
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::normalize_request;
    use crate::protocol::{DbcSpec, Request, RequestAction, TraceStartRequest};

    /// Builds a PCAN, 500 kbit/s, non-FD connect request with one DBC aliased `main`.
    fn connect_with_dbc(path: String) -> Request {
        Request {
            id: uuid::Uuid::new_v4(),
            action: RequestAction::Connect(crate::protocol::ConnectRequest {
                adapter: "pcan".to_string(),
                bitrate: 500_000,
                bitrate_data: None,
                fd: false,
                dbcs: vec![DbcSpec {
                    alias: "main".to_string(),
                    path,
                }],
            }),
        }
    }

    /// Builds a trace-start request writing to `path`.
    fn trace_start(path: String) -> Request {
        Request {
            id: uuid::Uuid::new_v4(),
            action: RequestAction::TraceStart(TraceStartRequest { path }),
        }
    }

    #[test]
    fn normalize_request_canonicalizes_connect_and_trace_paths() {
        let temp = tempfile::tempdir().expect("tempdir");
        let real_dir = temp.path().join("real");
        std::fs::create_dir_all(&real_dir).expect("real dir");
        let dbc_path = real_dir.join("bus.dbc");
        std::fs::write(&dbc_path, "VERSION \"\"\n").expect("write dbc");
        let expected_dbc_path = std::fs::canonicalize(&dbc_path)
            .expect("canonical dbc path")
            .display()
            .to_string();
        // Trace file location relative to whichever base directory is used.
        let trace_rel = std::path::Path::new("captures").join("run.asc");

        // On Unix, both inputs go through a symlinked directory, and the
        // expectations use the canonical (link-free) directory.
        #[cfg(unix)]
        let (dbc_input, trace_input, expected_trace) = {
            let link_dir = temp.path().join("link");
            std::os::unix::fs::symlink(&real_dir, &link_dir).expect("symlink dir");
            let canonical_real_dir = std::fs::canonicalize(&real_dir).expect("canonical real dir");
            (
                link_dir.join("bus.dbc"),
                link_dir.join(&trace_rel),
                canonical_real_dir.join(&trace_rel),
            )
        };

        #[cfg(not(unix))]
        let (dbc_input, trace_input, expected_trace) = (
            dbc_path,
            real_dir.join(&trace_rel),
            real_dir.join(&trace_rel),
        );

        let normalized = normalize_request(connect_with_dbc(dbc_input.display().to_string()))
            .expect("normalize connect");
        let RequestAction::Connect(connect) = normalized.action else {
            panic!("expected connect request");
        };
        assert_eq!(connect.dbcs[0].path, expected_dbc_path);

        let normalized = normalize_request(trace_start(trace_input.display().to_string()))
            .expect("normalize trace");
        let RequestAction::TraceStart(trace) = normalized.action else {
            panic!("expected trace request");
        };
        assert_eq!(trace.path, expected_trace.display().to_string());
    }

    #[test]
    fn normalize_request_rejects_relative_connect_dbc_paths() {
        let err = normalize_request(connect_with_dbc("relative.dbc".to_string()))
            .expect_err("relative DBC path must fail");
        assert!(err.to_string().contains("must be absolute"));
    }

    #[test]
    fn normalize_request_rejects_relative_trace_paths() {
        let err = normalize_request(trace_start("captures/run.asc".to_string()))
            .expect_err("relative trace path must fail");
        assert!(err.to_string().contains("must be absolute"));
    }
}