// mcp-tools 0.1.0 — Rust MCP tools library (documentation)
//! Base server implementation for MCP Tools

use super::*;
use crate::{McpToolsError, Result};
// use coderlib::{CoderLib, CoderLibConfig, PermissionService}; // TODO: Re-enable when coderlib is available
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};

/// Common interface exposed by every MCP tool server.
///
/// Implementors must be `Send + Sync`. Every method is async, through the
/// `async_trait` attribute, and reports failures via the crate's `Result`.
#[async_trait::async_trait]
pub trait McpServerBase: Send + Sync {
    /// Prepares the server for use; takes `&mut self` so setup may mutate state.
    async fn initialize(&mut self) -> Result<()>;

    /// Tears the server down; takes `&mut self` so cleanup may mutate state.
    async fn shutdown(&mut self) -> Result<()>;

    /// Describes what this server offers (tools, feature flags, server info).
    async fn get_capabilities(&self) -> Result<ServerCapabilities>;

    /// Executes one tool request and produces the matching response.
    async fn handle_tool_request(&self, request: McpToolRequest) -> Result<McpToolResponse>;

    /// Returns a snapshot of the server's runtime statistics.
    async fn get_stats(&self) -> Result<ServerStats>;
}

/// Shared foundation for MCP servers: configuration, request/connection
/// counters, and a per-session activity table.
///
/// CoderLib integration is temporarily disabled.
pub struct BaseServer {
    /// Configuration supplied when the server was constructed.
    config: ServerConfig,
    /// Counters, shared (via `Arc`) with in-flight request trackers.
    stats: Arc<RwLock<ServerStatsInternal>>,
    /// Construction instant; uptime is measured from here.
    start_time: Instant,
    /// Known sessions, keyed by session id.
    sessions: Arc<RwLock<std::collections::HashMap<String, SessionInfo>>>,
}

/// Mutable counters behind `BaseServer::stats`; `get_stats` turns them into a
/// public `ServerStats` snapshot.
#[derive(Debug, Default)]
struct ServerStatsInternal {
    /// Number of currently open connections.
    active_connections: usize,
    /// Number of requests started since the server was created.
    total_requests: u64,
    /// Error count, reported as `ServerStats::total_errors`.
    total_errors: u64,
    /// Sum of recorded request durations, in milliseconds.
    total_request_duration_ms: u64,
}

/// Bookkeeping for one client session, stored in `BaseServer::sessions`.
#[derive(Debug, Clone)]
struct SessionInfo {
    /// Session identifier (the same string used as the map key).
    id: String,
    /// When the first request for this session was recorded.
    created_at: SystemTime,
    /// When the most recent request for this session was recorded.
    last_activity: SystemTime,
    /// How many requests have been recorded for this session.
    request_count: u64,
}

impl BaseServer {
    /// Creates a new base server from `config`.
    ///
    /// Statistics and the session table start empty, and the uptime clock starts
    /// now. This constructor is currently infallible. The `Result` return type leaves
    /// room for the pending CoderLib initialization (see the TODO below).
    pub async fn new(config: ServerConfig) -> Result<Self> {
        info!("Initializing MCP server: {}", config.name);

        // TODO: Initialize CoderLib when available
        // let coderlib_config = CoderLibConfig::default();
        // let coderlib = Arc::new(CoderLib::new(coderlib_config).await?);
        // let permission_service = coderlib.permission_service();

        Ok(Self {
            config,
            stats: Arc::new(RwLock::new(ServerStatsInternal::default())),
            start_time: Instant::now(),
            sessions: Arc::new(RwLock::new(std::collections::HashMap::new())),
        })
    }

    // TODO: Re-enable when coderlib is available
    // /// Get CoderLib instance
    // pub fn coderlib(&self) -> Arc<CoderLib> {
    //     self.coderlib.clone()
    // }
    //
    // /// Get permission service
    // pub fn permission_service(&self) -> Arc<dyn PermissionService> {
    //     self.permission_service.clone()
    // }

    /// Returns the configuration this server was created with.
    pub fn config(&self) -> &ServerConfig {
        &self.config
    }

    /// Records the start of a request for `session_id` and returns its tracker.
    ///
    /// Increments the global request counter. It then either refreshes the
    /// session's `last_activity` and `request_count`, or creates a new session
    /// entry with a count of 1. When the returned `RequestTracker` is dropped, it
    /// adds the elapsed time to the aggregate request duration. Keep it alive for
    /// the whole request.
    pub async fn record_request_start(&self, session_id: &str) -> RequestTracker {
        // Start timing on entry so that time spent waiting for the locks below
        // counts toward the request's duration.
        let started = Instant::now();

        // Release the stats lock before waiting on the sessions lock. Holding
        // both would block `get_stats` and connection bookkeeping while this
        // call waits on session contention.
        {
            let mut stats = self.stats.write().await;
            stats.total_requests += 1;
        }

        {
            let mut sessions = self.sessions.write().await;
            let now = SystemTime::now();

            match sessions.get_mut(session_id) {
                Some(session) => {
                    session.last_activity = now;
                    session.request_count += 1;
                }
                None => {
                    sessions.insert(
                        session_id.to_string(),
                        SessionInfo {
                            id: session_id.to_string(),
                            created_at: now,
                            last_activity: now,
                            request_count: 1,
                        },
                    );
                }
            }
        }

        RequestTracker {
            start_time: started,
            stats: Arc::clone(&self.stats),
        }
    }

    /// Increments the active-connection counter and logs the new total.
    pub async fn add_connection(&self) {
        let mut stats = self.stats.write().await;
        stats.active_connections += 1;
        info!(
            "New connection. Active connections: {}",
            stats.active_connections
        );
    }

    /// Decrements the active-connection counter (never below zero) and logs the new total.
    pub async fn remove_connection(&self) {
        let mut stats = self.stats.write().await;
        // Unbalanced removals are ignored rather than underflowing.
        stats.active_connections = stats.active_connections.saturating_sub(1);
        info!(
            "Connection closed. Active connections: {}",
            stats.active_connections
        );
    }

    /// Builds the `ServerInfo` advertised to clients from the configuration.
    ///
    /// `coderlib_version` and `protocol_version` are hard-coded placeholders for now.
    pub fn get_server_info(&self) -> ServerInfo {
        ServerInfo {
            name: self.config.name.clone(),
            version: self.config.version.clone(),
            description: self.config.description.clone(),
            coderlib_version: "0.1.0".to_string(), // TODO: Get from coderlib when available
            protocol_version: "1.0".to_string(),
        }
    }

    /// Validates a tool request.
    ///
    /// The base implementation performs no checks. It only logs the request at
    /// debug level and always returns `Ok(())`. Tool-specific validation is left to
    /// concrete servers.
    pub async fn validate_request(&self, request: &McpToolRequest) -> Result<()> {
        debug!("Validating tool request: {}", request.tool);

        // Check if tool exists (this would be implemented by specific servers)
        // For now, just log the request
        debug!(
            "Tool request validated: {} with args: {}",
            request.tool, request.arguments
        );

        Ok(())
    }

    // TODO: Re-enable when coderlib is available
    // /// Handle permission check for tool
    // pub async fn check_tool_permission(
    //     &self,
    //     session_id: &str,
    //     tool_name: &str,
    //     permission: coderlib::Permission,
    // ) -> Result<()> {
    //     let has_permission = self.permission_service
    //         .check_permission(session_id, tool_name, permission, None)
    //         .await
    //         .map_err(|e| McpToolsError::CoderLib(e.into()))?;
    //
    //     if !has_permission {
    //         warn!("Permission denied for tool {} in session {}", tool_name, session_id);
    //         return Err(McpToolsError::Server(
    //             format!("Permission denied for tool: {}", tool_name)
    //         ));
    //     }
    //
    //     debug!("Permission granted for tool {} in session {}", tool_name, session_id);
    //     Ok(())
    // }

    /// Logs `error` at error level and wraps it in an error response for `request_id`.
    pub fn create_error_response(
        &self,
        request_id: Uuid,
        error: impl Into<String>,
    ) -> McpToolResponse {
        let error_msg = error.into();
        error!("Tool request error: {}", error_msg);
        McpToolResponse::error(request_id, error_msg)
    }

    /// Logs the number of content items and wraps them in a success response for `request_id`.
    pub fn create_success_response(
        &self,
        request_id: Uuid,
        content: Vec<McpContent>,
    ) -> McpToolResponse {
        debug!("Tool request successful: {} content items", content.len());
        McpToolResponse::success(request_id, content)
    }
}

#[async_trait::async_trait]
impl McpServerBase for BaseServer {
    /// Advertises the base feature set.
    ///
    /// The tool list is empty because concrete servers provide the tools.
    async fn get_capabilities(&self) -> Result<ServerCapabilities> {
        Ok(ServerCapabilities {
            tools: vec![], // To be implemented by specific servers
            // NOTE(review): "permission_system" is advertised although the CoderLib
            // permission check in this file is currently commented out — confirm intent.
            features: vec![
                "tool_execution".to_string(),
                "permission_system".to_string(),
                "session_management".to_string(),
            ],
            info: self.get_server_info(),
        })
    }

    /// Fallback tool handler.
    ///
    /// It records the request and validates it. It then answers with a
    /// "not implemented" error response, because the base server has no tools.
    /// Both a validation failure and the error response are counted in
    /// `total_errors`, which was otherwise never incremented.
    async fn handle_tool_request(&self, request: McpToolRequest) -> Result<McpToolResponse> {
        // Bound for the whole call so its Drop records the full request duration.
        let _tracker = self.record_request_start(&request.session_id).await;

        if let Err(err) = self.validate_request(&request).await {
            self.stats.write().await.total_errors += 1;
            return Err(err);
        }

        // This is a base implementation - specific servers will override this
        self.stats.write().await.total_errors += 1;
        Ok(self.create_error_response(request.id, "Tool not implemented in base server"))
    }

    /// Returns a snapshot of the counters plus uptime and the mean request duration.
    async fn get_stats(&self) -> Result<ServerStats> {
        let stats = self.stats.read().await;
        let uptime_secs = self.start_time.elapsed().as_secs();

        // NOTE(review): the divisor counts requests when they *start*, while their
        // durations are added only when each tracker drops. In-flight requests
        // therefore pull this average down slightly.
        let avg_request_duration_ms = if stats.total_requests > 0 {
            stats.total_request_duration_ms as f64 / stats.total_requests as f64
        } else {
            0.0
        };

        Ok(ServerStats {
            active_connections: stats.active_connections,
            total_requests: stats.total_requests,
            total_errors: stats.total_errors,
            uptime_secs,
            avg_request_duration_ms,
        })
    }

    /// Logs start-up; the base server has nothing else to initialize yet.
    async fn initialize(&mut self) -> Result<()> {
        info!("Initializing server: {}", self.config.name);

        // Initialize CoderLib components
        // This could include setting up tools, loading configuration, etc.

        info!("Server initialized successfully");
        Ok(())
    }

    /// Clears the session table and logs shutdown; counters are left untouched.
    async fn shutdown(&mut self) -> Result<()> {
        info!("Shutting down server: {}", self.config.name);

        // Cleanup resources
        let mut sessions = self.sessions.write().await;
        sessions.clear();

        info!("Server shutdown complete");
        Ok(())
    }
}

/// Guard that measures one request's wall-clock duration.
///
/// It is returned by `BaseServer::record_request_start`. When dropped, it adds
/// the time elapsed since `start_time` to the shared
/// `total_request_duration_ms` counter. Discarding it (or writing `let _ = ...`)
/// drops it immediately and records a near-zero duration, hence `#[must_use]`.
#[must_use = "the tracker records the request duration when dropped; bind it (e.g. `let _tracker = ...`) for the lifetime of the request"]
pub struct RequestTracker {
    /// Instant at which the request was considered started.
    start_time: Instant,
    /// Shared counters that receive the elapsed time on drop.
    stats: Arc<RwLock<ServerStatsInternal>>,
}

impl Drop for RequestTracker {
    fn drop(&mut self) {
        let duration_ms = self.start_time.elapsed().as_millis() as u64;

        // Update stats asynchronously
        let stats = self.stats.clone();
        tokio::spawn(async move {
            let mut stats = stats.write().await;
            stats.total_request_duration_ms += duration_ms;
        });
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server from the default configuration, panicking if construction fails.
    async fn default_server() -> BaseServer {
        BaseServer::new(ServerConfig::default()).await.unwrap()
    }

    #[tokio::test]
    async fn test_base_server_creation() {
        let server = default_server().await;
        let cfg = server.config();

        assert_eq!(cfg.name, "MCP Server");
        assert_eq!(cfg.port, 3000);
    }

    #[tokio::test]
    async fn test_server_capabilities() {
        let server = default_server().await;
        let caps = server.get_capabilities().await.unwrap();

        assert!(!caps.features.is_empty());
        assert!(caps.features.iter().any(|feature| feature == "tool_execution"));
    }

    #[tokio::test]
    async fn test_server_stats() {
        let server = default_server().await;
        let snapshot = server.get_stats().await.unwrap();

        assert_eq!(snapshot.active_connections, 0);
        assert_eq!(snapshot.total_requests, 0);
    }
}