Skip to main content

rab/extensions/mcp/
server.rs

1//! Server lifecycle manager — lazy connection, idle timeout, keep-alive.
2//! Mirrors pi-mcp-adapter's McpLifecycleManager + McpServerManager pattern.
3
4use crate::extensions::mcp::types::ServerEntry;
5use async_trait::async_trait;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::Mutex as StdMutex;
9use std::time::Instant;
10use tokio::sync::Mutex;
11use yoagent::mcp::McpClient;
12use yoagent::mcp::McpTransport;
13use yoagent::mcp::types::*;
14
15// ---------------------------------------------------------------------------
16// SSE-aware HTTP transport — handles servers that return SSE events (e.g. exa)
17// instead of plain JSON-RPC responses. Falls back to direct JSON parsing
18// for servers that return plain JSON-RPC.
19// ---------------------------------------------------------------------------
20
/// HTTP transport that handles both SSE (Server-Sent Events) and direct JSON-RPC responses.
///
/// Modern MCP servers (exa, etc.) return SSE events like:
/// ```text
/// event: message
/// data: {"jsonrpc":"2.0","result":{...},"id":1}
///
/// ```
/// This transport parses those events and extracts the JSON-RPC response.
struct SseHttpTransport {
    /// HTTP client used to POST every JSON-RPC request.
    client: reqwest::Client,
    /// Endpoint URL the requests are POSTed to, stored with trailing `/` removed.
    base_url: String,
    /// Extra `(name, value)` headers added to every outgoing request.
    headers: Vec<(String, String)>,
    /// Session ID returned by the server (Streamable HTTP).
    ///
    /// Filled from the first non-empty `mcp-session-id` response header and
    /// sent back as `Mcp-Session-Id` on later requests. A std mutex is used so
    /// the value can be updated through `&self`.
    session_id: StdMutex<Option<String>>,
}
37
38impl SseHttpTransport {
39    fn new(url: &str) -> Self {
40        Self {
41            client: reqwest::Client::new(),
42            base_url: url.trim_end_matches('/').to_string(),
43            headers: Vec::new(),
44            session_id: StdMutex::new(None),
45        }
46    }
47
48    fn with_headers(mut self, headers: Option<&std::collections::HashMap<String, String>>) -> Self {
49        if let Some(h) = headers {
50            for (k, v) in h {
51                self.headers.push((k.clone(), v.clone()));
52            }
53        }
54        self
55    }
56
57    /// Parse an SSE response body to extract JSON-RPC responses.
58    fn parse_sse_response(body: &str) -> Result<JsonRpcResponse, McpError> {
59        // Try direct JSON parse first (for old-style HTTP transport)
60        if let Ok(r) = serde_json::from_str::<JsonRpcResponse>(body) {
61            return Ok(r);
62        }
63
64        // SSE format: split by double newlines, look for `data:` lines
65        for event in body.split("\n\n") {
66            let event = event.trim();
67            if event.is_empty() {
68                continue;
69            }
70            // Find the data line
71            for line in event.lines() {
72                if let Some(data) = line
73                    .strip_prefix("data: ")
74                    .or_else(|| line.strip_prefix("data:"))
75                {
76                    let data = data.trim();
77                    if data.starts_with('{')
78                        && let Ok(r) = serde_json::from_str::<JsonRpcResponse>(data)
79                    {
80                        return Ok(r);
81                    }
82                }
83            }
84        }
85
86        Err(McpError::Transport(format!(
87            "Cannot parse SSE response: {}",
88            body.chars().take(200).collect::<String>()
89        )))
90    }
91}
92
93#[async_trait]
94impl McpTransport for SseHttpTransport {
95    async fn send(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse, McpError> {
96        let mut req = self
97            .client
98            .post(&self.base_url)
99            // Streamable HTTP requires the client to accept both formats
100            .header("Accept", "application/json, text/event-stream")
101            .json(&request);
102
103        for (k, v) in &self.headers {
104            req = req.header(k.as_str(), v.as_str());
105        }
106
107        // Include session ID if we have one (Streamable HTTP)
108        if let Ok(guard) = self.session_id.lock()
109            && let Some(ref sid) = *guard
110        {
111            req = req.header("Mcp-Session-Id", sid.as_str());
112        }
113
114        let resp = req
115            .send()
116            .await
117            .map_err(|e| McpError::Transport(format!("HTTP error: {}", e)))?;
118
119        let status = resp.status();
120
121        // Capture session ID from response headers (Streamable HTTP)
122        // reqwest normalizes header names to lowercase
123        if let Some(sid) = resp
124            .headers()
125            .get("mcp-session-id")
126            .and_then(|v| v.to_str().ok())
127            .filter(|s| !s.is_empty())
128            && let Ok(mut guard) = self.session_id.lock()
129            && guard.is_none()
130        {
131            *guard = Some(sid.to_string());
132        }
133
134        let body = resp
135            .text()
136            .await
137            .map_err(|e| McpError::Transport(format!("Failed to read response: {}", e)))?;
138
139        if status.is_success() || status == 202 {
140            Self::parse_sse_response(&body)
141        } else {
142            Err(McpError::Transport(format!(
143                "HTTP {} from server: {}",
144                status,
145                body.chars().take(200).collect::<String>()
146            )))
147        }
148    }
149
150    async fn close(&self) -> Result<(), McpError> {
151        Ok(())
152    }
153}
154
/// Connection status for a server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// Successfully connected and ready.
    Connected,
    /// Not connected, with no recorded failure: the state right after
    /// registration (or re-registration) and after a disconnect such as an
    /// idle-timeout sweep.
    Idle,
    /// Connection failed or server unreachable. Set when a connect attempt
    /// fails or when a caller reports a failed tool call.
    Failed,
}
165
/// A managed server connection.
///
/// Bookkeeping record kept by `ServerManager` for each registered server.
struct ServerConnection {
    /// Server definition supplied at registration (URL/headers or stdio
    /// command/args/env, lifecycle and idle-timeout settings).
    entry: ServerEntry,
    /// Shared handle to the live client; `None` while not connected.
    client: Option<Arc<Mutex<McpClient>>>,
    /// Current connection state.
    status: ConnectionStatus,
    /// When the server was last registered as new, connected, or touched;
    /// compared against the idle timeout by the idle sweep.
    last_used: Instant,
    /// When the most recent failure was recorded, if any; cleared on a
    /// successful connect or re-registration.
    last_failure: Option<Instant>,
    /// Caller-supplied hash of the server's configuration, stored as given.
    config_hash: u64,
}
175
/// Manages all MCP server connections with lazy connection, idle timeout, and health checks.
pub struct ServerManager {
    /// Registered servers keyed by name.
    servers: HashMap<String, ServerConnection>,
    /// Idle timeout applied to servers that neither set their own timeout nor
    /// use the `keep-alive` lifecycle.
    global_idle_timeout: std::time::Duration,
}
181
182impl ServerManager {
183    pub fn new(global_idle_timeout_minutes: u64) -> Self {
184        Self {
185            servers: HashMap::new(),
186            global_idle_timeout: std::time::Duration::from_secs(global_idle_timeout_minutes * 60),
187        }
188    }
189
190    /// Register or update a server definition (from config). Does not connect.
191    /// If the server already exists, its entry is replaced and the old connection
192    /// is dropped so that next use reconnects with the new config.
193    pub fn register(&mut self, name: &str, entry: ServerEntry, config_hash: u64) {
194        if let Some(conn) = self.servers.get_mut(name) {
195            // Update existing server: replace entry and drop old client so it
196            // reconnects lazily with the new config on next use.
197            conn.entry = entry;
198            conn.config_hash = config_hash;
199            conn.client = None;
200            conn.status = ConnectionStatus::Idle;
201            conn.last_failure = None;
202        } else {
203            self.servers.insert(
204                name.to_string(),
205                ServerConnection {
206                    entry,
207                    client: None,
208                    status: ConnectionStatus::Idle,
209                    last_used: Instant::now(),
210                    last_failure: None,
211                    config_hash,
212                },
213            );
214        }
215    }
216
217    /// Ensure a server is connected (lazy connect). Returns true if connected/available.
218    pub async fn ensure_connected(&mut self, name: &str) -> bool {
219        // Check if we have a cached connection that's still alive
220        if let Some(conn) = self.servers.get(name)
221            && conn.status == ConnectionStatus::Connected
222            && conn.client.is_some()
223        {
224            // Touch last_used so idle timer resets
225            if let Some(c) = self.servers.get_mut(name) {
226                c.last_used = Instant::now();
227            }
228            return true;
229        }
230
231        // Need to connect
232        let entry = match self.servers.get(name) {
233            Some(e) => e.entry.clone(),
234            None => return false,
235        };
236
237        let client = match &entry.url {
238            Some(url) => {
239                // Use SSE-aware HTTP transport instead of the plain yoagent one
240                let transport =
241                    Box::new(SseHttpTransport::new(url).with_headers(entry.headers.as_ref()));
242                let mut c = McpClient::from_transport(transport);
243                c.initialize().await.map(|_| c)
244            }
245            None => {
246                let env = entry.env.as_ref().cloned();
247                let cmd = entry.command.as_deref().unwrap_or("npx");
248                McpClient::connect_stdio(cmd, &to_str_slice(&entry.args), env).await
249            }
250        };
251
252        match client {
253            Ok(c) => {
254                let c = Arc::new(Mutex::new(c));
255                if let Some(conn) = self.servers.get_mut(name) {
256                    conn.client = Some(c);
257                    conn.status = ConnectionStatus::Connected;
258                    conn.last_used = Instant::now();
259                    conn.last_failure = None;
260                }
261                true
262            }
263            Err(e) => {
264                eprintln!("MCP: failed to connect to '{}': {}", name, e);
265                if let Some(conn) = self.servers.get_mut(name) {
266                    conn.status = ConnectionStatus::Failed;
267                    conn.last_failure = Some(Instant::now());
268                    conn.client = None;
269                }
270                false
271            }
272        }
273    }
274
275    /// Get a connected client for a server (must call ensure_connected first).
276    pub fn get_client(&self, name: &str) -> Option<Arc<Mutex<McpClient>>> {
277        self.servers.get(name).and_then(|c| c.client.clone())
278    }
279
280    /// Get the connection status for a server.
281    pub fn status(&self, name: &str) -> Option<ConnectionStatus> {
282        self.servers.get(name).map(|c| c.status.clone())
283    }
284
285    /// Mark a connection as failed after a tool call error.
286    pub fn mark_failed(&mut self, name: &str) {
287        if let Some(conn) = self.servers.get_mut(name) {
288            conn.status = ConnectionStatus::Failed;
289            conn.last_failure = Some(Instant::now());
290            conn.client = None;
291        }
292    }
293
294    /// Touch a server (update last_used timestamp, e.g. after successful tool call).
295    pub fn touch(&mut self, name: &str) {
296        if let Some(conn) = self.servers.get_mut(name) {
297            conn.last_used = Instant::now();
298            if conn.status == ConnectionStatus::Failed && conn.last_failure.is_some() {
299                let backoff = std::time::Duration::from_secs(60);
300                if conn.last_failure.unwrap().elapsed() > backoff {
301                    conn.status = ConnectionStatus::Idle;
302                    conn.last_failure = None;
303                }
304            }
305        }
306    }
307
308    /// Disconnect a server (idle shutdown).
309    pub async fn disconnect(&mut self, name: &str) {
310        if let Some(conn) = self.servers.get_mut(name) {
311            if let Some(ref client) = conn.client {
312                let _ = client.lock().await.close().await;
313            }
314            conn.client = None;
315            conn.status = ConnectionStatus::Idle;
316        }
317    }
318
319    /// Close all connections (on session shutdown).
320    pub async fn close_all(&mut self) {
321        let names: Vec<String> = self.servers.keys().cloned().collect();
322        for name in &names {
323            self.disconnect(name).await;
324        }
325    }
326
327    /// Get the idle timeout for a server (per-server override or global default).
328    pub fn idle_timeout(&self, name: &str) -> std::time::Duration {
329        if let Some(conn) = self.servers.get(name) {
330            idle_timeout_for(conn, self.global_idle_timeout)
331        } else {
332            self.global_idle_timeout
333        }
334    }
335
336    /// Check for idle servers and disconnect them.
337    pub async fn sweep_idle(&mut self) {
338        let now = Instant::now();
339        let idle_names: Vec<String> = self
340            .servers
341            .iter()
342            .filter(|(_name, conn)| {
343                if conn.status != ConnectionStatus::Connected {
344                    return false;
345                }
346                let timeout = idle_timeout_for(conn, self.global_idle_timeout);
347                now.duration_since(conn.last_used) > timeout
348            })
349            .map(|(name, _)| name.clone())
350            .collect();
351
352        for name in &idle_names {
353            self.disconnect(name).await;
354        }
355    }
356
357    /// Get a list of all registered server names.
358    pub fn server_names(&self) -> Vec<String> {
359        self.servers.keys().cloned().collect()
360    }
361
362    /// Synchronously remove a server entry, dropping any existing connection.
363    /// The client Arc is dropped; in-flight calls holding a clone of the Arc
364    /// can still complete.
365    pub fn remove(&mut self, name: &str) {
366        self.servers.remove(name);
367    }
368
369    /// Check if a server should be connected eagerly at startup.
370    pub fn should_connect_eagerly(&self, name: &str) -> bool {
371        self.servers
372            .get(name)
373            .is_some_and(|c| matches!(c.entry.lifecycle.as_deref(), Some("eager" | "keep-alive")))
374    }
375
376    /// Get the config hash for a server.
377    pub fn config_hash(&self, name: &str) -> Option<u64> {
378        self.servers.get(name).map(|c| c.config_hash)
379    }
380}
381
/// Borrows each owned argument as a `&str`, preserving order.
fn to_str_slice(args: &[String]) -> Vec<&str> {
    let mut borrowed: Vec<&str> = Vec::with_capacity(args.len());
    borrowed.extend(args.iter().map(String::as_str));
    borrowed
}
385
386/// Compute idle timeout for a server connection.
387fn idle_timeout_for(conn: &ServerConnection, global: std::time::Duration) -> std::time::Duration {
388    if let Some(t) = conn.entry.idle_timeout {
389        return std::time::Duration::from_secs(t * 60);
390    }
391    // keep-alive servers have no idle timeout
392    if conn.entry.lifecycle.as_deref() == Some("keep-alive") {
393        return std::time::Duration::MAX;
394    }
395    global
396}