// This client is built on `std::os::unix::net::UnixStream`, which only exists on Unix
// targets. Failing the build here gives non-Unix users an explicit message rather than
// only an unresolved-import error.
#[cfg(not(unix))]
compile_error!("libgrite-ipc client requires a Unix platform");
7
8use std::os::unix::net::UnixStream;
9use std::time::Duration;
10
11use crate::error::IpcError;
12use crate::framing::{read_framed, write_framed};
13use crate::messages::{ArchivedIpcResponse, IpcRequest, IpcResponse};
14use crate::DEFAULT_TIMEOUT_MS;
15
/// Blocking client for the daemon's Unix-socket IPC protocol.
///
/// Each call to [`IpcClient::send`] writes one framed request and reads one framed
/// response. After a socket I/O failure or timeout the client is marked poisoned and
/// rejects further requests until its stream is replaced by a reconnect.
#[derive(Debug)]
pub struct IpcClient {
    /// Connected socket to the daemon.
    stream: UnixStream,
    /// Socket path the client connected to; reused when reconnecting.
    endpoint: String,
    /// Read/write timeout applied to `stream`, in milliseconds.
    timeout_ms: u64,
    /// Set after a socket I/O error or timeout; further sends are rejected while set.
    poisoned: bool,
}
28
29impl IpcClient {
30 pub fn connect(endpoint: &str) -> Result<Self, IpcError> {
32 Self::connect_with_timeout(endpoint, DEFAULT_TIMEOUT_MS)
33 }
34
35 pub fn connect_with_timeout(endpoint: &str, timeout_ms: u64) -> Result<Self, IpcError> {
37 let stream = UnixStream::connect(endpoint).map_err(|e| {
38 if e.kind() == std::io::ErrorKind::ConnectionRefused
39 || e.kind() == std::io::ErrorKind::NotFound
40 {
41 IpcError::DaemonNotRunning
42 } else {
43 IpcError::ConnectionFailed(e.to_string())
44 }
45 })?;
46
47 let timeout = Duration::from_millis(timeout_ms);
48 stream
49 .set_read_timeout(Some(timeout))
50 .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
51 stream
52 .set_write_timeout(Some(timeout))
53 .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
54
55 Ok(Self {
56 stream,
57 endpoint: endpoint.to_string(),
58 timeout_ms,
59 poisoned: false,
60 })
61 }
62
63 pub fn endpoint(&self) -> &str {
65 &self.endpoint
66 }
67
68 pub fn timeout_ms(&self) -> u64 {
70 self.timeout_ms
71 }
72
73 pub fn send(&mut self, request: &IpcRequest) -> Result<IpcResponse, IpcError> {
78 if self.poisoned {
79 return Err(IpcError::ClientPoisoned);
80 }
81
82 let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(request)
84 .map_err(|e| IpcError::Serialization(e.to_string()))?;
85
86 write_framed(&mut self.stream, &bytes).map_err(|e| {
88 if e.kind() == std::io::ErrorKind::TimedOut
89 || e.kind() == std::io::ErrorKind::WouldBlock
90 {
91 self.poisoned = true;
92 IpcError::Timeout(self.timeout_ms)
93 } else {
94 self.poisoned = true;
95 IpcError::Io(e)
96 }
97 })?;
98
99 let response_bytes = read_framed(&mut self.stream).map_err(|e| {
101 if e.kind() == std::io::ErrorKind::TimedOut
102 || e.kind() == std::io::ErrorKind::WouldBlock
103 {
104 self.poisoned = true;
105 IpcError::Timeout(self.timeout_ms)
106 } else {
107 self.poisoned = true;
108 IpcError::Io(e)
109 }
110 })?;
111
112 let archived = rkyv::access::<ArchivedIpcResponse, rkyv::rancor::Error>(&response_bytes)
114 .map_err(|e| IpcError::Deserialization(e.to_string()))?;
115
116 let actual_version: u32 = archived.ipc_schema_version.into();
118 if actual_version != request.ipc_schema_version {
119 return Err(IpcError::VersionMismatch {
120 expected: request.ipc_schema_version,
121 actual: actual_version,
122 });
123 }
124
125 let response: IpcResponse = rkyv::deserialize::<IpcResponse, rkyv::rancor::Error>(archived)
127 .map_err(|e| IpcError::Deserialization(e.to_string()))?;
128
129 if !response.ok {
131 if let Some(ref error) = response.error {
132 return Err(IpcError::DaemonError {
133 code: error.code.clone(),
134 message: error.message.clone(),
135 });
136 }
137 }
138
139 Ok(response)
140 }
141
142 pub fn send_with_retry(
148 &mut self,
149 request: &IpcRequest,
150 max_retries: u32,
151 ) -> Result<IpcResponse, IpcError> {
152 let mut last_error = None;
153 let mut delay_ms = 100u64;
154
155 for attempt in 0..=max_retries {
156 if attempt > 0 {
158 std::thread::sleep(Duration::from_millis(delay_ms));
159 delay_ms *= 2;
160 match IpcClient::connect_with_timeout(&self.endpoint, self.timeout_ms) {
161 Ok(new_client) => {
162 self.stream = new_client.stream;
163 self.poisoned = false;
164 }
165 Err(e) => {
166 last_error = Some(e);
167 continue;
168 }
169 }
170 }
171
172 match self.send(request) {
173 Ok(response) => return Ok(response),
174 Err(e) => match &e {
175 IpcError::Timeout(_) | IpcError::Io(_) | IpcError::ClientPoisoned => {
176 last_error = Some(e);
177 }
178 _ => return Err(e),
179 },
180 }
181 }
182
183 Err(last_error
184 .unwrap_or_else(|| IpcError::ConnectionFailed("all retries exhausted".to_string())))
185 }
186}
187
188pub fn try_connect(endpoint: &str) -> Option<IpcClient> {
190 IpcClient::connect(endpoint).ok()
191}
192
#[cfg(test)]
mod tests {
    use super::DEFAULT_TIMEOUT_MS;

    /// The default timeout must be positive and no longer than one minute.
    #[test]
    // DEFAULT_TIMEOUT_MS is a compile-time constant, so clippy may flag the check as
    // constant; the test exists to guard against the constant being edited badly.
    #[allow(clippy::assertions_on_constants)]
    fn test_timeout_config() {
        assert!((1..=60_000).contains(&DEFAULT_TIMEOUT_MS));
    }
}
201}