1use std::borrow::Cow;
4
5use anyhow::{Context, Result, anyhow, bail};
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::process::{Child, Command};
8use tokio::sync::{Mutex, Notify, oneshot};
9use tokio::time::{Duration, timeout};
10use tokio_stream::StreamExt;
11
12use crate::config::schema::{McpServerConfig, McpTransport};
13use crate::tools::mcp_protocol::{INTERNAL_ERROR, JsonRpcError, JsonRpcRequest, JsonRpcResponse};
14
/// Largest stdio message line (in bytes) accepted from a server; longer lines are rejected
/// after they have been read.
const MAX_LINE_BYTES: usize = 4 << 20;
/// Seconds to wait for a JSON-RPC reply carried in a streamable-HTTP SSE response body.
const RECV_TIMEOUT_SECS: u64 = 30;

/// Default `Accept` header: streamable-HTTP servers may answer with JSON or an SSE stream.
const MCP_STREAMABLE_ACCEPT: &str = "application/json, text/event-stream";

/// Default `Content-Type` for POSTed JSON-RPC bodies.
const MCP_JSON_CONTENT_TYPE: &str = "application/json";
/// Header used to carry the streamable-HTTP session identifier in both directions.
const MCP_SESSION_ID_HEADER: &str = "Mcp-Session-Id";
30
/// A connection to an MCP server over which JSON-RPC messages are exchanged.
///
/// This module provides implementations for a stdio child process, streamable HTTP,
/// and legacy HTTP+SSE.
#[async_trait::async_trait]
pub trait McpTransportConn: Send + Sync {
    /// Sends `request` and returns the server's reply.
    ///
    /// The implementations in this module do not wait for a reply when `request.id` is
    /// `None` (a notification). In that case they return a placeholder response whose `id`,
    /// `result`, and `error` are all `None`.
    async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse>;

    /// Shuts the connection down.
    ///
    /// The implementations in this module are best-effort and always return `Ok(())`.
    async fn close(&mut self) -> Result<()>;
}
42
/// MCP transport that talks to a child process, one JSON-RPC message per line on its
/// stdin/stdout.
pub struct StdioTransport {
    // Keeps the child alive for the transport's lifetime. It is spawned with
    // `kill_on_drop(true)`, so dropping the transport kills the process.
    _child: Child,
    // Write side: each outgoing message is written as one line and flushed.
    stdin: tokio::process::ChildStdin,
    // Read side: the server's stdout split into lines; each line is parsed as one message.
    stdout_lines: tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
}
51
52impl StdioTransport {
53 pub fn new(config: &McpServerConfig) -> Result<Self> {
54 let mut child = Command::new(&config.command)
55 .args(&config.args)
56 .envs(&config.env)
57 .stdin(std::process::Stdio::piped())
58 .stdout(std::process::Stdio::piped())
59 .stderr(std::process::Stdio::inherit())
60 .kill_on_drop(true)
61 .spawn()
62 .with_context(|| format!("failed to spawn MCP server `{}`", config.name))?;
63
64 let stdin = child
65 .stdin
66 .take()
67 .ok_or_else(|| anyhow!("no stdin on MCP server `{}`", config.name))?;
68 let stdout = child
69 .stdout
70 .take()
71 .ok_or_else(|| anyhow!("no stdout on MCP server `{}`", config.name))?;
72 let stdout_lines = BufReader::new(stdout).lines();
73
74 Ok(Self {
75 _child: child,
76 stdin,
77 stdout_lines,
78 })
79 }
80
81 async fn send_raw(&mut self, line: &str) -> Result<()> {
82 self.stdin
83 .write_all(line.as_bytes())
84 .await
85 .context("failed to write to MCP server stdin")?;
86 self.stdin
87 .write_all(b"\n")
88 .await
89 .context("failed to write newline to MCP server stdin")?;
90 self.stdin.flush().await.context("failed to flush stdin")?;
91 Ok(())
92 }
93
94 async fn recv_raw(&mut self) -> Result<String> {
95 let line = self
96 .stdout_lines
97 .next_line()
98 .await?
99 .ok_or_else(|| anyhow!("MCP server closed stdout"))?;
100 if line.len() > MAX_LINE_BYTES {
101 bail!("MCP response too large: {} bytes", line.len());
102 }
103 Ok(line)
104 }
105}
106
107#[async_trait::async_trait]
108impl McpTransportConn for StdioTransport {
109 async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
110 let line = serde_json::to_string(request)?;
111 self.send_raw(&line).await?;
112 if request.id.is_none() {
113 return Ok(JsonRpcResponse {
114 jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
115 id: None,
116 result: None,
117 error: None,
118 });
119 }
120 loop {
124 let resp_line = self.recv_raw().await?;
125 let resp: JsonRpcResponse = serde_json::from_str(&resp_line)
126 .with_context(|| format!("invalid JSON-RPC response: {}", resp_line))?;
127 if resp.id.is_none() {
128 tracing::debug!(
129 "MCP stdio: skipping server notification while waiting for response"
130 );
131 continue;
132 }
133 if resp.id != request.id {
137 tracing::warn!(
138 "MCP stdio: discarding response with mismatched id \
139 (got {:?}, expected {:?}) — likely stale from a timed-out request",
140 resp.id,
141 request.id
142 );
143 continue;
144 }
145 return Ok(resp);
146 }
147 }
148
149 async fn close(&mut self) -> Result<()> {
150 let _ = self.stdin.shutdown().await;
151 Ok(())
152 }
153}
154
/// MCP streamable-HTTP transport.
///
/// Every message is POSTed to one URL, and the reply comes back either as a JSON body or
/// as an SSE stream.
pub struct HttpTransport {
    // Endpoint every request is POSTed to.
    url: String,
    // HTTP client; `HttpTransport::new` builds it with a 120-second request timeout.
    client: reqwest::Client,
    // User-configured headers added to every request.
    headers: std::collections::HashMap<String, String>,
    // Last non-empty `Mcp-Session-Id` the server returned; echoed on later requests.
    session_id: Option<String>,
}
164
165impl HttpTransport {
166 pub fn new(config: &McpServerConfig) -> Result<Self> {
167 let url = config
168 .url
169 .as_ref()
170 .ok_or_else(|| anyhow!("URL required for HTTP transport"))?
171 .clone();
172
173 let client = reqwest::Client::builder()
174 .timeout(Duration::from_secs(120))
175 .build()
176 .context("failed to build HTTP client")?;
177
178 Ok(Self {
179 url,
180 client,
181 headers: config.headers.clone(),
182 session_id: None,
183 })
184 }
185
186 fn apply_session_header(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
187 if let Some(session_id) = self.session_id.as_deref() {
188 req.header(MCP_SESSION_ID_HEADER, session_id)
189 } else {
190 req
191 }
192 }
193
194 fn update_session_id_from_headers(&mut self, headers: &reqwest::header::HeaderMap) {
195 if let Some(session_id) = headers
196 .get(MCP_SESSION_ID_HEADER)
197 .and_then(|v| v.to_str().ok())
198 .map(str::trim)
199 .filter(|v| !v.is_empty())
200 {
201 self.session_id = Some(session_id.to_string());
202 }
203 }
204}
205
206#[async_trait::async_trait]
207impl McpTransportConn for HttpTransport {
208 async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
209 let body = serde_json::to_string(request)?;
210
211 let has_accept = self
212 .headers
213 .keys()
214 .any(|k| k.eq_ignore_ascii_case("Accept"));
215 let has_content_type = self
216 .headers
217 .keys()
218 .any(|k| k.eq_ignore_ascii_case("Content-Type"));
219
220 let mut req = self.client.post(&self.url).body(body);
221 if !has_content_type {
222 req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
223 }
224 for (key, value) in &self.headers {
225 req = req.header(key, value);
226 }
227 req = self.apply_session_header(req);
228 if !has_accept {
229 req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
230 }
231
232 let resp = req
233 .send()
234 .await
235 .context("HTTP request to MCP server failed")?;
236
237 if !resp.status().is_success() {
238 bail!("MCP server returned HTTP {}", resp.status());
239 }
240
241 self.update_session_id_from_headers(resp.headers());
242
243 if request.id.is_none() {
244 return Ok(JsonRpcResponse {
245 jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
246 id: None,
247 result: None,
248 error: None,
249 });
250 }
251
252 let is_sse = resp
253 .headers()
254 .get(reqwest::header::CONTENT_TYPE)
255 .and_then(|v| v.to_str().ok())
256 .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
257 if is_sse {
258 let maybe_resp = timeout(
259 Duration::from_secs(RECV_TIMEOUT_SECS),
260 read_first_jsonrpc_from_sse_response(resp),
261 )
262 .await
263 .context("timeout waiting for MCP response from streamable HTTP SSE stream")??;
264 return maybe_resp
265 .ok_or_else(|| anyhow!("MCP server returned no response in SSE stream"));
266 }
267
268 let resp_text = resp.text().await.context("failed to read HTTP response")?;
269 parse_jsonrpc_response_text(&resp_text)
270 }
271
272 async fn close(&mut self) -> Result<()> {
273 Ok(())
274 }
275}
276
/// Status of the long-lived GET event stream used by [`SseTransport`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum SseStreamState {
    /// Initial state: no GET stream has been attempted yet.
    Unknown,
    /// A background task is reading the GET event stream.
    Connected,
    /// The server does not offer the stream. The GET returned 404/405 or a body that is not
    /// `text/event-stream`. Requests then rely on POST replies only.
    Unsupported,
}
286
/// MCP transport for servers using the legacy HTTP+SSE style.
///
/// A GET event stream delivers server messages, and requests are POSTed to a message
/// endpoint. The transport falls back to POST-only operation when the server has no stream.
pub struct SseTransport {
    // URL of the server's event stream (opened with GET).
    sse_url: String,
    // Configured server name, used in log messages.
    server_name: String,
    // HTTP client; built without a client-wide timeout.
    client: reqwest::Client,
    // User-configured headers added to every GET and POST.
    headers: std::collections::HashMap<String, String>,
    // Whether the GET stream is open, not yet tried, or unsupported by the server.
    stream_state: SseStreamState,
    // Shared with the reader task: the message URL and the requests awaiting responses.
    shared: std::sync::Arc<Mutex<SseSharedState>>,
    // Signalled by the reader when an `endpoint` event updates the message URL.
    notify: std::sync::Arc<Notify>,
    // Tells the reader task to stop; `None` until a stream has been opened.
    shutdown_tx: Option<oneshot::Sender<()>>,
    // Handle of the background task reading the GET stream, if one was started.
    reader_task: Option<tokio::task::JoinHandle<()>>,
}
298
299impl SseTransport {
300 pub fn new(config: &McpServerConfig) -> Result<Self> {
301 let sse_url = config
302 .url
303 .as_ref()
304 .ok_or_else(|| anyhow!("URL required for SSE transport"))?
305 .clone();
306
307 let client = reqwest::Client::builder()
308 .build()
309 .context("failed to build HTTP client")?;
310
311 Ok(Self {
312 sse_url,
313 server_name: config.name.clone(),
314 client,
315 headers: config.headers.clone(),
316 stream_state: SseStreamState::Unknown,
317 shared: std::sync::Arc::new(Mutex::new(SseSharedState::default())),
318 notify: std::sync::Arc::new(Notify::new()),
319 shutdown_tx: None,
320 reader_task: None,
321 })
322 }
323
324 async fn ensure_connected(&mut self) -> Result<()> {
325 if self.stream_state == SseStreamState::Unsupported {
326 return Ok(());
327 }
328 if let Some(task) = &self.reader_task {
329 if !task.is_finished() {
330 self.stream_state = SseStreamState::Connected;
331 return Ok(());
332 }
333 }
334
335 let has_accept = self
336 .headers
337 .keys()
338 .any(|k| k.eq_ignore_ascii_case("Accept"));
339
340 let mut req = self
341 .client
342 .get(&self.sse_url)
343 .header("Cache-Control", "no-cache");
344 for (key, value) in &self.headers {
345 req = req.header(key, value);
346 }
347 if !has_accept {
348 req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
349 }
350
351 let resp = req.send().await.context("SSE GET to MCP server failed")?;
352 if resp.status() == reqwest::StatusCode::NOT_FOUND
353 || resp.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED
354 {
355 self.stream_state = SseStreamState::Unsupported;
356 return Ok(());
357 }
358 if !resp.status().is_success() {
359 return Err(anyhow!("MCP server returned HTTP {}", resp.status()));
360 }
361 let is_event_stream = resp
362 .headers()
363 .get(reqwest::header::CONTENT_TYPE)
364 .and_then(|v| v.to_str().ok())
365 .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
366 if !is_event_stream {
367 self.stream_state = SseStreamState::Unsupported;
368 return Ok(());
369 }
370
371 let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();
372 self.shutdown_tx = Some(shutdown_tx);
373
374 let shared = self.shared.clone();
375 let notify = self.notify.clone();
376 let sse_url = self.sse_url.clone();
377 let server_name = self.server_name.clone();
378
379 self.reader_task = Some(tokio::spawn(async move {
380 let stream = resp
381 .bytes_stream()
382 .map(|item| item.map_err(std::io::Error::other));
383 let reader = tokio_util::io::StreamReader::new(stream);
384 let mut lines = BufReader::new(reader).lines();
385
386 let mut cur_event: Option<String> = None;
387 let mut cur_id: Option<String> = None;
388 let mut cur_data: Vec<String> = Vec::new();
389
390 loop {
391 tokio::select! {
392 _ = &mut shutdown_rx => {
393 break;
394 }
395 line = lines.next_line() => {
396 let Ok(line_opt) = line else { break; };
397 let Some(mut line) = line_opt else { break; };
398 if line.ends_with('\r') {
399 line.pop();
400 }
401 if line.is_empty() {
402 if cur_event.is_none() && cur_id.is_none() && cur_data.is_empty() {
403 continue;
404 }
405 let event = cur_event.take();
406 let data = cur_data.join("\n");
407 cur_data.clear();
408 let id = cur_id.take();
409 handle_sse_event(&server_name, &sse_url, &shared, ¬ify, event.as_deref(), id.as_deref(), data).await;
410 continue;
411 }
412
413 if line.starts_with(':') {
414 continue;
415 }
416
417 if let Some(rest) = line.strip_prefix("event:") {
418 cur_event = Some(rest.trim().to_string());
419 }
420 if let Some(rest) = line.strip_prefix("data:") {
421 let rest = rest.strip_prefix(' ').unwrap_or(rest);
422 cur_data.push(rest.to_string());
423 }
424 if let Some(rest) = line.strip_prefix("id:") {
425 cur_id = Some(rest.trim().to_string());
426 }
427 }
428 }
429 }
430
431 let pending = {
432 let mut guard = shared.lock().await;
433 std::mem::take(&mut guard.pending)
434 };
435 for (_, tx) in pending {
436 let _ = tx.send(JsonRpcResponse {
437 jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
438 id: None,
439 result: None,
440 error: Some(JsonRpcError {
441 code: INTERNAL_ERROR,
442 message: "SSE connection closed".to_string(),
443 data: None,
444 }),
445 });
446 }
447 }));
448 self.stream_state = SseStreamState::Connected;
449
450 Ok(())
451 }
452
453 async fn get_message_url(&self) -> Result<(String, bool)> {
454 let guard = self.shared.lock().await;
455 if let Some(url) = &guard.message_url {
456 return Ok((url.clone(), guard.message_url_from_endpoint));
457 }
458 drop(guard);
459
460 let derived = derive_message_url(&self.sse_url, "messages")
461 .or_else(|| derive_message_url(&self.sse_url, "message"))
462 .ok_or_else(|| anyhow!("invalid SSE URL"))?;
463 let mut guard = self.shared.lock().await;
464 if guard.message_url.is_none() {
465 guard.message_url = Some(derived.clone());
466 guard.message_url_from_endpoint = false;
467 }
468 Ok((derived, false))
469 }
470
471 fn maybe_try_alternate_message_url(
472 &self,
473 current_url: &str,
474 from_endpoint: bool,
475 ) -> Option<String> {
476 if from_endpoint {
477 return None;
478 }
479 let alt = if current_url.ends_with("/messages") {
480 derive_message_url(&self.sse_url, "message")
481 } else {
482 derive_message_url(&self.sse_url, "messages")
483 }?;
484 if alt == current_url {
485 return None;
486 }
487 Some(alt)
488 }
489}
490
/// State shared between an [`SseTransport`] and its background stream reader.
#[derive(Default)]
struct SseSharedState {
    // URL to POST messages to, once known (advertised by the server or derived from the SSE URL).
    message_url: Option<String>,
    // True when `message_url` came from a server `endpoint` event rather than being derived.
    message_url_from_endpoint: bool,
    // Senders for requests awaiting a reply over the event stream, keyed by numeric request id.
    pending: std::collections::HashMap<u64, oneshot::Sender<JsonRpcResponse>>,
}
497
498fn derive_message_url(sse_url: &str, message_path: &str) -> Option<String> {
499 let url = reqwest::Url::parse(sse_url).ok()?;
500 let mut segments: Vec<&str> = url.path_segments()?.collect();
501 if segments.is_empty() {
502 return None;
503 }
504 if segments.last().copied() == Some("sse") {
505 segments.pop();
506 segments.push(message_path);
507 let mut new_url = url.clone();
508 new_url.set_path(&format!("/{}", segments.join("/")));
509 return Some(new_url.to_string());
510 }
511 let mut new_url = url.clone();
512 let mut path = url.path().trim_end_matches('/').to_string();
513 path.push('/');
514 path.push_str(message_path);
515 new_url.set_path(&path);
516 Some(new_url.to_string())
517}
518
519async fn handle_sse_event(
520 server_name: &str,
521 sse_url: &str,
522 shared: &std::sync::Arc<Mutex<SseSharedState>>,
523 notify: &std::sync::Arc<Notify>,
524 event: Option<&str>,
525 _id: Option<&str>,
526 data: String,
527) {
528 let event = event.unwrap_or("message");
529 let trimmed = data.trim();
530 if trimmed.is_empty() {
531 return;
532 }
533
534 if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint") {
535 if let Some(url) = parse_endpoint_from_data(sse_url, trimmed) {
536 let mut guard = shared.lock().await;
537 guard.message_url = Some(url);
538 guard.message_url_from_endpoint = true;
539 drop(guard);
540 notify.notify_waiters();
541 }
542 return;
543 }
544
545 if !event.eq_ignore_ascii_case("message") {
546 return;
547 }
548
549 let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) else {
550 return;
551 };
552
553 let Ok(resp) = serde_json::from_value::<JsonRpcResponse>(value.clone()) else {
554 let _ = serde_json::from_value::<JsonRpcRequest>(value);
555 return;
556 };
557
558 let Some(id_val) = resp.id.clone() else {
559 return;
560 };
561 let id = match id_val.as_u64() {
562 Some(v) => v,
563 None => return,
564 };
565
566 let tx = {
567 let mut guard = shared.lock().await;
568 guard.pending.remove(&id)
569 };
570 if let Some(tx) = tx {
571 let _ = tx.send(resp);
572 } else {
573 tracing::debug!(
574 "MCP SSE `{}` received response for unknown id {}",
575 server_name,
576 id
577 );
578 }
579}
580
581fn parse_endpoint_from_data(sse_url: &str, data: &str) -> Option<String> {
582 if data.starts_with('{') {
583 let v: serde_json::Value = serde_json::from_str(data).ok()?;
584 let endpoint = v.get("endpoint")?.as_str()?;
585 return parse_endpoint_from_data(sse_url, endpoint);
586 }
587 if data.starts_with("http://") || data.starts_with("https://") {
588 return Some(data.to_string());
589 }
590 let base = reqwest::Url::parse(sse_url).ok()?;
591 base.join(data).ok().map(|u| u.to_string())
592}
593
/// Extracts the JSON payload from an SSE-framed body.
///
/// A leading BOM is ignored. The data lines of the last event that has any data are joined
/// with `\n` and trimmed. When no `data:` line exists at all, the whole (trimmed) input is
/// returned unchanged.
fn extract_json_from_sse_text(resp_text: &str) -> Cow<'_, str> {
    let body = resp_text.trim_start_matches('\u{feff}');
    let mut in_progress: Vec<&str> = Vec::new();
    let mut completed: Vec<&str> = Vec::new();

    for line in body
        .lines()
        .map(|raw| raw.trim_end_matches('\r').trim_start())
    {
        if line.is_empty() {
            // Blank line ends an event; only events that carried data replace `completed`.
            if !in_progress.is_empty() {
                completed = std::mem::take(&mut in_progress);
            }
        } else if !line.starts_with(':') {
            if let Some(payload) = line.strip_prefix("data:") {
                in_progress.push(payload.strip_prefix(' ').unwrap_or(payload));
            }
        }
    }

    // A final event without a terminating blank line still counts.
    if !in_progress.is_empty() {
        completed = in_progress;
    }

    match completed[..] {
        [] => Cow::Borrowed(body.trim()),
        [single] => Cow::Borrowed(single.trim()),
        _ => Cow::Owned(completed.join("\n").trim().to_string()),
    }
}
633
634fn parse_jsonrpc_response_text(resp_text: &str) -> Result<JsonRpcResponse> {
635 let trimmed = resp_text.trim();
636 if trimmed.is_empty() {
637 bail!("MCP server returned no response");
638 }
639
640 let json_text = if looks_like_sse_text(trimmed) {
641 extract_json_from_sse_text(trimmed)
642 } else {
643 Cow::Borrowed(trimmed)
644 };
645
646 let mcp_resp: JsonRpcResponse = serde_json::from_str(json_text.as_ref())
647 .with_context(|| format!("invalid JSON-RPC response: {}", resp_text))?;
648 Ok(mcp_resp)
649}
650
/// Heuristically detects SSE framing: some line (the first, or any line after a `\n`)
/// starts with `data:` or `event:`.
fn looks_like_sse_text(text: &str) -> bool {
    text.split('\n')
        .any(|segment| segment.starts_with("data:") || segment.starts_with("event:"))
}
657
658async fn read_first_jsonrpc_from_sse_response(
659 resp: reqwest::Response,
660) -> Result<Option<JsonRpcResponse>> {
661 let stream = resp
662 .bytes_stream()
663 .map(|item| item.map_err(std::io::Error::other));
664 let reader = tokio_util::io::StreamReader::new(stream);
665 let mut lines = BufReader::new(reader).lines();
666
667 let mut cur_event: Option<String> = None;
668 let mut cur_data: Vec<String> = Vec::new();
669
670 while let Ok(line_opt) = lines.next_line().await {
671 let Some(mut line) = line_opt else { break };
672 if line.ends_with('\r') {
673 line.pop();
674 }
675 if line.is_empty() {
676 if cur_event.is_none() && cur_data.is_empty() {
677 continue;
678 }
679 let event = cur_event.take();
680 let data = cur_data.join("\n");
681 cur_data.clear();
682
683 let event = event.unwrap_or_else(|| "message".to_string());
684 if event.eq_ignore_ascii_case("endpoint") || event.eq_ignore_ascii_case("mcp-endpoint")
685 {
686 continue;
687 }
688 if !event.eq_ignore_ascii_case("message") {
689 continue;
690 }
691
692 let trimmed = data.trim();
693 if trimmed.is_empty() {
694 continue;
695 }
696 let json_str = extract_json_from_sse_text(trimmed);
697 if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
698 return Ok(Some(resp));
699 }
700 continue;
701 }
702
703 if line.starts_with(':') {
704 continue;
705 }
706 if let Some(rest) = line.strip_prefix("event:") {
707 cur_event = Some(rest.trim().to_string());
708 }
709 if let Some(rest) = line.strip_prefix("data:") {
710 let rest = rest.strip_prefix(' ').unwrap_or(rest);
711 cur_data.push(rest.to_string());
712 }
713 }
714
715 Ok(None)
716}
717
718#[async_trait::async_trait]
719impl McpTransportConn for SseTransport {
720 async fn send_and_recv(&mut self, request: &JsonRpcRequest) -> Result<JsonRpcResponse> {
721 self.ensure_connected().await?;
722
723 let id = request.id.as_ref().and_then(|v| v.as_u64());
724 let body = serde_json::to_string(request)?;
725
726 let (mut message_url, mut from_endpoint) = self.get_message_url().await?;
727 if self.stream_state == SseStreamState::Connected && !from_endpoint {
728 for _ in 0..3 {
729 {
730 let guard = self.shared.lock().await;
731 if guard.message_url_from_endpoint {
732 if let Some(url) = &guard.message_url {
733 message_url = url.clone();
734 from_endpoint = true;
735 break;
736 }
737 }
738 }
739 let _ = timeout(Duration::from_millis(300), self.notify.notified()).await;
740 }
741 }
742 let primary_url = if from_endpoint {
743 message_url.clone()
744 } else {
745 self.sse_url.clone()
746 };
747 let secondary_url = if message_url == self.sse_url {
748 None
749 } else if primary_url == message_url {
750 Some(self.sse_url.clone())
751 } else {
752 Some(message_url.clone())
753 };
754 let has_secondary = secondary_url.is_some();
755
756 let mut rx = None;
757 if let Some(id) = id {
758 if self.stream_state == SseStreamState::Connected {
759 let (tx, ch) = oneshot::channel();
760 {
761 let mut guard = self.shared.lock().await;
762 guard.pending.insert(id, tx);
763 }
764 rx = Some((id, ch));
765 }
766 }
767
768 let mut got_direct = None;
769 let mut last_status = None;
770
771 for (i, url) in std::iter::once(primary_url)
772 .chain(secondary_url.into_iter())
773 .enumerate()
774 {
775 let has_accept = self
776 .headers
777 .keys()
778 .any(|k| k.eq_ignore_ascii_case("Accept"));
779 let has_content_type = self
780 .headers
781 .keys()
782 .any(|k| k.eq_ignore_ascii_case("Content-Type"));
783 let mut req = self
784 .client
785 .post(&url)
786 .timeout(Duration::from_secs(120))
787 .body(body.clone());
788 if !has_content_type {
789 req = req.header("Content-Type", MCP_JSON_CONTENT_TYPE);
790 }
791 for (key, value) in &self.headers {
792 req = req.header(key, value);
793 }
794 if !has_accept {
795 req = req.header("Accept", MCP_STREAMABLE_ACCEPT);
796 }
797
798 let resp = req.send().await.context("SSE POST to MCP server failed")?;
799 let status = resp.status();
800 last_status = Some(status);
801
802 if (status == reqwest::StatusCode::NOT_FOUND
803 || status == reqwest::StatusCode::METHOD_NOT_ALLOWED)
804 && i == 0
805 {
806 continue;
807 }
808
809 if !status.is_success() {
810 break;
811 }
812
813 if request.id.is_none() {
814 got_direct = Some(JsonRpcResponse {
815 jsonrpc: crate::tools::mcp_protocol::JSONRPC_VERSION.to_string(),
816 id: None,
817 result: None,
818 error: None,
819 });
820 break;
821 }
822
823 let is_sse = resp
824 .headers()
825 .get(reqwest::header::CONTENT_TYPE)
826 .and_then(|v| v.to_str().ok())
827 .is_some_and(|v| v.to_ascii_lowercase().contains("text/event-stream"));
828
829 if is_sse {
830 if i == 0 && has_secondary {
831 match timeout(
832 Duration::from_secs(3),
833 read_first_jsonrpc_from_sse_response(resp),
834 )
835 .await
836 {
837 Ok(res) => {
838 if let Some(resp) = res? {
839 got_direct = Some(resp);
840 }
841 break;
842 }
843 Err(_) => continue,
844 }
845 }
846 if let Some(resp) = read_first_jsonrpc_from_sse_response(resp).await? {
847 got_direct = Some(resp);
848 }
849 break;
850 }
851
852 let text = if i == 0 && has_secondary {
853 match timeout(Duration::from_secs(3), resp.text()).await {
854 Ok(Ok(t)) => t,
855 Ok(Err(_)) => String::new(),
856 Err(_) => continue,
857 }
858 } else {
859 resp.text().await.unwrap_or_default()
860 };
861 let trimmed = text.trim();
862 if !trimmed.is_empty() {
863 let json_str = if trimmed.contains("\ndata:") || trimmed.starts_with("data:") {
864 extract_json_from_sse_text(trimmed)
865 } else {
866 Cow::Borrowed(trimmed)
867 };
868 if let Ok(mcp_resp) = serde_json::from_str::<JsonRpcResponse>(json_str.as_ref()) {
869 got_direct = Some(mcp_resp);
870 }
871 }
872 break;
873 }
874
875 if let Some((id, _)) = rx.as_ref() {
876 if got_direct.is_some() {
877 let mut guard = self.shared.lock().await;
878 guard.pending.remove(id);
879 } else if let Some(status) = last_status {
880 if !status.is_success() {
881 let mut guard = self.shared.lock().await;
882 guard.pending.remove(id);
883 }
884 }
885 }
886
887 if let Some(resp) = got_direct {
888 return Ok(resp);
889 }
890
891 if let Some(status) = last_status {
892 if !status.is_success() {
893 bail!("MCP server returned HTTP {}", status);
894 }
895 } else {
896 bail!("MCP request not sent");
897 }
898
899 let Some((_id, rx)) = rx else {
900 bail!("MCP server returned no response");
901 };
902
903 rx.await.map_err(|_| anyhow!("SSE response channel closed"))
904 }
905
906 async fn close(&mut self) -> Result<()> {
907 if let Some(tx) = self.shutdown_tx.take() {
908 let _ = tx.send(());
909 }
910 if let Some(task) = self.reader_task.take() {
911 task.abort();
912 }
913 Ok(())
914 }
915}
916
917pub fn create_transport(config: &McpServerConfig) -> Result<Box<dyn McpTransportConn>> {
921 match config.transport {
922 McpTransport::Stdio => Ok(Box::new(StdioTransport::new(config)?)),
923 McpTransport::Http => Ok(Box::new(HttpTransport::new(config)?)),
924 McpTransport::Sse => Ok(Box::new(SseTransport::new(config)?)),
925 }
926}
927
928#[cfg(test)]
931mod tests {
932 use super::*;
933
934 #[test]
935 fn test_transport_default_is_stdio() {
936 let config = McpServerConfig::default();
937 assert_eq!(config.transport, McpTransport::Stdio);
938 }
939
940 #[test]
941 fn test_http_transport_requires_url() {
942 let config = McpServerConfig {
943 name: "test".into(),
944 transport: McpTransport::Http,
945 ..Default::default()
946 };
947 assert!(HttpTransport::new(&config).is_err());
948 }
949
950 #[test]
951 fn test_sse_transport_requires_url() {
952 let config = McpServerConfig {
953 name: "test".into(),
954 transport: McpTransport::Sse,
955 ..Default::default()
956 };
957 assert!(SseTransport::new(&config).is_err());
958 }
959
960 #[test]
961 fn test_extract_json_from_sse_data_no_space() {
962 let input = "data:{\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
963 let extracted = extract_json_from_sse_text(input);
964 let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
965 }
966
967 #[test]
968 fn test_extract_json_from_sse_with_event_and_id() {
969 let input = "id: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
970 let extracted = extract_json_from_sse_text(input);
971 let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
972 }
973
974 #[test]
975 fn test_extract_json_from_sse_multiline_data() {
976 let input = "event: message\ndata: {\ndata: \"jsonrpc\": \"2.0\",\ndata: \"result\": {}\ndata: }\n\n";
977 let extracted = extract_json_from_sse_text(input);
978 let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
979 }
980
981 #[test]
982 fn test_extract_json_from_sse_skips_bom_and_leading_whitespace() {
983 let input = "\u{feff}\n\n data: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
984 let extracted = extract_json_from_sse_text(input);
985 let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
986 }
987
988 #[test]
989 fn test_extract_json_from_sse_uses_last_event_with_data() {
990 let input =
991 ": keep-alive\n\nid: 1\nevent: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
992 let extracted = extract_json_from_sse_text(input);
993 let _: JsonRpcResponse = serde_json::from_str(extracted.as_ref()).unwrap();
994 }
995
996 #[test]
997 fn test_parse_jsonrpc_response_text_handles_plain_json() {
998 let parsed = parse_jsonrpc_response_text("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}")
999 .expect("plain JSON response should parse");
1000 assert_eq!(parsed.id, Some(serde_json::json!(1)));
1001 assert!(parsed.error.is_none());
1002 }
1003
1004 #[test]
1005 fn test_parse_jsonrpc_response_text_handles_sse_framed_json() {
1006 let sse =
1007 "event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":2,\"result\":{\"ok\":true}}\n\n";
1008 let parsed =
1009 parse_jsonrpc_response_text(sse).expect("SSE-framed JSON response should parse");
1010 assert_eq!(parsed.id, Some(serde_json::json!(2)));
1011 assert_eq!(
1012 parsed
1013 .result
1014 .as_ref()
1015 .and_then(|v| v.get("ok"))
1016 .and_then(|v| v.as_bool()),
1017 Some(true)
1018 );
1019 }
1020
1021 #[test]
1022 fn test_parse_jsonrpc_response_text_rejects_empty_payload() {
1023 assert!(parse_jsonrpc_response_text(" \n\t ").is_err());
1024 }
1025
1026 #[test]
1027 fn http_transport_updates_session_id_from_response_headers() {
1028 let config = McpServerConfig {
1029 name: "test-http".into(),
1030 transport: McpTransport::Http,
1031 url: Some("http://localhost/mcp".into()),
1032 ..Default::default()
1033 };
1034 let mut transport = HttpTransport::new(&config).expect("build transport");
1035
1036 let mut headers = reqwest::header::HeaderMap::new();
1037 headers.insert(
1038 reqwest::header::HeaderName::from_static("mcp-session-id"),
1039 reqwest::header::HeaderValue::from_static("session-abc"),
1040 );
1041 transport.update_session_id_from_headers(&headers);
1042 assert_eq!(transport.session_id.as_deref(), Some("session-abc"));
1043 }
1044
1045 #[test]
1046 fn http_transport_injects_session_id_header_when_available() {
1047 let config = McpServerConfig {
1048 name: "test-http".into(),
1049 transport: McpTransport::Http,
1050 url: Some("http://localhost/mcp".into()),
1051 ..Default::default()
1052 };
1053 let mut transport = HttpTransport::new(&config).expect("build transport");
1054 transport.session_id = Some("session-xyz".to_string());
1055
1056 let req = transport
1057 .apply_session_header(reqwest::Client::new().post("http://localhost/mcp"))
1058 .build()
1059 .expect("build request");
1060 assert_eq!(
1061 req.headers()
1062 .get(MCP_SESSION_ID_HEADER)
1063 .and_then(|v| v.to_str().ok()),
1064 Some("session-xyz")
1065 );
1066 }
1067
1068 #[test]
1071 fn derive_message_url_replaces_sse_segment_with_messages() {
1072 let url = derive_message_url("http://localhost:3000/mcp/sse", "messages");
1073 assert_eq!(url, Some("http://localhost:3000/mcp/messages".to_string()));
1074 }
1075
1076 #[test]
1077 fn derive_message_url_appends_when_no_sse_segment() {
1078 let url = derive_message_url("http://localhost:3000/mcp", "messages");
1079 assert_eq!(url, Some("http://localhost:3000/mcp/messages".to_string()));
1080 }
1081
1082 #[test]
1083 fn derive_message_url_returns_none_for_invalid_url() {
1084 let url = derive_message_url("not-a-url", "messages");
1085 assert!(url.is_none());
1086 }
1087
1088 #[test]
1089 fn derive_message_url_message_path_variant() {
1090 let url = derive_message_url("http://localhost:3000/mcp/sse", "message");
1091 assert_eq!(url, Some("http://localhost:3000/mcp/message".to_string()));
1092 }
1093
1094 #[test]
1097 fn parse_endpoint_absolute_http_url_returned_as_is() {
1098 let result = parse_endpoint_from_data("http://base/sse", "http://other/messages");
1099 assert_eq!(result, Some("http://other/messages".to_string()));
1100 }
1101
1102 #[test]
1103 fn parse_endpoint_absolute_https_url_returned_as_is() {
1104 let result = parse_endpoint_from_data("https://base/sse", "https://other/messages");
1105 assert_eq!(result, Some("https://other/messages".to_string()));
1106 }
1107
1108 #[test]
1109 fn parse_endpoint_relative_path_resolved_against_base() {
1110 let result = parse_endpoint_from_data("http://localhost:3000/sse", "/messages");
1111 assert_eq!(result, Some("http://localhost:3000/messages".to_string()));
1112 }
1113
1114 #[test]
1115 fn parse_endpoint_json_object_with_endpoint_key() {
1116 let json_data = r#"{"endpoint":"/messages"}"#;
1117 let result = parse_endpoint_from_data("http://localhost:3000/sse", json_data);
1118 assert_eq!(result, Some("http://localhost:3000/messages".to_string()));
1119 }
1120
1121 #[test]
1124 fn looks_like_sse_text_detects_data_prefix() {
1125 assert!(looks_like_sse_text("data:{\"jsonrpc\":\"2.0\"}"));
1126 }
1127
1128 #[test]
1129 fn looks_like_sse_text_detects_event_prefix() {
1130 assert!(looks_like_sse_text("event: message\ndata: {}"));
1131 }
1132
1133 #[test]
1134 fn looks_like_sse_text_detects_embedded_data_line() {
1135 assert!(looks_like_sse_text("id: 1\ndata:{\"x\":1}"));
1136 }
1137
1138 #[test]
1139 fn looks_like_sse_text_plain_json_is_not_sse() {
1140 assert!(!looks_like_sse_text(
1141 "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}"
1142 ));
1143 }
1144
1145 #[test]
1148 fn extract_json_skips_comment_lines() {
1149 let input = ": keep-alive\ndata: {\"jsonrpc\":\"2.0\",\"result\":{}}\n\n";
1150 let extracted = extract_json_from_sse_text(input);
1151 let v: serde_json::Value = serde_json::from_str(extracted.as_ref()).unwrap();
1152 assert_eq!(v["jsonrpc"], "2.0");
1153 }
1154
/// Whitespace-only input yields output that is empty once trimmed.
#[test]
fn extract_json_empty_input_returns_empty_trimmed() {
    let out = extract_json_from_sse_text(" ");
    let trimmed: &str = out.as_ref().trim();
    assert_eq!(trimmed, "");
}
1160
/// Input with no SSE framing passes through byte-for-byte unchanged.
#[test]
fn extract_json_plain_json_returned_unchanged() {
    let raw = r#"{"jsonrpc":"2.0","result":{}}"#;
    assert_eq!(extract_json_from_sse_text(raw).as_ref(), raw);
}
1168
/// A body made only of spaces, newlines and tabs is rejected as a
/// JSON-RPC response.
#[test]
fn parse_jsonrpc_response_rejects_whitespace_only() {
    let outcome = parse_jsonrpc_response_text(" \n\t ");
    assert!(outcome.is_err());
}
1175
/// A JSON-RPC response carrying an `error` object parses successfully, and
/// the server's error code is exposed on `resp.error`.
#[test]
fn parse_jsonrpc_response_with_error_result() {
    let json = r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"not found"}}"#;
    let resp = parse_jsonrpc_response_text(json)
        .expect("a well-formed JSON-RPC error response should parse");
    // Take the error out once, rather than checking `is_some()` and then
    // `unwrap()`ing. A missing error object then fails with a descriptive message.
    let err = resp
        .error
        .expect("response should carry the `error` object from the payload");
    assert_eq!(err.code, -32601);
}
1183
/// A stdio transport whose `command` points at a binary that does not exist
/// makes `create_transport` return an error.
#[test]
fn create_transport_stdio_fails_without_valid_command() {
    let config = McpServerConfig {
        transport: McpTransport::Stdio,
        command: "/usr/bin/construct_nonexistent_binary_abc123".into(),
        name: "test-stdio".into(),
        ..Default::default()
    };
    let outcome = create_transport(&config);
    assert!(outcome.is_err());
}
1198
/// Selecting the HTTP transport without setting `url` is rejected.
#[test]
fn create_transport_http_without_url_fails() {
    let config = McpServerConfig {
        transport: McpTransport::Http,
        name: "test-http".into(),
        ..Default::default()
    };
    let outcome = create_transport(&config);
    assert!(outcome.is_err());
}
1208
/// Selecting the SSE transport without setting `url` is rejected.
#[test]
fn create_transport_sse_without_url_fails() {
    let config = McpServerConfig {
        transport: McpTransport::Sse,
        name: "test-sse".into(),
        ..Default::default()
    };
    let outcome = create_transport(&config);
    assert!(outcome.is_err());
}
1218
/// Supplying a `url` is enough for the HTTP transport to be constructed.
/// The test only checks construction and never contacts the URL itself.
#[test]
fn create_transport_http_with_url_succeeds() {
    let config = McpServerConfig {
        url: Some("http://localhost:9999/mcp".into()),
        transport: McpTransport::Http,
        name: "test-http".into(),
        ..Default::default()
    };
    let outcome = create_transport(&config);
    assert!(outcome.is_ok());
}
1230
/// Supplying a `url` is enough for the SSE transport to be constructed.
/// The test only checks construction and never contacts the URL itself.
#[test]
fn create_transport_sse_with_url_succeeds() {
    let config = McpServerConfig {
        url: Some("http://localhost:9999/sse".into()),
        transport: McpTransport::Sse,
        name: "test-sse".into(),
        ..Default::default()
    };
    let outcome = create_transport(&config);
    assert!(outcome.is_ok());
}
1241
/// A whitespace-only `mcp-session-id` response header is not adopted as the
/// transport's session id.
#[test]
fn http_transport_ignores_empty_session_id_header() {
    let config = McpServerConfig {
        url: Some("http://localhost/mcp".into()),
        transport: McpTransport::Http,
        name: "test-http".into(),
        ..Default::default()
    };
    let mut transport = HttpTransport::new(&config).expect("build transport");

    // `HeaderName::from_static` requires a lowercase name.
    let session_key = reqwest::header::HeaderName::from_static("mcp-session-id");
    let blank_value = reqwest::header::HeaderValue::from_static(" ");
    let mut headers = reqwest::header::HeaderMap::new();
    headers.insert(session_key, blank_value);

    transport.update_session_id_from_headers(&headers);
    assert!(transport.session_id.is_none());
}
1262
/// A freshly constructed HTTP transport starts without a session id.
#[test]
fn http_transport_no_session_header_leaves_none() {
    let config = McpServerConfig {
        url: Some("http://localhost/mcp".into()),
        transport: McpTransport::Http,
        name: "test-http".into(),
        ..Default::default()
    };
    let fresh = HttpTransport::new(&config).expect("build transport");
    assert!(fresh.session_id.is_none());
}
1274
/// With no session established, `apply_session_header` leaves the outgoing
/// request without an `Mcp-Session-Id` header.
#[test]
fn http_transport_apply_session_header_noop_when_no_session() {
    let config = McpServerConfig {
        url: Some("http://localhost/mcp".into()),
        transport: McpTransport::Http,
        name: "test-http".into(),
        ..Default::default()
    };
    let transport = HttpTransport::new(&config).expect("build transport");

    let builder = reqwest::Client::new().post("http://localhost/mcp");
    let request = transport
        .apply_session_header(builder)
        .build()
        .expect("build request");

    let session_header = request.headers().get(MCP_SESSION_ID_HEADER);
    assert!(session_header.is_none());
}
1290}