1use std::time::Duration;
4
5use nng::{options::Options, Message, Protocol, Socket};
6
7use crate::error::IpcError;
8use crate::messages::{ArchivedIpcResponse, IpcRequest, IpcResponse};
9use crate::DEFAULT_TIMEOUT_MS;
10
11pub struct IpcClient {
13 socket: Socket,
14 endpoint: String,
15 timeout_ms: u64,
16}
17
18impl IpcClient {
19 pub fn connect(endpoint: &str) -> Result<Self, IpcError> {
21 Self::connect_with_timeout(endpoint, DEFAULT_TIMEOUT_MS)
22 }
23
24 pub fn connect_with_timeout(endpoint: &str, timeout_ms: u64) -> Result<Self, IpcError> {
26 let socket = Socket::new(Protocol::Req0)?;
27
28 let timeout = Duration::from_millis(timeout_ms);
30 socket
31 .set_opt::<nng::options::SendTimeout>(Some(timeout))
32 .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
33 socket
34 .set_opt::<nng::options::RecvTimeout>(Some(timeout))
35 .map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
36
37 socket.dial(endpoint).map_err(|e| {
39 if e == nng::Error::ConnectionRefused {
40 IpcError::DaemonNotRunning
41 } else {
42 IpcError::ConnectionFailed(e.to_string())
43 }
44 })?;
45
46 Ok(Self {
47 socket,
48 endpoint: endpoint.to_string(),
49 timeout_ms,
50 })
51 }
52
53 pub fn endpoint(&self) -> &str {
55 &self.endpoint
56 }
57
58 pub fn timeout_ms(&self) -> u64 {
60 self.timeout_ms
61 }
62
63 pub fn send(&self, request: &IpcRequest) -> Result<IpcResponse, IpcError> {
65 let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(request)
67 .map_err(|e| IpcError::Serialization(e.to_string()))?;
68
69 let msg = Message::from(bytes.as_slice());
71 self.socket.send(msg).map_err(|e| {
72 if e.1 == nng::Error::TimedOut {
73 IpcError::Timeout(self.timeout_ms)
74 } else {
75 IpcError::Nng(e.1.to_string())
76 }
77 })?;
78
79 let response_msg = self.socket.recv().map_err(|e| {
81 if e == nng::Error::TimedOut {
82 IpcError::Timeout(self.timeout_ms)
83 } else {
84 IpcError::Nng(e.to_string())
85 }
86 })?;
87
88 let archived = rkyv::access::<ArchivedIpcResponse, rkyv::rancor::Error>(&response_msg)
90 .map_err(|e| IpcError::Deserialization(e.to_string()))?;
91
92 let actual_version: u32 = archived.ipc_schema_version.into();
94 if actual_version != request.ipc_schema_version {
95 return Err(IpcError::VersionMismatch {
96 expected: request.ipc_schema_version,
97 actual: actual_version,
98 });
99 }
100
101 let response: IpcResponse =
103 rkyv::deserialize::<IpcResponse, rkyv::rancor::Error>(archived)
104 .map_err(|e| IpcError::Deserialization(e.to_string()))?;
105
106 if !response.ok {
108 if let Some(ref error) = response.error {
109 return Err(IpcError::DaemonError {
110 code: error.code.clone(),
111 message: error.message.clone(),
112 });
113 }
114 }
115
116 Ok(response)
117 }
118
119 pub fn send_with_retry(
121 &self,
122 request: &IpcRequest,
123 max_retries: u32,
124 ) -> Result<IpcResponse, IpcError> {
125 let mut last_error = None;
126 let mut delay_ms = 100;
127
128 for attempt in 0..=max_retries {
129 match self.send(request) {
130 Ok(response) => return Ok(response),
131 Err(e) => {
132 match &e {
134 IpcError::Timeout(_) | IpcError::Nng(_) => {
135 last_error = Some(e);
136 if attempt < max_retries {
137 std::thread::sleep(Duration::from_millis(delay_ms));
138 delay_ms *= 2; }
140 }
141 _ => return Err(e),
142 }
143 }
144 }
145 }
146
147 Err(last_error.unwrap())
148 }
149}
150
151pub fn try_connect(endpoint: &str) -> Option<IpcClient> {
153 IpcClient::connect(endpoint).ok()
154}
155
#[cfg(test)]
mod tests {
    use super::DEFAULT_TIMEOUT_MS;

    /// The default timeout must be positive and no longer than 60 000 ms.
    #[test]
    fn test_timeout_config() {
        let timeout_ms = DEFAULT_TIMEOUT_MS;

        assert!(timeout_ms > 0);
        assert!(timeout_ms <= 60_000);
    }
}