Skip to main content

construct/tools/
mcp_transport.rs

1//! MCP transport abstraction — supports stdio, SSE, and HTTP transports.
2
3use std::borrow::Cow;
4
5use anyhow::{Context, Result, anyhow, bail};
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::process::{Child, Command};
8use tokio::sync::{Mutex, Notify, oneshot};
9use tokio::time::{Duration, timeout};
10use tokio_stream::StreamExt;
11
12use crate::config::schema::{McpServerConfig, McpTransport};
13use crate::tools::mcp_protocol::{INTERNAL_ERROR, JsonRpcError, JsonRpcRequest, JsonRpcResponse};
14
/// Upper bound, in bytes, on one JSON-RPC message line read from a stdio server.
const MAX_LINE_BYTES: usize = 4 << 20; // 4 MiB

/// Deadline, in seconds, for handshake-style requests (initialize / list).
///
/// Tool calls do NOT use this — they rely on the outer per-tool timeout
/// in `mcp_client::call_tool` (default 180s, max 600s).
const RECV_TIMEOUT_SECS: u64 = 30;

/// `Accept` value required by the MCP streamable HTTP transport: the server
/// may answer with plain JSON or with an SSE stream.
const MCP_STREAMABLE_ACCEPT: &str = "application/json, text/event-stream";

/// Media type of outgoing JSON-RPC request bodies.
const MCP_JSON_CONTENT_TYPE: &str = "application/json";

/// Streamable HTTP session header; echoing it back preserves server-side state.
const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id";
30
31// ── Transport Trait ──────────────────────────────────────────────────────
32
/// Abstract transport for MCP communication.
///
/// Implementations in this module ([`StdioTransport`], [`HttpTransport`],
/// [`SseTransport`]) exchange JSON-RPC messages with a single MCP server.
/// Boxed instances are produced by [`create_transport`].
#[async_trait::async_trait]
pub trait McpTransportConn: Send + Sync {
    /// Send a JSON-RPC request and receive the response.
    ///
    /// The implementations in this module do not wait for a reply when
    /// `request.id` is `None` (a notification); they return an empty response
    /// with no `id`, `result`, or `error` once the message has been sent.
    async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse>;

    /// Close the connection.
    async fn close(&mut self) -> Result<()>;
}
42
43// ── Stdio Transport ──────────────────────────────────────────────────────
44
/// Stdio-based transport (spawn local process).
///
/// Requests are written to the child's stdin as one JSON document per line;
/// responses are read line by line from the child's stdout.
pub struct StdioTransport {
    /// The spawned server process. It is held only to keep it alive; it is
    /// spawned with `kill_on_drop(true)`, so dropping the transport kills it.
    _child: Child,
    /// Write end of the server's stdin pipe.
    stdin: tokio::process::ChildStdin,
    /// Line reader over the server's stdout pipe.
    stdout_lines: tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
}
51
52impl StdioTransport {
53    pub fn new(config: &McpServerConfig) -> Result<Self> {
54        let mut child = Command::new(&config.command)
55            .args(&config.args)
56            .envs(&config.env)
57            .stdin(std::process::Stdio::piped())
58            .stdout(std::process::Stdio::piped())
59            .stderr(std::process::Stdio::inherit())
60            .kill_on_drop(true)
61            .spawn()
62            .with_context(|| format!("failed to spawn MCP server `{}`", config.name))?;
63
64        let stdin = child
65            .stdin
66            .take()
67            .ok_or_else(|| anyhow!("no stdin on MCP server `{}`", config.name))?;
68        let stdout = child
69            .stdout
70            .take()
71            .ok_or_else(|| anyhow!("no stdout on MCP server `{}`", config.name))?;
72        let stdout_lines = BufReader::new(stdout).lines();
73
74        Ok(Self {
75            _child: child,
76            stdin,
77            stdout_lines,
78        })
79    }
80
81    async fn send_raw(&mut self, line: &str) -> Result<()> {
82        self.stdin
83            .write_all(line.as_bytes())
84            .await
85            .context("failed to write to MCP server stdin")?;
86        self.stdin
87            .write_all(b"\n")
88            .await
89            .context("failed to write newline to MCP server stdin")?;
90        self.stdin.flush().await.context("failed to flush stdin")?;
91        Ok(())
92    }
93
94    async fn recv_raw(&mut self) -> Result<String> {
95        let line = self
96            .stdout_lines
97            .next_line()
98            .await?
99            .ok_or_else(|| anyhow!("MCP server closed stdout"))?;
100        if line.len() > MAX_LINE_BYTES {
101            bail!("MCP response too large: {} bytes", line.len());
102        }
103        Ok(line)
104    }
105}
106
107#[async_trait::async_trait]
108impl McpTransportConn for StdioTransport {
109    async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
110        let line = serde_json::to_string(request)?;
111        self.send_raw(&line).await?;
112        if request.id.is_none() {
113            return Ok(JsonRpcResponse {
114                jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
115                id: None,
116                result: None,
117                error: None,
118            });
119        }
120        // No internal deadline — the caller (call_tool or connect) wraps this
121        // in its own timeout.  This avoids racing with the Python handler's
122        // poll timeout and eliminates orphaned responses that cause cross-wiring.
123        loop {
124            let resp_line = self.recv_raw().await?;
125            let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
126                .with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
127            if resp.id.is_none() {
128                tracing::debug!(
129                    "MCP stdio: skipping server notification while waiting for response"
130                );
131                continue;
132            }
133            // JSON-RPC 2.0 requires response id to match request id.
134            // Stale responses from timed-out requests may linger in the pipe —
135            // discard them and keep waiting for ours.
136            if resp.id != request.id {
137                tracing::warn!(
138                    "MCP stdio: discarding response with mismatched id \
139                     (got {:?}, expected {:?}) — likely stale from a timed-out request",
140                    resp.id,
141                    request.id
142                );
143                continue;
144            }
145            return Ok(resp);
146        }
147    }
148
149    async fn close(&mut self) -> Result<()> {
150        let _ = self.stdin.shutdown().await;
151        Ok(())
152    }
153}
154
155// ── HTTP Transport ───────────────────────────────────────────────────────
156
/// HTTP-based transport (POST requests).
///
/// Every request is POSTed to `url`. The reply is read either from a
/// `text/event-stream` response body or from a plain (JSON) body.
pub struct HttpTransport {
    /// Endpoint every JSON-RPC request is POSTed to.
    url: String,
    /// HTTP client (built in `new` with a 120 s request timeout).
    client: reqwest::Client,
    /// Extra headers from the server config, sent on every request.
    headers: std::collections::HashMap<String, String>,
    /// Last non-empty `Mcp-Session-Id` returned by the server, echoed back on
    /// later requests.
    session_id: Option<String>,
}
164
165impl HttpTransport {
166    pub fn new(config: &McpServerConfig) -> Result<Self> {
167        let url = config
168            .url
169            .as_ref()
170            .ok_or_else(|| anyhow!("URL required for HTTP transport"))?
171            .clone();
172
173        let client = reqwest::Client::builder()
174            .timeout(Duration::from_secs(120))
175            .build()
176            .context("failed to build HTTP client")?;
177
178        Ok(Self {
179            url,
180            client,
181            headers: config.headers.clone(),
182            session_id: None,
183        })
184    }
185
186    fn apply_session_header(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
187        if let Some(session_id) = self.session_id.as_deref() {
188            req.header(MCP_SESSION_ID_HEADER, session_id)
189        } else {
190            req
191        }
192    }
193
194    fn update_session_id_from_headers(&mut self, headers: &reqwest::header::HeaderMap) {
195        if let Some(session_id) = headers
196            .get(MCP_SESSION_ID_HEADER)
197            .and_then(|v| v.to_str().ok())
198            .map(str::trim)
199            .filter(|v| !v.is_empty())
200        {
201            self.session_id = Some(session_id.to_string());
202        }
203    }
204}
205
206#[async_trait::async_trait]
207impl McpTransportConn for HttpTransport {
208    async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
209        let body = serde_json::to_string(request)?;
210
211        let has_accept = self
212            .headers
213            .keys()
214            .any(|k| k.eq_ignore_ascii_case("Accept"));
215        let has_content_type = self
216            .headers
217            .keys()
218            .any(|k| k.eq_ignore_ascii_case("Content-Type"));
219
220        let mut req = self.client.post(&self.url).body(body);
221        if !has_content_type {
222            req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
223        }
224        for (key, value) in &self.headers {
225            req = req.header(key, value);
226        }
227        req = self.apply_session_header(req);
228        if !has_accept {
229            req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
230        }
231
232        let resp = req
233            .send()
234            .await
235            .context("HTTP request to MCP server failed")?;
236
237        if !resp.status().is_success() {
238            bail!("MCP server returned HTTP {}", resp.status());
239        }
240
241        self.update_session_id_from_headers(resp.headers());
242
243        if request.id.is_none() {
244            return Ok(JsonRpcResponse {
245                jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
246                id: None,
247                result: None,
248                error: None,
249            });
250        }
251
252        let is_sse = resp
253            .headers()
254            .get(reqwest::header::CONTENT_TYPE)
255            .and_then(|v| v.to_str().ok())
256            .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
257        if is_sse {
258            let maybe_resp = timeout(
259                Duration::from_secs(RECV_TIMEOUT_SECS),
260                read_first_jsonrpc_from_sse_response(resp),
261            )
262            .await
263            .context("timeout waiting for MCP response from streamable HTTP SSE stream")??;
264            return maybe_resp
265                .ok_or_else(|| anyhow!("MCP server returned no response in SSE stream"));
266        }
267
268        let resp_text = resp.text().await.context("failed to read HTTP response")?;
269        parse_jsonrpc_response_text(&resp_text)
270    }
271
272    async fn close(&mut self) -> Result<()> {
273        Ok(())
274    }
275}
276
277// ── SSE Transport ─────────────────────────────────────────────────────────
278
/// Availability of the long-lived SSE GET stream used by [`SseTransport`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum SseStreamState {
    /// No GET attempt has determined the state yet.
    Unknown,
    /// A background reader task is consuming the SSE stream.
    Connected,
    /// The GET returned 404/405 or a non-`text/event-stream` body. Responses
    /// can then only be taken from POST response bodies.
    Unsupported,
}
286
/// SSE-based transport (HTTP POST for requests, SSE for responses).
///
/// A background task reads the server's SSE stream and routes responses to
/// waiting requests by numeric id. A response may also come back directly in
/// the POST response body.
pub struct SseTransport {
    /// URL of the SSE stream (GET). It is also the first POST target until the
    /// server announces a message endpoint.
    sse_url: String,
    /// Server name from the config, used in log messages.
    server_name: String,
    /// HTTP client shared by the GET stream and all POSTs.
    client: reqwest::Client,
    /// Extra headers from the server config, sent on every GET/POST.
    headers: std::collections::HashMap<String, String>,
    /// Whether the SSE GET stream is connected, unsupported, or not yet tried.
    stream_state: SseStreamState,
    /// Message URL and pending-response senders, shared with the reader task.
    shared: std::sync::Arc<Mutex<SseSharedState>>,
    /// Signalled by the reader task when it learns a message endpoint.
    notify: std::sync::Arc<Notify>,
    /// Sending on (or dropping) this channel stops the reader task.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Handle of the background SSE reader task, if one was spawned.
    reader_task: Option<tokio::task::JoinHandle<()>>,
}
298
impl SseTransport {
    /// Creates an SSE transport for `config.url` without connecting yet.
    ///
    /// The HTTP client is built without a request timeout, because the GET
    /// stream is long-lived. POSTs set their own per-request timeout.
    ///
    /// # Errors
    /// Fails when `config.url` is missing or the HTTP client cannot be built.
    pub fn new(config: &McpServerConfig) -> Result<Self> {
        let sse_url = config
            .url
            .as_ref()
            .ok_or_else(|| anyhow!("URL required for SSE transport"))?
            .clone();

        let client = reqwest::Client::builder()
            .build()
            .context("failed to build HTTP client")?;

        Ok(Self {
            sse_url,
            server_name: config.name.clone(),
            client,
            headers: config.headers.clone(),
            stream_state: SseStreamState::Unknown,
            shared: std::sync::Arc::new(Mutex::new(SseSharedState::default())),
            notify: std::sync::Arc::new(Notify::new()),
            shutdown_tx: None,
            reader_task: None,
        })
    }

    /// Opens the SSE GET stream and spawns its reader task, if needed.
    ///
    /// Does nothing when the stream is already known to be unsupported or a
    /// reader task is still running. A finished reader task (stream closed)
    /// triggers a fresh GET. A 404/405 response or a non-event-stream content
    /// type marks the stream `Unsupported` instead of failing.
    ///
    /// # Errors
    /// Fails when the GET cannot be sent or returns another non-2xx status.
    async fn ensure_connected(&mut self) -> Result<()> {
        if self.stream_state == SseStreamState::Unsupported {
            return Ok(());
        }
        if let Some(task) = &self.reader_task {
            if !task.is_finished() {
                self.stream_state = SseStreamState::Connected;
                return Ok(());
            }
        }

        let has_accept = self
            .headers
            .keys()
            .any(|k| k.eq_ignore_ascii_case("Accept"));

        let mut req = self
            .client
            .get(&self.sse_url)
            .header("Cache-Control", "no-cache");
        for (key, value) in &self.headers {
            req = req.header(key, value);
        }
        if !has_accept {
            req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
        }

        // NOTE(review): the client from `new` has no timeout, so this waits on
        // the server's response headers with no deadline of its own.
        let resp = req.send().await.context("SSE GET to MCP server failed")?;
        if resp.status() == reqwest::StatusCode::NOT_FOUND
            || resp.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED
        {
            self.stream_state = SseStreamState::Unsupported;
            return Ok(());
        }
        if !resp.status().is_success() {
            return Err(anyhow!("MCP server returned HTTP {}", resp.status()));
        }
        let is_event_stream = resp
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
        if !is_event_stream {
            self.stream_state = SseStreamState::Unsupported;
            return Ok(());
        }

        let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
        self.shutdown_tx = Some(shutdown_tx);

        let shared = self.shared.clone();
        let notify = self.notify.clone();
        let sse_url = self.sse_url.clone();
        let server_name = self.server_name.clone();

        // Reader task: parses SSE fields line by line and dispatches each
        // complete event (terminated by a blank line) to `handle_sse_event`.
        self.reader_task = Some(tokio::spawn(async move {
            let stream = resp
                .bytes_stream()
                .map(|item| item.map_err(std::io::Error::other));
            let reader = tokio_util::io::StreamReader::new(stream);
            let mut lines = BufReader::new(reader).lines();

            let mut cur_event: Option<String> = None;
            let mut cur_id: Option<String> = None;
            let mut cur_data: Vec<String> = Vec::new();

            loop {
                tokio::select! {
                    // Shutdown signal (a tokio oneshot receiver also completes
                    // when its sender is dropped).
                    _ = &mut shutdown_rx => {
                        break;
                    }
                    line = lines.next_line() => {
                        // A read error or EOF ends the stream.
                        let Ok(line_opt) = line else { break; };
                        let Some(mut line) = line_opt else { break; };
                        if line.ends_with('\r') {
                            line.pop();
                        }
                        if line.is_empty() {
                            // Blank line: dispatch the accumulated event, if any.
                            if cur_event.is_none() && cur_id.is_none() && cur_data.is_empty() {
                                continue;
                            }
                            let event = cur_event.take();
                            let data = cur_data.join("\n");
                            cur_data.clear();
                            let id = cur_id.take();
                            handle_sse_event(&server_name, &sse_url, &shared, &notify, event.as_deref(), id.as_deref(), data).await;
                            continue;
                        }

                        // SSE comment line.
                        if line.starts_with(':') {
                            continue;
                        }

                        if let Some(rest) = line.strip_prefix("event:") {
                            cur_event = Some(rest.trim().to_string());
                        }
                        if let Some(rest) = line.strip_prefix("data:") {
                            let rest = rest.strip_prefix(' ').unwrap_or(rest);
                            cur_data.push(rest.to_string());
                        }
                        if let Some(rest) = line.strip_prefix("id:") {
                            cur_id = Some(rest.trim().to_string());
                        }
                    }
                }
            }

            // Stream is gone: fail every still-waiting request with an
            // INTERNAL_ERROR response instead of leaving it hanging.
            let pending = {
                let mut guard = shared.lock().await;
                std::mem::take(&mut guard.pending)
            };
            for (_, tx) in pending {
                let _ = tx.send(JsonRpcResponse {
                    jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
                    id: None,
                    result: None,
                    error: Some(JsonRpcError {
                        code: INTERNAL_ERROR,
                        message: "SSE connection closed".to_string(),
                        data: None,
                    }),
                });
            }
        }));
        self.stream_state = SseStreamState::Connected;

        Ok(())
    }

    /// Returns the URL to POST requests to, and whether it came from a server
    /// `endpoint` event (`true`) or was derived from the SSE URL (`false`).
    ///
    /// If no URL is cached yet, a `<...>/messages` URL is derived and cached,
    /// unless an endpoint arrived in the meantime. In that case the cached
    /// endpoint is kept, but this call still returns the derived URL.
    ///
    /// # Errors
    /// Fails when the SSE URL cannot be turned into a message URL.
    async fn get_message_url(&self) -> Result<(String, bool)> {
        let guard = self.shared.lock().await;
        if let Some(url) = &guard.message_url {
            return Ok((url.clone(), guard.message_url_from_endpoint));
        }
        drop(guard);

        // NOTE(review): derive_message_url fails only on URL parse/segment
        // errors, which do not depend on the path argument, so the "message"
        // fallback can only fail whenever "messages" does.
        let derived = derive_message_url(&self.sse_url, "messages")
            .or_else(|| derive_message_url(&self.sse_url, "message"))
            .ok_or_else(|| anyhow!("invalid SSE URL"))?;
        let mut guard = self.shared.lock().await;
        if guard.message_url.is_none() {
            guard.message_url = Some(derived.clone());
            guard.message_url_from_endpoint = false;
        }
        Ok((derived, false))
    }

    /// Returns the other derived message URL (`/messages` ↔ `/message`) when
    /// `current_url` was derived rather than announced by the server.
    ///
    /// Returns `None` for server-announced URLs, when derivation fails, or when
    /// the alternative equals `current_url`.
    // NOTE(review): not called from the visible part of this file — confirm
    // whether it is still needed.
    fn maybe_try_alternate_message_url(
        &self,
        current_url: &str,
        from_endpoint: bool,
    ) -> Option<String> {
        if from_endpoint {
            return None;
        }
        let alt = if current_url.ends_with("/messages") {
            derive_message_url(&self.sse_url, "message")
        } else {
            derive_message_url(&self.sse_url, "messages")
        }?;
        if alt == current_url {
            return None;
        }
        Some(alt)
    }
}
490
/// State shared between an [`SseTransport`] and its SSE reader task.
#[derive(Default)]
struct SseSharedState {
    /// URL that JSON-RPC requests are POSTed to, once known.
    message_url: Option<String>,
    /// `true` when `message_url` came from a server `endpoint` event rather
    /// than being derived from the SSE URL.
    message_url_from_endpoint: bool,
    /// Senders for in-flight requests, keyed by numeric JSON-RPC id.
    pending: std::collections::HashMap<u64, oneshot::Sender<JsonRpcResponse>>,
}
497
498fn derive_message_url(sse_url: &str, message_path: &str) -> Option<String> {
499    let url = reqwest::Url::parse(sse_url).ok()?;
500    let mut segments: Vec<&str> = url.path_segments()?.collect();
501    if segments.is_empty() {
502        return None;
503    }
504    if segments.last().copied() == Some("sse") {
505        segments.pop();
506        segments.push(message_path);
507        let mut new_url = url.clone();
508        new_url.set_path(&format!("/{}", segments.join("/")));
509        return Some(new_url.to_string());
510    }
511    let mut new_url = url.clone();
512    let mut path = url.path().trim_end_matches('/').to_string();
513    path.push('/');
514    path.push_str(message_path);
515    new_url.set_path(&path);
516    Some(new_url.to_string())
517}
518
519async fn handle_sse_event(
520    server_name: &str,
521    sse_url: &str,
522    shared: &std::sync::Arc<Mutex<SseSharedState>>,
523    notify: &std::sync::Arc<Notify>,
524    event: Option<&str>,
525    _id: Option<&str>,
526    data: String,
527) {
528    let event = event.unwrap_or("message");
529    let trimmed = data.trim();
530    if trimmed.is_empty() {
531        return;
532    }
533
534    if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint") {
535        if let Some(url) = parse_endpoint_from_data(sse_url, trimmed) {
536            let mut guard = shared.lock().await;
537            guard.message_url = Some(url);
538            guard.message_url_from_endpoint = true;
539            drop(guard);
540            notify.notify_waiters();
541        }
542        return;
543    }
544
545    if !event.eq_ignore_ascii_case("message") {
546        return;
547    }
548
549    let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) else {
550        return;
551    };
552
553    let Ok(resp) = serde_json::from_value::<JsonRpcResponse>(value.clone()) else {
554        let _ = serde_json::from_value::<JsonRpcRequest>(value);
555        return;
556    };
557
558    let Some(id_val) = resp.id.clone() else {
559        return;
560    };
561    let id = match id_val.as_u64() {
562        Some(v) => v,
563        None => return,
564    };
565
566    let tx = {
567        let mut guard = shared.lock().await;
568        guard.pending.remove(&id)
569    };
570    if let Some(tx) = tx {
571        let _ = tx.send(resp);
572    } else {
573        tracing::debug!(
574            "MCP SSE `{}` received response for unknown id {}",
575            server_name,
576            id
577        );
578    }
579}
580
581fn parse_endpoint_from_data(sse_url: &str, data: &str) -> Option<String> {
582    if data.starts_with('{') {
583        let v: serde_json::Value = serde_json::from_str(data).ok()?;
584        let endpoint = v.get("endpoint")?.as_str()?;
585        return parse_endpoint_from_data(sse_url, endpoint);
586    }
587    if data.starts_with("http://") || data.starts_with("https://") {
588        return Some(data.to_string());
589    }
590    let base = reqwest::Url::parse(sse_url).ok()?;
591    base.join(data).ok().map(|u| u.to_string())
592}
593
/// Returns the `data:` payload of the last event in SSE-framed `resp_text`.
///
/// A leading BOM is ignored. Lines are left-trimmed, and one space after
/// `data:` is dropped. Multi-line data is joined with `\n`. When the text
/// contains no `data:` lines at all, the whole (trimmed) text is returned, so
/// plain JSON passes through unchanged.
fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
    let text = resp_text.trim_start_matches('\u{feff}');
    let mut pending: Vec<&str> = Vec::new();
    let mut completed: Vec<&str> = Vec::new();

    for line in text
        .lines()
        .map(|raw| raw.trim_end_matches('\r').trim_start())
    {
        if line.is_empty() {
            // Blank line terminates an event.
            if !pending.is_empty() {
                completed = std::mem::take(&mut pending);
            }
        } else if let Some(payload) = line.strip_prefix("data:") {
            pending.push(payload.strip_prefix(' ').unwrap_or(payload));
        }
        // Comment lines (`:`) and other fields are ignored.
    }

    // An unterminated final event still counts.
    if !pending.is_empty() {
        completed = pending;
    }

    match completed[..] {
        [] => Cow::Borrowed(text.trim()),
        [single] => Cow::Borrowed(single.trim()),
        _ => Cow::Owned(completed.join("\n").trim().to_string()),
    }
}
633
634fn parse_jsonrpc_response_text(resp_text: &str) -> Result<JsonRpcResponse> {
635    let trimmed = resp_text.trim();
636    if trimmed.is_empty() {
637        bail!("MCP server returned no response");
638    }
639
640    let json_text = if looks_like_sse_text(trimmed) {
641        extract_json_from_sse_text(trimmed)
642    } else {
643        Cow::Borrowed(trimmed)
644    };
645
646    let mcp_resp: JsonRpcResponse = serde_json::from_str(json_text.as_ref())
647        .with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
648    Ok(mcp_resp)
649}
650
/// Heuristic: does `text` contain a line that starts with an SSE `data:` or
/// `event:` field?
///
/// Only the start of each `\n`-separated line is checked, with no trimming.
fn looks_like_sse_text(text: &str) -> bool {
    let starts_with_field = |s: &str| s.starts_with("data:") || s.starts_with("event:");
    starts_with_field(text) || text.split('\n').skip(1).any(starts_with_field)
}
657
658async fn read_first_jsonrpc_from_sse_response(
659    resp: reqwest::Response,
660) -> Result<Option<JsonRpcResponse>> {
661    let stream = resp
662        .bytes_stream()
663        .map(|item| item.map_err(std::io::Error::other));
664    let reader = tokio_util::io::StreamReader::new(stream);
665    let mut lines = BufReader::new(reader).lines();
666
667    let mut cur_event: Option<String> = None;
668    let mut cur_data: Vec<String> = Vec::new();
669
670    while let Ok(line_opt) = lines.next_line().await {
671        let Some(mut line) = line_opt else { break };
672        if line.ends_with('\r') {
673            line.pop();
674        }
675        if line.is_empty() {
676            if cur_event.is_none() && cur_data.is_empty() {
677                continue;
678            }
679            let event = cur_event.take();
680            let data = cur_data.join("\n");
681            cur_data.clear();
682
683            let event = event.unwrap_or_else(|| "message".to_string());
684            if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint")
685            {
686                continue;
687            }
688            if !event.eq_ignore_ascii_case("message") {
689                continue;
690            }
691
692            let trimmed = data.trim();
693            if trimmed.is_empty() {
694                continue;
695            }
696            let json_str = extract_json_from_sse_text(trimmed);
697            if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
698                return Ok(Some(resp));
699            }
700            continue;
701        }
702
703        if line.starts_with(':') {
704            continue;
705        }
706        if let Some(rest) = line.strip_prefix("event:") {
707            cur_event = Some(rest.trim().to_string());
708        }
709        if let Some(rest) = line.strip_prefix("data:") {
710            let rest = rest.strip_prefix(' ').unwrap_or(rest);
711            cur_data.push(rest.to_string());
712        }
713    }
714
715    Ok(None)
716}
717
#[async_trait::async_trait]
impl McpTransportConn for SseTransport {
    /// Sends `request` by HTTP POST and returns the matching response.
    ///
    /// Flow:
    /// 1. Ensure the SSE GET stream is open, or known to be unsupported.
    /// 2. Pick POST targets. If the server has announced an endpoint, that URL
    ///    is used. While the stream is connected, this waits up to 3 × 300 ms
    ///    for an announcement. Otherwise the SSE URL is tried first, with the
    ///    derived message URL as fallback.
    /// 3. For a numeric id with a connected stream, register a oneshot in
    ///    `pending` so the reader task can deliver the response.
    /// 4. POST. The response is taken from the POST body when it contains one;
    ///    otherwise this waits for the reader task to deliver it.
    ///
    /// # Errors
    /// Fails on POST transport errors, a non-2xx final status, when no
    /// response can be obtained, or when the pending channel closes unanswered.
    async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
        self.ensure_connected().await?;

        // Only numeric ids can be routed through `pending`.
        let id = request.id.as_ref().and_then(|v| v.as_u64());
        let body = serde_json::to_string(request)?;

        let (mut message_url, mut from_endpoint) = self.get_message_url().await?;
        if self.stream_state == SseStreamState::Connected && !from_endpoint {
            // Give the reader task a short chance to deliver an `endpoint` event.
            for _ in 0..3 {
                {
                    let guard = self.shared.lock().await;
                    if guard.message_url_from_endpoint {
                        if let Some(url) = &guard.message_url {
                            message_url = url.clone();
                            from_endpoint = true;
                            break;
                        }
                    }
                }
                let _ = timeout(Duration::from_millis(300), self.notify.notified()).await;
            }
        }
        let primary_url = if from_endpoint {
            message_url.clone()
        } else {
            self.sse_url.clone()
        };
        let secondary_url = if message_url == self.sse_url {
            None
        } else if primary_url == message_url {
            Some(self.sse_url.clone())
        } else {
            Some(message_url.clone())
        };
        let has_secondary = secondary_url.is_some();

        // Register before POSTing so a response that races ahead on the SSE
        // stream is not lost.
        let mut rx = None;
        if let Some(id) = id {
            if self.stream_state == SseStreamState::Connected {
                let (tx, ch) = oneshot::channel();
                {
                    let mut guard = self.shared.lock().await;
                    guard.pending.insert(id, tx);
                }
                rx = Some((id, ch));
            }
        }

        let mut got_direct = None;
        let mut last_status = None;

        // Try the primary URL; a 404/405 from it falls through to the secondary.
        for (i, url) in std::iter::once(primary_url)
            .chain(secondary_url.into_iter())
            .enumerate()
        {
            let has_accept = self
                .headers
                .keys()
                .any(|k| k.eq_ignore_ascii_case("Accept"));
            let has_content_type = self
                .headers
                .keys()
                .any(|k| k.eq_ignore_ascii_case("Content-Type"));
            let mut req = self
                .client
                .post(&url)
                .timeout(Duration::from_secs(120))
                .body(body.clone());
            if !has_content_type {
                req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
            }
            for (key, value) in &self.headers {
                req = req.header(key, value);
            }
            if !has_accept {
                req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
            }

            // NOTE(review): an early `?` return here (or `res?` below) skips the
            // pending cleanup after this loop, leaving the registered sender in
            // `shared.pending` until the reader task ends.
            let resp = req.send().await.context("SSE POST to MCP server failed")?;
            let status = resp.status();
            last_status = Some(status);

            if (status == reqwest::StatusCode::NOT_FOUND
                || status == reqwest::StatusCode::METHOD_NOT_ALLOWED)
                && i == 0
            {
                continue;
            }

            if !status.is_success() {
                break;
            }

            // Notification: nothing to wait for once the POST succeeded.
            if request.id.is_none() {
                got_direct = Some(JsonRpcResponse {
                    jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
                    id: None,
                    result: None,
                    error: None,
                });
                break;
            }

            let is_sse = resp
                .headers()
                .get(reqwest::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok())
                .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));

            if is_sse {
                if i == 0 && has_secondary {
                    // NOTE(review): if the primary accepted the POST but its body
                    // yields no response within 3 s, the loop moves on and
                    // re-POSTs the same request to the secondary URL. A slow
                    // request (e.g. a tool call) may therefore be sent twice.
                    match timeout(
                        Duration::from_secs(3),
                        read_first_jsonrpc_from_sse_response(resp),
                    )
                    .await
                    {
                        Ok(res) => {
                            if let Some(resp) = res? {
                                got_direct = Some(resp);
                            }
                            break;
                        }
                        Err(_) => continue,
                    }
                }
                if let Some(resp) = read_first_jsonrpc_from_sse_response(resp).await? {
                    got_direct = Some(resp);
                }
                break;
            }

            // Plain body: same 3 s probe (and same re-POST caveat) on the primary
            // when a secondary exists.
            let text = if i == 0 && has_secondary {
                match timeout(Duration::from_secs(3), resp.text()).await {
                    Ok(Ok(t)) => t,
                    Ok(Err(_)) => String::new(),
                    Err(_) => continue,
                }
            } else {
                resp.text().await.unwrap_or_default()
            };
            let trimmed = text.trim();
            if !trimmed.is_empty() {
                // NOTE(review): unlike the SSE-stream path, a JSON body here is
                // not checked for a `method` field before being accepted.
                let json_str = if trimmed.contains("\ndata:") || trimmed.starts_with("data:") {
                    extract_json_from_sse_text(trimmed)
                } else {
                    Cow::Borrowed(trimmed)
                };
                if let Ok(mcp_resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
                    got_direct = Some(mcp_resp);
                }
            }
            break;
        }

        // Drop the pending entry when the answer came directly or the POST failed.
        if let Some((id, _)) = rx.as_ref() {
            if got_direct.is_some() {
                let mut guard = self.shared.lock().await;
                guard.pending.remove(id);
            } else if let Some(status) = last_status {
                if !status.is_success() {
                    let mut guard = self.shared.lock().await;
                    guard.pending.remove(id);
                }
            }
        }

        if let Some(resp) = got_direct {
            return Ok(resp);
        }

        if let Some(status) = last_status {
            if !status.is_success() {
                bail!("MCP server returned HTTP {}", status);
            }
        } else {
            bail!("MCP request not sent");
        }

        let Some((_id, rx)) = rx else {
            bail!("MCP server returned no response");
        };

        // No internal deadline; presumably the caller wraps this call in its
        // own timeout — confirm. The reader task sends an INTERNAL_ERROR
        // response to all pending requests if the stream closes.
        rx.await.map_err(|_| anyhow!("SSE response channel closed"))
    }

    /// Stops the SSE reader task: signals shutdown, then aborts the task.
    async fn close(&mut self) -> Result<()> {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
        if let Some(task) = self.reader_task.take() {
            task.abort();
        }
        Ok(())
    }
}
916
917// ── Factory ──────────────────────────────────────────────────────────────
918
919/// Create a transport based on config.
920pub fn create_transport(config: &McpServerConfig) -> Result<Box<dyn McpTransportConn>> {
921    match config.transport {
922        McpTransport::Stdio => Ok(Box::new(StdioTransport::new(config)?)),
923        McpTransport::Http => Ok(Box::new(HttpTransport::new(config)?)),
924        McpTransport::Sse => Ok(Box::new(SseTransport::new(config)?)),
925    }
926}
927
928// ── Tests ─────────────────────────────────────────────────────────────────
929
#[cfg(test)]
mod tests {
    use super::*;

    // ── Fixtures ──────────────────────────────────────────────────────────────

    /// Server config with the given name and transport.
    ///
    /// `url` is written only when `Some`; otherwise the field keeps whatever
    /// `McpServerConfig::default()` provides, as do all other fields.
    fn server_config(
        name: &'static str,
        transport: McpTransport,
        url: Option<&'static str>,
    ) -> McpServerConfig {
        let mut config = McpServerConfig {
            name: name.into(),
            transport,
            ..Default::default()
        };
        if let Some(url) = url {
            config.url = Some(url.into());
        }
        config
    }

    /// HTTP transport pointed at a local URL that does not need to be reachable.
    fn local_http_transport() -> HttpTransport {
        let config = server_config("test-http", McpTransport::Http, Some("http://localhost/mcp"));
        HttpTransport::new(&config).expect("build transport")
    }

    /// Header map carrying a single `mcp-session-id` entry with `value`.
    fn session_id_headers(value: &'static str) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::HeaderName::from_static("mcp-session-id"),
            reqwest::header::HeaderValue::from_static(value),
        );
        headers
    }

    /// Builds a POST to the local endpoint after `transport` has had the chance
    /// to attach its session header.
    fn build_session_request(transport: HttpTransport) -> reqwest::Request {
        transport
            .apply_session_header(reqwest::Client::new().post("http://localhost/mcp"))
            .build()
            .expect("build request")
    }

    /// Asserts that the JSON pulled out of `input` deserializes as a JSON-RPC response.
    #[track_caller]
    fn assert_sse_yields_jsonrpc_response(input: &str) {
        let extracted = extract_json_from_sse_text(input);
        let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
    }

    // ── Construction ──────────────────────────────────────────────────────────

    #[test]
    fn test_transport_default_is_stdio() {
        assert_eq!(McpServerConfig::default().transport, McpTransport::Stdio);
    }

    #[test]
    fn test_http_transport_requires_url() {
        let config = server_config("test", McpTransport::Http, None);
        assert!(HttpTransport::new(&config).is_err());
    }

    #[test]
    fn test_sse_transport_requires_url() {
        let config = server_config("test", McpTransport::Sse, None);
        assert!(SseTransport::new(&config).is_err());
    }

    // ── SSE JSON extraction ───────────────────────────────────────────────────

    #[test]
    fn test_extract_json_from_sse_data_no_space() {
        assert_sse_yields_jsonrpc_response("data:{\"jsonrpc\":\"2.0\",\"result\":{}}\n\n");
    }

    #[test]
    fn test_extract_json_from_sse_with_event_and_id() {
        assert_sse_yields_jsonrpc_response(
            "id: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n",
        );
    }

    #[test]
    fn test_extract_json_from_sse_multiline_data() {
        assert_sse_yields_jsonrpc_response(
            "event: message\ndata: {\ndata:   \"jsonrpc\": \"2.0\",\ndata:   \"result\": {}\ndata: }\n\n",
        );
    }

    #[test]
    fn test_extract_json_from_sse_skips_bom_and_leading_whitespace() {
        assert_sse_yields_jsonrpc_response(
            "\u{feff}\n\n  data: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n",
        );
    }

    #[test]
    fn test_extract_json_from_sse_uses_last_event_with_data() {
        // NOTE(review): the only data-bearing event here follows a comment-only
        // event, so this does not yet distinguish "first" from "last" data event.
        assert_sse_yields_jsonrpc_response(
            ": keep-alive\n\nid: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n",
        );
    }

    // ── JSON-RPC response parsing ─────────────────────────────────────────────

    #[test]
    fn test_parse_jsonrpc_response_text_handles_plain_json() {
        let parsed = parse_jsonrpc_response_text(r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
            .expect("plain JSON response should parse");
        assert_eq!(parsed.id, Some(serde_json::json!(1)));
        assert!(parsed.error.is_none());
    }

    #[test]
    fn test_parse_jsonrpc_response_text_handles_sse_framed_json() {
        let framed =
            "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":2,\"result\":{\"ok\":true}}\n\n";
        let parsed =
            parse_jsonrpc_response_text(framed).expect("SSE-framed JSON response should parse");
        assert_eq!(parsed.id, Some(serde_json::json!(2)));

        let ok_flag = parsed
            .result
            .as_ref()
            .and_then(|result| result.get("ok"))
            .and_then(|flag| flag.as_bool());
        assert_eq!(ok_flag, Some(true));
    }

    #[test]
    fn test_parse_jsonrpc_response_text_rejects_empty_payload() {
        assert!(parse_jsonrpc_response_text(" \n\t ").is_err());
    }

    // ── HTTP session id handling ──────────────────────────────────────────────

    #[test]
    fn http_transport_updates_session_id_from_response_headers() {
        let mut transport = local_http_transport();
        transport.update_session_id_from_headers(&session_id_headers("session-abc"));
        assert_eq!(transport.session_id.as_deref(), Some("session-abc"));
    }

    #[test]
    fn http_transport_injects_session_id_header_when_available() {
        let mut transport = local_http_transport();
        transport.session_id = Some("session-xyz".to_string());

        let req = build_session_request(transport);
        let sent = req
            .headers()
            .get(MCP_SESSION_ID_HEADER)
            .and_then(|value| value.to_str().ok());
        assert_eq!(sent, Some("session-xyz"));
    }

    // ── derive_message_url ────────────────────────────────────────────────────

    #[test]
    fn derive_message_url_replaces_sse_segment_with_messages() {
        assert_eq!(
            derive_message_url("http://localhost:3000/mcp/sse", "messages").as_deref(),
            Some("http://localhost:3000/mcp/messages")
        );
    }

    #[test]
    fn derive_message_url_appends_when_no_sse_segment() {
        assert_eq!(
            derive_message_url("http://localhost:3000/mcp", "messages").as_deref(),
            Some("http://localhost:3000/mcp/messages")
        );
    }

    #[test]
    fn derive_message_url_returns_none_for_invalid_url() {
        assert!(derive_message_url("not-a-url", "messages").is_none());
    }

    #[test]
    fn derive_message_url_message_path_variant() {
        assert_eq!(
            derive_message_url("http://localhost:3000/mcp/sse", "message").as_deref(),
            Some("http://localhost:3000/mcp/message")
        );
    }

    // ── parse_endpoint_from_data ──────────────────────────────────────────────

    #[test]
    fn parse_endpoint_absolute_http_url_returned_as_is() {
        assert_eq!(
            parse_endpoint_from_data("http://base/sse", "http://other/messages").as_deref(),
            Some("http://other/messages")
        );
    }

    #[test]
    fn parse_endpoint_absolute_https_url_returned_as_is() {
        assert_eq!(
            parse_endpoint_from_data("https://base/sse", "https://other/messages").as_deref(),
            Some("https://other/messages")
        );
    }

    #[test]
    fn parse_endpoint_relative_path_resolved_against_base() {
        assert_eq!(
            parse_endpoint_from_data("http://localhost:3000/sse", "/messages").as_deref(),
            Some("http://localhost:3000/messages")
        );
    }

    #[test]
    fn parse_endpoint_json_object_with_endpoint_key() {
        let endpoint_json = r#"{"endpoint":"/messages"}"#;
        assert_eq!(
            parse_endpoint_from_data("http://localhost:3000/sse", endpoint_json).as_deref(),
            Some("http://localhost:3000/messages")
        );
    }

    // ── looks_like_sse_text ───────────────────────────────────────────────────

    #[test]
    fn looks_like_sse_text_detects_data_prefix() {
        assert!(looks_like_sse_text(r#"data:{"jsonrpc":"2.0"}"#));
    }

    #[test]
    fn looks_like_sse_text_detects_event_prefix() {
        assert!(looks_like_sse_text("event: message\ndata: {}"));
    }

    #[test]
    fn looks_like_sse_text_detects_embedded_data_line() {
        assert!(looks_like_sse_text("id: 1\ndata:{\"x\":1}"));
    }

    #[test]
    fn looks_like_sse_text_plain_json_is_not_sse() {
        assert!(!looks_like_sse_text(r#"{"jsonrpc":"2.0","id":1,"result":{}}"#));
    }

    // ── extract_json_from_sse_text edge cases ─────────────────────────────────

    #[test]
    fn extract_json_skips_comment_lines() {
        let extracted = extract_json_from_sse_text(
            ": keep-alive\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n",
        );
        let value: serde_json::Value = serde_json::from_str(extracted.as_ref()).unwrap();
        assert_eq!(value["jsonrpc"], "2.0");
    }

    #[test]
    fn extract_json_empty_input_returns_empty_trimmed() {
        let extracted = extract_json_from_sse_text("   ");
        assert!(extracted.as_ref().trim().is_empty());
    }

    #[test]
    fn extract_json_plain_json_returned_unchanged() {
        let payload = r#"{"jsonrpc":"2.0","result":{}}"#;
        // Without SSE framing the (already trimmed) payload passes through verbatim.
        let extracted = extract_json_from_sse_text(payload);
        assert_eq!(extracted.as_ref(), payload);
    }

    // ── parse_jsonrpc_response_text edge cases ────────────────────────────────

    #[test]
    fn parse_jsonrpc_response_rejects_whitespace_only() {
        assert!(parse_jsonrpc_response_text("   \n\t  ").is_err());
    }

    #[test]
    fn parse_jsonrpc_response_with_error_result() {
        let payload = r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"not found"}}"#;
        let response = parse_jsonrpc_response_text(payload).unwrap();
        let error = response.error.expect("error member should be present");
        assert_eq!(error.code, -32601);
    }

    // ── create_transport factory ──────────────────────────────────────────────

    #[test]
    fn create_transport_stdio_fails_without_valid_command() {
        // Spawning a non-existent binary should fail.
        let mut config = server_config("test-stdio", McpTransport::Stdio, None);
        config.command = "/usr/bin/construct_nonexistent_binary_abc123".into();
        assert!(create_transport(&config).is_err());
    }

    #[test]
    fn create_transport_http_without_url_fails() {
        let config = server_config("test-http", McpTransport::Http, None);
        assert!(create_transport(&config).is_err());
    }

    #[test]
    fn create_transport_sse_without_url_fails() {
        let config = server_config("test-sse", McpTransport::Sse, None);
        assert!(create_transport(&config).is_err());
    }

    #[test]
    fn create_transport_http_with_url_succeeds() {
        let config =
            server_config("test-http", McpTransport::Http, Some("http://localhost:9999/mcp"));
        // Build should succeed even if the server isn't running.
        assert!(create_transport(&config).is_ok());
    }

    #[test]
    fn create_transport_sse_with_url_succeeds() {
        let config = server_config("test-sse", McpTransport::Sse, Some("http://localhost:9999/sse"));
        assert!(create_transport(&config).is_ok());
    }

    // ── HTTP session id whitespace handling ───────────────────────────────────

    #[test]
    fn http_transport_ignores_empty_session_id_header() {
        let mut transport = local_http_transport();
        transport.update_session_id_from_headers(&session_id_headers("   "));
        // A whitespace-only session id must not be stored.
        assert!(transport.session_id.is_none());
    }

    #[test]
    fn http_transport_no_session_header_leaves_none() {
        assert!(local_http_transport().session_id.is_none());
    }

    #[test]
    fn http_transport_apply_session_header_noop_when_no_session() {
        let req = build_session_request(local_http_transport());
        assert!(req.headers().get(MCP_SESSION_ID_HEADER).is_none());
    }
}
1290}