reasonkit-core 0.1.8

The Reasoning Engine — Auditable Reasoning for Production AI | Rust-Native | Turn Prompts into Protocols
//! IPC Server Implementation
//!
//! Handles incoming IPC connections and routes requests to MCP servers
//! via the daemon's connection pool and server registry.

use super::config::ServerRegistry;
use super::pool::ConnectionPool;
use crate::error::{Error, Result};
use crate::mcp::tools::ToolResult;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{broadcast, RwLock};
use tracing::{debug, error, info, warn};
use uuid::Uuid;

#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};

/// One entry of a `ToolsList` response: a registered tool together with the name of
/// the MCP server that provides it.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ToolInfo {
    /// Name the tool is registered under.
    pub name: String,
    /// Name of the MCP server exposing this tool.
    pub server: String,
}

/// A framed JSON message exchanged between IPC clients and the daemon.
///
/// Serialized as an internally tagged object: the variant name is stored in a `"type"`
/// field next to the variant's own fields, e.g. `{"type":"Ping","id":"..."}`.
/// Every variant carries an `id`; the daemon copies a request's `id` into its reply.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum IpcMessage {
    // ---- Requests: client -> daemon ----
    /// Invoke `tool` with JSON `args`; answered with `ToolResult` or `Error`.
    CallTool { id: String, tool: String, args: serde_json::Value },
    /// List every registered tool; answered with `ToolsList`.
    ListTools { id: String },
    /// Liveness probe; answered with `Pong`.
    Ping { id: String },
    /// Ask the daemon to stop; answered with `Ok`.
    Shutdown { id: String },
    /// Request connection-pool statistics; answered with `Stats`.
    GetStats { id: String },

    // ---- Responses: daemon -> client ----
    /// Successful outcome of a `CallTool`.
    ToolResult { id: String, result: ToolResult },
    /// Registered tools, each paired with its providing server.
    ToolsList { id: String, tools: Vec<ToolInfo> },
    /// Reply to `Ping`.
    Pong { id: String },
    /// Connection-pool counters reported for `GetStats`.
    Stats {
        id: String,
        active_connections: usize,
        total_calls: u64,
        cache_hits: u64,
        cache_misses: u64,
    },
    /// The request failed; `error` is a human-readable description.
    Error { id: String, error: String },
    /// Plain acknowledgement (sent in reply to `Shutdown`).
    Ok { id: String },
}

/// Daemon IPC server.
///
/// Accepts client connections on a Unix-domain socket and answers framed JSON
/// [`IpcMessage`] requests using a shared server registry and connection pool.
pub struct DaemonServer {
    /// Registry used to resolve tool names to MCP server configs (loaded in `new`).
    registry: Arc<RwLock<ServerRegistry>>,
    /// Pool through which tool calls are executed; an `Arc` clone is handed to each client task.
    pool: Arc<ConnectionPool>,
    /// Shutdown broadcast channel: `run` subscribes to it, and an IPC `Shutdown`
    /// request sends on it.
    shutdown_tx: broadcast::Sender<()>,
    /// Listening socket that client connections are accepted from.
    #[cfg(unix)]
    listener: UnixListener,
}

impl DaemonServer {
    /// Upper bound, in bytes, on the JSON payload of one incoming frame.
    ///
    /// A frame whose length prefix exceeds this is rejected and the connection closed,
    /// so a misbehaving client cannot make the daemon allocate arbitrary memory.
    const MAX_MESSAGE_SIZE: usize = 1_000_000;

    /// Pause after a failed `accept` so a persistent error (e.g. running out of file
    /// descriptors) does not turn the accept loop into a busy spin.
    const ACCEPT_RETRY_DELAY: std::time::Duration = std::time::Duration::from_millis(100);

    /// Create a daemon server bound to the daemon's Unix socket.
    ///
    /// Any file already at the socket path (e.g. a stale socket from a previous run) is
    /// removed before binding. The new socket is restricted to mode `0o600`. If the
    /// server registry cannot be loaded, a warning is logged and `ServerRegistry::new()`
    /// is used instead.
    ///
    /// # Errors
    ///
    /// Fails if the socket path cannot be determined, the old socket cannot be removed,
    /// binding fails, or the socket permissions cannot be set.
    #[cfg(unix)]
    pub async fn new(shutdown_tx: broadcast::Sender<()>) -> Result<Self> {
        use std::os::unix::fs::PermissionsExt;

        let socket_path = super::manager::DaemonManager::get_socket_path()?;

        // A leftover socket file would make bind() fail with "address in use".
        if socket_path.exists() {
            std::fs::remove_file(&socket_path)?;
        }

        let listener = UnixListener::bind(&socket_path)
            .map_err(|e| Error::daemon(format!("Failed to bind socket: {}", e)))?;

        // Restrict the socket to the owning user.
        // NOTE(review): between bind() and this call the socket carries umask-derived
        // permissions; bind inside a 0o700 directory if that window matters.
        std::fs::set_permissions(&socket_path, std::fs::Permissions::from_mode(0o600))?;

        let registry = ServerRegistry::load().unwrap_or_else(|e| {
            warn!("Failed to load server registry: {}, using defaults", e);
            ServerRegistry::new()
        });

        info!("IPC server listening on {:?}", socket_path);

        Ok(Self {
            registry: Arc::new(RwLock::new(registry)),
            pool: Arc::new(ConnectionPool::new()),
            shutdown_tx,
            listener,
        })
    }

    /// Return a shared handle to the connection pool used for tool calls.
    pub fn pool(&self) -> Arc<ConnectionPool> {
        Arc::clone(&self.pool)
    }

    /// Accept client connections until a shutdown signal arrives, then clean up.
    ///
    /// Each accepted connection is served on its own spawned task. The shutdown signal
    /// comes from the broadcast channel given to `new`. An IPC `Shutdown` request also
    /// sends on that channel. On shutdown, [`Self::shutdown`] runs and this method returns.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::shutdown`]. Per-connection errors are only logged.
    pub async fn run(&self) -> Result<()> {
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        loop {
            tokio::select! {
                result = self.listener.accept() => {
                    match result {
                        Ok((stream, _addr)) => {
                            let registry = Arc::clone(&self.registry);
                            let pool = Arc::clone(&self.pool);
                            let shutdown_tx = self.shutdown_tx.clone();

                            tokio::spawn(async move {
                                if let Err(e) = Self::handle_client(stream, registry, pool, shutdown_tx).await {
                                    error!("Client handler error: {}", e);
                                }
                            });
                        }
                        Err(e) => {
                            error!("Failed to accept connection: {}", e);
                            // Back off instead of retrying immediately; errors like EMFILE
                            // would otherwise spin this loop at full CPU.
                            tokio::time::sleep(Self::ACCEPT_RETRY_DELAY).await;
                        }
                    }
                }

                // Any outcome of recv() (a value, or a lag notification) is treated as shutdown.
                _ = shutdown_rx.recv() => {
                    info!("Shutdown signal received, stopping server...");
                    self.shutdown().await?;
                    break;
                }
            }
        }

        Ok(())
    }

    /// Serve one client connection until it disconnects or sends an unusable frame.
    ///
    /// Wire format, in both directions: a 4-byte little-endian `u32` length followed by that
    /// many bytes of JSON encoding an [`IpcMessage`]. Every well-framed request gets exactly
    /// one reply frame. Undecodable JSON is answered with [`IpcMessage::Error`] instead of
    /// being dropped, so the client is never left waiting for a response.
    ///
    /// # Errors
    ///
    /// Returns an error if reading a frame body or writing a reply fails. A failure while
    /// reading the length prefix (including EOF) is treated as a normal disconnect.
    #[cfg(unix)]
    async fn handle_client(
        mut stream: UnixStream,
        registry: Arc<RwLock<ServerRegistry>>,
        pool: Arc<ConnectionPool>,
        shutdown_tx: broadcast::Sender<()>,
    ) -> Result<()> {
        let client_id = Uuid::new_v4();
        debug!("Client connected: {}", client_id);

        loop {
            let mut len_buf = [0u8; 4];
            if stream.read_exact(&mut len_buf).await.is_err() {
                debug!("Client disconnected: {}", client_id);
                break;
            }

            let len = u32::from_le_bytes(len_buf) as usize;

            if len > Self::MAX_MESSAGE_SIZE {
                warn!("Client {} sent oversized message: {} bytes", client_id, len);
                // The unread body leaves the stream out of sync, so the connection must be
                // closed. Tell the client why first (best-effort; it may already be gone).
                let reply = IpcMessage::Error {
                    id: String::new(),
                    error: format!(
                        "Message too large: {} bytes (max {})",
                        len,
                        Self::MAX_MESSAGE_SIZE
                    ),
                };
                let _ = Self::send_message(&mut stream, &reply).await;
                break;
            }

            let mut buf = vec![0u8; len];
            stream.read_exact(&mut buf).await?;

            let response = match serde_json::from_slice::<IpcMessage>(&buf) {
                Ok(message) => {
                    Self::process_message(message, &registry, &pool, &shutdown_tx).await
                }
                Err(e) => {
                    warn!("Failed to deserialize message: {}", e);
                    IpcMessage::Error {
                        id: Self::request_id(&buf),
                        error: format!("Invalid message: {}", e),
                    }
                }
            };

            Self::send_message(&mut stream, &response).await?;
        }

        Ok(())
    }

    /// Best-effort recovery of the string `"id"` field from a request that failed to
    /// deserialize as an [`IpcMessage`], so the error reply can still be correlated.
    ///
    /// Returns an empty string if the payload is not a JSON object with a string `id`.
    #[cfg(unix)]
    fn request_id(raw: &[u8]) -> String {
        serde_json::from_slice::<serde_json::Value>(raw)
            .ok()
            .and_then(|value| value.get("id")?.as_str().map(str::to_owned))
            .unwrap_or_default()
    }

    /// Handle one decoded request and build the reply to send back.
    ///
    /// The reply always carries the request's `id`. A daemon->client variant received from
    /// a client is answered with [`IpcMessage::Error`].
    async fn process_message(
        msg: IpcMessage,
        registry: &Arc<RwLock<ServerRegistry>>,
        pool: &Arc<ConnectionPool>,
        shutdown_tx: &broadcast::Sender<()>,
    ) -> IpcMessage {
        match msg {
            IpcMessage::CallTool { id, tool, args } => {
                debug!("Calling tool: {}", tool);

                // NOTE(review): this read guard stays alive across `pool.call_tool(..).await`.
                // A registry writer elsewhere would wait for in-flight tool calls, and tokio's
                // write-preferring RwLock would then also block new readers. If
                // `get_client_config` returns an owned value, scope the guard to the lookup.
                let reg = registry.read().await;

                let config = match reg.get_client_config(&tool) {
                    Some(c) => c,
                    None => {
                        return IpcMessage::Error {
                            id,
                            error: format!("Unknown tool: {}", tool),
                        }
                    }
                };

                match pool.call_tool(&config, &tool, args).await {
                    Ok(result) => IpcMessage::ToolResult { id, result },
                    Err(e) => IpcMessage::Error {
                        id,
                        error: format!("Tool execution failed: {}", e),
                    },
                }
            }

            IpcMessage::ListTools { id } => {
                let reg = registry.read().await;
                // Tools with no resolvable server are skipped.
                let tools: Vec<ToolInfo> = reg
                    .list_tools()
                    .iter()
                    .filter_map(|name| {
                        reg.get_server_for_tool(name).map(|server| ToolInfo {
                            name: name.to_string(),
                            server: server.name.clone(),
                        })
                    })
                    .collect();

                IpcMessage::ToolsList { id, tools }
            }

            IpcMessage::Ping { id } => IpcMessage::Pong { id },

            IpcMessage::GetStats { id } => {
                let stats = pool.stats();
                IpcMessage::Stats {
                    id,
                    active_connections: stats.active_connections,
                    total_calls: stats.total_calls,
                    cache_hits: stats.cache_hits,
                    cache_misses: stats.cache_misses,
                }
            }

            IpcMessage::Shutdown { id } => {
                info!("Shutdown requested via IPC");
                // Ignore send errors: no receiver simply means nothing is listening for shutdown.
                shutdown_tx.send(()).ok();
                IpcMessage::Ok { id }
            }

            // Daemon -> client variants are never valid requests. Echo the caller's id so the
            // rejection can be matched to what was sent. The match is exhaustive, so adding a
            // new variant forces a decision here.
            IpcMessage::ToolResult { id, .. }
            | IpcMessage::ToolsList { id, .. }
            | IpcMessage::Pong { id }
            | IpcMessage::Stats { id, .. }
            | IpcMessage::Error { id, .. }
            | IpcMessage::Ok { id } => IpcMessage::Error {
                id,
                error: "Invalid message type from client".to_string(),
            },
        }
    }

    /// Write `msg` to `stream` as one frame (4-byte little-endian length, then JSON), then flush.
    ///
    /// # Errors
    ///
    /// Fails if serialization fails, if the payload does not fit in the `u32` length
    /// prefix, or if writing to the socket fails.
    #[cfg(unix)]
    async fn send_message(stream: &mut UnixStream, msg: &IpcMessage) -> Result<()> {
        let json = serde_json::to_vec(msg)?;
        // `as u32` would silently truncate and desynchronize the client's framing.
        let len = u32::try_from(json.len()).map_err(|_| {
            Error::daemon(format!("Outgoing message too large: {} bytes", json.len()))
        })?;

        stream.write_all(&len.to_le_bytes()).await?;
        stream.write_all(&json).await?;
        stream.flush().await?;

        Ok(())
    }

    /// Graceful shutdown: clear the connection pool and remove the socket file.
    async fn shutdown(&self) -> Result<()> {
        info!("Shutting down daemon...");

        self.pool.clear();

        // Dropping a UnixListener does not unlink its socket file, so remove it here.
        // This is best-effort: if it fails, `new` still removes the stale file on next start.
        #[cfg(unix)]
        {
            if let Some(path) = self
                .listener
                .local_addr()
                .ok()
                .and_then(|addr| addr.as_pathname().map(|p| p.to_path_buf()))
            {
                if let Err(e) = std::fs::remove_file(&path) {
                    if e.kind() != std::io::ErrorKind::NotFound {
                        warn!("Failed to remove socket {:?}: {}", path, e);
                    }
                }
            }
        }

        info!("Shutdown complete");
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ipc_message_serialization() {
        let msg = IpcMessage::Ping {
            id: "test-123".to_string(),
        };

        let json = serde_json::to_string(&msg).unwrap();
        // `#[serde(tag = "type")]` stores the variant name in a "type" field.
        assert!(json.contains(r#""type":"Ping""#));
        assert!(json.contains("test-123"));

        let deserialized: IpcMessage = serde_json::from_str(&json).unwrap();
        match deserialized {
            IpcMessage::Ping { id } => assert_eq!(id, "test-123"),
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_call_tool_wire_format() {
        let raw = r#"{"type":"CallTool","id":"req-1","tool":"gigathink","args":{"query":"hi"}}"#;
        let msg: IpcMessage = serde_json::from_str(raw).unwrap();
        match msg {
            IpcMessage::CallTool { id, tool, args } => {
                assert_eq!(id, "req-1");
                assert_eq!(tool, "gigathink");
                assert_eq!(args["query"], "hi");
            }
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_unknown_message_type_is_rejected() {
        let raw = r#"{"type":"Bogus","id":"x"}"#;
        assert!(serde_json::from_str::<IpcMessage>(raw).is_err());
    }

    #[test]
    fn test_tool_info_serialization() {
        let info = ToolInfo {
            name: "gigathink".to_string(),
            server: "reasonkit-thinktools".to_string(),
        };

        let json = serde_json::to_string(&info).unwrap();
        assert!(json.contains("gigathink"));
        assert!(json.contains("reasonkit-thinktools"));

        let back: ToolInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(back.name, "gigathink");
        assert_eq!(back.server, "reasonkit-thinktools");
    }

    #[test]
    fn test_stats_message() {
        let msg = IpcMessage::Stats {
            id: "stats-1".to_string(),
            active_connections: 3,
            total_calls: 100,
            cache_hits: 80,
            cache_misses: 20,
        };

        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains(r#""type":"Stats""#));

        match serde_json::from_str::<IpcMessage>(&json).unwrap() {
            IpcMessage::Stats {
                id,
                active_connections,
                total_calls,
                cache_hits,
                cache_misses,
            } => {
                assert_eq!(id, "stats-1");
                assert_eq!(active_connections, 3);
                assert_eq!(total_calls, 100);
                assert_eq!(cache_hits, 80);
                assert_eq!(cache_misses, 20);
            }
            other => panic!("Wrong message type: {:?}", other),
        }
    }

    #[test]
    fn test_ipc_message_size_limit() {
        // The 1 MB frame limit must round-trip through the 4-byte little-endian length prefix.
        let max_size = 1_000_000_u32;
        let prefix = max_size.to_le_bytes();
        assert_eq!(prefix, [0x40, 0x42, 0x0F, 0x00]);
        assert_eq!(u32::from_le_bytes(prefix) as usize, 1_000_000);
    }
}