bamboo_engine/mcp/manager/
mod.rs1use chrono::Utc;
2use dashmap::DashMap;
3use std::sync::atomic::{AtomicBool, Ordering};
4use std::sync::Arc;
5use std::time::{Duration as StdDuration, Instant};
6use tokio::sync::{Mutex, OwnedSemaphorePermit, RwLock, Semaphore};
7use tracing::{error, info, warn};
8
9use crate::mcp::config::{McpConfig, McpServerConfig, TransportConfig};
10use crate::mcp::error::{McpError, Result};
11use crate::mcp::protocol::{McpProtocolClient, McpTransport};
12use crate::mcp::tool_index::ToolIndex;
13use crate::mcp::transports::{SseTransport, StdioTransport, StreamableHttpTransport};
14use crate::mcp::types::{McpEvent, McpTool, RuntimeInfo, ServerStatus};
15use bamboo_infrastructure::Config;
16
17mod config_sync;
18mod fingerprint;
19mod lifecycle;
20mod reconnect;
21
22#[cfg(test)]
23mod tests;
24
/// Default upper bound on simultaneous tool calls to a single MCP server.
const DEFAULT_MAX_CONCURRENT_CALLS_PER_SERVER: usize = 4;
/// Default number of consecutive failures that opens a server's circuit.
const DEFAULT_CIRCUIT_FAILURE_THRESHOLD: u32 = 3;
/// Default length, in milliseconds, of the window during which an open circuit rejects calls.
const DEFAULT_CIRCUIT_OPEN_MS: u64 = 5000;

/// Per-server quality-of-service tuning: a concurrency cap plus circuit-breaker limits.
#[derive(Debug, Clone, Copy)]
struct McpQosConfig {
    /// Maximum number of in-flight calls (`McpServerQos::new` clamps this to at least 1).
    max_concurrent_calls: usize,
    /// Consecutive failures required before the circuit opens.
    circuit_failure_threshold: u32,
    /// How long, in milliseconds, the circuit stays open once tripped.
    circuit_open_ms: u64,
}

impl Default for McpQosConfig {
    /// Uses the `DEFAULT_*` constants declared alongside this type.
    fn default() -> Self {
        let max_concurrent_calls = DEFAULT_MAX_CONCURRENT_CALLS_PER_SERVER;
        let circuit_failure_threshold = DEFAULT_CIRCUIT_FAILURE_THRESHOLD;
        let circuit_open_ms = DEFAULT_CIRCUIT_OPEN_MS;

        Self {
            max_concurrent_calls,
            circuit_failure_threshold,
            circuit_open_ms,
        }
    }
}
45
/// Mutable circuit-breaker bookkeeping for one server; starts closed with no failures.
#[derive(Default, Debug)]
struct McpQosState {
    /// Failures seen since the last success or the last circuit reset.
    consecutive_failures: u32,
    /// When `Some(t)` with `t` still in the future, new calls are rejected.
    circuit_open_until: Option<std::time::Instant>,
}
51
52#[derive(Debug)]
53struct McpServerQos {
54 config: McpQosConfig,
55 permits: Arc<Semaphore>,
56 state: Mutex<McpQosState>,
57}
58
59impl McpServerQos {
60 fn new(config: McpQosConfig) -> Self {
61 let max_permits = config.max_concurrent_calls.max(1);
62 Self {
63 config,
64 permits: Arc::new(Semaphore::new(max_permits)),
65 state: Mutex::new(McpQosState::default()),
66 }
67 }
68
69 async fn check_circuit(&self, server_id: &str, tool_name: &str) -> Result<()> {
70 let mut state = self.state.lock().await;
71
72 if let Some(open_until) = state.circuit_open_until {
73 let now = Instant::now();
74 if now < open_until {
75 let remaining = open_until.saturating_duration_since(now).as_millis();
76 return Err(McpError::ToolExecution(format!(
77 "MCP QoS circuit open for server '{}' (tool '{}'), retry in ~{}ms",
78 server_id, tool_name, remaining
79 )));
80 }
81
82 state.circuit_open_until = None;
83 state.consecutive_failures = 0;
84 }
85
86 Ok(())
87 }
88
89 async fn acquire_permit(&self) -> Result<OwnedSemaphorePermit> {
90 self.permits.clone().acquire_owned().await.map_err(|error| {
91 McpError::ToolExecution(format!("MCP QoS permit unavailable: {error}"))
92 })
93 }
94
95 async fn record_success(&self) {
96 let mut state = self.state.lock().await;
97 state.consecutive_failures = 0;
98 state.circuit_open_until = None;
99 }
100
101 async fn record_failure(&self, server_id: &str, tool_name: &str, error: &McpError) {
102 let mut state = self.state.lock().await;
103 state.consecutive_failures = state.consecutive_failures.saturating_add(1);
104
105 if state.consecutive_failures >= self.config.circuit_failure_threshold {
106 state.circuit_open_until =
107 Some(Instant::now() + StdDuration::from_millis(self.config.circuit_open_ms));
108 warn!(
109 "MCP QoS opening circuit for server '{}' after {} consecutive failures (tool '{}', last_error={})",
110 server_id, state.consecutive_failures, tool_name, error
111 );
112 }
113 }
114}
115
116struct ServerRuntime {
118 config: McpServerConfig,
119 client: RwLock<McpProtocolClient>,
120 info: RwLock<RuntimeInfo>,
121 tools: RwLock<Vec<McpTool>>,
122 shutdown: AtomicBool,
123 reconnecting: AtomicBool,
124 qos: McpServerQos,
125 proxy_fingerprint: Option<String>,
128}
129
130pub struct McpServerManager {
132 runtimes: DashMap<String, Arc<ServerRuntime>>,
133 index: Arc<ToolIndex>,
134 event_tx: Option<tokio::sync::mpsc::Sender<McpEvent>>,
135 config: Option<Arc<tokio::sync::RwLock<Config>>>,
136}
137
138impl Clone for McpServerManager {
139 fn clone(&self) -> Self {
140 Self {
141 runtimes: self.runtimes.clone(),
142 index: self.index.clone(),
143 event_tx: self.event_tx.clone(),
144 config: self.config.clone(),
145 }
146 }
147}
148
149impl McpServerManager {
150 pub fn new() -> Self {
151 Self {
152 runtimes: DashMap::new(),
153 index: Arc::new(ToolIndex::new()),
154 event_tx: None,
155 config: None,
156 }
157 }
158
159 pub fn new_with_config(config: Arc<tokio::sync::RwLock<Config>>) -> Self {
161 Self {
162 runtimes: DashMap::new(),
163 index: Arc::new(ToolIndex::new()),
164 event_tx: None,
165 config: Some(config),
166 }
167 }
168
169 pub fn with_event_channel(mut self, tx: tokio::sync::mpsc::Sender<McpEvent>) -> Self {
170 self.event_tx = Some(tx);
171 self
172 }
173
174 pub fn tool_index(&self) -> Arc<ToolIndex> {
175 self.index.clone()
176 }
177
178 pub fn list_servers(&self) -> Vec<String> {
180 self.runtimes
181 .iter()
182 .map(|entry| entry.key().clone())
183 .collect()
184 }
185
186 pub fn get_server_info(&self, server_id: &str) -> Option<RuntimeInfo> {
188 self.runtimes
189 .get(server_id)
190 .and_then(|runtime| runtime.info.try_read().ok().map(|info| info.clone()))
191 }
192
193 pub fn is_server_running(&self, server_id: &str) -> bool {
195 self.runtimes.contains_key(server_id)
196 }
197
198 pub async fn shutdown_all(&self) {
200 let server_ids: Vec<String> = self.list_servers();
201 for server_id in server_ids {
202 if let Err(e) = self.stop_server(&server_id).await {
203 error!("Error stopping server '{}': {}", server_id, e);
204 }
205 }
206 }
207}
208
209impl Default for McpServerManager {
210 fn default() -> Self {
211 Self::new()
212 }
213}