use std::time::Duration;
use nng::{options::Options, Message, Protocol, Socket};
use crate::error::IpcError;
use crate::messages::{ArchivedIpcResponse, IpcRequest, IpcResponse};
use crate::DEFAULT_TIMEOUT_MS;
/// Blocking request/reply client for the daemon's IPC endpoint.
///
/// Wraps a single nng REQ socket; every call to [`IpcClient::send`] performs
/// one request followed by one response read on that socket.
pub struct IpcClient {
    /// REQ socket dialed to `endpoint`.
    socket: Socket,
    /// Send and receive timeout applied to `socket`, in milliseconds.
    timeout_ms: u64,
    /// Address the socket was dialed to, as returned by [`IpcClient::endpoint`].
    endpoint: String,
}
impl IpcClient {
pub fn connect(endpoint: &str) -> Result<Self, IpcError> {
Self::connect_with_timeout(endpoint, DEFAULT_TIMEOUT_MS)
}
pub fn connect_with_timeout(endpoint: &str, timeout_ms: u64) -> Result<Self, IpcError> {
let socket = Socket::new(Protocol::Req0)?;
let timeout = Duration::from_millis(timeout_ms);
socket
.set_opt::<nng::options::SendTimeout>(Some(timeout))
.map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
socket
.set_opt::<nng::options::RecvTimeout>(Some(timeout))
.map_err(|e| IpcError::ConnectionFailed(e.to_string()))?;
socket.dial(endpoint).map_err(|e| {
if e == nng::Error::ConnectionRefused {
IpcError::DaemonNotRunning
} else {
IpcError::ConnectionFailed(e.to_string())
}
})?;
Ok(Self {
socket,
endpoint: endpoint.to_string(),
timeout_ms,
})
}
pub fn endpoint(&self) -> &str {
&self.endpoint
}
pub fn timeout_ms(&self) -> u64 {
self.timeout_ms
}
pub fn send(&self, request: &IpcRequest) -> Result<IpcResponse, IpcError> {
let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(request)
.map_err(|e| IpcError::Serialization(e.to_string()))?;
let msg = Message::from(bytes.as_slice());
self.socket.send(msg).map_err(|e| {
if e.1 == nng::Error::TimedOut {
IpcError::Timeout(self.timeout_ms)
} else {
IpcError::Nng(e.1.to_string())
}
})?;
let response_msg = self.socket.recv().map_err(|e| {
if e == nng::Error::TimedOut {
IpcError::Timeout(self.timeout_ms)
} else {
IpcError::Nng(e.to_string())
}
})?;
let archived = rkyv::access::<ArchivedIpcResponse, rkyv::rancor::Error>(&response_msg)
.map_err(|e| IpcError::Deserialization(e.to_string()))?;
let actual_version: u32 = archived.ipc_schema_version.into();
if actual_version != request.ipc_schema_version {
return Err(IpcError::VersionMismatch {
expected: request.ipc_schema_version,
actual: actual_version,
});
}
let response: IpcResponse =
rkyv::deserialize::<IpcResponse, rkyv::rancor::Error>(archived)
.map_err(|e| IpcError::Deserialization(e.to_string()))?;
if !response.ok {
if let Some(ref error) = response.error {
return Err(IpcError::DaemonError {
code: error.code.clone(),
message: error.message.clone(),
});
}
}
Ok(response)
}
pub fn send_with_retry(
&self,
request: &IpcRequest,
max_retries: u32,
) -> Result<IpcResponse, IpcError> {
let mut last_error = None;
let mut delay_ms = 100;
for attempt in 0..=max_retries {
match self.send(request) {
Ok(response) => return Ok(response),
Err(e) => {
match &e {
IpcError::Timeout(_) | IpcError::Nng(_) => {
last_error = Some(e);
if attempt < max_retries {
std::thread::sleep(Duration::from_millis(delay_ms));
delay_ms *= 2; }
}
_ => return Err(e),
}
}
}
}
Err(last_error.unwrap())
}
}
/// Attempts [`IpcClient::connect`] on `endpoint` with the default timeout.
///
/// Any connection error is discarded: the result is `None` whenever the
/// connect fails, for whatever reason.
pub fn try_connect(endpoint: &str) -> Option<IpcClient> {
    let outcome = IpcClient::connect(endpoint);
    outcome.ok()
}
#[cfg(test)]
mod tests {
    use super::DEFAULT_TIMEOUT_MS;

    /// The default timeout must be strictly positive and no more than one minute.
    #[test]
    fn test_timeout_config() {
        assert!((1..=60_000).contains(&DEFAULT_TIMEOUT_MS));
    }
}