use super::config::ServerRegistry;
use super::pool::ConnectionPool;
use crate::error::{Error, Result};
use crate::mcp::tools::ToolResult;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{broadcast, RwLock};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
/// Describes one tool exposed over IPC together with the server that provides it.
///
/// Instances are sent to clients inside [`IpcMessage::ToolsList`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
/// Tool name, as listed by the server registry.
pub name: String,
/// Name of the registry server entry that provides this tool.
pub server: String,
}
/// Messages exchanged between IPC clients and the daemon.
///
/// Serialized as JSON with the variant name stored in a `"type"` field
/// (serde's internally tagged representation). Every variant carries an `id`;
/// for the requests it handles, the daemon copies the request's `id` into the
/// reply so the client can match the two.
///
/// `CallTool`, `ListTools`, `Ping`, `Shutdown` and `GetStats` are the requests
/// the daemon processes. The remaining variants are replies; the daemon
/// rejects them with an `Error` if a client sends one.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum IpcMessage {
/// Request: invoke `tool` with the JSON arguments `args`.
CallTool {
id: String,
tool: String,
args: serde_json::Value,
},
/// Request: list the tools known to the server registry.
ListTools {
id: String,
},
/// Request: liveness probe, answered with `Pong`.
Ping {
id: String,
},
/// Request: ask the daemon to broadcast its shutdown signal.
Shutdown {
id: String,
},
/// Request: fetch connection-pool statistics, answered with `Stats`.
GetStats {
id: String,
},
/// Reply to `CallTool` when the tool call succeeded.
ToolResult {
id: String,
result: ToolResult,
},
/// Reply to `ListTools`.
ToolsList {
id: String,
tools: Vec<ToolInfo>,
},
/// Reply to `Ping`.
Pong {
id: String,
},
/// Reply to `GetStats`; the counters are copied from the connection pool's stats.
Stats {
id: String,
active_connections: usize,
total_calls: u64,
cache_hits: u64,
cache_misses: u64,
},
/// Reply for a failed request; `error` is a human-readable description.
Error {
id: String,
error: String,
},
/// Reply to `Shutdown`, acknowledging the request.
Ok {
id: String,
},
}
/// IPC server of the daemon: accepts client connections on a Unix domain
/// socket and answers [`IpcMessage`] requests.
pub struct DaemonServer {
/// Server registry shared with every client handler task.
registry: Arc<RwLock<ServerRegistry>>,
/// Connection pool used to execute tool calls and report statistics.
pool: Arc<ConnectionPool>,
/// Broadcast channel used to signal shutdown to the accept loop.
shutdown_tx: broadcast::Sender<()>,
/// Listening socket; only present on Unix targets.
#[cfg(unix)]
listener: UnixListener,
}
impl DaemonServer {
    /// Largest JSON payload, in bytes, accepted in a single incoming frame.
    const MAX_MESSAGE_SIZE: usize = 1_000_000;

    /// Binds the daemon's Unix domain socket and builds the server.
    ///
    /// Any file already present at the socket path is removed first, the new
    /// socket is restricted to mode `0o600`, and the server registry is loaded
    /// (falling back to an empty registry, with a warning, if loading fails).
    ///
    /// # Errors
    ///
    /// Returns an error if the socket path cannot be determined, a stale
    /// socket cannot be removed, binding fails, or the permissions cannot be
    /// set.
    #[cfg(unix)]
    pub async fn new(shutdown_tx: broadcast::Sender<()>) -> Result<Self> {
        let socket_path = super::manager::DaemonManager::get_socket_path()?;

        // Remove unconditionally and tolerate `NotFound`: this avoids the
        // check-then-remove race of `exists()` and also clears dangling
        // symlinks, which `exists()` reports as absent.
        match std::fs::remove_file(&socket_path) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
            Err(e) => return Err(e.into()),
        }

        let listener = UnixListener::bind(&socket_path)
            .map_err(|e| Error::daemon(format!("Failed to bind socket: {}", e)))?;

        // Only the owning user may talk to the daemon.
        // NOTE(review): the mode is tightened after `bind`, so the socket
        // presumably carries umask-derived permissions for a brief moment.
        {
            use std::os::unix::fs::PermissionsExt;
            let perms = std::fs::Permissions::from_mode(0o600);
            std::fs::set_permissions(&socket_path, perms)?;
        }

        let registry = ServerRegistry::load().unwrap_or_else(|e| {
            warn!("Failed to load server registry: {}, using defaults", e);
            ServerRegistry::new()
        });

        info!("IPC server listening on {:?}", socket_path);

        Ok(Self {
            registry: Arc::new(RwLock::new(registry)),
            pool: Arc::new(ConnectionPool::new()),
            shutdown_tx,
            listener,
        })
    }

    /// Returns a shared handle to the connection pool.
    pub fn pool(&self) -> Arc<ConnectionPool> {
        Arc::clone(&self.pool)
    }

    /// Accepts clients until a shutdown signal is received.
    ///
    /// Each accepted connection is served on its own spawned task. Accept
    /// failures are logged and the loop continues. When the shutdown channel
    /// fires, the pool is cleared and the loop ends.
    ///
    /// Gated on Unix because it relies on the Unix-only listener and client
    /// handler.
    #[cfg(unix)]
    pub async fn run(&self) -> Result<()> {
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        loop {
            tokio::select! {
                accepted = self.listener.accept() => {
                    match accepted {
                        Ok((stream, _addr)) => {
                            let registry = Arc::clone(&self.registry);
                            let pool = Arc::clone(&self.pool);
                            let shutdown_tx = self.shutdown_tx.clone();
                            tokio::spawn(async move {
                                if let Err(e) =
                                    Self::handle_client(stream, registry, pool, shutdown_tx).await
                                {
                                    error!("Client handler error: {}", e);
                                }
                            });
                        }
                        Err(e) => {
                            error!("Failed to accept connection: {}", e);
                        }
                    }
                }
                _ = shutdown_rx.recv() => {
                    info!("Shutdown signal received, stopping server...");
                    self.shutdown().await?;
                    break;
                }
            }
        }

        Ok(())
    }

    /// Serves one client connection until it disconnects or misbehaves.
    ///
    /// Wire format, in both directions: a 4-byte little-endian length prefix
    /// followed by that many bytes of JSON encoding an [`IpcMessage`].
    ///
    /// A frame larger than [`Self::MAX_MESSAGE_SIZE`] ends the connection. A
    /// frame that cannot be decoded is answered with an `Error` message so
    /// the client is not left waiting for a reply that never comes.
    ///
    /// # Errors
    ///
    /// Returns an error if reading a frame body or writing a reply fails.
    #[cfg(unix)]
    async fn handle_client(
        mut stream: UnixStream,
        registry: Arc<RwLock<ServerRegistry>>,
        pool: Arc<ConnectionPool>,
        shutdown_tx: broadcast::Sender<()>,
    ) -> Result<()> {
        let client_id = Uuid::new_v4();
        debug!("Client connected: {}", client_id);

        loop {
            // Any failure while reading the prefix is treated as a disconnect.
            let mut len_buf = [0u8; 4];
            if stream.read_exact(&mut len_buf).await.is_err() {
                debug!("Client disconnected: {}", client_id);
                break;
            }

            let len = u32::from_le_bytes(len_buf) as usize;
            if len > Self::MAX_MESSAGE_SIZE {
                warn!("Client {} sent oversized message: {} bytes", client_id, len);
                break;
            }

            let mut buf = vec![0u8; len];
            stream.read_exact(&mut buf).await?;

            let response = match serde_json::from_slice::<IpcMessage>(&buf) {
                Ok(message) => {
                    Self::process_message(message, &registry, &pool, &shutdown_tx).await
                }
                Err(e) => {
                    warn!("Failed to deserialize message: {}", e);
                    IpcMessage::Error {
                        id: Self::recover_request_id(&buf),
                        error: format!("Malformed request: {}", e),
                    }
                }
            };

            Self::send_message(&mut stream, &response).await?;
        }

        Ok(())
    }

    /// Best-effort extraction of the string `id` field from a frame that
    /// failed to decode as an [`IpcMessage`]; falls back to a fresh UUID.
    #[cfg(unix)]
    fn recover_request_id(raw: &[u8]) -> String {
        serde_json::from_slice::<serde_json::Value>(raw)
            .ok()
            .and_then(|value| {
                value
                    .get("id")
                    .and_then(|id| id.as_str())
                    .map(str::to_owned)
            })
            .unwrap_or_else(|| Uuid::new_v4().to_string())
    }

    /// Computes the reply for a single client request.
    ///
    /// The reply always carries the `id` of the message it answers. Variants
    /// that are only valid as daemon replies are rejected with an `Error`.
    async fn process_message(
        msg: IpcMessage,
        registry: &Arc<RwLock<ServerRegistry>>,
        pool: &Arc<ConnectionPool>,
        shutdown_tx: &broadcast::Sender<()>,
    ) -> IpcMessage {
        match msg {
            IpcMessage::CallTool { id, tool, args } => {
                debug!("Calling tool: {}", tool);

                // NOTE(review): this read guard stays alive across the tool
                // call below, so registry writers wait for the call to finish.
                // If `get_client_config` returns an owned value, consider
                // dropping the guard before awaiting — confirm its signature.
                let reg = registry.read().await;
                let config = match reg.get_client_config(&tool) {
                    Some(c) => c,
                    None => {
                        return IpcMessage::Error {
                            id,
                            error: format!("Unknown tool: {}", tool),
                        }
                    }
                };

                match pool.call_tool(&config, &tool, args).await {
                    Ok(result) => IpcMessage::ToolResult { id, result },
                    Err(e) => IpcMessage::Error {
                        id,
                        error: format!("Tool execution failed: {}", e),
                    },
                }
            }
            IpcMessage::ListTools { id } => {
                let reg = registry.read().await;
                // Tools without a resolvable server are left out of the list.
                let tools: Vec<ToolInfo> = reg
                    .list_tools()
                    .iter()
                    .filter_map(|name| {
                        reg.get_server_for_tool(name).map(|server| ToolInfo {
                            name: name.to_string(),
                            server: server.name.clone(),
                        })
                    })
                    .collect();
                IpcMessage::ToolsList { id, tools }
            }
            IpcMessage::Ping { id } => IpcMessage::Pong { id },
            IpcMessage::GetStats { id } => {
                let stats = pool.stats();
                IpcMessage::Stats {
                    id,
                    active_connections: stats.active_connections,
                    total_calls: stats.total_calls,
                    cache_hits: stats.cache_hits,
                    cache_misses: stats.cache_misses,
                }
            }
            IpcMessage::Shutdown { id } => {
                info!("Shutdown requested via IPC");
                // A send error only means nobody is listening; nothing to do.
                shutdown_tx.send(()).ok();
                IpcMessage::Ok { id }
            }
            // Reply-only variants, listed explicitly so that a newly added
            // variant forces a decision here, and so the client's own `id`
            // is echoed back instead of an unrelated one.
            IpcMessage::ToolResult { id, .. }
            | IpcMessage::ToolsList { id, .. }
            | IpcMessage::Pong { id }
            | IpcMessage::Stats { id, .. }
            | IpcMessage::Error { id, .. }
            | IpcMessage::Ok { id } => IpcMessage::Error {
                id,
                error: "Invalid message type from client".to_string(),
            },
        }
    }

    /// Writes one framed message: 4-byte little-endian length, then JSON.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails, if the payload does not fit
    /// in the 32-bit length prefix, or if writing to the stream fails.
    #[cfg(unix)]
    async fn send_message(stream: &mut UnixStream, msg: &IpcMessage) -> Result<()> {
        let json = serde_json::to_vec(msg)?;

        // A silently truncated prefix would desynchronise the stream.
        if json.len() > u32::MAX as usize {
            return Err(Error::daemon(format!(
                "Message too large to frame: {} bytes",
                json.len()
            )));
        }
        // Cannot truncate: bounded by the check above.
        let len = (json.len() as u32).to_le_bytes();

        stream.write_all(&len).await?;
        stream.write_all(&json).await?;
        stream.flush().await?;
        Ok(())
    }

    /// Releases daemon resources by clearing the connection pool.
    async fn shutdown(&self) -> Result<()> {
        info!("Shutting down daemon...");
        self.pool.clear();
        info!("Shutdown complete");
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Ping` keeps its variant tag and id through a JSON round trip.
    #[test]
    fn test_ipc_message_serialization() {
        let ping = IpcMessage::Ping {
            id: String::from("test-123"),
        };

        let encoded = serde_json::to_string(&ping).unwrap();
        assert!(encoded.contains("Ping"));
        assert!(encoded.contains("test-123"));

        let decoded: IpcMessage = serde_json::from_str(&encoded).unwrap();
        if let IpcMessage::Ping { id } = decoded {
            assert_eq!(id, "test-123");
        } else {
            panic!("Wrong message type");
        }
    }

    /// Both `ToolInfo` fields appear in the serialized JSON.
    #[test]
    fn test_tool_info_serialization() {
        let tool = ToolInfo {
            name: String::from("gigathink"),
            server: String::from("reasonkit-thinktools"),
        };

        let encoded = serde_json::to_string(&tool).unwrap();
        assert!(encoded.contains("gigathink"));
        assert!(encoded.contains("reasonkit-thinktools"));
    }

    /// A `Stats` message serializes its tag and counter values.
    #[test]
    fn test_stats_message() {
        let stats = IpcMessage::Stats {
            id: String::from("stats-1"),
            active_connections: 3,
            total_calls: 100,
            cache_hits: 80,
            cache_misses: 20,
        };

        let encoded = serde_json::to_string(&stats).unwrap();
        assert!(encoded.contains("Stats"));
        assert!(encoded.contains("100"));
    }

    /// Pins the numeric value used as the IPC frame size limit.
    #[test]
    fn test_ipc_message_size_limit() {
        let limit: u32 = 1_000_000;
        assert_eq!(limit, 1_000_000);
    }
}