Skip to main content

rab/extensions/mcp/
server.rs

1//! Server lifecycle manager — lazy connection, idle timeout, keep-alive.
2//! Mirrors pi-mcp-adapter's McpLifecycleManager + McpServerManager pattern.
3
4use crate::extensions::mcp::types::ServerEntry;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::Mutex as StdMutex;
9use std::time::Instant;
10use tokio::sync::Mutex;
11use yoagent::mcp::McpClient;
12use yoagent::mcp::McpTransport;
13use yoagent::mcp::types::*;
14
15// ---------------------------------------------------------------------------
16// SSE-aware HTTP transport — handles servers that return SSE events (e.g. exa)
17// instead of plain JSON-RPC responses. Falls back to direct JSON parsing
18// for servers that return plain JSON-RPC.
19// ---------------------------------------------------------------------------
20
21/// HTTP transport that handles both SSE (Server-Sent Events) and direct JSON-RPC responses.
22///
23/// Modern MCP servers (exa, etc.) return SSE events like:
24/// ```text
25/// event: message
26/// data: {"jsonrpc":"2.0","result":{...},"id":1}
27///
28/// ```
29/// This transport parses those events and extracts the JSON-RPC response.
30struct SseHttpTransport {
31    client: reqwest::Client,
32    base_url: String,
33    headers: Vec<(String, String)>,
34    /// Session ID returned by the server (Streamable HTTP).
35    session_id: StdMutex<Option<String>>,
36}
37
38impl SseHttpTransport {
39    fn new(url: &str) -> Self {
40        Self {
41            client: reqwest::Client::new(),
42            base_url: url.trim_end_matches('/').to_string(),
43            headers: Vec::new(),
44            session_id: StdMutex::new(None),
45        }
46    }
47
48    fn with_headers(mut self, headers: Option<&std::collections::HashMap<String, String>>) -> Self {
49        if let Some(h) = headers {
50            for (k, v) in h {
51                self.headers.push((k.clone(), v.clone()));
52            }
53        }
54        self
55    }
56
57    /// Parse an SSE response body to extract JSON-RPC responses.
58    fn parse_sse_response(body: &str) -> Result<JsonRpcResponse, McpError> {
59        // Try direct JSON parse first (for old-style HTTP transport)
60        if let Ok(r) = serde_json::from_str::<JsonRpcResponse>(body) {
61            return Ok(r);
62        }
63
64        // SSE format: split by double newlines, look for `data:` lines
65        for event in body.split("\n\n") {
66            let event = event.trim();
67            if event.is_empty() {
68                continue;
69            }
70            // Find the data line
71            for line in event.lines() {
72                if let Some(data) = line
73                    .strip_prefix("data: ")
74                    .or_else(|| line.strip_prefix("data:"))
75                {
76                    let data = data.trim();
77                    if data.starts_with('{')
78                        && let Ok(r) = serde_json::from_str::<JsonRpcResponse>(data)
79                    {
80                        return Ok(r);
81                    }
82                }
83            }
84        }
85
86        Err(McpError::Transport(format!(
87            "Cannot parse SSE response: {}",
88            body.chars().take(200).collect::<String>()
89        )))
90    }
91}
92
93#[async_trait]
94impl McpTransport for SseHttpTransport {
95    async fn send(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse, McpError> {
96        let mut req = self
97            .client
98            .post(&self.base_url)
99            // Streamable HTTP requires the client to accept both formats
100            .header("Accept", "application/json, text/event-stream")
101            .json(&request);
102
103        for (k, v) in &self.headers {
104            req = req.header(k.as_str(), v.as_str());
105        }
106
107        // Include session ID if we have one (Streamable HTTP)
108        if let Ok(guard) = self.session_id.lock()
109            && let Some(ref sid) = *guard
110        {
111            req = req.header("Mcp-Session-Id", sid.as_str());
112        }
113
114        let resp = req
115            .send()
116            .await
117            .map_err(|e| McpError::Transport(format!("HTTP error: {}", e)))?;
118
119        let status = resp.status();
120
121        // Capture session ID from response headers (Streamable HTTP)
122        // reqwest normalizes header names to lowercase
123        if let Some(sid) = resp
124            .headers()
125            .get("mcp-session-id")
126            .and_then(|v| v.to_str().ok())
127            .filter(|s| !s.is_empty())
128            && let Ok(mut guard) = self.session_id.lock()
129            && guard.is_none()
130        {
131            *guard = Some(sid.to_string());
132        }
133
134        let body = resp
135            .text()
136            .await
137            .map_err(|e| McpError::Transport(format!("Failed to read response: {}", e)))?;
138
139        if status.is_success() || status == 202 {
140            Self::parse_sse_response(&body)
141        } else {
142            Err(McpError::Transport(format!(
143                "HTTP {} from server: {}",
144                status,
145                body.chars().take(200).collect::<String>()
146            )))
147        }
148    }
149
150    async fn close(&self) -> Result<(), McpError> {
151        Ok(())
152    }
153}
154
/// Lifecycle state of a managed server connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// A client has been established and is ready for use.
    Connected,
    /// No live connection: the server is registered but not connected yet, or
    /// it was disconnected (for example after an idle timeout).
    Idle,
    /// The last connection attempt or call failed, or the server is unreachable.
    Failed,
}
165
166/// A managed server connection.
167struct ServerConnection {
168    entry: ServerEntry,
169    client: Option<Arc<Mutex<McpClient>>>,
170    status: ConnectionStatus,
171    last_used: Instant,
172    last_failure: Option<Instant>,
173    config_hash: u64,
174}
175
176/// Manages all MCP server connections with lazy connection, idle timeout, and health checks.
177pub struct ServerManager {
178    servers: HashMap<String, ServerConnection>,
179    global_idle_timeout: std::time::Duration,
180}
181
182impl ServerManager {
183    pub fn new(global_idle_timeout_minutes: u64) -> Self {
184        Self {
185            servers: HashMap::new(),
186            global_idle_timeout: std::time::Duration::from_secs(global_idle_timeout_minutes * 60),
187        }
188    }
189
190    /// Register a server definition (from config). Does not connect.
191    pub fn register(&mut self, name: &str, entry: ServerEntry, config_hash: u64) {
192        self.servers
193            .entry(name.to_string())
194            .or_insert_with(|| ServerConnection {
195                entry,
196                client: None,
197                status: ConnectionStatus::Idle,
198                last_used: Instant::now(),
199                last_failure: None,
200                config_hash,
201            });
202    }
203
204    /// Ensure a server is connected (lazy connect). Returns true if connected/available.
205    pub async fn ensure_connected(&mut self, name: &str) -> bool {
206        // Check if we have a cached connection that's still alive
207        if let Some(conn) = self.servers.get(name)
208            && conn.status == ConnectionStatus::Connected
209            && conn.client.is_some()
210        {
211            // Touch last_used so idle timer resets
212            if let Some(c) = self.servers.get_mut(name) {
213                c.last_used = Instant::now();
214            }
215            return true;
216        }
217
218        // Need to connect
219        let entry = match self.servers.get(name) {
220            Some(e) => e.entry.clone(),
221            None => return false,
222        };
223
224        let client = match &entry.url {
225            Some(url) => {
226                // Use SSE-aware HTTP transport instead of the plain yoagent one
227                let transport =
228                    Box::new(SseHttpTransport::new(url).with_headers(entry.headers.as_ref()));
229                let mut c = McpClient::from_transport(transport);
230                c.initialize().await.map(|_| c)
231            }
232            None => {
233                let env = entry.env.as_ref().cloned();
234                let cmd = entry.command.as_deref().unwrap_or("npx");
235                McpClient::connect_stdio(cmd, &to_str_slice(&entry.args), env).await
236            }
237        };
238
239        match client {
240            Ok(c) => {
241                let c = Arc::new(Mutex::new(c));
242                if let Some(conn) = self.servers.get_mut(name) {
243                    conn.client = Some(c);
244                    conn.status = ConnectionStatus::Connected;
245                    conn.last_used = Instant::now();
246                    conn.last_failure = None;
247                }
248                true
249            }
250            Err(e) => {
251                eprintln!("MCP: failed to connect to '{}': {}", name, e);
252                if let Some(conn) = self.servers.get_mut(name) {
253                    conn.status = ConnectionStatus::Failed;
254                    conn.last_failure = Some(Instant::now());
255                    conn.client = None;
256                }
257                false
258            }
259        }
260    }
261
262    /// Get a connected client for a server (must call ensure_connected first).
263    pub fn get_client(&self, name: &str) -> Option<Arc<Mutex<McpClient>>> {
264        self.servers.get(name).and_then(|c| c.client.clone())
265    }
266
267    /// Get the connection status for a server.
268    pub fn status(&self, name: &str) -> Option<ConnectionStatus> {
269        self.servers.get(name).map(|c| c.status.clone())
270    }
271
272    /// Mark a connection as failed after a tool call error.
273    pub fn mark_failed(&mut self, name: &str) {
274        if let Some(conn) = self.servers.get_mut(name) {
275            conn.status = ConnectionStatus::Failed;
276            conn.last_failure = Some(Instant::now());
277            conn.client = None;
278        }
279    }
280
281    /// Touch a server (update last_used timestamp, e.g. after successful tool call).
282    pub fn touch(&mut self, name: &str) {
283        if let Some(conn) = self.servers.get_mut(name) {
284            conn.last_used = Instant::now();
285            if conn.status == ConnectionStatus::Failed && conn.last_failure.is_some() {
286                let backoff = std::time::Duration::from_secs(60);
287                if conn.last_failure.unwrap().elapsed() > backoff {
288                    conn.status = ConnectionStatus::Idle;
289                    conn.last_failure = None;
290                }
291            }
292        }
293    }
294
295    /// Disconnect a server (idle shutdown).
296    pub async fn disconnect(&mut self, name: &str) {
297        if let Some(conn) = self.servers.get_mut(name) {
298            if let Some(ref client) = conn.client {
299                let _ = client.lock().await.close().await;
300            }
301            conn.client = None;
302            conn.status = ConnectionStatus::Idle;
303        }
304    }
305
306    /// Close all connections (on session shutdown).
307    pub async fn close_all(&mut self) {
308        let names: Vec<String> = self.servers.keys().cloned().collect();
309        for name in &names {
310            self.disconnect(name).await;
311        }
312    }
313
314    /// Get the idle timeout for a server (per-server override or global default).
315    pub fn idle_timeout(&self, name: &str) -> std::time::Duration {
316        if let Some(conn) = self.servers.get(name) {
317            idle_timeout_for(conn, self.global_idle_timeout)
318        } else {
319            self.global_idle_timeout
320        }
321    }
322
323    /// Check for idle servers and disconnect them.
324    pub async fn sweep_idle(&mut self) {
325        let now = Instant::now();
326        let idle_names: Vec<String> = self
327            .servers
328            .iter()
329            .filter(|(_name, conn)| {
330                if conn.status != ConnectionStatus::Connected {
331                    return false;
332                }
333                let timeout = idle_timeout_for(conn, self.global_idle_timeout);
334                now.duration_since(conn.last_used) > timeout
335            })
336            .map(|(name, _)| name.clone())
337            .collect();
338
339        for name in &idle_names {
340            self.disconnect(name).await;
341        }
342    }
343
344    /// Get a list of all registered server names.
345    pub fn server_names(&self) -> Vec<String> {
346        self.servers.keys().cloned().collect()
347    }
348
349    /// Check if a server should be connected eagerly at startup.
350    pub fn should_connect_eagerly(&self, name: &str) -> bool {
351        self.servers
352            .get(name)
353            .is_some_and(|c| matches!(c.entry.lifecycle.as_deref(), Some("eager" | "keep-alive")))
354    }
355
356    /// Get the config hash for a server.
357    pub fn config_hash(&self, name: &str) -> Option<u64> {
358        self.servers.get(name).map(|c| c.config_hash)
359    }
360}
361
/// Borrows every owned argument as a `&str`, preserving order.
fn to_str_slice(args: &[String]) -> Vec<&str> {
    let mut borrowed = Vec::with_capacity(args.len());
    for arg in args {
        borrowed.push(arg.as_str());
    }
    borrowed
}
365
366/// Compute idle timeout for a server connection.
367fn idle_timeout_for(conn: &ServerConnection, global: std::time::Duration) -> std::time::Duration {
368    if let Some(t) = conn.entry.idle_timeout {
369        return std::time::Duration::from_secs(t * 60);
370    }
371    // keep-alive servers have no idle timeout
372    if conn.entry.lifecycle.as_deref() == Some("keep-alive") {
373        return std::time::Duration::MAX;
374    }
375    global
376}