1use crate::daemon::error::DaemonError;
2use crate::daemon::{config::DaemonConfig, lifecycle};
3use crate::ipc;
4use crate::protocol::{
5 ConnectResult, Request, RequestAction, Response, ResponseData, SessionStatus,
6 normalize_connect_request_paths, normalize_trace_start_request_path,
7};
8use thiserror::Error;
9use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
10use tokio::time::{Duration, sleep, timeout};
11
12#[derive(Debug, Error)]
13pub enum ConnectionError {
14 #[error(transparent)]
15 Daemon(#[from] DaemonError),
16 #[error("connection timeout")]
17 Timeout,
18 #[error("connection error: {0}")]
19 Io(#[from] std::io::Error),
20 #[error("serialization error: {0}")]
21 Serde(#[from] serde_json::Error),
22 #[error("response missing")]
23 MissingResponse,
24 #[error("unexpected response: {0}")]
25 UnexpectedResponse(&'static str),
26 #[error(transparent)]
27 PathNormalization(#[from] crate::protocol::PathNormalizationError),
28}
29
30pub async fn execute_request(request: Request) -> Result<Response, ConnectionError> {
31 let request = normalize_request(request)?;
32 match &request.action {
33 RequestAction::AdaptersList => Ok(Response::ok(
34 request.id,
35 ResponseData::AdaptersList {
36 adapters: crate::can::available_adapters(),
37 },
38 )),
39 RequestAction::Connect(connect) => connect_request(request.id, connect.clone()).await,
40 RequestAction::Disconnect => {
41 let response = send_to_running_daemon(&request).await?;
42 if response.success {
43 lifecycle::wait_for_daemon_shutdown(Duration::from_secs(5)).await?;
44 }
45 Ok(response)
46 }
47 _ => send_to_running_daemon(&request).await,
48 }
49}
50
51fn normalize_request(mut request: Request) -> Result<Request, ConnectionError> {
52 request.action = match request.action {
53 RequestAction::Connect(connect) => {
54 RequestAction::Connect(normalize_connect_request_paths(connect)?)
55 }
56 RequestAction::TraceStart(trace) => {
57 RequestAction::TraceStart(normalize_trace_start_request_path(trace)?)
58 }
59 other => other,
60 };
61 Ok(request)
62}
63
64async fn connect_request(
65 request_id: uuid::Uuid,
66 connect: crate::protocol::ConnectRequest,
67) -> Result<Response, ConnectionError> {
68 let config = DaemonConfig { connect };
69 if lifecycle::ensure_daemon_running().await.is_err() {
70 match lifecycle::bootstrap_daemon(&config).await {
71 Ok(()) => {
72 let status = read_status().await?;
73 return Ok(Response::ok(
74 request_id,
75 ResponseData::Connected(ConnectResult {
76 created: true,
77 already_connected: false,
78 status,
79 }),
80 ));
81 }
82 Err(DaemonError::AlreadyRunning) => {}
83 Err(err) => return Err(err.into()),
84 }
85 }
86 send_to_running_daemon(&Request {
87 id: request_id,
88 action: RequestAction::Connect(config.connect),
89 })
90 .await
91}
92
93async fn read_status() -> Result<SessionStatus, ConnectionError> {
94 let response = send_to_running_daemon(&Request {
95 id: uuid::Uuid::new_v4(),
96 action: RequestAction::Status,
97 })
98 .await?;
99 match response.data {
100 Some(ResponseData::Status(status)) => Ok(status),
101 _ => Err(ConnectionError::UnexpectedResponse(
102 "expected status after bootstrap",
103 )),
104 }
105}
106
107async fn send_to_running_daemon(request: &Request) -> Result<Response, ConnectionError> {
108 lifecycle::ensure_daemon_running().await?;
109 RequestTransport::default()
110 .send_to_endpoint(&lifecycle::socket_path(), request)
111 .await
112}
113
/// Retry and deadline settings for one client-to-daemon round trip over the IPC socket.
#[derive(Clone, Copy, Debug)]
struct RequestTransport {
    /// Maximum number of attempts, counting the first, before a retriable failure is
    /// returned to the caller.
    max_attempts: u32,
    /// Delay before the first retry. Later retries double it, capped at 8x
    /// (see `retry_delay_for_attempt`).
    retry_delay_base: Duration,
    /// Deadline for establishing the socket connection.
    connect_timeout: Duration,
    /// Deadline for writing the serialized request.
    write_timeout: Duration,
    /// Deadline for reading the daemon's reply line.
    read_timeout: Duration,
}
122
123impl Default for RequestTransport {
124 fn default() -> Self {
125 Self {
126 max_attempts: 3,
127 retry_delay_base: Duration::from_millis(100),
128 connect_timeout: Duration::from_secs(2),
129 write_timeout: Duration::from_secs(5),
130 read_timeout: Duration::from_secs(30),
131 }
132 }
133}
134
135impl RequestTransport {
136 async fn send_to_endpoint(
137 self,
138 endpoint: &std::path::Path,
139 request: &Request,
140 ) -> Result<Response, ConnectionError> {
141 let payload = {
142 let mut line = serde_json::to_string(request)?;
143 line.push('\n');
144 line
145 };
146
147 let mut attempt = 0_u32;
148 loop {
149 match self.send_once(endpoint, &payload).await {
150 Ok(response) => return Ok(response),
151 Err(err) => {
152 attempt += 1;
153 if attempt >= self.max_attempts || !self.should_retry(&err) {
154 return Err(err);
155 }
156 sleep(self.retry_delay_for_attempt(attempt)).await;
157 }
158 }
159 }
160 }
161
162 fn should_retry(self, err: &ConnectionError) -> bool {
163 match err {
164 ConnectionError::Timeout => true,
165 ConnectionError::Io(io_err) => matches!(
166 io_err.kind(),
167 std::io::ErrorKind::NotFound
168 | std::io::ErrorKind::ConnectionRefused
169 | std::io::ErrorKind::ConnectionAborted
170 | std::io::ErrorKind::ConnectionReset
171 | std::io::ErrorKind::TimedOut
172 | std::io::ErrorKind::Interrupted
173 ),
174 ConnectionError::Daemon(_)
175 | ConnectionError::Serde(_)
176 | ConnectionError::MissingResponse
177 | ConnectionError::UnexpectedResponse(_)
178 | ConnectionError::PathNormalization(_) => false,
179 }
180 }
181
182 fn retry_delay_for_attempt(self, attempt: u32) -> Duration {
183 let shift = attempt.saturating_sub(1).min(3);
184 self.retry_delay_base
185 .checked_mul(1_u32 << shift)
186 .unwrap_or(Duration::MAX)
187 }
188
189 async fn send_once(
190 self,
191 endpoint: &std::path::Path,
192 payload: &str,
193 ) -> Result<Response, ConnectionError> {
194 let mut stream = timeout(self.connect_timeout, ipc::connect(endpoint))
195 .await
196 .map_err(|_| ConnectionError::Timeout)??;
197 timeout(self.write_timeout, stream.write_all(payload.as_bytes()))
198 .await
199 .map_err(|_| ConnectionError::Timeout)??;
200
201 let mut reader = BufReader::new(stream);
202 let mut line = String::new();
203 timeout(self.read_timeout, reader.read_line(&mut line))
204 .await
205 .map_err(|_| ConnectionError::Timeout)??;
206 if line.is_empty() {
207 return Err(ConnectionError::MissingResponse);
208 }
209 Ok(serde_json::from_str::<Response>(line.trim_end())?)
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::normalize_request;
    use crate::protocol::{DbcSpec, Request, RequestAction, TraceStartRequest};

    /// Builds a `Connect` request for a PCAN adapter with a single DBC aliased `main`.
    fn connect_request_with_dbc(dbc_path: String) -> Request {
        Request {
            id: uuid::Uuid::new_v4(),
            action: RequestAction::Connect(crate::protocol::ConnectRequest {
                adapter: "pcan".to_string(),
                bitrate: 500_000,
                bitrate_data: None,
                fd: false,
                dbcs: vec![DbcSpec {
                    alias: "main".to_string(),
                    path: dbc_path,
                }],
            }),
        }
    }

    /// Builds a `TraceStart` request writing to `path`.
    fn trace_start_request(path: String) -> Request {
        Request {
            id: uuid::Uuid::new_v4(),
            action: RequestAction::TraceStart(TraceStartRequest { path }),
        }
    }

    /// Connect DBC paths and trace output paths are rewritten to canonical paths.
    ///
    /// On Unix the inputs go through a symlinked directory, so symlink resolution is
    /// exercised too. `captures/run.asc` is never created, so the expected trace
    /// path is the canonical real directory joined with those components.
    #[test]
    fn normalize_request_canonicalizes_connect_and_trace_paths() {
        let temp = tempfile::tempdir().expect("tempdir");
        let real_dir = temp.path().join("real");
        std::fs::create_dir_all(&real_dir).expect("real dir");
        let dbc_path = real_dir.join("bus.dbc");
        std::fs::write(&dbc_path, "VERSION \"\"\n").expect("write dbc");

        // Build every expectation from canonical paths on all platforms. The temp dir
        // may itself sit behind a symlink, and `std::fs::canonicalize` returns
        // extended-length (`\\?\`) paths on Windows.
        let canonical_real_dir = std::fs::canonicalize(&real_dir).expect("canonical real dir");
        let expected_dbc_path = std::fs::canonicalize(&dbc_path)
            .expect("canonical dbc path")
            .display()
            .to_string();
        let expected_trace = canonical_real_dir
            .join("captures")
            .join("run.asc")
            .display()
            .to_string();

        #[cfg(unix)]
        let input_dir = {
            let link_dir = temp.path().join("link");
            std::os::unix::fs::symlink(&real_dir, &link_dir).expect("symlink dir");
            link_dir
        };
        #[cfg(not(unix))]
        let input_dir = real_dir.clone();

        let dbc_input = input_dir.join("bus.dbc").display().to_string();
        let trace_input = input_dir
            .join("captures")
            .join("run.asc")
            .display()
            .to_string();

        let normalized =
            normalize_request(connect_request_with_dbc(dbc_input)).expect("normalize connect");
        let RequestAction::Connect(connect) = normalized.action else {
            panic!("expected connect request");
        };
        assert_eq!(connect.dbcs[0].path, expected_dbc_path);

        let normalized =
            normalize_request(trace_start_request(trace_input)).expect("normalize trace");
        let RequestAction::TraceStart(trace) = normalized.action else {
            panic!("expected trace request");
        };
        assert_eq!(trace.path, expected_trace);
    }

    /// A relative DBC path in a `Connect` request is rejected.
    #[test]
    fn normalize_request_rejects_relative_connect_dbc_paths() {
        let err = normalize_request(connect_request_with_dbc("relative.dbc".to_string()))
            .expect_err("relative DBC path must fail");
        assert!(err.to_string().contains("must be absolute"));
    }

    /// A relative trace output path in a `TraceStart` request is rejected.
    #[test]
    fn normalize_request_rejects_relative_trace_paths() {
        let err = normalize_request(trace_start_request("captures/run.asc".to_string()))
            .expect_err("relative trace path must fail");
        assert!(err.to_string().contains("must be absolute"));
    }
}