1use crate::daemon::error::DaemonError;
2use crate::daemon::{config::DaemonConfig, lifecycle};
3use crate::ipc;
4use crate::protocol::{Request, Response};
5use thiserror::Error;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::time::{Duration, sleep, timeout};
8
9#[derive(Debug, Error)]
10pub enum ConnectionError {
11 #[error(transparent)]
12 Daemon(#[from] DaemonError),
13 #[error("connection timeout")]
14 Timeout,
15 #[error("connection error: {0}")]
16 Io(#[from] std::io::Error),
17 #[error("serialization error: {0}")]
18 Serde(#[from] serde_json::Error),
19 #[error("response missing")]
20 MissingResponse,
21}
22
23pub async fn send_request(bus: &str, request: &Request) -> Result<Response, ConnectionError> {
24 lifecycle::ensure_daemon_running(bus).await?;
25 RequestTransport::default()
26 .send_to_endpoint(&lifecycle::socket_path(bus), request)
27 .await
28}
29
30pub async fn bootstrap_and_send_request(
31 bus: &str,
32 config: &DaemonConfig,
33 request: &Request,
34) -> Result<Response, ConnectionError> {
35 lifecycle::bootstrap_daemon(bus, config).await?;
36 RequestTransport::default()
37 .send_to_endpoint(&lifecycle::socket_path(bus), request)
38 .await
39}
40
/// Retry and timeout policy for one request/response exchange with the
/// daemon socket.
#[derive(Debug, Clone, Copy)]
struct RequestTransport {
    /// Maximum number of attempts, counting the first one.
    max_attempts: u32,
    /// Base delay from which the backoff between attempts is derived.
    retry_delay_base: Duration,
    /// Time budget for establishing the connection.
    connect_timeout: Duration,
    /// Time budget for writing the request.
    write_timeout: Duration,
    /// Time budget for reading the response line.
    read_timeout: Duration,
}

impl Default for RequestTransport {
    /// Three attempts with a 100 ms backoff base, and 2 s / 5 s / 30 s
    /// budgets for connect / write / read respectively.
    fn default() -> Self {
        let max_attempts = 3;
        let retry_delay_base = Duration::from_millis(100);
        let (connect_timeout, write_timeout, read_timeout) = (
            Duration::from_secs(2),
            Duration::from_secs(5),
            Duration::from_secs(30),
        );

        Self {
            max_attempts,
            retry_delay_base,
            connect_timeout,
            write_timeout,
            read_timeout,
        }
    }
}
61
62impl RequestTransport {
63 async fn send_to_endpoint(
64 self,
65 endpoint: &std::path::Path,
66 request: &Request,
67 ) -> Result<Response, ConnectionError> {
68 let payload = {
69 let mut line = serde_json::to_string(request)?;
70 line.push('\n');
71 line
72 };
73
74 let mut attempt = 0_u32;
75 loop {
76 match self.send_once(endpoint, &payload).await {
77 Ok(response) => return Ok(response),
78 Err(err) => {
79 attempt += 1;
80 if attempt >= self.max_attempts || !self.should_retry(&err) {
81 return Err(err);
82 }
83 sleep(self.retry_delay_for_attempt(attempt)).await;
84 }
85 }
86 }
87 }
88
89 fn should_retry(self, err: &ConnectionError) -> bool {
90 match err {
91 ConnectionError::Timeout => true,
92 ConnectionError::Io(io_err) => matches!(
93 io_err.kind(),
94 std::io::ErrorKind::NotFound
95 | std::io::ErrorKind::ConnectionRefused
96 | std::io::ErrorKind::ConnectionAborted
97 | std::io::ErrorKind::ConnectionReset
98 | std::io::ErrorKind::TimedOut
99 | std::io::ErrorKind::Interrupted
100 ),
101 ConnectionError::Daemon(_)
102 | ConnectionError::Serde(_)
103 | ConnectionError::MissingResponse => false,
104 }
105 }
106
107 fn retry_delay_for_attempt(self, attempt: u32) -> Duration {
108 let shift = attempt.saturating_sub(1).min(3);
109 self.retry_delay_base
110 .checked_mul(1_u32 << shift)
111 .unwrap_or(Duration::MAX)
112 }
113
114 async fn send_once(
115 self,
116 endpoint: &std::path::Path,
117 payload: &str,
118 ) -> Result<Response, ConnectionError> {
119 let mut stream = timeout(self.connect_timeout, ipc::connect(endpoint))
120 .await
121 .map_err(|_| ConnectionError::Timeout)??;
122 timeout(self.write_timeout, stream.write_all(payload.as_bytes()))
123 .await
124 .map_err(|_| ConnectionError::Timeout)??;
125
126 let mut reader = BufReader::new(stream);
127 let mut line = String::new();
128 timeout(self.read_timeout, reader.read_line(&mut line))
129 .await
130 .map_err(|_| ConnectionError::Timeout)??;
131 if line.is_empty() {
132 return Err(ConnectionError::MissingResponse);
133 }
134 Ok(serde_json::from_str::<Response>(line.trim_end())?)
135 }
136}