Skip to main content

libgrite_ipc/
client.rs

1//! IPC client for connecting to the daemon
2//!
3//! This module requires Unix (uses Unix domain sockets).
4
// Unix domain sockets are the only transport this client implements, so refuse
// to build anywhere else instead of failing later on missing `std::os::unix`.
#[cfg(not(target_family = "unix"))]
compile_error!("libgrite-ipc client requires a Unix platform");
7
8use std::os::unix::net::UnixStream;
9use std::time::Duration;
10
11use crate::error::IpcError;
12use crate::framing::{read_framed, write_framed};
13use crate::messages::{ArchivedIpcResponse, IpcRequest, IpcResponse};
14use crate::DEFAULT_TIMEOUT_MS;
15
/// IPC client for daemon communication over a Unix domain socket.
///
/// A client becomes *poisoned* after a timeout or IO error, because the
/// underlying stream may contain partial data from the failed exchange.
/// Poisoned clients reject further `send()` calls with [`IpcError::ClientPoisoned`].
/// Use [`send_with_retry`](Self::send_with_retry) for automatic reconnection.
#[derive(Debug)]
pub struct IpcClient {
    /// Connected socket; read/write timeouts are applied when connecting.
    stream: UnixStream,
    /// Socket path used to connect, kept so the client can reconnect.
    endpoint: String,
    /// Read/write timeout in milliseconds, also reported in timeout errors.
    timeout_ms: u64,
    /// Set after a failed read/write; the stream can no longer be trusted.
    poisoned: bool,
}
28
29impl IpcClient {
30    /// Connect to a daemon at the given endpoint (Unix socket path)
31    pub fn connect(endpoint: &str) -> Result<Self, IpcError> {
32        Self::connect_with_timeout(endpoint, DEFAULT_TIMEOUT_MS)
33    }
34
35    /// Connect with a custom timeout
36    pub fn connect_with_timeout(endpoint: &str, timeout_ms: u64) -> Result<Self, IpcError> {
37        let stream = UnixStream::connect(endpoint).map_err(|e| {
38            if e.kind() == std::io::ErrorKind::ConnectionRefused
39                || e.kind() == std::io::ErrorKind::NotFound
40            {
41                IpcError::DaemonNotRunning
42            } else {
43                IpcError::ConnectionFailed(e.to_string())
44            }
45        })?;
46
47        let timeout = Duration::from_millis(timeout_ms);
48        stream
49            .set_read_timeout(Some(timeout))
50            .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
51        stream
52            .set_write_timeout(Some(timeout))
53            .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
54
55        Ok(Self {
56            stream,
57            endpoint: endpoint.to_string(),
58            timeout_ms,
59            poisoned: false,
60        })
61    }
62
63    /// Get the endpoint this client is connected to
64    pub fn endpoint(&self) -> &str {
65        &self.endpoint
66    }
67
68    /// Get the configured timeout in milliseconds
69    pub fn timeout_ms(&self) -> u64 {
70        self.timeout_ms
71    }
72
73    /// Send a request and wait for a response
74    ///
75    /// Returns [`IpcError::ClientPoisoned`] if this client was poisoned by a
76    /// previous timeout or IO error.
77    pub fn send(&mut self, request: &IpcRequest) -> Result<IpcResponse, IpcError> {
78        if self.poisoned {
79            return Err(IpcError::ClientPoisoned);
80        }
81
82        // Serialize the request
83        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(request)
84            .map_err(|e| IpcError::Serialization(e.to_string()))?;
85
86        // Send length-prefixed request
87        write_framed(&mut self.stream, &bytes).map_err(|e| {
88            if e.kind() == std::io::ErrorKind::TimedOut
89                || e.kind() == std::io::ErrorKind::WouldBlock
90            {
91                self.poisoned = true;
92                IpcError::Timeout(self.timeout_ms)
93            } else {
94                self.poisoned = true;
95                IpcError::Io(e)
96            }
97        })?;
98
99        // Read length-prefixed response
100        let response_bytes = read_framed(&mut self.stream).map_err(|e| {
101            if e.kind() == std::io::ErrorKind::TimedOut
102                || e.kind() == std::io::ErrorKind::WouldBlock
103            {
104                self.poisoned = true;
105                IpcError::Timeout(self.timeout_ms)
106            } else {
107                self.poisoned = true;
108                IpcError::Io(e)
109            }
110        })?;
111
112        // Deserialize the response
113        let archived = rkyv::access::<ArchivedIpcResponse, rkyv::rancor::Error>(&response_bytes)
114            .map_err(|e| IpcError::Deserialization(e.to_string()))?;
115
116        // Check version
117        let actual_version: u32 = archived.ipc_schema_version.into();
118        if actual_version != request.ipc_schema_version {
119            return Err(IpcError::VersionMismatch {
120                expected: request.ipc_schema_version,
121                actual: actual_version,
122            });
123        }
124
125        // Deserialize to owned type
126        let response: IpcResponse = rkyv::deserialize::<IpcResponse, rkyv::rancor::Error>(archived)
127            .map_err(|e| IpcError::Deserialization(e.to_string()))?;
128
129        // Check for daemon error
130        if !response.ok {
131            if let Some(ref error) = response.error {
132                return Err(IpcError::DaemonError {
133                    code: error.code.clone(),
134                    message: error.message.clone(),
135                });
136            }
137        }
138
139        Ok(response)
140    }
141
142    /// Send a request with retries using exponential backoff
143    ///
144    /// Each retry creates a fresh connection to avoid stale stream state.
145    /// If reconnection fails, that attempt is consumed but the retry loop
146    /// continues (with backoff) rather than silently burning all retries.
147    pub fn send_with_retry(
148        &mut self,
149        request: &IpcRequest,
150        max_retries: u32,
151    ) -> Result<IpcResponse, IpcError> {
152        let mut last_error = None;
153        let mut delay_ms = 100u64;
154
155        for attempt in 0..=max_retries {
156            // Reconnect before each retry (not on the first attempt)
157            if attempt > 0 {
158                std::thread::sleep(Duration::from_millis(delay_ms));
159                delay_ms *= 2;
160                match IpcClient::connect_with_timeout(&self.endpoint, self.timeout_ms) {
161                    Ok(new_client) => {
162                        self.stream = new_client.stream;
163                        self.poisoned = false;
164                    }
165                    Err(e) => {
166                        last_error = Some(e);
167                        continue;
168                    }
169                }
170            }
171
172            match self.send(request) {
173                Ok(response) => return Ok(response),
174                Err(e) => match &e {
175                    IpcError::Timeout(_) | IpcError::Io(_) | IpcError::ClientPoisoned => {
176                        last_error = Some(e);
177                    }
178                    _ => return Err(e),
179                },
180            }
181        }
182
183        Err(last_error
184            .unwrap_or_else(|| IpcError::ConnectionFailed("all retries exhausted".to_string())))
185    }
186}
187
188/// Try to connect to a daemon, returning None if not running
189pub fn try_connect(endpoint: &str) -> Option<IpcClient> {
190    IpcClient::connect(endpoint).ok()
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    // The checks are on a compile-time constant by design: they guard its configured range.
    #[allow(clippy::assertions_on_constants)]
    fn test_timeout_config() {
        assert!(DEFAULT_TIMEOUT_MS > 0);
        assert!(DEFAULT_TIMEOUT_MS <= 60_000);
    }

    /// Short path in the temp dir with (almost certainly) no socket bound to it.
    /// Kept short so it stays under the platform's Unix socket path limit.
    fn missing_socket_path(tag: &str) -> String {
        std::env::temp_dir()
            .join(format!("grite-missing-{tag}-{}.sock", std::process::id()))
            .to_string_lossy()
            .into_owned()
    }

    #[test]
    fn connect_to_missing_socket_reports_daemon_not_running() {
        let path = missing_socket_path("c");
        assert!(matches!(
            IpcClient::connect(&path),
            Err(IpcError::DaemonNotRunning)
        ));
    }

    #[test]
    fn try_connect_to_missing_socket_is_none() {
        let path = missing_socket_path("t");
        assert!(try_connect(&path).is_none());
    }
}