use parking_lot::RwLock;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};
use super::handler::{DefaultHandler, McpHandler, ToolExecutor};
use super::protocol::{
McpMessage, McpNotification, McpRequest, McpResponse, RequestId, ToolDefinition, ToolResult,
};
use super::transport::McpTransport;
use anyhow::{anyhow, Result};
/// Settings used to construct an [`McpBridge`].
#[derive(Debug, Clone)]
pub struct McpBridgeConfig {
    /// Server name handed to the `DefaultHandler` built by `McpBridge::new`.
    pub server_name: String,
    /// Server version handed to the `DefaultHandler` built by `McpBridge::new`.
    pub server_version: String,
    /// How long `McpBridge::request` waits for a response, in milliseconds.
    pub request_timeout_ms: u64,
    /// NOTE(review): `McpBridge` in this file never reads this value; confirm
    /// where (or whether) the limit is enforced.
    pub max_concurrent_requests: usize,
}

impl Default for McpBridgeConfig {
    /// Identity "Continuum" / "0.1.0", a 30 000 ms request timeout and
    /// `max_concurrent_requests` set to 100.
    fn default() -> Self {
        let server_name = String::from("Continuum");
        let server_version = String::from("0.1.0");
        McpBridgeConfig {
            max_concurrent_requests: 100,
            request_timeout_ms: 30_000,
            server_version,
            server_name,
        }
    }
}
pub struct McpBridge {
transport: RwLock<Option<Arc<dyn McpTransport>>>,
handler: Arc<DefaultHandler>,
config: McpBridgeConfig,
request_id_counter: AtomicU64,
pending_responses: Arc<RwLock<HashMap<RequestId, mpsc::Sender<McpResponse>>>>,
running: Arc<AtomicBool>,
}
impl McpBridge {
pub fn new(config: McpBridgeConfig) -> Self {
let handler = DefaultHandler::new(&config.server_name, &config.server_version);
Self {
transport: RwLock::new(None),
handler: Arc::new(handler),
config,
request_id_counter: AtomicU64::new(0),
pending_responses: Arc::new(RwLock::new(HashMap::new())),
running: Arc::new(AtomicBool::new(false)),
}
}
pub fn with_transport(self, transport: Box<dyn McpTransport>) -> Self {
*self.transport.write() = Some(Arc::from(transport));
self
}
pub fn register_tool(&self, tool: ToolDefinition, executor: Arc<dyn ToolExecutor>) {
self.handler.register_tool(tool, executor);
}
pub fn register_simple_tool<F>(&self, name: &str, description: &str, executor: F)
where
F: Fn(&str, Value) -> Result<ToolResult> + Send + Sync + 'static,
{
let tool = ToolDefinition {
name: name.to_string(),
description: Some(description.to_string()),
input_schema: None,
};
self.register_tool(tool, Arc::new(super::handler::SimpleToolExecutor(executor)));
}
fn next_request_id(&self) -> RequestId {
RequestId::Number(self.request_id_counter.fetch_add(1, Ordering::SeqCst) as i64)
}
pub async fn start(&self) -> Result<()> {
self.running.store(true, Ordering::SeqCst);
let transport_opt = {
let transport_guard = self.transport.read();
transport_guard.clone()
};
let handler = self.handler.clone();
let pending = self.pending_responses.clone();
let running = self.running.clone();
tokio::spawn(async move {
info!("MCP message loop started");
if transport_opt.is_none() {
info!("No transport configured, message loop will idle");
}
loop {
if !running.load(Ordering::SeqCst) {
info!("MCP message loop stopping");
break;
}
if let Some(ref t) = transport_opt {
match t.receive().await {
Ok(Some(message)) => {
match message {
McpMessage::Request(request) => {
match handler.handle(&request).await {
Ok(response) => {
if let Err(e) =
t.send(&McpMessage::Response(response)).await
{
warn!("Failed to send response: {}", e);
}
}
Err(e) => {
warn!(
"Handler error for request {:?}: {}",
request.id, e
);
}
}
}
McpMessage::Notification(notification) => {
if let Err(e) = handler.handle_notification(¬ification).await
{
warn!("Notification handler error: {}", e);
}
}
McpMessage::Response(response) => {
let sender_opt = pending.write().remove(&response.id);
if let Some(sender) = sender_opt {
if let Err(e) = sender.send(response).await {
warn!(
"Failed to forward response to pending request: {}",
e
);
}
} else {
debug!(
"Received response for unknown request {:?}",
response.id
);
}
}
McpMessage::Error(error) => {
warn!("Received error: {:?}", error);
}
}
}
Ok(None) => {
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
Err(e) => {
warn!("Transport receive error: {}", e);
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}
} else {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}
pending.write().clear();
info!("MCP message loop stopped");
});
Ok(())
}
pub async fn stop(&self) -> Result<()> {
self.running.store(false, Ordering::SeqCst);
let transport = self.transport.write().take();
if let Some(transport) = transport {
transport.close().await?;
}
Ok(())
}
pub async fn request(&self, method: &str, params: Option<Value>) -> Result<McpResponse> {
let id = self.next_request_id();
let timeout_duration = std::time::Duration::from_millis(self.config.request_timeout_ms);
let (tx, mut rx) = mpsc::channel::<McpResponse>(1);
{
self.pending_responses.write().insert(id.clone(), tx);
}
let request = McpRequest {
id: id.clone(),
method: method.to_string(),
params,
};
let message = McpMessage::Request(request);
let transport = {
let transport_guard = self.transport.read();
transport_guard
.as_ref()
.ok_or_else(|| anyhow!("Transport not initialized"))?
.clone()
};
transport.send(&message).await?;
let result = tokio::time::timeout(timeout_duration, rx.recv()).await;
self.pending_responses.write().remove(&id);
match result {
Ok(Some(response)) => Ok(response),
Ok(None) => Err(anyhow!("Response channel closed")),
Err(_) => Err(anyhow!(
"Request timeout after {}ms",
self.config.request_timeout_ms
)),
}
}
#[allow(clippy::await_holding_lock)]
pub async fn notify(&self, method: &str, params: Option<Value>) -> Result<()> {
let notification = McpNotification {
method: method.to_string(),
params,
};
let message = McpMessage::Notification(notification);
let transport_guard = self.transport.read();
let transport = transport_guard
.as_ref()
.ok_or_else(|| anyhow!("Transport not initialized"))?;
transport.send(&message).await?;
Ok(())
}
pub async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
let response = self.request("tools/list", None).await?;
if let Some(result) = response.result {
let tools: Vec<ToolDefinition> = serde_json::from_value(
result.get("tools").cloned().unwrap_or(Value::Array(vec![])),
)?;
Ok(tools)
} else {
Ok(vec![])
}
}
pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<ToolResult> {
let params = serde_json::json!({
"name": name,
"arguments": arguments
});
let response = self.request("tools/call", Some(params)).await?;
if let Some(result) = response.result {
let tool_result: ToolResult = serde_json::from_value(result)?;
Ok(tool_result)
} else if let Some(error) = response.error {
Err(anyhow!("Tool call error: {}", error.message))
} else {
Err(anyhow!("Unknown error"))
}
}
pub async fn initialize(&self, client_info: &str, version: &str) -> Result<()> {
let params = serde_json::json!({
"protocol_version": "2024-11-05",
"capabilities": {},
"client_info": {
"name": client_info,
"version": version
}
});
let response = self.request("initialize", Some(params)).await?;
if response.error.is_some() {
return Err(anyhow!("Initialize failed"));
}
self.notify("notifications/initialized", None).await?;
Ok(())
}
pub fn is_running(&self) -> bool {
self.running.load(Ordering::SeqCst)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp_bridge::protocol::ContentBlock;

    #[tokio::test]
    async fn test_bridge_creation() {
        let config = McpBridgeConfig::default();
        let bridge = McpBridge::new(config);
        assert!(!bridge.is_running());
    }

    #[tokio::test]
    async fn test_register_tool() {
        let bridge = McpBridge::new(McpBridgeConfig::default());
        // Registration must not panic; the handler's registry is not
        // observable through the bridge's API.
        bridge.register_simple_tool("test_tool", "A test tool", |_name, _args| {
            Ok(ToolResult {
                is_error: false,
                content: vec![ContentBlock::Text {
                    text: "OK".to_string(),
                }],
            })
        });
    }

    #[tokio::test]
    async fn test_next_request_id() {
        let bridge = McpBridge::new(McpBridgeConfig::default());
        let id1 = bridge.next_request_id();
        let id2 = bridge.next_request_id();
        assert_ne!(id1, id2);
        // Ids are handed out sequentially starting at zero.
        assert_eq!(id1, RequestId::Number(0));
        assert_eq!(id2, RequestId::Number(1));
    }

    #[tokio::test]
    async fn test_request_without_transport_fails() {
        let bridge = McpBridge::new(McpBridgeConfig::default());
        let err = match bridge.request("tools/list", None).await {
            Ok(_) => panic!("request must fail when no transport is installed"),
            Err(e) => e,
        };
        assert!(err.to_string().contains("Transport not initialized"));
    }

    #[tokio::test]
    async fn test_notify_without_transport_fails() {
        let bridge = McpBridge::new(McpBridgeConfig::default());
        let err = bridge
            .notify("notifications/initialized", None)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Transport not initialized"));
    }

    #[tokio::test]
    async fn test_start_and_stop_without_transport() {
        let bridge = McpBridge::new(McpBridgeConfig::default());
        bridge.start().await.unwrap();
        assert!(bridge.is_running());
        bridge.stop().await.unwrap();
        assert!(!bridge.is_running());
    }

    #[test]
    fn test_config_default() {
        let config = McpBridgeConfig::default();
        assert_eq!(config.server_name, "Continuum");
        assert_eq!(config.server_version, "0.1.0");
        assert_eq!(config.request_timeout_ms, 30000);
        assert_eq!(config.max_concurrent_requests, 100);
    }
}