use crate::error::{IpcError, Result};
use crate::graceful::{GracefulChannel, ShutdownState};
use crate::local_socket::{LocalSocketListener, LocalSocketStream};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{Read, Write};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::{Duration, SystemTime};
pub type ConnectionId = u64;
#[derive(Debug, Clone)]
pub struct SocketServerConfig {
pub path: String,
pub max_connections: usize,
pub connection_timeout: Duration,
pub cleanup_on_start: bool,
pub buffer_size: usize,
}
impl Default for SocketServerConfig {
fn default() -> Self {
Self {
path: default_socket_path(),
max_connections: 100,
connection_timeout: Duration::from_secs(30),
cleanup_on_start: true,
buffer_size: 8192,
}
}
}
impl SocketServerConfig {
pub fn with_path(path: &str) -> Self {
Self {
path: path.to_string(),
..Default::default()
}
}
}
pub fn default_socket_path() -> String {
#[cfg(unix)]
{
let runtime_dir = std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/tmp".to_string());
format!("{}/ipckit.sock", runtime_dir)
}
#[cfg(windows)]
{
r"\\.\pipe\ipckit".to_string()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionMetadata {
#[serde(with = "system_time_serde")]
pub connected_at: SystemTime,
pub client_pid: Option<u32>,
pub client_info: Option<String>,
}
mod system_time_serde {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
pub fn serialize<S>(time: &SystemTime, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let duration = time.duration_since(UNIX_EPOCH).unwrap_or(Duration::ZERO);
duration.as_secs_f64().serialize(serializer)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
D: Deserializer<'de>,
{
let secs = f64::deserialize(deserializer)?;
Ok(UNIX_EPOCH + Duration::from_secs_f64(secs))
}
}
impl Default for ConnectionMetadata {
fn default() -> Self {
Self {
connected_at: SystemTime::now(),
client_pid: None,
client_info: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub msg_type: MessageType,
pub payload: serde_json::Value,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageType {
Text,
Binary,
Request,
Response,
Error,
Ping,
Pong,
}
impl Message {
pub fn text(content: &str) -> Self {
Self {
msg_type: MessageType::Text,
payload: serde_json::json!({ "content": content }),
}
}
pub fn request(method: &str, params: serde_json::Value) -> Self {
Self {
msg_type: MessageType::Request,
payload: serde_json::json!({
"method": method,
"params": params
}),
}
}
pub fn response(result: serde_json::Value) -> Self {
Self {
msg_type: MessageType::Response,
payload: serde_json::json!({ "result": result }),
}
}
pub fn error(code: i32, message: &str) -> Self {
Self {
msg_type: MessageType::Error,
payload: serde_json::json!({
"code": code,
"message": message
}),
}
}
pub fn ping() -> Self {
Self {
msg_type: MessageType::Ping,
payload: serde_json::json!({}),
}
}
pub fn pong() -> Self {
Self {
msg_type: MessageType::Pong,
payload: serde_json::json!({}),
}
}
pub fn json(value: serde_json::Value) -> Self {
Self {
msg_type: MessageType::Text,
payload: value,
}
}
pub fn binary(data: Vec<u8>) -> Self {
Self {
msg_type: MessageType::Binary,
payload: serde_json::json!({
"data": base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &data)
}),
}
}
pub fn as_binary(&self) -> Option<Vec<u8>> {
self.payload
.get("data")
.and_then(|v| v.as_str())
.and_then(|s| {
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, s).ok()
})
}
pub fn as_text(&self) -> Option<&str> {
self.payload.get("content").and_then(|v| v.as_str())
}
pub fn method(&self) -> Option<&str> {
self.payload.get("method").and_then(|v| v.as_str())
}
pub fn params(&self) -> Option<&serde_json::Value> {
self.payload.get("params")
}
pub fn result(&self) -> Option<&serde_json::Value> {
self.payload.get("result")
}
}
pub struct Connection {
id: ConnectionId,
stream: LocalSocketStream,
metadata: ConnectionMetadata,
buffer: Vec<u8>,
}
impl Connection {
fn new(id: ConnectionId, stream: LocalSocketStream) -> Self {
Self {
id,
stream,
metadata: ConnectionMetadata::default(),
buffer: Vec::with_capacity(8192),
}
}
pub fn id(&self) -> ConnectionId {
self.id
}
pub fn metadata(&self) -> &ConnectionMetadata {
&self.metadata
}
pub fn set_client_info(&mut self, info: &str) {
self.metadata.client_info = Some(info.to_string());
}
pub fn send(&mut self, msg: &Message) -> Result<()> {
let data = serde_json::to_vec(msg).map_err(|e| IpcError::serialization(e.to_string()))?;
let len = data.len() as u32;
self.stream.write_all(&len.to_le_bytes())?;
self.stream.write_all(&data)?;
self.stream.flush()?;
Ok(())
}
pub fn recv(&mut self) -> Result<Message> {
let mut len_buf = [0u8; 4];
self.stream.read_exact(&mut len_buf)?;
let len = u32::from_le_bytes(len_buf) as usize;
if len > 16 * 1024 * 1024 {
return Err(IpcError::BufferTooSmall {
needed: len,
got: 16 * 1024 * 1024,
});
}
self.buffer.resize(len, 0);
self.stream.read_exact(&mut self.buffer)?;
serde_json::from_slice(&self.buffer).map_err(|e| IpcError::deserialization(e.to_string()))
}
pub fn try_recv(&mut self) -> Result<Option<Message>> {
Err(IpcError::WouldBlock)
}
pub fn request(
&mut self,
method: &str,
params: serde_json::Value,
) -> Result<serde_json::Value> {
self.send(&Message::request(method, params))?;
let response = self.recv()?;
match response.msg_type {
MessageType::Response => response
.result()
.cloned()
.ok_or_else(|| IpcError::deserialization("Missing result in response".to_string())),
MessageType::Error => {
let msg = response
.payload
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("Unknown error");
Err(IpcError::Other(msg.to_string()))
}
_ => Err(IpcError::deserialization(
"Unexpected message type".to_string(),
)),
}
}
}
pub trait ConnectionHandler: Clone + Send + 'static {
fn on_connect(&self, conn: &mut Connection) -> Result<()> {
let _ = conn;
Ok(())
}
fn on_message(&self, conn: &mut Connection, msg: Message) -> Result<Option<Message>>;
fn on_disconnect(&self, conn_id: ConnectionId) {
let _ = conn_id;
}
}
#[derive(Clone)]
pub struct FnHandler<F>
where
F: Fn(&mut Connection, Message) -> Result<Option<Message>> + Clone + Send + 'static,
{
handler: F,
}
impl<F> FnHandler<F>
where
F: Fn(&mut Connection, Message) -> Result<Option<Message>> + Clone + Send + 'static,
{
pub fn new(handler: F) -> Self {
Self { handler }
}
}
impl<F> ConnectionHandler for FnHandler<F>
where
F: Fn(&mut Connection, Message) -> Result<Option<Message>> + Clone + Send + 'static,
{
fn on_message(&self, conn: &mut Connection, msg: Message) -> Result<Option<Message>> {
(self.handler)(conn, msg)
}
}
pub struct SocketServer {
config: SocketServerConfig,
listener: LocalSocketListener,
connections: Arc<RwLock<HashMap<ConnectionId, Arc<RwLock<Connection>>>>>,
shutdown: Arc<ShutdownState>,
next_id: AtomicU64,
}
impl SocketServer {
pub fn new(config: SocketServerConfig) -> Result<Self> {
#[cfg(unix)]
if config.cleanup_on_start && !config.path.starts_with(r"\\.\pipe\") {
let _ = std::fs::remove_file(&config.path);
}
let listener = LocalSocketListener::bind(&config.path)?;
Ok(Self {
config,
listener,
connections: Arc::new(RwLock::new(HashMap::new())),
shutdown: Arc::new(ShutdownState::new()),
next_id: AtomicU64::new(1),
})
}
pub fn with_defaults() -> Result<Self> {
Self::new(SocketServerConfig::default())
}
pub fn at(path: &str) -> Result<Self> {
Self::new(SocketServerConfig::with_path(path))
}
pub fn socket_path(&self) -> &str {
&self.config.path
}
pub fn connection_count(&self) -> usize {
self.connections.read().len()
}
pub fn accept(&self) -> Result<Connection> {
if self.shutdown.is_shutdown() {
return Err(IpcError::Closed);
}
let stream = self.listener.accept()?;
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let conn = Connection::new(id, stream);
self.connections
.write()
.insert(id, Arc::new(RwLock::new(conn)));
let stream = self.listener.accept()?;
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
Ok(Connection::new(id, stream))
}
pub fn incoming(&self) -> impl Iterator<Item = Result<Connection>> + '_ {
std::iter::from_fn(move || {
if self.shutdown.is_shutdown() {
return None;
}
match self.listener.accept() {
Ok(stream) => {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
Some(Ok(Connection::new(id, stream)))
}
Err(e) => Some(Err(e)),
}
})
}
pub fn run<H: ConnectionHandler>(&self, handler: H) -> Result<()> {
for conn_result in self.incoming() {
if self.shutdown.is_shutdown() {
break;
}
match conn_result {
Ok(mut conn) => {
let handler = handler.clone();
let shutdown = Arc::clone(&self.shutdown);
std::thread::spawn(move || {
if let Err(e) = handler.on_connect(&mut conn) {
tracing::error!("Connection error: {}", e);
return;
}
loop {
if shutdown.is_shutdown() {
break;
}
match conn.recv() {
Ok(msg) => match handler.on_message(&mut conn, msg) {
Ok(Some(response)) => {
if let Err(e) = conn.send(&response) {
tracing::error!("Send error: {}", e);
break;
}
}
Ok(None) => {}
Err(e) => {
tracing::error!("Handler error: {}", e);
let _ = conn.send(&Message::error(-1, &e.to_string()));
}
},
Err(IpcError::Io(ref e))
if e.kind() == std::io::ErrorKind::UnexpectedEof =>
{
break;
}
Err(e) => {
tracing::error!("Receive error: {}", e);
break;
}
}
}
handler.on_disconnect(conn.id());
});
}
Err(e) => {
tracing::error!("Accept error: {}", e);
}
}
}
Ok(())
}
pub fn spawn<H: ConnectionHandler>(self, handler: H) -> JoinHandle<Result<()>> {
std::thread::spawn(move || self.run(handler))
}
pub fn shutdown(&self) {
self.shutdown.shutdown();
}
pub fn is_shutdown(&self) -> bool {
self.shutdown.is_shutdown()
}
}
impl GracefulChannel for SocketServer {
fn shutdown(&self) {
self.shutdown.shutdown();
}
fn is_shutdown(&self) -> bool {
self.shutdown.is_shutdown()
}
fn drain(&self) -> Result<()> {
self.shutdown.wait_for_drain(None)
}
fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
self.shutdown();
self.shutdown.wait_for_drain(Some(timeout))
}
}
pub struct SocketClient {
connection: Connection,
}
impl SocketClient {
pub fn connect(path: &str) -> Result<Self> {
let stream = LocalSocketStream::connect(path)?;
let connection = Connection::new(0, stream);
Ok(Self { connection })
}
pub fn connect_timeout(path: &str, timeout: Duration) -> Result<Self> {
use std::sync::mpsc;
use std::thread;
let path_owned = path.to_string();
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let result = LocalSocketStream::connect(&path_owned);
let _ = tx.send(result);
});
match rx.recv_timeout(timeout) {
Ok(Ok(stream)) => {
let connection = Connection::new(0, stream);
Ok(Self { connection })
}
Ok(Err(e)) => Err(e),
Err(_) => Err(IpcError::Timeout),
}
}
pub fn connect_default() -> Result<Self> {
Self::connect(&default_socket_path())
}
pub fn connect_default_timeout(timeout: Duration) -> Result<Self> {
Self::connect_timeout(&default_socket_path(), timeout)
}
pub fn send(&mut self, msg: &Message) -> Result<()> {
self.connection.send(msg)
}
pub fn recv(&mut self) -> Result<Message> {
self.connection.recv()
}
pub fn request(
&mut self,
method: &str,
params: serde_json::Value,
) -> Result<serde_json::Value> {
self.connection.request(method, params)
}
pub fn connection(&mut self) -> &mut Connection {
&mut self.connection
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::time::Duration;
#[test]
fn test_message_creation() {
let text = Message::text("Hello");
assert_eq!(text.msg_type, MessageType::Text);
assert_eq!(text.as_text(), Some("Hello"));
let request = Message::request("ping", serde_json::json!({}));
assert_eq!(request.msg_type, MessageType::Request);
assert_eq!(request.method(), Some("ping"));
let response = Message::response(serde_json::json!({"pong": true}));
assert_eq!(response.msg_type, MessageType::Response);
assert!(response.result().is_some());
let error = Message::error(404, "Not found");
assert_eq!(error.msg_type, MessageType::Error);
}
#[test]
fn test_message_serialization() {
let msg = Message::request("test", serde_json::json!({"key": "value"}));
let json = serde_json::to_string(&msg).unwrap();
let deserialized: Message = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.msg_type, msg.msg_type);
assert_eq!(deserialized.method(), msg.method());
}
#[test]
fn test_socket_server_config() {
let config = SocketServerConfig::default();
assert_eq!(config.max_connections, 100);
assert!(config.cleanup_on_start);
let custom = SocketServerConfig::with_path("/tmp/test.sock");
assert_eq!(custom.path, "/tmp/test.sock");
}
#[test]
fn test_connection_metadata() {
let metadata = ConnectionMetadata::default();
assert!(metadata.client_pid.is_none());
assert!(metadata.client_info.is_none());
}
#[test]
fn test_fn_handler() {
let handler = FnHandler::new(|_conn, msg| {
if msg.method() == Some("ping") {
Ok(Some(Message::response(serde_json::json!({"pong": true}))))
} else {
Ok(None)
}
});
let _handler2 = handler.clone();
}
#[test]
#[ignore] fn test_socket_client_server() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
let socket_name = format!("test_socket_server_{}", std::process::id());
let server_ready = Arc::new(AtomicBool::new(false));
let server_ready_clone = server_ready.clone();
let socket_name_clone = socket_name.clone();
let server_handle = thread::spawn(move || {
let config = SocketServerConfig::with_path(&socket_name_clone);
let server = match SocketServer::new(config) {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to create server: {}", e);
return;
}
};
server_ready_clone.store(true, Ordering::SeqCst);
if let Ok(mut conn) = server.accept() {
if let Ok(msg) = conn.recv() {
if msg.method() == Some("ping") {
conn.send(&Message::response(serde_json::json!({"pong": true})))
.ok();
}
}
}
});
let start = std::time::Instant::now();
while !server_ready.load(Ordering::SeqCst) {
if start.elapsed() > Duration::from_secs(5) {
panic!("Server failed to start within timeout");
}
thread::sleep(Duration::from_millis(10));
}
thread::sleep(Duration::from_millis(100));
let mut client = None;
for _ in 0..10 {
match SocketClient::connect(&socket_name) {
Ok(c) => {
client = Some(c);
break;
}
Err(_) => {
thread::sleep(Duration::from_millis(50));
}
}
}
let mut client = client.expect("Failed to connect to server");
let result = client.request("ping", serde_json::json!({})).unwrap();
assert_eq!(result["pong"], true);
server_handle.join().unwrap();
}
}