use chrono::Utc;
use dashmap::DashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration as StdDuration, Instant};
use tokio::sync::{Mutex, OwnedSemaphorePermit, RwLock, Semaphore};
use tracing::{error, info, warn};
use crate::mcp::config::{McpConfig, McpServerConfig, TransportConfig};
use crate::mcp::error::{McpError, Result};
use crate::mcp::protocol::{McpProtocolClient, McpTransport};
use crate::mcp::tool_index::ToolIndex;
use crate::mcp::transports::{SseTransport, StdioTransport, StreamableHttpTransport};
use crate::mcp::types::{McpEvent, McpTool, RuntimeInfo, ServerStatus};
use bamboo_infrastructure::Config;
mod config_sync;
mod fingerprint;
mod lifecycle;
mod reconnect;
#[cfg(test)]
mod tests;
const DEFAULT_MAX_CONCURRENT_CALLS_PER_SERVER: usize = 4;
const DEFAULT_CIRCUIT_FAILURE_THRESHOLD: u32 = 3;
const DEFAULT_CIRCUIT_OPEN_MS: u64 = 5_000;
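/// Per-server QoS limits: how many tool calls may run concurrently and when
/// the circuit breaker trips and recovers.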
#[derive(Debug, Clone, Copy)]
struct McpQosConfig {
    max_concurrent_calls: usize,
    circuit_failure_threshold: u32,
    circuit_open_ms: u64,
}
impl Default for McpQosConfig {
    fn default() -> Self {
        Self {
            max_concurrent_calls: DEFAULT_MAX_CONCURRENT_CALLS_PER_SERVER,
            circuit_failure_threshold: DEFAULT_CIRCUIT_FAILURE_THRESHOLD,
            circuit_open_ms: DEFAULT_CIRCUIT_OPEN_MS,
        }
    }
}
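/// Mutable circuit-breaker state tracked for a single server.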
#[derive(Debug, Default)]
struct McpQosState {
    consecutive_failures: u32,
    circuit_open_until: Option<Instant>,
}
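/// Concurrency and circuit-breaker guard for one MCP server. The intended
/// call order is `check_circuit`, then `acquire_permit`, and finally either
/// `record_success` or `record_failure` once the tool call completes.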
#[derive(Debug)]
struct McpServerQos {
    config: McpQosConfig,
    permits: Arc<Semaphore>,
    state: Mutex<McpQosState>,
}
impl McpServerQos {
    fn new(config: McpQosConfig) -> Self {
        let max_permits = config.max_concurrent_calls.max(1);
        Self {
            config,
            permits: Arc::new(Semaphore::new(max_permits)),
            state: Mutex::new(McpQosState::default()),
        }
    }
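    /// Fails fast with `McpError::ToolExecution` while the circuit is open;
    /// once the open window has elapsed, the breaker is reset and the call may
    /// proceed.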
    async fn check_circuit(&self, server_id: &str, tool_name: &str) -> Result<()> {
        let mut state = self.state.lock().await;
        if let Some(open_until) = state.circuit_open_until {
            let now = Instant::now();
            if now < open_until {
                let remaining = open_until.saturating_duration_since(now).as_millis();
                return Err(McpError::ToolExecution(format!(
                    "MCP QoS circuit open for server '{}' (tool '{}'), retry in ~{}ms",
                    server_id, tool_name, remaining
                )));
            }
            state.circuit_open_until = None;
            state.consecutive_failures = 0;
        }
        Ok(())
    }
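    /// Waits for one of the per-server concurrency permits; this only errors
    /// if the semaphore has been closed.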
    async fn acquire_permit(&self) -> Result<OwnedSemaphorePermit> {
        self.permits.clone().acquire_owned().await.map_err(|error| {
            McpError::ToolExecution(format!("MCP QoS permit unavailable: {error}"))
        })
    }
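    /// Clears the failure counter and closes the circuit.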
    async fn record_success(&self) {
        let mut state = self.state.lock().await;
        state.consecutive_failures = 0;
        state.circuit_open_until = None;
    }
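    /// Counts a failure and, once `circuit_failure_threshold` consecutive
    /// failures are reached, opens the circuit for `circuit_open_ms`.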
    async fn record_failure(&self, server_id: &str, tool_name: &str, error: &McpError) {
        let mut state = self.state.lock().await;
        state.consecutive_failures = state.consecutive_failures.saturating_add(1);
        if state.consecutive_failures >= self.config.circuit_failure_threshold {
            state.circuit_open_until =
                Some(Instant::now() + StdDuration::from_millis(self.config.circuit_open_ms));
            warn!(
                "MCP QoS opening circuit for server '{}' after {} consecutive failures (tool '{}', last_error={})",
                server_id, state.consecutive_failures, tool_name, error
            );
        }
    }
}
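/// Everything the manager tracks for a single MCP server: its configuration,
/// protocol client, cached runtime info and tool list, lifecycle flags, the
/// QoS guard, and the proxy configuration fingerprint.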
struct ServerRuntime {
    config: McpServerConfig,
    client: RwLock<McpProtocolClient>,
    info: RwLock<RuntimeInfo>,
    tools: RwLock<Vec<McpTool>>,
    shutdown: AtomicBool,
    reconnecting: AtomicBool,
    qos: McpServerQos,
    proxy_fingerprint: Option<String>,
}
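/// Owns all MCP server runtimes and the shared tool index. Lifecycle,
/// reconnect, and config-sync behaviour is implemented in the sibling
/// submodules.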
pub struct McpServerManager {
    runtimes: DashMap<String, Arc<ServerRuntime>>,
    index: Arc<ToolIndex>,
    event_tx: Option<tokio::sync::mpsc::Sender<McpEvent>>,
    config: Option<Arc<tokio::sync::RwLock<Config>>>,
}
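// Manual clone: the `DashMap` is copied entry-by-entry (sharing each
// `Arc<ServerRuntime>`), while the index, event channel, and config handle are
// shared through their own `Arc`/`Sender` clones.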
impl Clone for McpServerManager {
    fn clone(&self) -> Self {
        Self {
            runtimes: self.runtimes.clone(),
            index: self.index.clone(),
            event_tx: self.event_tx.clone(),
            config: self.config.clone(),
        }
    }
}
impl McpServerManager {
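    /// Creates an empty manager with no servers registered and no shared
    /// application config.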
    pub fn new() -> Self {
        Self {
            runtimes: DashMap::new(),
            index: Arc::new(ToolIndex::new()),
            event_tx: None,
            config: None,
        }
    }
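    /// Creates an empty manager that holds a handle to the shared application
    /// `Config`.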
    pub fn new_with_config(config: Arc<tokio::sync::RwLock<Config>>) -> Self {
        Self {
            runtimes: DashMap::new(),
            index: Arc::new(ToolIndex::new()),
            event_tx: None,
            config: Some(config),
        }
    }
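    /// Builder-style setter for the channel on which `McpEvent` notifications
    /// are published.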
    pub fn with_event_channel(mut self, tx: tokio::sync::mpsc::Sender<McpEvent>) -> Self {
        self.event_tx = Some(tx);
        self
    }
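    /// Shared handle to the aggregated tool index.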
    pub fn tool_index(&self) -> Arc<ToolIndex> {
        self.index.clone()
    }
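    /// IDs of all currently registered servers.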
    pub fn list_servers(&self) -> Vec<String> {
        self.runtimes
            .iter()
            .map(|entry| entry.key().clone())
            .collect()
    }
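    /// Snapshot of a server's runtime info, or `None` if the server is unknown
    /// or its info lock is currently held for writing.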
    pub fn get_server_info(&self, server_id: &str) -> Option<RuntimeInfo> {
        self.runtimes
            .get(server_id)
            .and_then(|runtime| runtime.info.try_read().ok().map(|info| info.clone()))
    }
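    /// Whether a runtime is currently registered under `server_id`.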
    pub fn is_server_running(&self, server_id: &str) -> bool {
        self.runtimes.contains_key(server_id)
    }
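    /// Stops every registered server, logging (rather than propagating) any
    /// individual shutdown errors.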
    pub async fn shutdown_all(&self) {
        let server_ids: Vec<String> = self.list_servers();
        for server_id in server_ids {
            if let Err(e) = self.stop_server(&server_id).await {
                error!("Error stopping server '{}': {}", server_id, e);
            }
        }
    }
}
impl Default for McpServerManager {
    fn default() -> Self {
        Self::new()
    }
}