use super::*;
use crate::{McpToolsError, Result};
use std::sync::Arc;
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
/// Common asynchronous interface for an MCP tool server.
///
/// Implementors must be `Send + Sync`; the methods are made async through
/// `async_trait`.
#[async_trait::async_trait]
pub trait McpServerBase: Send + Sync {
    /// Lifecycle hook for one-time setup; takes `&mut self` so it may mutate state.
    async fn initialize(&mut self) -> Result<()>;

    /// Lifecycle hook for teardown; takes `&mut self` so it may release state.
    async fn shutdown(&mut self) -> Result<()>;

    /// Describes the tools, features and identity this server exposes.
    async fn get_capabilities(&self) -> Result<ServerCapabilities>;

    /// Executes (or rejects) a single tool invocation.
    async fn handle_tool_request(&self, request: McpToolRequest) -> Result<McpToolResponse>;

    /// Returns a snapshot of the server's runtime statistics.
    async fn get_stats(&self) -> Result<ServerStats>;
}
/// Shared building block for MCP servers: configuration, request/connection
/// statistics, and per-session bookkeeping.
pub struct BaseServer {
    /// Configuration supplied to `BaseServer::new`.
    config: ServerConfig,
    /// Captured at construction; used to compute uptime in `get_stats`.
    start_time: Instant,
    /// Counters shared with every `RequestTracker` handed out for a request.
    stats: Arc<RwLock<ServerStatsInternal>>,
    /// Per-session activity, keyed by session id.
    sessions: Arc<RwLock<std::collections::HashMap<String, SessionInfo>>>,
}
/// Mutable counters behind `BaseServer::stats`; snapshotted into `ServerStats`
/// by `get_stats`.
#[derive(Debug, Default)]
struct ServerStatsInternal {
    // Incremented by `add_connection`, decremented (never below 0) by `remove_connection`.
    active_connections: usize,
    // Incremented once per `record_request_start` call, i.e. when a request begins.
    total_requests: u64,
    // Surfaced by `get_stats` as `total_errors`.
    total_errors: u64,
    // Sum of request durations in milliseconds, added when a `RequestTracker` is dropped.
    total_request_duration_ms: u64,
}
/// Activity record for one client session, created on the session's first request.
#[derive(Debug, Clone)]
struct SessionInfo {
    // Same string as the key under which this record is stored in `BaseServer::sessions`.
    // NOTE(review): not read anywhere in this file — confirm it is used elsewhere.
    id: String,
    // Time of the first request seen for this session.
    // NOTE(review): not read anywhere in this file — confirm it is used elsewhere.
    created_at: SystemTime,
    // Time of the most recent request for this session.
    last_activity: SystemTime,
    // Number of requests recorded for this session, including the first.
    request_count: u64,
}
impl BaseServer {
pub async fn new(config: ServerConfig) -> Result<Self> {
info!("Initializing MCP server: {}", config.name);
Ok(Self {
config,
stats: Arc::new(RwLock::new(ServerStatsInternal::default())),
start_time: Instant::now(),
sessions: Arc::new(RwLock::new(std::collections::HashMap::new())),
})
}
pub fn config(&self) -> &ServerConfig {
&self.config
}
pub async fn record_request_start(&self, session_id: &str) -> RequestTracker {
let mut stats = self.stats.write().await;
stats.total_requests += 1;
let mut sessions = self.sessions.write().await;
let now = SystemTime::now();
if let Some(session) = sessions.get_mut(session_id) {
session.last_activity = now;
session.request_count += 1;
} else {
sessions.insert(
session_id.to_string(),
SessionInfo {
id: session_id.to_string(),
created_at: now,
last_activity: now,
request_count: 1,
},
);
}
RequestTracker {
start_time: Instant::now(),
stats: self.stats.clone(),
}
}
pub async fn add_connection(&self) {
let mut stats = self.stats.write().await;
stats.active_connections += 1;
info!(
"New connection. Active connections: {}",
stats.active_connections
);
}
pub async fn remove_connection(&self) {
let mut stats = self.stats.write().await;
if stats.active_connections > 0 {
stats.active_connections -= 1;
}
info!(
"Connection closed. Active connections: {}",
stats.active_connections
);
}
pub fn get_server_info(&self) -> ServerInfo {
ServerInfo {
name: self.config.name.clone(),
version: self.config.version.clone(),
description: self.config.description.clone(),
coderlib_version: "0.1.0".to_string(), protocol_version: "1.0".to_string(),
}
}
pub async fn validate_request(&self, request: &McpToolRequest) -> Result<()> {
debug!("Validating tool request: {}", request.tool);
debug!(
"Tool request validated: {} with args: {}",
request.tool, request.arguments
);
Ok(())
}
pub fn create_error_response(
&self,
request_id: Uuid,
error: impl Into<String>,
) -> McpToolResponse {
let error_msg = error.into();
error!("Tool request error: {}", error_msg);
McpToolResponse::error(request_id, error_msg)
}
pub fn create_success_response(
&self,
request_id: Uuid,
content: Vec<McpContent>,
) -> McpToolResponse {
debug!("Tool request successful: {} content items", content.len());
McpToolResponse::success(request_id, content)
}
}
#[async_trait::async_trait]
impl McpServerBase for BaseServer {
    /// Advertises the base feature set together with `get_server_info()`.
    /// The base server registers no tools.
    async fn get_capabilities(&self) -> Result<ServerCapabilities> {
        Ok(ServerCapabilities {
            tools: vec![],
            features: vec![
                "tool_execution".to_string(),
                "permission_system".to_string(),
                "session_management".to_string(),
            ],
            info: self.get_server_info(),
        })
    }

    /// Records and validates the request, then replies with an error response,
    /// because the base server implements no tools.
    ///
    /// Both a failed validation and the "not implemented" response are counted in
    /// `total_errors`, so `get_stats` reflects them.
    ///
    /// # Errors
    /// Propagates any error returned by `validate_request`.
    async fn handle_tool_request(&self, request: McpToolRequest) -> Result<McpToolResponse> {
        // Kept alive until the function returns; its `Drop` records the request duration.
        let _tracker = self.record_request_start(&request.session_id).await;

        if let Err(err) = self.validate_request(&request).await {
            self.stats.write().await.total_errors += 1;
            return Err(err);
        }

        // Every request reaching this point gets an error response, so count it as one.
        self.stats.write().await.total_errors += 1;
        Ok(self.create_error_response(request.id, "Tool not implemented in base server"))
    }

    /// Snapshots the counters and computes uptime and mean request duration.
    ///
    /// NOTE(review): `total_requests` is incremented when a request starts, while
    /// durations are added when its tracker drops. The average therefore includes
    /// in-flight requests with a duration of zero.
    async fn get_stats(&self) -> Result<ServerStats> {
        let stats = self.stats.read().await;
        let uptime_secs = self.start_time.elapsed().as_secs();
        let avg_request_duration_ms = if stats.total_requests > 0 {
            stats.total_request_duration_ms as f64 / stats.total_requests as f64
        } else {
            0.0
        };
        Ok(ServerStats {
            active_connections: stats.active_connections,
            total_requests: stats.total_requests,
            total_errors: stats.total_errors,
            uptime_secs,
            avg_request_duration_ms,
        })
    }

    /// Logs the start and completion of initialization; the base server has no setup work.
    async fn initialize(&mut self) -> Result<()> {
        info!("Initializing server: {}", self.config.name);
        info!("Server initialized successfully");
        Ok(())
    }

    /// Discards all session records and logs the shutdown.
    async fn shutdown(&mut self) -> Result<()> {
        info!("Shutting down server: {}", self.config.name);
        let mut sessions = self.sessions.write().await;
        sessions.clear();
        info!("Server shutdown complete");
        Ok(())
    }
}
pub struct RequestTracker {
start_time: Instant,
stats: Arc<RwLock<ServerStatsInternal>>,
}
impl Drop for RequestTracker {
fn drop(&mut self) {
let duration_ms = self.start_time.elapsed().as_millis() as u64;
let stats = self.stats.clone();
tokio::spawn(async move {
let mut stats = stats.write().await;
stats.total_request_duration_ms += duration_ms;
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BaseServer` from `ServerConfig::default()`, panicking on failure.
    async fn default_server() -> BaseServer {
        BaseServer::new(ServerConfig::default())
            .await
            .expect("default config should build a server")
    }

    #[tokio::test]
    async fn test_base_server_creation() {
        let server = default_server().await;
        let cfg = server.config();
        assert_eq!(cfg.name, "MCP Server");
        assert_eq!(cfg.port, 3000);
    }

    #[tokio::test]
    async fn test_server_capabilities() {
        let server = default_server().await;
        let caps = server.get_capabilities().await.unwrap();
        let features = &caps.features;
        assert!(!features.is_empty());
        assert!(features.iter().any(|f| f == "tool_execution"));
    }

    #[tokio::test]
    async fn test_server_stats() {
        let server = default_server().await;
        let snapshot = server.get_stats().await.unwrap();
        assert_eq!(snapshot.active_connections, 0);
        assert_eq!(snapshot.total_requests, 0);
    }
}