#![allow(dead_code)]
use std::io;
use serde::{Serialize, de::DeserializeOwned};
use thiserror::Error;
/// Every failure the IPC framing and dispatch layer can report.
///
/// The `Display` text of each variant comes from its `#[error]` attribute.
#[derive(Debug, Error)]
pub enum IpcError {
    /// IPC is not enabled.
    #[error("IPC not enabled")]
    NotEnabled,

    /// A request named a method that is not recognized; carries the method name.
    #[error("unknown method: {0}")]
    UnknownMethod(String),

    /// MessagePack encoding failed. Created automatically from
    /// `rmp_serde::encode::Error` by `?` thanks to `#[from]`.
    #[error("serialization error: {0}")]
    Serialization(#[from] rmp_serde::encode::Error),

    /// MessagePack decoding failed. Created automatically from
    /// `rmp_serde::decode::Error` by `?` thanks to `#[from]`.
    #[error("deserialization error: {0}")]
    Deserialization(#[from] rmp_serde::decode::Error),

    /// An underlying I/O operation failed (converted via `#[from]`).
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// The bytes on the wire do not follow the expected frame layout;
    /// carries a human-readable reason.
    #[error("invalid protocol: {0}")]
    InvalidProtocol(String),

    /// A method handler reported a failure; carries its message.
    #[error("handler error: {0}")]
    Handler(String),
}
/// A decoded IPC request: the method to invoke plus its still-encoded parameters.
///
/// Derives `Clone`/`PartialEq`/`Eq` so requests can be duplicated (e.g. for retries or
/// logging) and compared directly in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IpcRequest {
    /// Name of the method being called (at most 255 bytes of UTF-8 on the wire).
    pub method: String,
    /// MessagePack-encoded parameters; decode with `IpcRequest::deserialize_params`.
    pub params: Vec<u8>,
}
impl IpcRequest {
    /// Largest frame body (everything after the 4-byte length prefix) that the
    /// big-endian `u32` prefix can describe.
    const MAX_BODY_LEN: usize = u32::MAX as usize;

    /// Parses a request body, i.e. a frame with its 4-byte length prefix already stripped.
    ///
    /// Expected layout: `[method_len: u8][method: method_len bytes of UTF-8][params...]`.
    /// The params bytes are kept verbatim; decode them with
    /// [`IpcRequest::deserialize_params`].
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::InvalidProtocol`] if `data` is empty, holds fewer bytes than the
    /// declared method length, or the method name is not valid UTF-8.
    pub fn from_bytes(data: &[u8]) -> Result<Self, IpcError> {
        let (&method_len, rest) = data
            .split_first()
            .ok_or_else(|| IpcError::InvalidProtocol("empty request".to_string()))?;
        let method_len = usize::from(method_len);
        if rest.len() < method_len {
            return Err(IpcError::InvalidProtocol("truncated method".to_string()));
        }
        let (method_bytes, params) = rest.split_at(method_len);
        // Validate in place and only allocate once the bytes are known to be UTF-8.
        let method = std::str::from_utf8(method_bytes)
            .map_err(|e| IpcError::InvalidProtocol(format!("invalid method UTF-8: {e}")))?
            .to_owned();
        Ok(Self {
            method,
            params: params.to_vec(),
        })
    }

    /// Encodes a complete request frame for `method`, MessagePack-serializing `params`.
    ///
    /// Layout: `[body_len: u32 big-endian][method_len: u8][method bytes][params]`, where
    /// `body_len` counts every byte after the prefix. Note that
    /// [`IpcRequest::from_bytes`] expects the frame *without* this 4-byte prefix.
    ///
    /// # Errors
    ///
    /// * [`IpcError::InvalidProtocol`] if `method` is longer than 255 bytes (its length must
    ///   fit the one-byte length field) or the encoded body is too large for the `u32`
    ///   length prefix.
    /// * [`IpcError::Serialization`] if `params` cannot be MessagePack-encoded.
    pub fn to_bytes<T: Serialize>(method: &str, params: &T) -> Result<Vec<u8>, IpcError> {
        let method_bytes = method.as_bytes();
        if method_bytes.len() > usize::from(u8::MAX) {
            return Err(IpcError::InvalidProtocol(
                "method name too long".to_string(),
            ));
        }
        let params_bytes = rmp_serde::to_vec(params)?;
        let body_len = 1 + method_bytes.len() + params_bytes.len();
        // Refuse to emit a frame whose length would wrap in the u32 prefix: a wrapped
        // prefix would no longer match the number of bytes actually sent.
        if body_len > Self::MAX_BODY_LEN {
            return Err(IpcError::InvalidProtocol("request too large".to_string()));
        }
        let mut buf = Vec::with_capacity(4 + body_len);
        // Both casts below are lossless: their bounds were checked above.
        buf.extend_from_slice(&(body_len as u32).to_be_bytes());
        buf.push(method_bytes.len() as u8);
        buf.extend_from_slice(method_bytes);
        buf.extend_from_slice(&params_bytes);
        Ok(buf)
    }

    /// Decodes the MessagePack-encoded parameters into `T`.
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::Deserialization`] if `self.params` is not a valid encoding of `T`.
    pub fn deserialize_params<T: DeserializeOwned>(&self) -> Result<T, IpcError> {
        rmp_serde::from_slice(&self.params).map_err(IpcError::from)
    }
}
/// A decoded IPC response: a success flag plus the still-encoded payload.
///
/// Derives `Clone`/`PartialEq`/`Eq` so responses can be duplicated and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IpcResponse {
    /// `true` for a successful result, `false` when `payload` carries an error message.
    pub success: bool,
    /// MessagePack-encoded result (on success) or error message string (on failure).
    pub payload: Vec<u8>,
}
impl IpcResponse {
    /// Largest frame body (status byte + payload) that the big-endian `u32` length
    /// prefix can describe.
    const MAX_BODY_LEN: usize = u32::MAX as usize;

    /// Builds a successful response whose payload is `result` encoded as MessagePack.
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::Serialization`] if `result` cannot be encoded.
    pub fn success<T: Serialize>(result: &T) -> Result<Self, IpcError> {
        Ok(Self {
            success: true,
            payload: rmp_serde::to_vec(result)?,
        })
    }

    /// Builds a failed response whose payload is `message` encoded as a MessagePack string.
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::Serialization`] if encoding fails.
    pub fn error(message: &str) -> Result<Self, IpcError> {
        Ok(Self {
            success: false,
            payload: rmp_serde::to_vec(&message)?,
        })
    }

    /// Encodes the response as a frame: `[body_len: u32 big-endian][status: u8][payload]`,
    /// where `status` is `1` for success and `0` for failure and `body_len` counts every
    /// byte after the prefix.
    ///
    /// # Panics
    ///
    /// Panics if the frame body is too large for the `u32` length prefix. Emitting a
    /// silently truncated length would produce a corrupt frame; use
    /// [`IpcResponse::try_to_bytes`] to handle that case without panicking.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.try_to_bytes()
            .unwrap_or_else(|e| panic!("cannot frame IPC response: {e}"))
    }

    /// Fallible counterpart of [`IpcResponse::to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::InvalidProtocol`] if the frame body is too large for the
    /// `u32` length prefix.
    pub fn try_to_bytes(&self) -> Result<Vec<u8>, IpcError> {
        let body_len = 1 + self.payload.len();
        if body_len > Self::MAX_BODY_LEN {
            return Err(IpcError::InvalidProtocol("response too large".to_string()));
        }
        let mut buf = Vec::with_capacity(4 + body_len);
        // Lossless: `body_len` was bounded by `u32::MAX` above.
        buf.extend_from_slice(&(body_len as u32).to_be_bytes());
        buf.push(u8::from(self.success));
        buf.extend_from_slice(&self.payload);
        Ok(buf)
    }

    /// Parses a response body, i.e. a frame with its 4-byte length prefix already stripped.
    ///
    /// Expected layout: `[status: u8][payload...]`, where `status` must be `0` (failure)
    /// or `1` (success) — the only values [`IpcResponse::to_bytes`] writes.
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::InvalidProtocol`] if `data` is empty or the status byte is
    /// neither `0` nor `1`.
    pub fn from_bytes(data: &[u8]) -> Result<Self, IpcError> {
        let (&status, payload) = data
            .split_first()
            .ok_or_else(|| IpcError::InvalidProtocol("empty response".to_string()))?;
        // Reject unknown status bytes instead of silently treating them as failures;
        // anything other than 0/1 indicates a malformed or misaligned frame.
        let success = match status {
            0 => false,
            1 => true,
            other => {
                return Err(IpcError::InvalidProtocol(format!(
                    "invalid response status byte: {other}"
                )));
            }
        };
        Ok(Self {
            success,
            payload: payload.to_vec(),
        })
    }

    /// Decodes the MessagePack payload into `T` (the result type on success, or a
    /// `String` message for responses built with [`IpcResponse::error`]).
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::Deserialization`] if the payload is not a valid encoding of `T`.
    pub fn deserialize_payload<T: DeserializeOwned>(&self) -> Result<T, IpcError> {
        rmp_serde::from_slice(&self.payload).map_err(IpcError::from)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestParams {
        query: String,
    }

    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct TestResult {
        items: Vec<String>,
    }

    /// Reads the big-endian `u32` length prefix at the start of an encoded frame.
    fn frame_len(bytes: &[u8]) -> usize {
        u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize
    }

    #[test]
    fn test_request_roundtrip() {
        let params = TestParams {
            query: "hello".to_string(),
        };
        let bytes = IpcRequest::to_bytes("search", &params).unwrap();
        let request = IpcRequest::from_bytes(&bytes[4..]).unwrap();
        assert_eq!(request.method, "search");
        let decoded: TestParams = request.deserialize_params().unwrap();
        assert_eq!(decoded, params);
    }

    #[test]
    fn test_request_length_prefix_matches_body() {
        let params = TestParams {
            query: "hello".to_string(),
        };
        let bytes = IpcRequest::to_bytes("search", &params).unwrap();
        assert_eq!(frame_len(&bytes), bytes.len() - 4);
    }

    #[test]
    fn test_request_rejects_malformed_input() {
        assert!(matches!(
            IpcRequest::from_bytes(&[]),
            Err(IpcError::InvalidProtocol(_))
        ));
        // Declares a 5-byte method name but supplies only 2 bytes.
        assert!(matches!(
            IpcRequest::from_bytes(&[5, b'a', b'b']),
            Err(IpcError::InvalidProtocol(_))
        ));
        // 0xFF can never appear in valid UTF-8.
        assert!(matches!(
            IpcRequest::from_bytes(&[1, 0xFF]),
            Err(IpcError::InvalidProtocol(_))
        ));
    }

    #[test]
    fn test_request_method_name_length_limit() {
        // 255 bytes is the largest length the one-byte field can hold.
        assert!(IpcRequest::to_bytes(&"m".repeat(255), &()).is_ok());
        assert!(matches!(
            IpcRequest::to_bytes(&"m".repeat(256), &()),
            Err(IpcError::InvalidProtocol(_))
        ));
    }

    #[test]
    fn test_response_success_roundtrip() {
        let result = TestResult {
            items: vec!["a".to_string(), "b".to_string()],
        };
        let response = IpcResponse::success(&result).unwrap();
        let bytes = response.to_bytes();
        assert_eq!(frame_len(&bytes), bytes.len() - 4);
        let parsed = IpcResponse::from_bytes(&bytes[4..]).unwrap();
        assert!(parsed.success);
        let decoded: TestResult = parsed.deserialize_payload().unwrap();
        assert_eq!(decoded, result);
    }

    #[test]
    fn test_response_error_roundtrip() {
        let response = IpcResponse::error("something went wrong").unwrap();
        let bytes = response.to_bytes();
        let parsed = IpcResponse::from_bytes(&bytes[4..]).unwrap();
        assert!(!parsed.success);
        let message: String = parsed.deserialize_payload().unwrap();
        assert_eq!(message, "something went wrong");
    }

    #[test]
    fn test_response_rejects_empty_input() {
        assert!(matches!(
            IpcResponse::from_bytes(&[]),
            Err(IpcError::InvalidProtocol(_))
        ));
    }
}