use super::*;
use crate::{McpToolsError, Result};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info};
/// Async interface for a client that talks to an MCP server.
///
/// Implementors must be `Send + Sync` so a client can be shared across tasks.
/// Every operation is fallible and reports failures through the crate's
/// [`Result`] alias. The exact semantics of each call are defined by the
/// implementor; [`BaseClient`] is the in-crate reference implementation.
#[async_trait::async_trait]
pub trait McpClientBase: Send + Sync {
/// Connects the client to its server.
/// Takes `&mut self`, so callers need exclusive access to the client.
async fn connect(&mut self) -> Result<()>;
/// Disconnects the client from its server.
/// Takes `&mut self`, so callers need exclusive access to the client.
async fn disconnect(&mut self) -> Result<()>;
/// Returns the capabilities advertised by the server.
async fn get_server_capabilities(&self) -> Result<ServerCapabilities>;
/// Executes a single tool call described by `request` and returns its response.
async fn execute_tool(&self, request: McpToolRequest) -> Result<McpToolResponse>;
/// Returns the client's current connection status.
async fn get_status(&self) -> Result<ConnectionStatus>;
}
/// Minimal MCP client that tracks connection state and caches server
/// capabilities; its trait methods simulate the server rather than
/// performing network I/O.
pub struct BaseClient {
/// Configuration supplied at construction time (exposed via `config()`).
config: ClientConfig,
/// Current connection state, behind a shared async `RwLock`.
status: Arc<RwLock<ConnectionStatus>>,
/// Capabilities cached by `get_server_capabilities`; `None` until first
/// fetched and reset to `None` by `disconnect`.
capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
}
impl BaseClient {
    /// Builds a client for `config`.
    ///
    /// The client starts out [`ConnectionStatus::Disconnected`] with an empty
    /// capabilities cache.
    pub fn new(config: ClientConfig) -> Self {
        let status = Arc::new(RwLock::new(ConnectionStatus::Disconnected));
        let capabilities = Arc::new(RwLock::new(None));
        Self {
            config,
            status,
            capabilities,
        }
    }

    /// Borrows the configuration this client was created with.
    pub fn config(&self) -> &ClientConfig {
        &self.config
    }

    /// Overwrites the stored connection status with `status`.
    async fn set_status(&self, status: ConnectionStatus) {
        // The write guard is a temporary and is released at the end of the statement.
        *self.status.write().await = status;
    }

    /// Stores `capabilities` in the cache, replacing any previous entry.
    async fn cache_capabilities(&self, capabilities: ServerCapabilities) {
        *self.capabilities.write().await = Some(capabilities);
    }
}
#[async_trait::async_trait]
impl McpClientBase for BaseClient {
async fn connect(&mut self) -> Result<()> {
info!("Connecting to MCP server: {}", self.config.server_url);
self.set_status(ConnectionStatus::Connecting).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
self.set_status(ConnectionStatus::Connected).await;
info!("Connected to MCP server successfully");
Ok(())
}
async fn disconnect(&mut self) -> Result<()> {
info!("Disconnecting from MCP server");
self.set_status(ConnectionStatus::Disconnected).await;
let mut cached = self.capabilities.write().await;
*cached = None;
info!("Disconnected from MCP server");
Ok(())
}
async fn get_server_capabilities(&self) -> Result<ServerCapabilities> {
{
let cached = self.capabilities.read().await;
if let Some(capabilities) = cached.as_ref() {
return Ok(capabilities.clone());
}
}
let capabilities = ServerCapabilities {
tools: vec![],
features: vec!["mock".to_string()],
info: ServerInfo {
name: "Mock Server".to_string(),
version: "1.0.0".to_string(),
description: "Mock MCP Server".to_string(),
coderlib_version: "0.1.0".to_string(), protocol_version: "1.0".to_string(),
},
};
self.cache_capabilities(capabilities.clone()).await;
Ok(capabilities)
}
async fn execute_tool(&self, request: McpToolRequest) -> Result<McpToolResponse> {
debug!("Executing tool: {}", request.tool);
let status = self.status.read().await;
if *status != ConnectionStatus::Connected {
return Err(McpToolsError::Client("Not connected to server".to_string()));
}
Ok(McpToolResponse::error(
request.id,
"Tool execution not implemented in base client",
))
}
async fn get_status(&self) -> Result<ConnectionStatus> {
let status = self.status.read().await;
Ok(status.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built client starts disconnected.
    #[tokio::test]
    async fn test_base_client_creation() {
        let config = ClientConfig::default();
        let client = BaseClient::new(config);
        let status = client.get_status().await.unwrap();
        assert_eq!(status, ConnectionStatus::Disconnected);
    }

    /// connect/disconnect drive the status through Connected and back.
    #[tokio::test]
    async fn test_client_connection() {
        let config = ClientConfig::default();
        let mut client = BaseClient::new(config);
        client.connect().await.unwrap();
        let status = client.get_status().await.unwrap();
        assert_eq!(status, ConnectionStatus::Connected);
        client.disconnect().await.unwrap();
        let status = client.get_status().await.unwrap();
        assert_eq!(status, ConnectionStatus::Disconnected);
    }

    /// The mock capabilities have the expected contents and are cached after
    /// the first fetch.
    #[tokio::test]
    async fn test_server_capabilities_are_mocked_and_cached() {
        let client = BaseClient::new(ClientConfig::default());
        assert!(client.capabilities.read().await.is_none());

        let first = client.get_server_capabilities().await.unwrap();
        assert!(first.tools.is_empty());
        assert_eq!(first.features, vec!["mock".to_string()]);
        assert_eq!(first.info.name, "Mock Server");
        assert_eq!(first.info.protocol_version, "1.0");
        assert!(client.capabilities.read().await.is_some());

        let second = client.get_server_capabilities().await.unwrap();
        assert_eq!(second.info.name, first.info.name);
        assert_eq!(second.features, first.features);
    }

    /// Disconnecting drops any cached capabilities.
    #[tokio::test]
    async fn test_disconnect_clears_cached_capabilities() {
        let mut client = BaseClient::new(ClientConfig::default());
        client.connect().await.unwrap();
        client.get_server_capabilities().await.unwrap();
        assert!(client.capabilities.read().await.is_some());
        client.disconnect().await.unwrap();
        assert!(client.capabilities.read().await.is_none());
    }
}