rab/extensions/mcp/
server.rs1use crate::extensions::mcp::types::ServerEntry;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::Mutex as StdMutex;
9use std::time::Instant;
10use tokio::sync::Mutex;
11use yoagent::mcp::McpClient;
12use yoagent::mcp::McpTransport;
13use yoagent::mcp::types::*;
14
15struct SseHttpTransport {
31 client: reqwest::Client,
32 base_url: String,
33 headers: Vec<(String, String)>,
34 session_id: StdMutex<Option<String>>,
36}
37
38impl SseHttpTransport {
39 fn new(url: &str) -> Self {
40 Self {
41 client: reqwest::Client::new(),
42 base_url: url.trim_end_matches('/').to_string(),
43 headers: Vec::new(),
44 session_id: StdMutex::new(None),
45 }
46 }
47
48 fn with_headers(mut self, headers: Option<&std::collections::HashMap<String, String>>) -> Self {
49 if let Some(h) = headers {
50 for (k, v) in h {
51 self.headers.push((k.clone(), v.clone()));
52 }
53 }
54 self
55 }
56
57 fn parse_sse_response(body: &str) -> Result<JsonRpcResponse, McpError> {
59 if let Ok(r) = serde_json::from_str::<JsonRpcResponse>(body) {
61 return Ok(r);
62 }
63
64 for event in body.split("\n\n") {
66 let event = event.trim();
67 if event.is_empty() {
68 continue;
69 }
70 for line in event.lines() {
72 if let Some(data) = line
73 .strip_prefix("data: ")
74 .or_else(|| line.strip_prefix("data:"))
75 {
76 let data = data.trim();
77 if data.starts_with('{')
78 && let Ok(r) = serde_json::from_str::<JsonRpcResponse>(data)
79 {
80 return Ok(r);
81 }
82 }
83 }
84 }
85
86 Err(McpError::Transport(format!(
87 "Cannot parse SSE response: {}",
88 body.chars().take(200).collect::<String>()
89 )))
90 }
91}
92
93#[async_trait]
94impl McpTransport for SseHttpTransport {
95 async fn send(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse, McpError> {
96 let mut req = self
97 .client
98 .post(&self.base_url)
99 .header("Accept", "application/json, text/event-stream")
101 .json(&request);
102
103 for (k, v) in &self.headers {
104 req = req.header(k.as_str(), v.as_str());
105 }
106
107 if let Ok(guard) = self.session_id.lock()
109 && let Some(ref sid) = *guard
110 {
111 req = req.header("Mcp-Session-Id", sid.as_str());
112 }
113
114 let resp = req
115 .send()
116 .await
117 .map_err(|e| McpError::Transport(format!("HTTP error: {}", e)))?;
118
119 let status = resp.status();
120
121 if let Some(sid) = resp
124 .headers()
125 .get("mcp-session-id")
126 .and_then(|v| v.to_str().ok())
127 .filter(|s| !s.is_empty())
128 && let Ok(mut guard) = self.session_id.lock()
129 && guard.is_none()
130 {
131 *guard = Some(sid.to_string());
132 }
133
134 let body = resp
135 .text()
136 .await
137 .map_err(|e| McpError::Transport(format!("Failed to read response: {}", e)))?;
138
139 if status.is_success() || status == 202 {
140 Self::parse_sse_response(&body)
141 } else {
142 Err(McpError::Transport(format!(
143 "HTTP {} from server: {}",
144 status,
145 body.chars().take(200).collect::<String>()
146 )))
147 }
148 }
149
150 async fn close(&self) -> Result<(), McpError> {
151 Ok(())
152 }
153}
154
/// Connection state tracked for each registered server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// A client was successfully created and initialized for the server.
    Connected,
    /// No active client: the state right after registration or a disconnect.
    Idle,
    /// The last connection attempt failed, or the server was marked as failed.
    Failed,
}
165
166struct ServerConnection {
168 entry: ServerEntry,
169 client: Option<Arc<Mutex<McpClient>>>,
170 status: ConnectionStatus,
171 last_used: Instant,
172 last_failure: Option<Instant>,
173 config_hash: u64,
174}
175
176pub struct ServerManager {
178 servers: HashMap<String, ServerConnection>,
179 global_idle_timeout: std::time::Duration,
180}
181
182impl ServerManager {
183 pub fn new(global_idle_timeout_minutes: u64) -> Self {
184 Self {
185 servers: HashMap::new(),
186 global_idle_timeout: std::time::Duration::from_secs(global_idle_timeout_minutes * 60),
187 }
188 }
189
190 pub fn register(&mut self, name: &str, entry: ServerEntry, config_hash: u64) {
192 self.servers
193 .entry(name.to_string())
194 .or_insert_with(|| ServerConnection {
195 entry,
196 client: None,
197 status: ConnectionStatus::Idle,
198 last_used: Instant::now(),
199 last_failure: None,
200 config_hash,
201 });
202 }
203
204 pub async fn ensure_connected(&mut self, name: &str) -> bool {
206 if let Some(conn) = self.servers.get(name)
208 && conn.status == ConnectionStatus::Connected
209 && conn.client.is_some()
210 {
211 if let Some(c) = self.servers.get_mut(name) {
213 c.last_used = Instant::now();
214 }
215 return true;
216 }
217
218 let entry = match self.servers.get(name) {
220 Some(e) => e.entry.clone(),
221 None => return false,
222 };
223
224 let client = match &entry.url {
225 Some(url) => {
226 let transport =
228 Box::new(SseHttpTransport::new(url).with_headers(entry.headers.as_ref()));
229 let mut c = McpClient::from_transport(transport);
230 c.initialize().await.map(|_| c)
231 }
232 None => {
233 let env = entry.env.as_ref().cloned();
234 let cmd = entry.command.as_deref().unwrap_or("npx");
235 McpClient::connect_stdio(cmd, &to_str_slice(&entry.args), env).await
236 }
237 };
238
239 match client {
240 Ok(c) => {
241 let c = Arc::new(Mutex::new(c));
242 if let Some(conn) = self.servers.get_mut(name) {
243 conn.client = Some(c);
244 conn.status = ConnectionStatus::Connected;
245 conn.last_used = Instant::now();
246 conn.last_failure = None;
247 }
248 true
249 }
250 Err(e) => {
251 eprintln!("MCP: failed to connect to '{}': {}", name, e);
252 if let Some(conn) = self.servers.get_mut(name) {
253 conn.status = ConnectionStatus::Failed;
254 conn.last_failure = Some(Instant::now());
255 conn.client = None;
256 }
257 false
258 }
259 }
260 }
261
262 pub fn get_client(&self, name: &str) -> Option<Arc<Mutex<McpClient>>> {
264 self.servers.get(name).and_then(|c| c.client.clone())
265 }
266
267 pub fn status(&self, name: &str) -> Option<ConnectionStatus> {
269 self.servers.get(name).map(|c| c.status.clone())
270 }
271
272 pub fn mark_failed(&mut self, name: &str) {
274 if let Some(conn) = self.servers.get_mut(name) {
275 conn.status = ConnectionStatus::Failed;
276 conn.last_failure = Some(Instant::now());
277 conn.client = None;
278 }
279 }
280
281 pub fn touch(&mut self, name: &str) {
283 if let Some(conn) = self.servers.get_mut(name) {
284 conn.last_used = Instant::now();
285 if conn.status == ConnectionStatus::Failed && conn.last_failure.is_some() {
286 let backoff = std::time::Duration::from_secs(60);
287 if conn.last_failure.unwrap().elapsed() > backoff {
288 conn.status = ConnectionStatus::Idle;
289 conn.last_failure = None;
290 }
291 }
292 }
293 }
294
295 pub async fn disconnect(&mut self, name: &str) {
297 if let Some(conn) = self.servers.get_mut(name) {
298 if let Some(ref client) = conn.client {
299 let _ = client.lock().await.close().await;
300 }
301 conn.client = None;
302 conn.status = ConnectionStatus::Idle;
303 }
304 }
305
306 pub async fn close_all(&mut self) {
308 let names: Vec<String> = self.servers.keys().cloned().collect();
309 for name in &names {
310 self.disconnect(name).await;
311 }
312 }
313
314 pub fn idle_timeout(&self, name: &str) -> std::time::Duration {
316 if let Some(conn) = self.servers.get(name) {
317 idle_timeout_for(conn, self.global_idle_timeout)
318 } else {
319 self.global_idle_timeout
320 }
321 }
322
323 pub async fn sweep_idle(&mut self) {
325 let now = Instant::now();
326 let idle_names: Vec<String> = self
327 .servers
328 .iter()
329 .filter(|(_name, conn)| {
330 if conn.status != ConnectionStatus::Connected {
331 return false;
332 }
333 let timeout = idle_timeout_for(conn, self.global_idle_timeout);
334 now.duration_since(conn.last_used) > timeout
335 })
336 .map(|(name, _)| name.clone())
337 .collect();
338
339 for name in &idle_names {
340 self.disconnect(name).await;
341 }
342 }
343
344 pub fn server_names(&self) -> Vec<String> {
346 self.servers.keys().cloned().collect()
347 }
348
349 pub fn should_connect_eagerly(&self, name: &str) -> bool {
351 self.servers
352 .get(name)
353 .is_some_and(|c| matches!(c.entry.lifecycle.as_deref(), Some("eager" | "keep-alive")))
354 }
355
356 pub fn config_hash(&self, name: &str) -> Option<u64> {
358 self.servers.get(name).map(|c| c.config_hash)
359 }
360}
361
/// Borrows every owned argument as a `&str`, preserving order.
fn to_str_slice(args: &[String]) -> Vec<&str> {
    args.iter().map(String::as_str).collect()
}
365
366fn idle_timeout_for(conn: &ServerConnection, global: std::time::Duration) -> std::time::Duration {
368 if let Some(t) = conn.entry.idle_timeout {
369 return std::time::Duration::from_secs(t * 60);
370 }
371 if conn.entry.lifecycle.as_deref() == Some("keep-alive") {
373 return std::time::Duration::MAX;
374 }
375 global
376}