rab/extensions/mcp/
server.rs1use crate::extensions::mcp::types::ServerEntry;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::Mutex as StdMutex;
9use std::time::Instant;
10use tokio::sync::Mutex;
11use yoagent::mcp::McpClient;
12use yoagent::mcp::McpTransport;
13use yoagent::mcp::types::*;
14
15struct SseHttpTransport {
31 client: reqwest::Client,
32 base_url: String,
33 headers: Vec<(String, String)>,
34 session_id: StdMutex<Option<String>>,
36}
37
38impl SseHttpTransport {
39 fn new(url: &str) -> Self {
40 Self {
41 client: reqwest::Client::new(),
42 base_url: url.trim_end_matches('/').to_string(),
43 headers: Vec::new(),
44 session_id: StdMutex::new(None),
45 }
46 }
47
48 fn with_headers(mut self, headers: Option<&std::collections::HashMap<String, String>>) -> Self {
49 if let Some(h) = headers {
50 for (k, v) in h {
51 self.headers.push((k.clone(), v.clone()));
52 }
53 }
54 self
55 }
56
57 fn parse_sse_response(body: &str) -> Result<JsonRpcResponse, McpError> {
59 if let Ok(r) = serde_json::from_str::<JsonRpcResponse>(body) {
61 return Ok(r);
62 }
63
64 for event in body.split("\n\n") {
66 let event = event.trim();
67 if event.is_empty() {
68 continue;
69 }
70 for line in event.lines() {
72 if let Some(data) = line
73 .strip_prefix("data: ")
74 .or_else(|| line.strip_prefix("data:"))
75 {
76 let data = data.trim();
77 if data.starts_with('{')
78 && let Ok(r) = serde_json::from_str::<JsonRpcResponse>(data)
79 {
80 return Ok(r);
81 }
82 }
83 }
84 }
85
86 Err(McpError::Transport(format!(
87 "Cannot parse SSE response: {}",
88 body.chars().take(200).collect::<String>()
89 )))
90 }
91}
92
93#[async_trait]
94impl McpTransport for SseHttpTransport {
95 async fn send(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse, McpError> {
96 let mut req = self
97 .client
98 .post(&self.base_url)
99 .header("Accept", "application/json, text/event-stream")
101 .json(&request);
102
103 for (k, v) in &self.headers {
104 req = req.header(k.as_str(), v.as_str());
105 }
106
107 if let Ok(guard) = self.session_id.lock()
109 && let Some(ref sid) = *guard
110 {
111 req = req.header("Mcp-Session-Id", sid.as_str());
112 }
113
114 let resp = req
115 .send()
116 .await
117 .map_err(|e| McpError::Transport(format!("HTTP error: {}", e)))?;
118
119 let status = resp.status();
120
121 if let Some(sid) = resp
124 .headers()
125 .get("mcp-session-id")
126 .and_then(|v| v.to_str().ok())
127 .filter(|s| !s.is_empty())
128 && let Ok(mut guard) = self.session_id.lock()
129 && guard.is_none()
130 {
131 *guard = Some(sid.to_string());
132 }
133
134 let body = resp
135 .text()
136 .await
137 .map_err(|e| McpError::Transport(format!("Failed to read response: {}", e)))?;
138
139 if status.is_success() || status == 202 {
140 Self::parse_sse_response(&body)
141 } else {
142 Err(McpError::Transport(format!(
143 "HTTP {} from server: {}",
144 status,
145 body.chars().take(200).collect::<String>()
146 )))
147 }
148 }
149
150 async fn close(&self) -> Result<(), McpError> {
151 Ok(())
152 }
153}
154
/// Lifecycle state of a server tracked by `ServerManager`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// A live client is held for the server.
    Connected,
    /// No client is held; the server can be connected on demand.
    Idle,
    /// The last connection attempt failed or the server was explicitly marked failed.
    Failed,
}
165
166struct ServerConnection {
168 entry: ServerEntry,
169 client: Option<Arc<Mutex<McpClient>>>,
170 status: ConnectionStatus,
171 last_used: Instant,
172 last_failure: Option<Instant>,
173 config_hash: u64,
174}
175
176pub struct ServerManager {
178 servers: HashMap<String, ServerConnection>,
179 global_idle_timeout: std::time::Duration,
180}
181
182impl ServerManager {
183 pub fn new(global_idle_timeout_minutes: u64) -> Self {
184 Self {
185 servers: HashMap::new(),
186 global_idle_timeout: std::time::Duration::from_secs(global_idle_timeout_minutes * 60),
187 }
188 }
189
190 pub fn register(&mut self, name: &str, entry: ServerEntry, config_hash: u64) {
194 if let Some(conn) = self.servers.get_mut(name) {
195 conn.entry = entry;
198 conn.config_hash = config_hash;
199 conn.client = None;
200 conn.status = ConnectionStatus::Idle;
201 conn.last_failure = None;
202 } else {
203 self.servers.insert(
204 name.to_string(),
205 ServerConnection {
206 entry,
207 client: None,
208 status: ConnectionStatus::Idle,
209 last_used: Instant::now(),
210 last_failure: None,
211 config_hash,
212 },
213 );
214 }
215 }
216
217 pub async fn ensure_connected(&mut self, name: &str) -> bool {
219 if let Some(conn) = self.servers.get(name)
221 && conn.status == ConnectionStatus::Connected
222 && conn.client.is_some()
223 {
224 if let Some(c) = self.servers.get_mut(name) {
226 c.last_used = Instant::now();
227 }
228 return true;
229 }
230
231 let entry = match self.servers.get(name) {
233 Some(e) => e.entry.clone(),
234 None => return false,
235 };
236
237 let client = match &entry.url {
238 Some(url) => {
239 let transport =
241 Box::new(SseHttpTransport::new(url).with_headers(entry.headers.as_ref()));
242 let mut c = McpClient::from_transport(transport);
243 c.initialize().await.map(|_| c)
244 }
245 None => {
246 let env = entry.env.as_ref().cloned();
247 let cmd = entry.command.as_deref().unwrap_or("npx");
248 McpClient::connect_stdio(cmd, &to_str_slice(&entry.args), env).await
249 }
250 };
251
252 match client {
253 Ok(c) => {
254 let c = Arc::new(Mutex::new(c));
255 if let Some(conn) = self.servers.get_mut(name) {
256 conn.client = Some(c);
257 conn.status = ConnectionStatus::Connected;
258 conn.last_used = Instant::now();
259 conn.last_failure = None;
260 }
261 true
262 }
263 Err(e) => {
264 eprintln!("MCP: failed to connect to '{}': {}", name, e);
265 if let Some(conn) = self.servers.get_mut(name) {
266 conn.status = ConnectionStatus::Failed;
267 conn.last_failure = Some(Instant::now());
268 conn.client = None;
269 }
270 false
271 }
272 }
273 }
274
275 pub fn get_client(&self, name: &str) -> Option<Arc<Mutex<McpClient>>> {
277 self.servers.get(name).and_then(|c| c.client.clone())
278 }
279
280 pub fn status(&self, name: &str) -> Option<ConnectionStatus> {
282 self.servers.get(name).map(|c| c.status.clone())
283 }
284
285 pub fn mark_failed(&mut self, name: &str) {
287 if let Some(conn) = self.servers.get_mut(name) {
288 conn.status = ConnectionStatus::Failed;
289 conn.last_failure = Some(Instant::now());
290 conn.client = None;
291 }
292 }
293
294 pub fn touch(&mut self, name: &str) {
296 if let Some(conn) = self.servers.get_mut(name) {
297 conn.last_used = Instant::now();
298 if conn.status == ConnectionStatus::Failed && conn.last_failure.is_some() {
299 let backoff = std::time::Duration::from_secs(60);
300 if conn.last_failure.unwrap().elapsed() > backoff {
301 conn.status = ConnectionStatus::Idle;
302 conn.last_failure = None;
303 }
304 }
305 }
306 }
307
308 pub async fn disconnect(&mut self, name: &str) {
310 if let Some(conn) = self.servers.get_mut(name) {
311 if let Some(ref client) = conn.client {
312 let _ = client.lock().await.close().await;
313 }
314 conn.client = None;
315 conn.status = ConnectionStatus::Idle;
316 }
317 }
318
319 pub async fn close_all(&mut self) {
321 let names: Vec<String> = self.servers.keys().cloned().collect();
322 for name in &names {
323 self.disconnect(name).await;
324 }
325 }
326
327 pub fn idle_timeout(&self, name: &str) -> std::time::Duration {
329 if let Some(conn) = self.servers.get(name) {
330 idle_timeout_for(conn, self.global_idle_timeout)
331 } else {
332 self.global_idle_timeout
333 }
334 }
335
336 pub async fn sweep_idle(&mut self) {
338 let now = Instant::now();
339 let idle_names: Vec<String> = self
340 .servers
341 .iter()
342 .filter(|(_name, conn)| {
343 if conn.status != ConnectionStatus::Connected {
344 return false;
345 }
346 let timeout = idle_timeout_for(conn, self.global_idle_timeout);
347 now.duration_since(conn.last_used) > timeout
348 })
349 .map(|(name, _)| name.clone())
350 .collect();
351
352 for name in &idle_names {
353 self.disconnect(name).await;
354 }
355 }
356
357 pub fn server_names(&self) -> Vec<String> {
359 self.servers.keys().cloned().collect()
360 }
361
362 pub fn remove(&mut self, name: &str) {
366 self.servers.remove(name);
367 }
368
369 pub fn should_connect_eagerly(&self, name: &str) -> bool {
371 self.servers
372 .get(name)
373 .is_some_and(|c| matches!(c.entry.lifecycle.as_deref(), Some("eager" | "keep-alive")))
374 }
375
376 pub fn config_hash(&self, name: &str) -> Option<u64> {
378 self.servers.get(name).map(|c| c.config_hash)
379 }
380}
381
/// Borrows each owned argument as `&str`, for APIs that take `&[&str]`.
/// Order is preserved and an empty input yields an empty vector.
fn to_str_slice(args: &[String]) -> Vec<&str> {
    args.iter().map(String::as_str).collect()
}
385
386fn idle_timeout_for(conn: &ServerConnection, global: std::time::Duration) -> std::time::Duration {
388 if let Some(t) = conn.entry.idle_timeout {
389 return std::time::Duration::from_secs(t * 60);
390 }
391 if conn.entry.lifecycle.as_deref() == Some("keep-alive") {
393 return std::time::Duration::MAX;
394 }
395 global
396}