Skip to main content

libgrite_ipc/
client.rs

1//! IPC client for connecting to the daemon
2
3use std::time::Duration;
4
5use nng::{options::Options, Message, Protocol, Socket};
6
7use crate::error::IpcError;
8use crate::messages::{ArchivedIpcResponse, IpcRequest, IpcResponse};
9use crate::DEFAULT_TIMEOUT_MS;
10
11/// IPC client for daemon communication
12pub struct IpcClient {
13    socket: Socket,
14    endpoint: String,
15    timeout_ms: u64,
16}
17
18impl IpcClient {
19    /// Connect to a daemon at the given endpoint
20    pub fn connect(endpoint: &str) -> Result<Self, IpcError> {
21        Self::connect_with_timeout(endpoint, DEFAULT_TIMEOUT_MS)
22    }
23
24    /// Connect with a custom timeout
25    pub fn connect_with_timeout(endpoint: &str, timeout_ms: u64) -> Result<Self, IpcError> {
26        let socket = Socket::new(Protocol::Req0)?;
27
28        // Set timeouts
29        let timeout = Duration::from_millis(timeout_ms);
30        socket
31            .set_opt::<nng::options::SendTimeout>(Some(timeout))
32            .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
33        socket
34            .set_opt::<nng::options::RecvTimeout>(Some(timeout))
35            .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
36
37        // Connect to the endpoint
38        socket.dial(endpoint).map_err(|e| {
39            if e == nng::Error::ConnectionRefused {
40                IpcError::DaemonNotRunning
41            } else {
42                IpcError::ConnectionFailed(e.to_string())
43            }
44        })?;
45
46        Ok(Self {
47            socket,
48            endpoint: endpoint.to_string(),
49            timeout_ms,
50        })
51    }
52
53    /// Get the endpoint this client is connected to
54    pub fn endpoint(&self) -> &str {
55        &self.endpoint
56    }
57
58    /// Get the configured timeout in milliseconds
59    pub fn timeout_ms(&self) -> u64 {
60        self.timeout_ms
61    }
62
63    /// Send a request and wait for a response
64    pub fn send(&self, request: &IpcRequest) -> Result<IpcResponse, IpcError> {
65        // Serialize the request
66        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(request)
67            .map_err(|e| IpcError::Serialization(e.to_string()))?;
68
69        // Send the request
70        let msg = Message::from(bytes.as_slice());
71        self.socket.send(msg).map_err(|e| {
72            if e.1 == nng::Error::TimedOut {
73                IpcError::Timeout(self.timeout_ms)
74            } else {
75                IpcError::Nng(e.1.to_string())
76            }
77        })?;
78
79        // Receive the response
80        let response_msg = self.socket.recv().map_err(|e| {
81            if e == nng::Error::TimedOut {
82                IpcError::Timeout(self.timeout_ms)
83            } else {
84                IpcError::Nng(e.to_string())
85            }
86        })?;
87
88        // Deserialize the response
89        let archived = rkyv::access::<ArchivedIpcResponse, rkyv::rancor::Error>(&response_msg)
90            .map_err(|e| IpcError::Deserialization(e.to_string()))?;
91
92        // Check version
93        let actual_version: u32 = archived.ipc_schema_version.into();
94        if actual_version != request.ipc_schema_version {
95            return Err(IpcError::VersionMismatch {
96                expected: request.ipc_schema_version,
97                actual: actual_version,
98            });
99        }
100
101        // Deserialize to owned type
102        let response: IpcResponse =
103            rkyv::deserialize::<IpcResponse, rkyv::rancor::Error>(archived)
104                .map_err(|e| IpcError::Deserialization(e.to_string()))?;
105
106        // Check for daemon error
107        if !response.ok {
108            if let Some(ref error) = response.error {
109                return Err(IpcError::DaemonError {
110                    code: error.code.clone(),
111                    message: error.message.clone(),
112                });
113            }
114        }
115
116        Ok(response)
117    }
118
119    /// Send a request with retries using exponential backoff
120    pub fn send_with_retry(
121        &self,
122        request: &IpcRequest,
123        max_retries: u32,
124    ) -> Result<IpcResponse, IpcError> {
125        let mut last_error = None;
126        let mut delay_ms = 100;
127
128        for attempt in 0..=max_retries {
129            match self.send(request) {
130                Ok(response) => return Ok(response),
131                Err(e) => {
132                    // Only retry on timeout or transient errors
133                    match &e {
134                        IpcError::Timeout(_) | IpcError::Nng(_) => {
135                            last_error = Some(e);
136                            if attempt < max_retries {
137                                std::thread::sleep(Duration::from_millis(delay_ms));
138                                delay_ms *= 2; // Exponential backoff
139                            }
140                        }
141                        _ => return Err(e),
142                    }
143                }
144            }
145        }
146
147        Err(last_error.unwrap())
148    }
149}
150
151/// Try to connect to a daemon, returning None if not running
152pub fn try_connect(endpoint: &str) -> Option<IpcClient> {
153    IpcClient::connect(endpoint).ok()
154}
155
#[cfg(test)]
mod tests {
    // Exercising the client end to end needs a running daemon or a mock
    // server; those integration tests belong in the grit-daemon crate.

    use super::DEFAULT_TIMEOUT_MS;

    #[test]
    fn test_timeout_config() {
        // Sanity bound: the default must be positive and at most one minute.
        assert!((1..=60_000).contains(&DEFAULT_TIMEOUT_MS));
    }
}