mcp_tools/common/
server_base.rs

1//! Base server implementation for MCP Tools
2
3use super::*;
4use crate::{McpToolsError, Result};
5// use coderlib::{CoderLib, CoderLibConfig, PermissionService}; // TODO: Re-enable when coderlib is available
6use std::sync::Arc;
7use std::time::{Instant, SystemTime, UNIX_EPOCH};
8use tokio::sync::RwLock;
9use tracing::{debug, error, info, warn};
10
11/// Base MCP server trait
12#[async_trait::async_trait]
13pub trait McpServerBase: Send + Sync {
14    /// Get server capabilities
15    async fn get_capabilities(&self) -> Result<ServerCapabilities>;
16
17    /// Handle tool execution request
18    async fn handle_tool_request(&self, request: McpToolRequest) -> Result<McpToolResponse>;
19
20    /// Get server statistics
21    async fn get_stats(&self) -> Result<ServerStats>;
22
23    /// Initialize server
24    async fn initialize(&mut self) -> Result<()>;
25
26    /// Shutdown server
27    async fn shutdown(&mut self) -> Result<()>;
28}
29
30/// Base server implementation (CoderLib integration disabled temporarily)
31pub struct BaseServer {
32    /// Server configuration
33    config: ServerConfig,
34    /// Server statistics
35    stats: Arc<RwLock<ServerStatsInternal>>,
36    /// Server start time
37    start_time: Instant,
38    /// Active sessions
39    sessions: Arc<RwLock<std::collections::HashMap<String, SessionInfo>>>,
40}
41
/// Mutable counters behind `BaseServer::get_stats`.
///
/// Kept private; `get_stats` copies these into the public `ServerStats`
/// snapshot and derives the average request duration from them.
#[derive(Debug, Default)]
struct ServerStatsInternal {
    // Incremented by `add_connection`, decremented (never below zero) by `remove_connection`.
    active_connections: usize,
    // Incremented once per `record_request_start` call.
    total_requests: u64,
    // Count of failed requests, as reported by `get_stats`.
    total_errors: u64,
    // Sum of request durations in milliseconds, accumulated when a `RequestTracker` drops.
    total_request_duration_ms: u64,
}
50
/// Per-session bookkeeping maintained by `BaseServer::record_request_start`.
///
/// NOTE(review): within this file these fields are only written, never read
/// (outside tests), so rustc will report them as dead code — confirm whether
/// a consumer (e.g. session expiry) is planned.
#[derive(Debug, Clone)]
struct SessionInfo {
    // Session identifier; duplicates the map key it is stored under.
    id: String,
    // Time of the first request seen for this session.
    created_at: SystemTime,
    // Time of the most recent request seen for this session.
    last_activity: SystemTime,
    // Number of requests recorded for this session.
    request_count: u64,
}
59
60impl BaseServer {
61    /// Create new base server
62    pub async fn new(config: ServerConfig) -> Result<Self> {
63        info!("Initializing MCP server: {}", config.name);
64
65        // TODO: Initialize CoderLib when available
66        // let coderlib_config = CoderLibConfig::default();
67        // let coderlib = Arc::new(CoderLib::new(coderlib_config).await?);
68        // let permission_service = coderlib.permission_service();
69
70        Ok(Self {
71            config,
72            stats: Arc::new(RwLock::new(ServerStatsInternal::default())),
73            start_time: Instant::now(),
74            sessions: Arc::new(RwLock::new(std::collections::HashMap::new())),
75        })
76    }
77
78    // TODO: Re-enable when coderlib is available
79    // /// Get CoderLib instance
80    // pub fn coderlib(&self) -> Arc<CoderLib> {
81    //     self.coderlib.clone()
82    // }
83    //
84    // /// Get permission service
85    // pub fn permission_service(&self) -> Arc<dyn PermissionService> {
86    //     self.permission_service.clone()
87    // }
88
89    /// Get server configuration
90    pub fn config(&self) -> &ServerConfig {
91        &self.config
92    }
93
94    /// Record request start
95    pub async fn record_request_start(&self, session_id: &str) -> RequestTracker {
96        let mut stats = self.stats.write().await;
97        stats.total_requests += 1;
98
99        // Update session info
100        let mut sessions = self.sessions.write().await;
101        let now = SystemTime::now();
102
103        if let Some(session) = sessions.get_mut(session_id) {
104            session.last_activity = now;
105            session.request_count += 1;
106        } else {
107            sessions.insert(
108                session_id.to_string(),
109                SessionInfo {
110                    id: session_id.to_string(),
111                    created_at: now,
112                    last_activity: now,
113                    request_count: 1,
114                },
115            );
116        }
117
118        RequestTracker {
119            start_time: Instant::now(),
120            stats: self.stats.clone(),
121        }
122    }
123
124    /// Add connection
125    pub async fn add_connection(&self) {
126        let mut stats = self.stats.write().await;
127        stats.active_connections += 1;
128        info!(
129            "New connection. Active connections: {}",
130            stats.active_connections
131        );
132    }
133
134    /// Remove connection
135    pub async fn remove_connection(&self) {
136        let mut stats = self.stats.write().await;
137        if stats.active_connections > 0 {
138            stats.active_connections -= 1;
139        }
140        info!(
141            "Connection closed. Active connections: {}",
142            stats.active_connections
143        );
144    }
145
146    /// Get server information
147    pub fn get_server_info(&self) -> ServerInfo {
148        ServerInfo {
149            name: self.config.name.clone(),
150            version: self.config.version.clone(),
151            description: self.config.description.clone(),
152            coderlib_version: "0.1.0".to_string(), // TODO: Get from coderlib when available
153            protocol_version: "1.0".to_string(),
154        }
155    }
156
157    /// Validate tool request
158    pub async fn validate_request(&self, request: &McpToolRequest) -> Result<()> {
159        debug!("Validating tool request: {}", request.tool);
160
161        // Check if tool exists (this would be implemented by specific servers)
162        // For now, just log the request
163        debug!(
164            "Tool request validated: {} with args: {}",
165            request.tool, request.arguments
166        );
167
168        Ok(())
169    }
170
171    // TODO: Re-enable when coderlib is available
172    // /// Handle permission check for tool
173    // pub async fn check_tool_permission(
174    //     &self,
175    //     session_id: &str,
176    //     tool_name: &str,
177    //     permission: coderlib::Permission,
178    // ) -> Result<()> {
179    //     let has_permission = self.permission_service
180    //         .check_permission(session_id, tool_name, permission, None)
181    //         .await
182    //         .map_err(|e| McpToolsError::CoderLib(e.into()))?;
183    //
184    //     if !has_permission {
185    //         warn!("Permission denied for tool {} in session {}", tool_name, session_id);
186    //         return Err(McpToolsError::Server(
187    //             format!("Permission denied for tool: {}", tool_name)
188    //         ));
189    //     }
190    //
191    //     debug!("Permission granted for tool {} in session {}", tool_name, session_id);
192    //     Ok(())
193    // }
194
195    /// Create error response
196    pub fn create_error_response(
197        &self,
198        request_id: Uuid,
199        error: impl Into<String>,
200    ) -> McpToolResponse {
201        let error_msg = error.into();
202        error!("Tool request error: {}", error_msg);
203        McpToolResponse::error(request_id, error_msg)
204    }
205
206    /// Create success response
207    pub fn create_success_response(
208        &self,
209        request_id: Uuid,
210        content: Vec<McpContent>,
211    ) -> McpToolResponse {
212        debug!("Tool request successful: {} content items", content.len());
213        McpToolResponse::success(request_id, content)
214    }
215}
216
217#[async_trait::async_trait]
218impl McpServerBase for BaseServer {
219    async fn get_capabilities(&self) -> Result<ServerCapabilities> {
220        Ok(ServerCapabilities {
221            tools: vec![], // To be implemented by specific servers
222            features: vec![
223                "tool_execution".to_string(),
224                "permission_system".to_string(),
225                "session_management".to_string(),
226            ],
227            info: self.get_server_info(),
228        })
229    }
230
231    async fn handle_tool_request(&self, request: McpToolRequest) -> Result<McpToolResponse> {
232        let _tracker = self.record_request_start(&request.session_id).await;
233
234        // Validate request
235        self.validate_request(&request).await?;
236
237        // This is a base implementation - specific servers will override this
238        Ok(self.create_error_response(request.id, "Tool not implemented in base server"))
239    }
240
241    async fn get_stats(&self) -> Result<ServerStats> {
242        let stats = self.stats.read().await;
243        let uptime_secs = self.start_time.elapsed().as_secs();
244
245        let avg_request_duration_ms = if stats.total_requests > 0 {
246            stats.total_request_duration_ms as f64 / stats.total_requests as f64
247        } else {
248            0.0
249        };
250
251        Ok(ServerStats {
252            active_connections: stats.active_connections,
253            total_requests: stats.total_requests,
254            total_errors: stats.total_errors,
255            uptime_secs,
256            avg_request_duration_ms,
257        })
258    }
259
260    async fn initialize(&mut self) -> Result<()> {
261        info!("Initializing server: {}", self.config.name);
262
263        // Initialize CoderLib components
264        // This could include setting up tools, loading configuration, etc.
265
266        info!("Server initialized successfully");
267        Ok(())
268    }
269
270    async fn shutdown(&mut self) -> Result<()> {
271        info!("Shutting down server: {}", self.config.name);
272
273        // Cleanup resources
274        let mut sessions = self.sessions.write().await;
275        sessions.clear();
276
277        info!("Server shutdown complete");
278        Ok(())
279    }
280}
281
282/// Request tracker for timing and statistics
283pub struct RequestTracker {
284    start_time: Instant,
285    stats: Arc<RwLock<ServerStatsInternal>>,
286}
287
288impl Drop for RequestTracker {
289    fn drop(&mut self) {
290        let duration_ms = self.start_time.elapsed().as_millis() as u64;
291
292        // Update stats asynchronously
293        let stats = self.stats.clone();
294        tokio::spawn(async move {
295            let mut stats = stats.write().await;
296            stats.total_request_duration_ms += duration_ms;
297        });
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a server from the default configuration.
    async fn default_server() -> BaseServer {
        BaseServer::new(ServerConfig::default()).await.unwrap()
    }

    #[tokio::test]
    async fn test_base_server_creation() {
        let server = default_server().await;

        assert_eq!(server.config().name, "MCP Server");
        assert_eq!(server.config().port, 3000);
    }

    #[tokio::test]
    async fn test_server_capabilities() {
        let server = default_server().await;

        let capabilities = server.get_capabilities().await.unwrap();
        assert!(!capabilities.features.is_empty());
        assert!(capabilities
            .features
            .contains(&"tool_execution".to_string()));
    }

    #[tokio::test]
    async fn test_server_stats() {
        let server = default_server().await;

        let stats = server.get_stats().await.unwrap();
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.total_requests, 0);
    }

    #[tokio::test]
    async fn test_connection_count_clamps_at_zero() {
        let server = default_server().await;

        server.add_connection().await;
        server.add_connection().await;
        server.remove_connection().await;
        assert_eq!(server.get_stats().await.unwrap().active_connections, 1);

        server.remove_connection().await;
        // Unbalanced removal must not underflow.
        server.remove_connection().await;
        assert_eq!(server.get_stats().await.unwrap().active_connections, 0);
    }

    #[tokio::test]
    async fn test_request_start_counts_requests_and_sessions() {
        let server = default_server().await;

        {
            let _first = server.record_request_start("session-a").await;
            let _second = server.record_request_start("session-a").await;
            let _third = server.record_request_start("session-b").await;
        }

        assert_eq!(server.get_stats().await.unwrap().total_requests, 3);

        let sessions = server.sessions.read().await;
        assert_eq!(sessions.len(), 2);
        assert_eq!(sessions.get("session-a").map(|s| s.request_count), Some(2));
        assert_eq!(sessions.get("session-b").map(|s| s.request_count), Some(1));
    }

    #[tokio::test]
    async fn test_shutdown_clears_sessions() {
        let mut server = default_server().await;

        {
            let _tracker = server.record_request_start("session-a").await;
        }
        server.shutdown().await.unwrap();

        assert!(server.sessions.read().await.is_empty());
    }
}