Skip to main content

agentkit_mcp/
lib.rs

1use std::collections::{BTreeMap, VecDeque};
2use std::fmt;
3use std::process::Stdio;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use agentkit_capabilities::{
9    CapabilityContext, CapabilityError, CapabilityName, CapabilityProvider, Invocable,
10    InvocableOutput, InvocableRequest, InvocableResult, InvocableSpec, PromptContents,
11    PromptDescriptor, PromptId, PromptProvider, ResourceContents, ResourceDescriptor, ResourceId,
12    ResourceProvider,
13};
14use agentkit_core::{
15    DataRef, Item, ItemKind, MetadataMap, Part, TextPart, ToolOutput, ToolResultPart,
16};
17use agentkit_tools_core::{
18    AuthOperation, AuthRequest, AuthResolution, Tool, ToolAnnotations, ToolContext, ToolError,
19    ToolName, ToolRegistry, ToolRequest, ToolResult, ToolSpec,
20};
21use async_trait::async_trait;
22use futures_util::TryStreamExt;
23use reqwest::{Client, StatusCode, Url};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use thiserror::Error;
27use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
28use tokio::process::{Child, ChildStdin, ChildStdout, Command};
29use tokio::sync::{Mutex, mpsc, oneshot};
30use tokio::task::JoinHandle;
31use tokio::time::sleep;
32use tokio_util::io::StreamReader;
33
/// Newest MCP protocol revision known to this crate.
const MCP_LATEST_PROTOCOL_VERSION: &str = "2025-11-25";

/// Every MCP protocol revision this crate accepts, listed newest first.
///
/// The first entry is always [`MCP_LATEST_PROTOCOL_VERSION`]; presumably this list is
/// consulted when negotiating `protocolVersion` during `initialize` (usage is outside
/// this section — confirm).
const MCP_SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[
    MCP_LATEST_PROTOCOL_VERSION,
    "2025-06-18",
    "2025-03-26",
    "2024-11-05",
];
37
38/// Unique identifier for a registered MCP server.
39///
40/// Each MCP server in a [`McpServerManager`] is addressed by its `McpServerId`.
41/// The inner string is typically a short, human-readable name such as `"filesystem"`
42/// or `"github"`.
43///
44/// # Example
45///
46/// ```rust
47/// use agentkit_mcp::McpServerId;
48///
49/// let id = McpServerId::new("filesystem");
50/// assert_eq!(id.to_string(), "filesystem");
51/// ```
52#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
53pub struct McpServerId(pub String);
54
55impl McpServerId {
56    /// Creates a new server identifier from any string-like value.
57    pub fn new(value: impl Into<String>) -> Self {
58        Self(value.into())
59    }
60}
61
62impl fmt::Display for McpServerId {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        self.0.fmt(f)
65    }
66}
67
/// Launch settings for an MCP server that speaks JSON-RPC over its standard streams.
///
/// This is the usual way to run a local MCP server: the configured program is started
/// as a child process, and messages are exchanged one JSON document per line over its
/// stdin (client to server) and stdout (server to client).
///
/// # Example
///
/// ```rust
/// use agentkit_mcp::StdioTransportConfig;
///
/// let config = StdioTransportConfig::new("npx")
///     .with_arg("-y")
///     .with_arg("@modelcontextprotocol/server-filesystem")
///     .with_env("HOME", "/home/user")
///     .with_cwd("/tmp");
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StdioTransportConfig {
    /// Program to run, e.g. `"npx"`, `"python"`, or `"node"`.
    pub command: String,
    /// Arguments passed to the program, in order.
    pub args: Vec<String>,
    /// Extra `(name, value)` environment variables set on the child, in addition to
    /// the environment it inherits.
    pub env: Vec<(String, String)>,
    /// Working directory for the child; `None` leaves it unchanged.
    pub cwd: Option<std::path::PathBuf>,
}

impl StdioTransportConfig {
    /// Starts a configuration that runs `command` with no arguments, no extra
    /// environment variables, and no working-directory override.
    pub fn new(command: impl Into<String>) -> Self {
        let command = command.into();
        Self {
            command,
            args: vec![],
            env: vec![],
            cwd: None,
        }
    }

    /// Appends one command-line argument and returns the updated configuration.
    pub fn with_arg(self, arg: impl Into<String>) -> Self {
        let mut config = self;
        config.args.push(arg.into());
        config
    }

    /// Adds one environment variable for the child and returns the updated
    /// configuration. Earlier entries are kept as-is; duplicates are not removed.
    pub fn with_env(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut config = self;
        config.env.push((key.into(), value.into()));
        config
    }

    /// Sets (or replaces) the child's working directory and returns the updated
    /// configuration.
    pub fn with_cwd(self, cwd: impl Into<std::path::PathBuf>) -> Self {
        Self {
            cwd: Some(cwd.into()),
            ..self
        }
    }
}
126
/// Connection settings for an MCP server using the legacy HTTP+SSE transport.
///
/// The client opens a Server-Sent Events stream at [`url`](Self::url), waits for the
/// server's `endpoint` event naming where to POST, and from then on sends JSON-RPC
/// messages to that endpoint while reading server messages from the stream.
///
/// # Example
///
/// ```rust
/// use agentkit_mcp::SseTransportConfig;
///
/// let config = SseTransportConfig::new("https://mcp.example.com/sse")
///     .with_header("Authorization", "Bearer tok_abc123");
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SseTransportConfig {
    /// URL of the SSE stream to open.
    pub url: String,
    /// Extra `(name, value)` HTTP headers attached to every request, such as
    /// authentication tokens.
    pub headers: Vec<(String, String)>,
}

impl SseTransportConfig {
    /// Starts an SSE configuration for `url` with no extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            headers: vec![],
        }
    }

    /// Adds one HTTP header to send with every request and returns the updated
    /// configuration.
    pub fn with_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut config = self;
        config.headers.push((key.into(), value.into()));
        config
    }
}
164
/// Connection settings for an MCP server using the Streamable HTTP transport.
///
/// This is the transport for current remote MCP servers: a single HTTP endpoint
/// accepts JSON-RPC over POST and may answer with plain JSON or with an SSE stream
/// carrying server messages.
///
/// # Example
///
/// ```rust
/// use agentkit_mcp::StreamableHttpTransportConfig;
///
/// let config = StreamableHttpTransportConfig::new("https://mcp.example.com/mcp")
///     .with_header("Authorization", "Bearer tok_abc123");
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StreamableHttpTransportConfig {
    /// URL of the MCP endpoint.
    pub url: String,
    /// Extra `(name, value)` HTTP headers sent with every request, such as
    /// authentication tokens.
    pub headers: Vec<(String, String)>,
}

impl StreamableHttpTransportConfig {
    /// Starts a Streamable HTTP configuration for the endpoint at `url`, with no
    /// extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            headers: vec![],
        }
    }

    /// Adds one HTTP header to send with every request and returns the updated
    /// configuration.
    pub fn with_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut config = self;
        config.headers.push((key.into(), value.into()));
        config
    }
}
202
/// Selects which transport an MCP server should use.
///
/// Stored in [`McpServerConfig::transport`]; it determines how the client will
/// communicate with the MCP server. The built-in options are [`Stdio`](Self::Stdio),
/// [`StreamableHttp`](Self::StreamableHttp), and the legacy [`Sse`](Self::Sse);
/// use [`Custom`](Self::Custom) to provide your own [`McpTransportFactory`].
///
/// Only `Clone` is derived: the `Custom` variant holds a `dyn McpTransportFactory`,
/// and that trait does not require `Debug` or `PartialEq`.
#[derive(Clone)]
pub enum McpTransportBinding {
    /// Run a local server as a child process described by [`StdioTransportConfig`],
    /// talking over its stdin/stdout.
    Stdio(StdioTransportConfig),
    /// Talk to a single HTTP endpoint using the MCP Streamable HTTP transport.
    StreamableHttp(StreamableHttpTransportConfig),
    /// Legacy HTTP+SSE transport: read from a Server-Sent Events stream and POST to
    /// the endpoint the server announces.
    Sse(SseTransportConfig),
    /// A user-supplied transport factory. It is held in an `Arc`, so cloning the
    /// binding shares the factory rather than copying it.
    Custom(Arc<dyn McpTransportFactory>),
}
220
221/// Full configuration for a single MCP server, combining an identifier, a transport
222/// binding, and optional metadata.
223///
224/// Register one or more of these with [`McpServerManager`] to manage the lifecycle
225/// of MCP servers in an agentkit runtime.
226///
227/// # Example
228///
229/// ```rust
230/// use agentkit_mcp::{McpServerConfig, McpTransportBinding, StdioTransportConfig};
231///
232/// let config = McpServerConfig::new(
233///     "filesystem",
234///     McpTransportBinding::Stdio(
235///         StdioTransportConfig::new("npx")
236///             .with_arg("-y")
237///             .with_arg("@modelcontextprotocol/server-filesystem"),
238///     ),
239/// );
240/// ```
241#[derive(Clone)]
242pub struct McpServerConfig {
243    /// Unique identifier for this server.
244    pub id: McpServerId,
245    /// Transport binding that determines how communication happens.
246    pub transport: McpTransportBinding,
247    /// Arbitrary metadata attached to this server configuration.
248    pub metadata: MetadataMap,
249}
250
251impl McpServerConfig {
252    /// Creates a new server configuration with the given identifier and transport.
253    ///
254    /// # Arguments
255    ///
256    /// * `id` - A unique name for this server (e.g. `"filesystem"`).
257    /// * `transport` - The [`McpTransportBinding`] that determines how to connect.
258    pub fn new(id: impl Into<String>, transport: McpTransportBinding) -> Self {
259        Self {
260            id: McpServerId::new(id),
261            transport,
262            metadata: MetadataMap::new(),
263        }
264    }
265
266    /// Creates a stdio-backed server configuration.
267    pub fn stdio(id: impl Into<String>, command: impl Into<String>) -> Self {
268        Self::new(
269            id,
270            McpTransportBinding::Stdio(StdioTransportConfig::new(command)),
271        )
272    }
273
274    /// Creates an SSE-backed server configuration.
275    pub fn sse(id: impl Into<String>, url: impl Into<String>) -> Self {
276        Self::new(id, McpTransportBinding::Sse(SseTransportConfig::new(url)))
277    }
278
279    /// Creates a Streamable HTTP-backed server configuration.
280    pub fn streamable_http(id: impl Into<String>, url: impl Into<String>) -> Self {
281        Self::new(
282            id,
283            McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(url)),
284        )
285    }
286
287    /// Replaces the configuration metadata.
288    pub fn with_metadata(mut self, metadata: MetadataMap) -> Self {
289        self.metadata = metadata;
290        self
291    }
292}
293
/// A single JSON-RPC frame exchanged with an MCP server.
///
/// This is the low-level wire unit. Most users will not interact with `McpFrame`
/// directly; instead use [`McpConnection`] or the higher-level adapters.
///
/// Transports send and receive [`value`](Self::value) itself. The derived serde
/// impls of `McpFrame` wrap it as `{"value": ...}`, which is not the wire format.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpFrame {
    /// The raw JSON-RPC value (request, response, or notification).
    pub value: Value,
}
303
/// Factory trait for creating new [`McpTransport`] connections.
///
/// Each call to [`connect`](Self::connect) should yield an independent, live
/// transport. The built-in [`StdioTransportFactory`], [`StreamableHttpTransportFactory`],
/// and [`SseTransportFactory`] cover the stdio, Streamable HTTP, and legacy SSE
/// transports. Implement this trait yourself for in-memory, WebSocket, or other custom
/// transports, and plug it in through [`McpTransportBinding::Custom`].
///
/// Implementors must be `Send + Sync` because custom factories are shared behind an
/// `Arc` (see [`McpTransportBinding::Custom`]).
///
/// # Errors
///
/// Returns [`McpError`] if the connection cannot be established.
#[async_trait]
pub trait McpTransportFactory: Send + Sync {
    /// Establishes a new transport connection and returns it.
    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError>;
}
319
/// Bidirectional transport for exchanging [`McpFrame`] messages with an MCP server.
///
/// Implement this trait to provide a custom transport. Each transport instance
/// represents a single, live connection. Every method takes `&mut self`, so a
/// transport is driven by one caller at a time.
///
/// # Errors
///
/// All methods return [`McpError`] on I/O or protocol failures.
#[async_trait]
pub trait McpTransport: Send + Sync {
    /// Sends a JSON-RPC frame to the server.
    async fn send(&mut self, message: McpFrame) -> Result<(), McpError>;
    /// Receives the next JSON-RPC frame from the server, or `None` if the stream has ended.
    // NOTE(review): the built-in Streamable HTTP transport returns `Ok(None)` whenever its
    // queue of frames collected during `send` is empty, not only when the session is over.
    // Callers should not treat `None` as a hard disconnect for that transport — confirm.
    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError>;
    /// Closes the transport, releasing any underlying resources.
    async fn close(&mut self) -> Result<(), McpError>;
}
337
338/// Factory that spawns a child process and connects via stdin/stdout.
339///
340/// Created from a [`StdioTransportConfig`]. Each call to
341/// [`connect`](McpTransportFactory::connect) spawns a new child process.
342pub struct StdioTransportFactory {
343    config: StdioTransportConfig,
344}
345
346impl StdioTransportFactory {
347    /// Creates a new factory from the given stdio transport configuration.
348    pub fn new(config: StdioTransportConfig) -> Self {
349        Self { config }
350    }
351}
352
353#[async_trait]
354impl McpTransportFactory for StdioTransportFactory {
355    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
356        let mut command = Command::new(&self.config.command);
357        command.args(&self.config.args);
358        command.stdin(Stdio::piped());
359        command.stdout(Stdio::piped());
360        command.stderr(Stdio::inherit());
361
362        if let Some(cwd) = &self.config.cwd {
363            command.current_dir(cwd);
364        }
365
366        for (key, value) in &self.config.env {
367            command.env(key, value);
368        }
369
370        let mut child = command.spawn().map_err(McpError::Io)?;
371        let stdin = child
372            .stdin
373            .take()
374            .ok_or_else(|| McpError::Transport("failed to capture MCP stdin".into()))?;
375        let stdout = child
376            .stdout
377            .take()
378            .ok_or_else(|| McpError::Transport("failed to capture MCP stdout".into()))?;
379
380        Ok(Box::new(StdioTransport {
381            child,
382            stdin,
383            stdout: BufReader::new(stdout),
384        }))
385    }
386}
387
388/// Factory that opens an HTTP SSE stream and connects via Server-Sent Events.
389///
390/// Created from an [`SseTransportConfig`]. Each call to
391/// [`connect`](McpTransportFactory::connect) opens a new HTTP connection.
392pub struct SseTransportFactory {
393    config: SseTransportConfig,
394}
395
396impl SseTransportFactory {
397    /// Creates a new factory from the given SSE transport configuration.
398    pub fn new(config: SseTransportConfig) -> Self {
399        Self { config }
400    }
401}
402
403/// Factory that connects to a Streamable HTTP MCP endpoint.
404///
405/// Created from a [`StreamableHttpTransportConfig`]. Each call to
406/// [`connect`](McpTransportFactory::connect) creates a new HTTP-backed MCP session.
407pub struct StreamableHttpTransportFactory {
408    config: StreamableHttpTransportConfig,
409}
410
411impl StreamableHttpTransportFactory {
412    /// Creates a new factory from the given Streamable HTTP transport configuration.
413    pub fn new(config: StreamableHttpTransportConfig) -> Self {
414        Self { config }
415    }
416}
417
418#[async_trait]
419impl McpTransportFactory for SseTransportFactory {
420    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
421        let client = Client::builder()
422            .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
423            .build()
424            .map_err(McpError::Http)?;
425
426        let mut request = client
427            .get(&self.config.url)
428            .header("Accept", "text/event-stream")
429            .header("Cache-Control", "no-cache");
430
431        for (key, value) in &self.config.headers {
432            request = request.header(key, value);
433        }
434
435        let response = request.send().await.map_err(McpError::Http)?;
436        let status = response.status();
437        if !status.is_success() {
438            let body = response
439                .text()
440                .await
441                .unwrap_or_else(|_| "<unreadable response body>".into());
442            return Err(McpError::Transport(format!(
443                "SSE connection failed with status {status}: {body}"
444            )));
445        }
446
447        let response_url = response.url().clone();
448        let stream = response.bytes_stream().map_err(std::io::Error::other);
449        let reader = BufReader::new(StreamReader::new(stream));
450        let (frame_tx, frame_rx) = mpsc::unbounded_channel();
451        let (endpoint_tx, endpoint_rx) = oneshot::channel();
452        let read_task = tokio::spawn(read_sse_stream(reader, response_url, frame_tx, endpoint_tx));
453
454        let endpoint_url = endpoint_rx
455            .await
456            .map_err(|_| McpError::Transport("SSE stream closed before endpoint event".into()))??;
457
458        Ok(Box::new(SseTransport {
459            client,
460            endpoint_url,
461            headers: self.config.headers.clone(),
462            frame_rx,
463            read_task,
464        }))
465    }
466}
467
468#[async_trait]
469impl McpTransportFactory for StreamableHttpTransportFactory {
470    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
471        let client = Client::builder()
472            .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
473            .build()
474            .map_err(McpError::Http)?;
475
476        let endpoint_url = Url::parse(&self.config.url)
477            .map_err(|error| McpError::Transport(format!("invalid MCP endpoint URL: {error}")))?;
478
479        Ok(Box::new(StreamableHttpTransport {
480            client,
481            endpoint_url,
482            headers: self.config.headers.clone(),
483            protocol_version: None,
484            session_id: None,
485            pending_frames: VecDeque::new(),
486        }))
487    }
488}
489
/// Live stdio connection to a spawned MCP server process.
///
/// Frames are written as newline-terminated JSON to `stdin` and read line by line
/// from `stdout`.
struct StdioTransport {
    /// Handle to the spawned server process; `close` kills it.
    child: Child,
    /// Pipe to the child's standard input, written by `send`.
    stdin: ChildStdin,
    /// Buffered reader over the child's standard output, read by `recv`.
    stdout: BufReader<ChildStdout>,
}
495
/// Live legacy HTTP+SSE connection.
///
/// Outgoing frames are POSTed to `endpoint_url`. Incoming frames are produced by a
/// background task reading the SSE stream and delivered through `frame_rx`.
struct SseTransport {
    /// HTTP client reused for every POST.
    client: Client,
    /// POST target announced by the server's `endpoint` event.
    endpoint_url: Url,
    /// Extra headers attached to every POST.
    headers: Vec<(String, String)>,
    /// Frames (or read errors) forwarded by the background SSE reader. When this
    /// channel closes, `recv` reports end of stream.
    frame_rx: mpsc::UnboundedReceiver<Result<McpFrame, McpError>>,
    /// Background SSE reader, aborted by `close`.
    // NOTE(review): no `Drop` impl is visible in this section, so dropping the transport
    // without calling `close` detaches (rather than aborts) this task — confirm intent.
    read_task: JoinHandle<()>,
}
503
/// Streamable HTTP connection state.
///
/// This transport keeps no long-lived stream open: every `send` is its own POST.
/// Reply frames collected during `send` are queued in `pending_frames` for `recv`
/// to hand out.
struct StreamableHttpTransport {
    /// HTTP client reused for every request.
    client: Client,
    /// The single MCP endpoint used for POST, resume GET, and session DELETE.
    endpoint_url: Url,
    /// Extra headers applied to every request.
    headers: Vec<(String, String)>,
    /// `protocolVersion` taken from the `initialize` result, once known.
    protocol_version: Option<String>,
    /// `MCP-Session-Id` returned by the server in its reply to `initialize`, if any.
    session_id: Option<String>,
    /// Frames received during `send`, waiting to be returned by `recv`.
    pending_frames: VecDeque<McpFrame>,
}
512
513#[async_trait]
514impl McpTransport for StdioTransport {
515    async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
516        let mut encoded = serde_json::to_vec(&message.value).map_err(McpError::Serialize)?;
517        encoded.push(b'\n');
518        self.stdin.write_all(&encoded).await.map_err(McpError::Io)?;
519        self.stdin.flush().await.map_err(McpError::Io)?;
520        Ok(())
521    }
522
523    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
524        let mut line = String::new();
525        let read = self
526            .stdout
527            .read_line(&mut line)
528            .await
529            .map_err(McpError::Io)?;
530        if read == 0 {
531            return Ok(None);
532        }
533
534        let value = serde_json::from_str(line.trim()).map_err(McpError::Serialize)?;
535        Ok(Some(McpFrame { value }))
536    }
537
538    async fn close(&mut self) -> Result<(), McpError> {
539        let _ = self.stdin.shutdown().await;
540        let _ = self.child.kill().await;
541        Ok(())
542    }
543}
544
545#[async_trait]
546impl McpTransport for SseTransport {
547    async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
548        let mut request = self
549            .client
550            .post(self.endpoint_url.clone())
551            .header("Content-Type", "application/json");
552
553        for (key, value) in &self.headers {
554            request = request.header(key, value);
555        }
556
557        let response = request
558            .json(&message.value)
559            .send()
560            .await
561            .map_err(McpError::Http)?;
562        let status = response.status();
563        if !status.is_success() {
564            let body = response
565                .text()
566                .await
567                .unwrap_or_else(|_| "<unreadable response body>".into());
568            return Err(McpError::Transport(format!(
569                "SSE POST failed with status {status}: {body}"
570            )));
571        }
572
573        Ok(())
574    }
575
576    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
577        match self.frame_rx.recv().await {
578            Some(Ok(frame)) => Ok(Some(frame)),
579            Some(Err(error)) => Err(error),
580            None => Ok(None),
581        }
582    }
583
584    async fn close(&mut self) -> Result<(), McpError> {
585        self.read_task.abort();
586        Ok(())
587    }
588}
589
590#[async_trait]
591impl McpTransport for StreamableHttpTransport {
592    async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
593        let is_request = is_jsonrpc_request(&message.value);
594        let request_id = message.value.get("id").cloned();
595        let is_initialize =
596            message.value.get("method").and_then(Value::as_str) == Some("initialize");
597
598        let mut request = self
599            .client
600            .post(self.endpoint_url.clone())
601            .header("Content-Type", "application/json")
602            .header("Accept", "application/json, text/event-stream");
603
604        request = apply_streamable_http_headers(
605            request,
606            &self.headers,
607            self.protocol_version.as_deref(),
608            self.session_id.as_deref(),
609        );
610
611        let response = request
612            .json(&message.value)
613            .send()
614            .await
615            .map_err(McpError::Http)?;
616
617        if is_initialize {
618            self.capture_session_id(response.headers());
619        }
620
621        let status = response.status();
622        if !status.is_success() {
623            return Err(
624                streamable_http_status_error("Streamable HTTP POST", status, response).await,
625            );
626        }
627
628        if !is_request {
629            return Ok(());
630        }
631
632        let content_type = response
633            .headers()
634            .get(reqwest::header::CONTENT_TYPE)
635            .and_then(|value| value.to_str().ok())
636            .unwrap_or_default()
637            .to_string();
638
639        if content_type.starts_with("application/json") {
640            let value = response.json::<Value>().await.map_err(McpError::Http)?;
641            self.maybe_update_protocol_version(&message.value, &value)?;
642            self.pending_frames.push_back(McpFrame { value });
643            return Ok(());
644        }
645
646        if !content_type.starts_with("text/event-stream") {
647            let body = response
648                .text()
649                .await
650                .unwrap_or_else(|_| "<unreadable response body>".into());
651            return Err(McpError::Transport(format!(
652                "unexpected Streamable HTTP response content type {content_type:?}: {body}"
653            )));
654        }
655
656        let request_id = request_id.ok_or_else(|| {
657            McpError::Protocol("JSON-RPC request over Streamable HTTP is missing an id".into())
658        })?;
659        self.collect_streamable_http_response(response, &message.value, &request_id)
660            .await
661    }
662
663    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
664        Ok(self.pending_frames.pop_front())
665    }
666
667    async fn close(&mut self) -> Result<(), McpError> {
668        let Some(session_id) = self.session_id.clone() else {
669            return Ok(());
670        };
671
672        let mut request = self.client.delete(self.endpoint_url.clone());
673        request = apply_streamable_http_headers(
674            request,
675            &self.headers,
676            self.protocol_version.as_deref(),
677            Some(session_id.as_str()),
678        );
679
680        let response = request.send().await.map_err(McpError::Http)?;
681        if response.status().is_success()
682            || response.status() == StatusCode::METHOD_NOT_ALLOWED
683            || response.status() == StatusCode::NOT_FOUND
684        {
685            self.session_id = None;
686            return Ok(());
687        }
688
689        Err(
690            streamable_http_status_error("Streamable HTTP DELETE", response.status(), response)
691                .await,
692        )
693    }
694}
695
696impl StreamableHttpTransport {
697    async fn collect_streamable_http_response(
698        &mut self,
699        response: reqwest::Response,
700        request_message: &Value,
701        request_id: &Value,
702    ) -> Result<(), McpError> {
703        let mut retry_delay = Duration::from_millis(0);
704        let mut last_event_id = None;
705        let mut saw_response = false;
706
707        saw_response |= self
708            .read_streamable_http_events(
709                response,
710                request_message,
711                request_id,
712                &mut last_event_id,
713                &mut retry_delay,
714            )
715            .await?;
716
717        while !saw_response && last_event_id.is_some() {
718            if !retry_delay.is_zero() {
719                sleep(retry_delay).await;
720            }
721
722            let response = self
723                .resume_streamable_http_stream(last_event_id.as_deref().unwrap())
724                .await?;
725            saw_response |= self
726                .read_streamable_http_events(
727                    response,
728                    request_message,
729                    request_id,
730                    &mut last_event_id,
731                    &mut retry_delay,
732                )
733                .await?;
734        }
735
736        Ok(())
737    }
738
739    async fn read_streamable_http_events(
740        &mut self,
741        response: reqwest::Response,
742        request_message: &Value,
743        request_id: &Value,
744        last_event_id: &mut Option<String>,
745        retry_delay: &mut Duration,
746    ) -> Result<bool, McpError> {
747        let stream = response.bytes_stream().map_err(std::io::Error::other);
748        let mut reader = BufReader::new(StreamReader::new(stream));
749        let mut saw_response = false;
750
751        while let Some(event) = read_next_sse_event(&mut reader).await? {
752            if let Some(id) = event.id.clone() {
753                *last_event_id = Some(id);
754            }
755            if let Some(retry_ms) = event.retry_ms {
756                *retry_delay = Duration::from_millis(retry_ms);
757            }
758
759            let Some(frame) = streamable_http_event_to_frame(event)? else {
760                continue;
761            };
762
763            self.maybe_update_protocol_version(request_message, &frame.value)?;
764            if frame.value.get("id") == Some(request_id) {
765                saw_response = true;
766            }
767            self.pending_frames.push_back(frame);
768        }
769
770        Ok(saw_response)
771    }
772
773    async fn resume_streamable_http_stream(
774        &self,
775        last_event_id: &str,
776    ) -> Result<reqwest::Response, McpError> {
777        let mut request = self
778            .client
779            .get(self.endpoint_url.clone())
780            .header("Accept", "text/event-stream")
781            .header("Cache-Control", "no-cache")
782            .header("Last-Event-ID", last_event_id);
783
784        request = apply_streamable_http_headers(
785            request,
786            &self.headers,
787            self.protocol_version.as_deref(),
788            self.session_id.as_deref(),
789        );
790
791        let response = request.send().await.map_err(McpError::Http)?;
792        let status = response.status();
793        if !status.is_success() {
794            return Err(
795                streamable_http_status_error("Streamable HTTP GET", status, response).await,
796            );
797        }
798
799        let content_type = response
800            .headers()
801            .get(reqwest::header::CONTENT_TYPE)
802            .and_then(|value| value.to_str().ok())
803            .unwrap_or_default();
804        if !content_type.starts_with("text/event-stream") {
805            let content_type = content_type.to_string();
806            let body = response
807                .text()
808                .await
809                .unwrap_or_else(|_| "<unreadable response body>".into());
810            return Err(McpError::Transport(format!(
811                "Streamable HTTP GET expected text/event-stream, got {content_type:?}: {body}"
812            )));
813        }
814
815        Ok(response)
816    }
817
818    fn maybe_update_protocol_version(
819        &mut self,
820        request_message: &Value,
821        response_value: &Value,
822    ) -> Result<(), McpError> {
823        if request_message.get("method").and_then(Value::as_str) != Some("initialize") {
824            return Ok(());
825        }
826
827        let protocol_version = response_value
828            .get("result")
829            .and_then(|result| result.get("protocolVersion"))
830            .and_then(Value::as_str);
831
832        if let Some(protocol_version) = protocol_version {
833            self.protocol_version = Some(protocol_version.to_string());
834        }
835
836        Ok(())
837    }
838
839    fn capture_session_id(&mut self, headers: &reqwest::header::HeaderMap) {
840        self.session_id = headers
841            .get("MCP-Session-Id")
842            .and_then(|value| value.to_str().ok())
843            .map(|value| value.to_string());
844    }
845}
846
/// Descriptor for a tool advertised by an MCP server.
///
/// Returned as part of a [`McpDiscoverySnapshot`] after server discovery. The
/// [`input_schema`](Self::input_schema) field is the JSON Schema that describes
/// the tool's expected input.
///
/// The derived serde impls use the Rust field names as keys (e.g. `input_schema`),
/// not the camelCase keys of the MCP wire format.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpToolDescriptor {
    /// The tool name as reported by the MCP server.
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// JSON Schema describing the tool's input parameters.
    pub input_schema: Value,
    /// Arbitrary metadata attached to this descriptor.
    pub metadata: MetadataMap,
}
863
/// Descriptor for a resource advertised by an MCP server.
///
/// Resources represent data that the server can provide (e.g. files, database
/// records). Each resource is identified by a URI, stored in [`id`](Self::id).
///
/// The derived serde impls use the Rust field names as keys (e.g. `mime_type`),
/// not the camelCase keys of the MCP wire format.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpResourceDescriptor {
    /// The resource URI (e.g. `"file:///tmp/example.txt"`).
    pub id: String,
    /// Human-readable name of the resource.
    pub name: String,
    /// Optional description of the resource.
    pub description: Option<String>,
    /// Optional MIME type (e.g. `"text/plain"`, `"application/json"`).
    pub mime_type: Option<String>,
    /// Arbitrary metadata attached to this descriptor.
    pub metadata: MetadataMap,
}
881
882/// Descriptor for a prompt template advertised by an MCP server.
883///
884/// Prompts are reusable message templates that can be parameterized with arguments.
885/// The [`input_schema`](Self::input_schema) describes the expected arguments.
886#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
887pub struct McpPromptDescriptor {
888    /// Unique identifier for the prompt (typically the same as `name`).
889    pub id: String,
890    /// Human-readable name of the prompt.
891    pub name: String,
892    /// Optional description of what the prompt does.
893    pub description: Option<String>,
894    /// JSON Schema describing the prompt's input arguments.
895    pub input_schema: Value,
896    /// Arbitrary metadata attached to this descriptor.
897    pub metadata: MetadataMap,
898}
899
900/// A snapshot of all capabilities discovered from a single MCP server.
901///
902/// Obtained by calling [`McpConnection::discover`] or as part of a
903/// [`McpServerHandle`]. Contains the full list of tools, resources, and prompts
904/// that the server advertised at discovery time.
905#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
906pub struct McpDiscoverySnapshot {
907    /// The server this snapshot was taken from.
908    pub server_id: McpServerId,
909    /// Tools advertised by the server.
910    pub tools: Vec<McpToolDescriptor>,
911    /// Resources advertised by the server.
912    pub resources: Vec<McpResourceDescriptor>,
913    /// Prompts advertised by the server.
914    pub prompts: Vec<McpPromptDescriptor>,
915    /// Arbitrary metadata attached to this snapshot.
916    pub metadata: MetadataMap,
917}
918
919/// A live connection to a single MCP server.
920///
921/// Handles JSON-RPC request/response framing, automatic auth enrichment, and
922/// high-level methods for tool calls, resource reads, prompt retrieval, and
923/// capability discovery.
924///
925/// Create a connection with [`McpConnection::connect`] or indirectly through
926/// [`McpServerManager::connect_server`].
927///
928/// # Example
929///
930/// ```rust,no_run
931/// use agentkit_mcp::{McpConnection, McpServerConfig, McpTransportBinding, StdioTransportConfig};
932///
933/// # #[tokio::main]
934/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
935/// let config = McpServerConfig::new(
936///     "filesystem",
937///     McpTransportBinding::Stdio(StdioTransportConfig::new("npx")
938///         .with_arg("-y")
939///         .with_arg("@modelcontextprotocol/server-filesystem")),
940/// );
941///
942/// let connection = McpConnection::connect(&config).await?;
943/// let snapshot = connection.discover().await?;
944/// println!("found {} tools", snapshot.tools.len());
945/// # Ok(())
946/// # }
947/// ```
948pub struct McpConnection {
949    server_id: McpServerId,
950    transport: Mutex<Box<dyn McpTransport>>,
951    auth: Mutex<Option<MetadataMap>>,
952    next_id: AtomicU64,
953}
954
955/// The result of replaying an MCP operation after auth resolution.
956///
957/// Returned by [`McpConnection::replay_auth_operation`] and
958/// [`McpServerManager::resolve_auth_and_resume`].
959#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
960pub enum McpOperationResult {
961    /// The server was successfully (re)connected; contains the discovery snapshot.
962    Connected(McpDiscoverySnapshot),
963    /// A tool call completed; contains the raw JSON result.
964    Tool(Value),
965    /// A resource was read successfully.
966    Resource(ResourceContents),
967    /// A prompt was retrieved successfully.
968    Prompt(PromptContents),
969}
970
971impl McpConnection {
972    /// Connects to an MCP server, performs the JSON-RPC `initialize` handshake, and
973    /// returns a ready-to-use connection.
974    ///
975    /// # Errors
976    ///
977    /// Returns [`McpError`] if the transport fails to connect, the handshake is
978    /// rejected, or the server requires authentication ([`McpError::AuthRequired`]).
979    pub async fn connect(config: &McpServerConfig) -> Result<Self, McpError> {
980        Self::connect_with_auth(config, None).await
981    }
982
983    async fn connect_with_auth(
984        config: &McpServerConfig,
985        auth: Option<&MetadataMap>,
986    ) -> Result<Self, McpError> {
987        let factory: Arc<dyn McpTransportFactory> = match &config.transport {
988            McpTransportBinding::Stdio(binding) => {
989                Arc::new(StdioTransportFactory::new(binding.clone()))
990            }
991            McpTransportBinding::StreamableHttp(binding) => {
992                Arc::new(StreamableHttpTransportFactory::new(binding.clone()))
993            }
994            McpTransportBinding::Sse(binding) => {
995                Arc::new(SseTransportFactory::new(binding.clone()))
996            }
997            McpTransportBinding::Custom(factory) => factory.clone(),
998        };
999
1000        let mut transport = factory.connect().await?;
1001        let mut params = serde_json::Map::new();
1002        params.insert(
1003            "protocolVersion".into(),
1004            Value::String(MCP_LATEST_PROTOCOL_VERSION.into()),
1005        );
1006        params.insert("capabilities".into(), json!({}));
1007        params.insert(
1008            "clientInfo".into(),
1009            json!({
1010                "name": "agentkit-mcp",
1011                "version": env!("CARGO_PKG_VERSION")
1012            }),
1013        );
1014        if let Some(auth) = auth {
1015            params.insert("auth".into(), metadata_to_value(auth));
1016        }
1017        let init_params = Value::Object(params.clone());
1018        transport
1019            .send(McpFrame {
1020                value: json!({
1021                    "jsonrpc": "2.0",
1022                    "id": 0,
1023                    "method": "initialize",
1024                    "params": init_params.clone()
1025                }),
1026            })
1027            .await?;
1028        let init_response = transport.recv().await?.ok_or_else(|| {
1029            McpError::Transport("transport closed during MCP initialization".into())
1030        })?;
1031        if let Some(error) = init_response.value.get("error") {
1032            if let Some(auth_request) =
1033                parse_auth_request(&config.id, "initialize", &init_params, error)
1034            {
1035                return Err(McpError::AuthRequired(Box::new(auth_request)));
1036            }
1037            return Err(McpError::Invocation(error.to_string()));
1038        }
1039        let negotiated_protocol_version = init_response
1040            .value
1041            .get("result")
1042            .and_then(|result| result.get("protocolVersion"))
1043            .and_then(Value::as_str)
1044            .ok_or_else(|| {
1045                McpError::Protocol("initialize response missing result.protocolVersion".into())
1046            })?;
1047        if !MCP_SUPPORTED_PROTOCOL_VERSIONS.contains(&negotiated_protocol_version) {
1048            return Err(McpError::Protocol(format!(
1049                "unsupported MCP protocol version negotiated during initialize: {negotiated_protocol_version}"
1050            )));
1051        }
1052        transport
1053            .send(McpFrame {
1054                value: json!({
1055                    "jsonrpc": "2.0",
1056                    "method": "notifications/initialized",
1057                    "params": {}
1058                }),
1059            })
1060            .await?;
1061
1062        Ok(Self {
1063            server_id: config.id.clone(),
1064            transport: Mutex::new(transport),
1065            auth: Mutex::new(auth.cloned()),
1066            next_id: AtomicU64::new(1),
1067        })
1068    }
1069
1070    /// Returns the [`McpServerId`] for this connection.
1071    pub fn server_id(&self) -> &McpServerId {
1072        &self.server_id
1073    }
1074
1075    /// Closes the underlying transport, shutting down the connection to the server.
1076    ///
1077    /// # Errors
1078    ///
1079    /// Returns [`McpError`] if the transport cannot be closed cleanly.
1080    pub async fn close(&self) -> Result<(), McpError> {
1081        let mut transport = self.transport.lock().await;
1082        transport.close().await
1083    }
1084
1085    /// Stores or clears authentication credentials for future requests on this
1086    /// connection.
1087    ///
1088    /// After calling this method with [`AuthResolution::Provided`], every subsequent
1089    /// JSON-RPC request will include the credentials in an `auth` field.
1090    ///
1091    /// # Errors
1092    ///
1093    /// Returns [`McpError`] if the resolution cannot be applied.
1094    pub async fn resolve_auth(&self, resolution: AuthResolution) -> Result<(), McpError> {
1095        let mut auth = self.auth.lock().await;
1096        match resolution {
1097            AuthResolution::Provided { credentials, .. } => {
1098                *auth = Some(credentials);
1099            }
1100            AuthResolution::Cancelled { .. } => {
1101                *auth = None;
1102            }
1103        }
1104        Ok(())
1105    }
1106
1107    /// Performs full capability discovery by listing tools, resources, and prompts.
1108    ///
1109    /// Returns an [`McpDiscoverySnapshot`] containing everything the server advertises.
1110    ///
1111    /// # Errors
1112    ///
1113    /// Returns [`McpError`] if any of the list requests fail.
1114    pub async fn discover(&self) -> Result<McpDiscoverySnapshot, McpError> {
1115        Ok(McpDiscoverySnapshot {
1116            server_id: self.server_id.clone(),
1117            tools: self.list_tools().await?,
1118            resources: self.list_resources().await?,
1119            prompts: self.list_prompts().await?,
1120            metadata: MetadataMap::new(),
1121        })
1122    }
1123
1124    /// Lists all tools advertised by the connected MCP server.
1125    ///
1126    /// # Errors
1127    ///
1128    /// Returns [`McpError`] if the `tools/list` request fails.
1129    pub async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>, McpError> {
1130        let result = self.request("tools/list", json!({})).await?;
1131        result
1132            .get("tools")
1133            .and_then(Value::as_array)
1134            .cloned()
1135            .unwrap_or_default()
1136            .into_iter()
1137            .map(parse_tool_descriptor)
1138            .collect()
1139    }
1140
1141    /// Lists all resources advertised by the connected MCP server.
1142    ///
1143    /// # Errors
1144    ///
1145    /// Returns [`McpError`] if the `resources/list` request fails.
1146    pub async fn list_resources(&self) -> Result<Vec<McpResourceDescriptor>, McpError> {
1147        let result = self.request("resources/list", json!({})).await?;
1148        result
1149            .get("resources")
1150            .and_then(Value::as_array)
1151            .cloned()
1152            .unwrap_or_default()
1153            .into_iter()
1154            .map(parse_resource_descriptor)
1155            .collect()
1156    }
1157
1158    /// Lists all prompts advertised by the connected MCP server.
1159    ///
1160    /// # Errors
1161    ///
1162    /// Returns [`McpError`] if the `prompts/list` request fails.
1163    pub async fn list_prompts(&self) -> Result<Vec<McpPromptDescriptor>, McpError> {
1164        let result = self.request("prompts/list", json!({})).await?;
1165        result
1166            .get("prompts")
1167            .and_then(Value::as_array)
1168            .cloned()
1169            .unwrap_or_default()
1170            .into_iter()
1171            .map(parse_prompt_descriptor)
1172            .collect()
1173    }
1174
1175    /// Invokes a tool on the MCP server and returns the raw JSON result.
1176    ///
1177    /// # Arguments
1178    ///
1179    /// * `name` - The tool name as it appears in the server's tool list.
1180    /// * `arguments` - A JSON value matching the tool's input schema.
1181    ///
1182    /// # Errors
1183    ///
1184    /// Returns [`McpError::AuthRequired`] if the server demands authentication,
1185    /// or another [`McpError`] variant on transport or protocol failures.
1186    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value, McpError> {
1187        self.request(
1188            "tools/call",
1189            json!({
1190                "name": name,
1191                "arguments": arguments,
1192            }),
1193        )
1194        .await
1195    }
1196
1197    /// Reads a resource from the MCP server by URI.
1198    ///
1199    /// # Arguments
1200    ///
1201    /// * `uri` - The resource URI (e.g. `"file:///tmp/example.txt"`).
1202    ///
1203    /// # Errors
1204    ///
1205    /// Returns [`McpError`] if the resource cannot be read or the response is malformed.
1206    pub async fn read_resource(&self, uri: &str) -> Result<ResourceContents, McpError> {
1207        let result = self
1208            .request(
1209                "resources/read",
1210                json!({
1211                    "uri": uri,
1212                }),
1213            )
1214            .await?;
1215        let content = result
1216            .get("contents")
1217            .and_then(Value::as_array)
1218            .and_then(|values| values.first())
1219            .cloned()
1220            .ok_or_else(|| McpError::Protocol("resources/read returned no contents".into()))?;
1221
1222        let data = if let Some(text) = content.get("text").and_then(Value::as_str) {
1223            DataRef::InlineText(text.into())
1224        } else if let Some(found_uri) = content.get("uri").and_then(Value::as_str) {
1225            DataRef::Uri(found_uri.into())
1226        } else {
1227            return Err(McpError::Protocol(
1228                "unsupported resource content shape".into(),
1229            ));
1230        };
1231
1232        Ok(ResourceContents {
1233            data,
1234            metadata: MetadataMap::new(),
1235        })
1236    }
1237
1238    /// Retrieves a prompt from the MCP server, rendering it with the given arguments.
1239    ///
1240    /// # Arguments
1241    ///
1242    /// * `name` - The prompt name as it appears in the server's prompt list.
1243    /// * `arguments` - A JSON value containing the prompt's input arguments.
1244    ///
1245    /// # Errors
1246    ///
1247    /// Returns [`McpError`] if the prompt cannot be retrieved or the response is malformed.
1248    pub async fn get_prompt(
1249        &self,
1250        name: &str,
1251        arguments: Value,
1252    ) -> Result<PromptContents, McpError> {
1253        let result = self
1254            .request(
1255                "prompts/get",
1256                json!({
1257                    "name": name,
1258                    "arguments": arguments,
1259                }),
1260            )
1261            .await?;
1262        let items = result
1263            .get("messages")
1264            .and_then(Value::as_array)
1265            .cloned()
1266            .unwrap_or_default()
1267            .into_iter()
1268            .map(parse_prompt_message)
1269            .collect::<Result<Vec<_>, _>>()?;
1270
1271        Ok(PromptContents {
1272            items,
1273            metadata: MetadataMap::new(),
1274        })
1275    }
1276
1277    async fn request(&self, method: &str, params: Value) -> Result<Value, McpError> {
1278        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1279        let params = self.enrich_params(params.clone()).await;
1280        let mut transport = self.transport.lock().await;
1281        transport
1282            .send(McpFrame {
1283                value: json!({
1284                    "jsonrpc": "2.0",
1285                    "id": id,
1286                    "method": method,
1287                    "params": params,
1288                }),
1289            })
1290            .await?;
1291
1292        loop {
1293            let Some(frame) = transport.recv().await? else {
1294                return Err(McpError::Transport(
1295                    "transport closed while waiting for MCP response".into(),
1296                ));
1297            };
1298
1299            if frame.value.get("id").and_then(Value::as_u64) != Some(id) {
1300                continue;
1301            }
1302
1303            if let Some(error) = frame.value.get("error") {
1304                if let Some(auth_request) =
1305                    parse_auth_request(&self.server_id, method, &params, error)
1306                {
1307                    return Err(McpError::AuthRequired(Box::new(auth_request)));
1308                }
1309                return Err(McpError::Invocation(error.to_string()));
1310            }
1311
1312            return frame
1313                .value
1314                .get("result")
1315                .cloned()
1316                .ok_or_else(|| McpError::Protocol("MCP response missing result".into()));
1317        }
1318    }
1319
1320    async fn enrich_params(&self, params: Value) -> Value {
1321        let auth = self.auth.lock().await;
1322        let Some(auth) = auth.as_ref() else {
1323            return params;
1324        };
1325
1326        match params {
1327            Value::Object(mut object) => {
1328                object
1329                    .entry("auth")
1330                    .or_insert_with(|| metadata_to_value(auth));
1331                Value::Object(object)
1332            }
1333            other => other,
1334        }
1335    }
1336
1337    /// Replays an MCP operation that previously failed with an auth challenge.
1338    ///
1339    /// This is called after credentials have been resolved via [`resolve_auth`](Self::resolve_auth).
1340    /// The operation is re-issued with the stored credentials attached.
1341    ///
1342    /// # Errors
1343    ///
1344    /// Returns [`McpError::AuthResolution`] if the operation targets a different server,
1345    /// or other [`McpError`] variants if the replayed operation itself fails.
1346    pub async fn replay_auth_operation(
1347        &self,
1348        operation: &AuthOperation,
1349    ) -> Result<McpOperationResult, McpError> {
1350        match operation {
1351            AuthOperation::McpToolCall {
1352                server_id,
1353                tool_name,
1354                input,
1355                ..
1356            } => {
1357                self.ensure_server_match(server_id)?;
1358                self.call_tool(tool_name, input.clone())
1359                    .await
1360                    .map(McpOperationResult::Tool)
1361            }
1362            AuthOperation::McpResourceRead {
1363                server_id,
1364                resource_id,
1365                ..
1366            } => {
1367                self.ensure_server_match(server_id)?;
1368                self.read_resource(resource_id)
1369                    .await
1370                    .map(McpOperationResult::Resource)
1371            }
1372            AuthOperation::McpPromptGet {
1373                server_id,
1374                prompt_id,
1375                args,
1376                ..
1377            } => {
1378                self.ensure_server_match(server_id)?;
1379                self.get_prompt(prompt_id, args.clone())
1380                    .await
1381                    .map(McpOperationResult::Prompt)
1382            }
1383            AuthOperation::ToolCall {
1384                tool_name,
1385                input,
1386                metadata,
1387                ..
1388            } => {
1389                if let Some(server_id) = metadata.get("server_id").and_then(Value::as_str) {
1390                    self.ensure_server_match(server_id)?;
1391                }
1392                let tool_name = normalize_mcp_tool_name(self.server_id(), tool_name);
1393                self.call_tool(&tool_name, input.clone())
1394                    .await
1395                    .map(McpOperationResult::Tool)
1396            }
1397            AuthOperation::McpConnect { .. } => Err(McpError::AuthResolution(
1398                "connect operations must be replayed through the server manager".into(),
1399            )),
1400            AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
1401                "unsupported auth operation for replay: {kind}"
1402            ))),
1403        }
1404    }
1405
1406    fn ensure_server_match(&self, server_id: &str) -> Result<(), McpError> {
1407        if self.server_id.0 == server_id {
1408            Ok(())
1409        } else {
1410            Err(McpError::AuthResolution(format!(
1411                "auth operation targets server {server_id}, but connection is for {}",
1412                self.server_id
1413            )))
1414        }
1415    }
1416}
1417
1418/// Adapter that exposes an MCP tool as an [`Invocable`] for the capabilities system.
1419///
1420/// This is the capabilities-layer adapter. For the tool-layer adapter, see
1421/// [`McpToolAdapter`]. Names are prefixed with `mcp.<server_id>.<tool_name>`.
1422pub struct McpInvocable {
1423    connection: Arc<McpConnection>,
1424    descriptor: McpToolDescriptor,
1425    spec: InvocableSpec,
1426}
1427
1428impl McpInvocable {
1429    /// Creates a new invocable adapter for the given MCP tool.
1430    ///
1431    /// # Arguments
1432    ///
1433    /// * `connection` - A shared connection to the MCP server that owns the tool.
1434    /// * `descriptor` - The tool descriptor obtained from discovery.
1435    pub fn new(connection: Arc<McpConnection>, descriptor: McpToolDescriptor) -> Self {
1436        let spec = InvocableSpec {
1437            name: CapabilityName::new(format!(
1438                "mcp.{}.{}",
1439                connection.server_id(),
1440                descriptor.name
1441            )),
1442            description: descriptor
1443                .description
1444                .clone()
1445                .unwrap_or_else(|| descriptor.name.clone()),
1446            input_schema: descriptor.input_schema.clone(),
1447            metadata: descriptor.metadata.clone(),
1448        };
1449
1450        Self {
1451            connection,
1452            descriptor,
1453            spec,
1454        }
1455    }
1456}
1457
1458#[async_trait]
1459impl Invocable for McpInvocable {
1460    fn spec(&self) -> &InvocableSpec {
1461        &self.spec
1462    }
1463
1464    async fn invoke(
1465        &self,
1466        request: InvocableRequest,
1467        _ctx: &mut CapabilityContext<'_>,
1468    ) -> Result<InvocableResult, CapabilityError> {
1469        let result = self
1470            .connection
1471            .call_tool(&self.descriptor.name, request.input)
1472            .await
1473            .map_err(|error| match error {
1474                McpError::AuthRequired(request) => {
1475                    CapabilityError::Unavailable(format!("auth required: {:?}", request))
1476                }
1477                other => CapabilityError::ExecutionFailed(other.to_string()),
1478            })?;
1479
1480        Ok(InvocableResult {
1481            output: value_to_invocable_output(result),
1482            metadata: MetadataMap::new(),
1483        })
1484    }
1485}
1486
1487/// Adapter that exposes a single MCP resource as a [`ResourceProvider`].
1488///
1489/// Created automatically by [`McpCapabilityProvider::from_snapshot`] for each
1490/// resource discovered on the server.
1491pub struct McpResourceHandle {
1492    connection: Arc<McpConnection>,
1493    descriptor: ResourceDescriptor,
1494}
1495
1496#[async_trait]
1497impl ResourceProvider for McpResourceHandle {
1498    async fn list_resources(&self) -> Result<Vec<ResourceDescriptor>, CapabilityError> {
1499        Ok(vec![self.descriptor.clone()])
1500    }
1501
1502    async fn read_resource(
1503        &self,
1504        id: &ResourceId,
1505        _ctx: &mut CapabilityContext<'_>,
1506    ) -> Result<ResourceContents, CapabilityError> {
1507        self.connection
1508            .read_resource(&id.0)
1509            .await
1510            .map_err(|error| match error {
1511                McpError::AuthRequired(request) => {
1512                    CapabilityError::Unavailable(format!("auth required: {:?}", request))
1513                }
1514                other => CapabilityError::ExecutionFailed(other.to_string()),
1515            })
1516    }
1517}
1518
1519/// Adapter that exposes a single MCP prompt as a [`PromptProvider`].
1520///
1521/// Created automatically by [`McpCapabilityProvider::from_snapshot`] for each
1522/// prompt discovered on the server.
1523pub struct McpPromptHandle {
1524    connection: Arc<McpConnection>,
1525    descriptor: PromptDescriptor,
1526}
1527
1528#[async_trait]
1529impl PromptProvider for McpPromptHandle {
1530    async fn list_prompts(&self) -> Result<Vec<PromptDescriptor>, CapabilityError> {
1531        Ok(vec![self.descriptor.clone()])
1532    }
1533
1534    async fn get_prompt(
1535        &self,
1536        id: &PromptId,
1537        args: Value,
1538        _ctx: &mut CapabilityContext<'_>,
1539    ) -> Result<PromptContents, CapabilityError> {
1540        self.connection
1541            .get_prompt(&id.0, args)
1542            .await
1543            .map_err(|error| match error {
1544                McpError::AuthRequired(request) => {
1545                    CapabilityError::Unavailable(format!("auth required: {:?}", request))
1546                }
1547                other => CapabilityError::ExecutionFailed(other.to_string()),
1548            })
1549    }
1550}
1551
1552/// A [`CapabilityProvider`] that surfaces MCP tools, resources, and prompts into the
1553/// agentkit capabilities system.
1554///
1555/// Built from a discovery snapshot, this provider wraps each MCP tool as an
1556/// [`McpInvocable`], each resource as an [`McpResourceHandle`], and each prompt as
1557/// an [`McpPromptHandle`].
1558///
1559/// # Example
1560///
1561/// ```rust,no_run
1562/// use std::sync::Arc;
1563/// use agentkit_mcp::{McpCapabilityProvider, McpServerConfig, McpTransportBinding, StdioTransportConfig};
1564///
1565/// # #[tokio::main]
1566/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
1567/// let config = McpServerConfig::new(
1568///     "filesystem",
1569///     McpTransportBinding::Stdio(StdioTransportConfig::new("npx")
1570///         .with_arg("-y")
1571///         .with_arg("@modelcontextprotocol/server-filesystem")),
1572/// );
1573/// let (connection, provider, snapshot) = McpCapabilityProvider::connect(&config).await?;
1574/// // `provider` implements CapabilityProvider and can be registered with an agent.
1575/// # Ok(())
1576/// # }
1577/// ```
1578pub struct McpCapabilityProvider {
1579    invocables: Vec<Arc<dyn Invocable>>,
1580    resources: Vec<Arc<dyn ResourceProvider>>,
1581    prompts: Vec<Arc<dyn PromptProvider>>,
1582}
1583
1584impl McpCapabilityProvider {
1585    /// Creates a capability provider from an existing connection and its discovery
1586    /// snapshot.
1587    ///
1588    /// Each tool, resource, and prompt in the snapshot is wrapped in the appropriate
1589    /// adapter type.
1590    pub fn from_snapshot(connection: Arc<McpConnection>, snapshot: &McpDiscoverySnapshot) -> Self {
1591        let invocables = snapshot
1592            .tools
1593            .iter()
1594            .cloned()
1595            .map(|descriptor| {
1596                Arc::new(McpInvocable::new(connection.clone(), descriptor)) as Arc<dyn Invocable>
1597            })
1598            .collect();
1599
1600        let resources = snapshot
1601            .resources
1602            .iter()
1603            .cloned()
1604            .map(|descriptor| {
1605                Arc::new(McpResourceHandle {
1606                    connection: connection.clone(),
1607                    descriptor: ResourceDescriptor {
1608                        id: ResourceId::new(descriptor.id),
1609                        name: descriptor.name,
1610                        description: descriptor.description,
1611                        mime_type: descriptor.mime_type,
1612                        metadata: descriptor.metadata,
1613                    },
1614                }) as Arc<dyn ResourceProvider>
1615            })
1616            .collect();
1617
1618        let prompts = snapshot
1619            .prompts
1620            .iter()
1621            .cloned()
1622            .map(|descriptor| {
1623                Arc::new(McpPromptHandle {
1624                    connection: connection.clone(),
1625                    descriptor: PromptDescriptor {
1626                        id: PromptId::new(descriptor.id),
1627                        name: descriptor.name,
1628                        description: descriptor.description,
1629                        input_schema: descriptor.input_schema,
1630                        metadata: descriptor.metadata,
1631                    },
1632                }) as Arc<dyn PromptProvider>
1633            })
1634            .collect();
1635
1636        Self {
1637            invocables,
1638            resources,
1639            prompts,
1640        }
1641    }
1642
1643    /// Merges multiple capability providers into a single provider.
1644    ///
1645    /// This is useful when managing several MCP servers through a
1646    /// [`McpServerManager`] and you want one combined provider for the agent.
1647    pub fn merge<I>(providers: I) -> Self
1648    where
1649        I: IntoIterator<Item = Self>,
1650    {
1651        let mut invocables = Vec::new();
1652        let mut resources = Vec::new();
1653        let mut prompts = Vec::new();
1654
1655        for provider in providers {
1656            invocables.extend(provider.invocables);
1657            resources.extend(provider.resources);
1658            prompts.extend(provider.prompts);
1659        }
1660
1661        Self {
1662            invocables,
1663            resources,
1664            prompts,
1665        }
1666    }
1667
1668    /// Connects to an MCP server, performs discovery, and builds a capability
1669    /// provider in one step.
1670    ///
1671    /// Returns the shared connection, the provider, and the discovery snapshot.
1672    ///
1673    /// # Errors
1674    ///
1675    /// Returns [`McpError`] if connection or discovery fails.
1676    pub async fn connect(
1677        config: &McpServerConfig,
1678    ) -> Result<(Arc<McpConnection>, Self, McpDiscoverySnapshot), McpError> {
1679        let connection = Arc::new(McpConnection::connect(config).await?);
1680        let snapshot = connection.discover().await?;
1681        let provider = Self::from_snapshot(connection.clone(), &snapshot);
1682
1683        Ok((connection, provider, snapshot))
1684    }
1685}
1686
1687impl CapabilityProvider for McpCapabilityProvider {
1688    fn invocables(&self) -> Vec<Arc<dyn Invocable>> {
1689        self.invocables.clone()
1690    }
1691
1692    fn resources(&self) -> Vec<Arc<dyn ResourceProvider>> {
1693        self.resources.clone()
1694    }
1695
1696    fn prompts(&self) -> Vec<Arc<dyn PromptProvider>> {
1697        self.prompts.clone()
1698    }
1699}
1700
1701/// A connected MCP server together with its configuration and discovery snapshot.
1702///
1703/// Obtained from [`McpServerManager::connect_server`] or
1704/// [`McpServerManager::connect_all`]. Provides convenience methods to create
1705/// tool registries and capability providers from the server's discovered capabilities.
1706#[derive(Clone)]
1707pub struct McpServerHandle {
1708    config: McpServerConfig,
1709    connection: Arc<McpConnection>,
1710    snapshot: McpDiscoverySnapshot,
1711}
1712
1713impl McpServerHandle {
1714    /// Returns the original configuration used to connect this server.
1715    pub fn config(&self) -> &McpServerConfig {
1716        &self.config
1717    }
1718
1719    /// Returns the server's unique identifier.
1720    pub fn server_id(&self) -> &McpServerId {
1721        self.connection.server_id()
1722    }
1723
1724    /// Returns a shared reference to the underlying [`McpConnection`].
1725    pub fn connection(&self) -> Arc<McpConnection> {
1726        self.connection.clone()
1727    }
1728
1729    /// Returns the discovery snapshot captured when the server was connected.
1730    pub fn snapshot(&self) -> &McpDiscoverySnapshot {
1731        &self.snapshot
1732    }
1733
1734    /// Builds a [`ToolRegistry`] containing an [`McpToolAdapter`] for each tool
1735    /// discovered on this server.
1736    pub fn tool_registry(&self) -> ToolRegistry {
1737        self.snapshot
1738            .tools
1739            .iter()
1740            .cloned()
1741            .fold(ToolRegistry::new(), |registry, descriptor| {
1742                registry.with(McpToolAdapter::new(
1743                    self.server_id(),
1744                    self.connection.clone(),
1745                    descriptor,
1746                ))
1747            })
1748    }
1749
1750    /// Builds an [`McpCapabilityProvider`] from this server's discovery snapshot.
1751    pub fn capability_provider(&self) -> McpCapabilityProvider {
1752        McpCapabilityProvider::from_snapshot(self.connection.clone(), &self.snapshot)
1753    }
1754}
1755
1756/// Manages the lifecycle of one or more MCP servers: registration, connection,
1757/// discovery, refresh, disconnection, and auth resolution.
1758///
1759/// This is the primary entry point for integrating MCP servers into an agentkit
1760/// application. Register server configurations, connect them, and then obtain a
1761/// combined [`ToolRegistry`] or [`McpCapabilityProvider`] for use in an agent loop.
1762///
1763/// # Example
1764///
1765/// ```rust,no_run
1766/// use agentkit_mcp::{
1767///     McpServerConfig, McpServerManager, McpTransportBinding, StdioTransportConfig,
1768/// };
1769///
1770/// # #[tokio::main]
1771/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
1772/// let mut manager = McpServerManager::new()
1773///     .with_server(McpServerConfig::new(
1774///         "filesystem",
1775///         McpTransportBinding::Stdio(
1776///             StdioTransportConfig::new("npx")
1777///                 .with_arg("-y")
1778///                 .with_arg("@modelcontextprotocol/server-filesystem"),
1779///         ),
1780///     ))
1781///     .with_server(McpServerConfig::new(
1782///         "github",
1783///         McpTransportBinding::Stdio(
1784///             StdioTransportConfig::new("npx")
1785///                 .with_arg("-y")
1786///                 .with_arg("@modelcontextprotocol/server-github"),
1787///         ),
1788///     ));
1789///
1790/// let handles = manager.connect_all().await?;
1791/// let registry = manager.tool_registry();
1792/// println!("tools: {:?}", registry.specs().iter().map(|s| &s.name).collect::<Vec<_>>());
1793/// # Ok(())
1794/// # }
1795/// ```
1796#[derive(Default)]
1797pub struct McpServerManager {
1798    configs: BTreeMap<McpServerId, McpServerConfig>,
1799    connections: BTreeMap<McpServerId, McpServerHandle>,
1800    auth: BTreeMap<McpServerId, MetadataMap>,
1801}
1802
1803impl McpServerManager {
1804    /// Creates an empty server manager with no registered servers.
1805    pub fn new() -> Self {
1806        Self::default()
1807    }
1808
1809    /// Registers a server configuration and returns `self` for chaining.
1810    ///
1811    /// The server is not connected until [`connect_server`](Self::connect_server) or
1812    /// [`connect_all`](Self::connect_all) is called.
1813    pub fn with_server(mut self, config: McpServerConfig) -> Self {
1814        self.register_server(config);
1815        self
1816    }
1817
1818    /// Registers a server configuration by mutable reference.
1819    ///
1820    /// The server is not connected until [`connect_server`](Self::connect_server) or
1821    /// [`connect_all`](Self::connect_all) is called.
1822    pub fn register_server(&mut self, config: McpServerConfig) -> &mut Self {
1823        self.configs.insert(config.id.clone(), config);
1824        self
1825    }
1826
1827    /// Returns the handle for a connected server, or `None` if it is not connected.
1828    pub fn connected_server(&self, server_id: &McpServerId) -> Option<&McpServerHandle> {
1829        self.connections.get(server_id)
1830    }
1831
1832    /// Returns handles for all currently connected servers.
1833    pub fn connected_servers(&self) -> Vec<&McpServerHandle> {
1834        self.connections.values().collect()
1835    }
1836
1837    /// Connects a single registered server by its identifier.
1838    ///
1839    /// Performs the MCP handshake and full capability discovery.
1840    ///
1841    /// # Errors
1842    ///
1843    /// Returns [`McpError::UnknownServer`] if the server ID has not been registered,
1844    /// or other [`McpError`] variants if connection or discovery fails.
1845    pub async fn connect_server(
1846        &mut self,
1847        server_id: &McpServerId,
1848    ) -> Result<McpServerHandle, McpError> {
1849        let config = self
1850            .configs
1851            .get(server_id)
1852            .cloned()
1853            .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1854        let connection =
1855            Arc::new(McpConnection::connect_with_auth(&config, self.auth.get(server_id)).await?);
1856        let snapshot = connection.discover().await?;
1857        let handle = McpServerHandle {
1858            config,
1859            connection,
1860            snapshot,
1861        };
1862        self.connections.insert(server_id.clone(), handle.clone());
1863        Ok(handle)
1864    }
1865
1866    /// Connects all registered servers sequentially.
1867    ///
1868    /// Returns a handle for each server in registration order. If any server fails
1869    /// to connect, the error is returned immediately and remaining servers are
1870    /// not attempted.
1871    ///
1872    /// # Errors
1873    ///
1874    /// Returns the first [`McpError`] encountered during connection.
1875    pub async fn connect_all(&mut self) -> Result<Vec<McpServerHandle>, McpError> {
1876        let server_ids = self.configs.keys().cloned().collect::<Vec<_>>();
1877        let mut handles = Vec::with_capacity(server_ids.len());
1878
1879        for server_id in server_ids {
1880            handles.push(self.connect_server(&server_id).await?);
1881        }
1882
1883        Ok(handles)
1884    }
1885
1886    /// Re-discovers capabilities for a connected server, updating the stored snapshot.
1887    ///
1888    /// Call this after the server's capabilities may have changed (e.g. after
1889    /// installing a plugin).
1890    ///
1891    /// # Errors
1892    ///
1893    /// Returns [`McpError::UnknownServer`] if the server is not connected, or other
1894    /// [`McpError`] variants if discovery fails.
1895    pub async fn refresh_server(
1896        &mut self,
1897        server_id: &McpServerId,
1898    ) -> Result<McpDiscoverySnapshot, McpError> {
1899        let handle = self
1900            .connections
1901            .get_mut(server_id)
1902            .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1903        let snapshot = handle.connection.discover().await?;
1904        handle.snapshot = snapshot.clone();
1905        Ok(snapshot)
1906    }
1907
1908    /// Disconnects a server and removes it from the active connections.
1909    ///
1910    /// The server configuration remains registered and can be reconnected later
1911    /// with [`connect_server`](Self::connect_server).
1912    ///
1913    /// # Errors
1914    ///
1915    /// Returns [`McpError::UnknownServer`] if the server is not connected.
1916    pub async fn disconnect_server(&mut self, server_id: &McpServerId) -> Result<(), McpError> {
1917        let Some(handle) = self.connections.remove(server_id) else {
1918            return Err(McpError::UnknownServer(server_id.to_string()));
1919        };
1920        handle.connection.close().await
1921    }
1922
1923    /// Stores or clears authentication credentials for a server and, if already
1924    /// connected, updates the live connection as well.
1925    ///
1926    /// # Errors
1927    ///
1928    /// Returns [`McpError::UnknownServer`] if the server ID from the resolution
1929    /// does not match any registered server.
1930    pub async fn resolve_auth(&mut self, resolution: AuthResolution) -> Result<(), McpError> {
1931        let server_id = resolution
1932            .request()
1933            .server_id()
1934            .ok_or_else(|| McpError::AuthResolution("auth resolution missing server id".into()))?;
1935        let server_id = McpServerId::new(server_id);
1936        match &resolution {
1937            AuthResolution::Provided { credentials, .. } => {
1938                self.auth.insert(server_id.clone(), credentials.clone());
1939            }
1940            AuthResolution::Cancelled { .. } => {
1941                self.auth.remove(&server_id);
1942            }
1943        }
1944
1945        if let Some(handle) = self.connections.get(&server_id) {
1946            handle.connection.resolve_auth(resolution).await?;
1947            return Ok(());
1948        }
1949
1950        if self.configs.contains_key(&server_id) {
1951            Ok(())
1952        } else {
1953            Err(McpError::UnknownServer(server_id.to_string()))
1954        }
1955    }
1956
1957    /// Resolves authentication and immediately replays the operation that originally
1958    /// triggered the auth challenge.
1959    ///
1960    /// This is a convenience method combining [`resolve_auth`](Self::resolve_auth)
1961    /// and [`replay_auth_request`](Self::replay_auth_request).
1962    ///
1963    /// # Errors
1964    ///
1965    /// Returns [`McpError`] if auth resolution or the replayed operation fails.
1966    pub async fn resolve_auth_and_resume(
1967        &mut self,
1968        resolution: AuthResolution,
1969    ) -> Result<McpOperationResult, McpError> {
1970        let request = resolution.request().clone();
1971        self.resolve_auth(resolution).await?;
1972        self.replay_auth_request(&request).await
1973    }
1974
1975    /// Replays an auth request's original MCP operation using stored credentials.
1976    ///
1977    /// For connect operations the server is (re)connected. For tool calls, resource
1978    /// reads, and prompt retrievals the request is re-issued on the existing or
1979    /// newly established connection.
1980    ///
1981    /// # Errors
1982    ///
1983    /// Returns [`McpError`] if the operation cannot be replayed.
1984    pub async fn replay_auth_request(
1985        &mut self,
1986        request: &AuthRequest,
1987    ) -> Result<McpOperationResult, McpError> {
1988        match &request.operation {
1989            AuthOperation::McpConnect { server_id, .. } => {
1990                let server_id = McpServerId::new(server_id);
1991                let handle = self.connect_server(&server_id).await?;
1992                Ok(McpOperationResult::Connected(handle.snapshot.clone()))
1993            }
1994            AuthOperation::McpToolCall { server_id, .. }
1995            | AuthOperation::McpResourceRead { server_id, .. }
1996            | AuthOperation::McpPromptGet { server_id, .. } => {
1997                let connection = self.connection_for_auth_server(server_id).await?;
1998                connection.replay_auth_operation(&request.operation).await
1999            }
2000            AuthOperation::ToolCall { metadata, .. } => {
2001                let server_id = metadata
2002                    .get("server_id")
2003                    .and_then(Value::as_str)
2004                    .ok_or_else(|| {
2005                        McpError::AuthResolution(
2006                            "tool-call auth replay requires metadata.server_id".into(),
2007                        )
2008                    })?;
2009                let connection = self.connection_for_auth_server(server_id).await?;
2010                connection.replay_auth_operation(&request.operation).await
2011            }
2012            AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
2013                "unsupported auth operation for replay: {kind}"
2014            ))),
2015        }
2016    }
2017
2018    async fn connection_for_auth_server(
2019        &mut self,
2020        server_id: &str,
2021    ) -> Result<Arc<McpConnection>, McpError> {
2022        let server_id = McpServerId::new(server_id);
2023        if !self.connections.contains_key(&server_id) {
2024            self.connect_server(&server_id).await?;
2025        }
2026        self.connections
2027            .get(&server_id)
2028            .map(McpServerHandle::connection)
2029            .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))
2030    }
2031
2032    /// Builds a combined [`ToolRegistry`] containing [`McpToolAdapter`]s for every
2033    /// tool discovered across all connected servers.
2034    ///
2035    /// Tool names are prefixed as `mcp.<server_id>.<tool_name>`.
2036    pub fn tool_registry(&self) -> ToolRegistry {
2037        self.connections
2038            .values()
2039            .fold(ToolRegistry::new(), |mut registry, handle| {
2040                for tool in handle.snapshot.tools.iter().cloned() {
2041                    registry.register(McpToolAdapter::new(
2042                        handle.server_id(),
2043                        handle.connection.clone(),
2044                        tool,
2045                    ));
2046                }
2047                registry
2048            })
2049    }
2050
2051    /// Builds a combined [`McpCapabilityProvider`] from all connected servers,
2052    /// merging their tools, resources, and prompts.
2053    pub fn capability_provider(&self) -> McpCapabilityProvider {
2054        McpCapabilityProvider::merge(
2055            self.connections
2056                .values()
2057                .map(McpServerHandle::capability_provider),
2058        )
2059    }
2060}
2061
2062/// Adapter that exposes an MCP tool as an agentkit [`Tool`].
2063///
2064/// This is the tool-layer adapter for the tool registry. For the capabilities-layer
2065/// adapter, see [`McpInvocable`]. Tool names are prefixed as
2066/// `mcp.<server_id>.<tool_name>`.
2067///
2068/// # Example
2069///
2070/// ```rust
2071/// use std::sync::Arc;
2072/// use agentkit_core::MetadataMap;
2073/// use agentkit_mcp::{McpToolAdapter, McpToolDescriptor, McpServerId};
2074/// # // McpToolAdapter::new requires a connection which we cannot construct in a doc test,
2075/// # // so this example only shows the construction pattern.
2076/// ```
2077pub struct McpToolAdapter {
2078    descriptor: McpToolDescriptor,
2079    connection: Arc<McpConnection>,
2080    spec: ToolSpec,
2081}
2082
2083impl McpToolAdapter {
2084    /// Creates a new tool adapter for the given MCP tool.
2085    ///
2086    /// # Arguments
2087    ///
2088    /// * `server_id` - The server's identifier, used to namespace the tool name.
2089    /// * `connection` - A shared connection to the owning MCP server.
2090    /// * `descriptor` - The tool descriptor obtained from discovery.
2091    pub fn new(
2092        server_id: &McpServerId,
2093        connection: Arc<McpConnection>,
2094        descriptor: McpToolDescriptor,
2095    ) -> Self {
2096        let spec = ToolSpec {
2097            name: ToolName::new(format!("mcp.{}.{}", server_id, descriptor.name)),
2098            description: descriptor
2099                .description
2100                .clone()
2101                .unwrap_or_else(|| descriptor.name.clone()),
2102            input_schema: descriptor.input_schema.clone(),
2103            annotations: ToolAnnotations::default(),
2104            metadata: descriptor.metadata.clone(),
2105        };
2106
2107        Self {
2108            descriptor,
2109            connection,
2110            spec,
2111        }
2112    }
2113}
2114
2115#[async_trait]
2116impl Tool for McpToolAdapter {
2117    fn spec(&self) -> &ToolSpec {
2118        &self.spec
2119    }
2120
2121    async fn invoke(
2122        &self,
2123        request: ToolRequest,
2124        _ctx: &mut ToolContext<'_>,
2125    ) -> Result<ToolResult, ToolError> {
2126        let result = self
2127            .connection
2128            .call_tool(&self.descriptor.name, request.input)
2129            .await
2130            .map_err(|error| match error {
2131                McpError::AuthRequired(request) => ToolError::AuthRequired(request),
2132                other => ToolError::ExecutionFailed(other.to_string()),
2133            })?;
2134
2135        Ok(ToolResult {
2136            result: ToolResultPart {
2137                call_id: request.call_id,
2138                output: invocable_output_to_tool_output(value_to_invocable_output(result)),
2139                is_error: false,
2140                metadata: MetadataMap::new(),
2141            },
2142            duration: None,
2143            metadata: MetadataMap::new(),
2144        })
2145    }
2146}
2147
2148fn parse_tool_descriptor(value: Value) -> Result<McpToolDescriptor, McpError> {
2149    Ok(McpToolDescriptor {
2150        name: required_string(&value, "name")?,
2151        description: value
2152            .get("description")
2153            .and_then(Value::as_str)
2154            .map(str::to_owned),
2155        input_schema: value
2156            .get("inputSchema")
2157            .cloned()
2158            .unwrap_or_else(|| json!({ "type": "object" })),
2159        metadata: MetadataMap::new(),
2160    })
2161}
2162
2163fn parse_resource_descriptor(value: Value) -> Result<McpResourceDescriptor, McpError> {
2164    Ok(McpResourceDescriptor {
2165        id: required_string(&value, "uri")?,
2166        name: value
2167            .get("name")
2168            .and_then(Value::as_str)
2169            .map(str::to_owned)
2170            .unwrap_or_else(|| {
2171                value
2172                    .get("uri")
2173                    .and_then(Value::as_str)
2174                    .unwrap_or_default()
2175                    .to_string()
2176            }),
2177        description: value
2178            .get("description")
2179            .and_then(Value::as_str)
2180            .map(str::to_owned),
2181        mime_type: value
2182            .get("mimeType")
2183            .and_then(Value::as_str)
2184            .map(str::to_owned),
2185        metadata: MetadataMap::new(),
2186    })
2187}
2188
2189fn parse_prompt_descriptor(value: Value) -> Result<McpPromptDescriptor, McpError> {
2190    let name = required_string(&value, "name")?;
2191    let properties = value
2192        .get("arguments")
2193        .and_then(Value::as_array)
2194        .cloned()
2195        .unwrap_or_default()
2196        .into_iter()
2197        .filter_map(|arg| {
2198            let name = arg.get("name")?.as_str()?.to_string();
2199            Some((name, json!({ "type": "string" })))
2200        })
2201        .collect::<serde_json::Map<String, Value>>();
2202
2203    Ok(McpPromptDescriptor {
2204        id: name.clone(),
2205        name,
2206        description: value
2207            .get("description")
2208            .and_then(Value::as_str)
2209            .map(str::to_owned),
2210        input_schema: json!({
2211            "type": "object",
2212            "properties": properties,
2213        }),
2214        metadata: MetadataMap::new(),
2215    })
2216}
2217
2218fn parse_prompt_message(value: Value) -> Result<Item, McpError> {
2219    let role = value.get("role").and_then(Value::as_str).unwrap_or("user");
2220    let kind = match role {
2221        "assistant" => ItemKind::Assistant,
2222        "system" => ItemKind::System,
2223        _ => ItemKind::User,
2224    };
2225
2226    let content = value.get("content").cloned().unwrap_or(Value::Null);
2227    let text = if let Some(text) = content.get("text").and_then(Value::as_str) {
2228        text.to_string()
2229    } else if let Some(text) = content.as_str() {
2230        text.to_string()
2231    } else {
2232        content.to_string()
2233    };
2234
2235    Ok(Item {
2236        id: None,
2237        kind,
2238        parts: vec![Part::Text(TextPart {
2239            text,
2240            metadata: MetadataMap::new(),
2241        })],
2242        metadata: MetadataMap::new(),
2243    })
2244}
2245
2246fn required_string(value: &Value, field: &str) -> Result<String, McpError> {
2247    value
2248        .get(field)
2249        .and_then(Value::as_str)
2250        .map(str::to_owned)
2251        .ok_or_else(|| McpError::Protocol(format!("missing string field {field}")))
2252}
2253
2254fn value_to_invocable_output(value: Value) -> InvocableOutput {
2255    if let Some(content) = value.get("content").and_then(Value::as_array) {
2256        let text = content
2257            .iter()
2258            .filter_map(|item| item.get("text").and_then(Value::as_str))
2259            .collect::<Vec<_>>()
2260            .join("\n");
2261        if !text.is_empty() {
2262            return InvocableOutput::Text(text);
2263        }
2264    }
2265
2266    if let Some(text) = value.as_str() {
2267        InvocableOutput::Text(text.to_string())
2268    } else {
2269        InvocableOutput::Structured(value)
2270    }
2271}
2272
2273fn invocable_output_to_tool_output(output: InvocableOutput) -> ToolOutput {
2274    match output {
2275        InvocableOutput::Text(text) => ToolOutput::Text(text),
2276        InvocableOutput::Structured(value) => ToolOutput::Structured(value),
2277        InvocableOutput::Items(items) => {
2278            ToolOutput::Parts(items.into_iter().flat_map(|item| item.parts).collect())
2279        }
2280        InvocableOutput::Data(data) => ToolOutput::Structured(json!({ "data": data })),
2281    }
2282}
2283
2284fn metadata_to_value(metadata: &MetadataMap) -> Value {
2285    Value::Object(
2286        metadata
2287            .iter()
2288            .map(|(key, value)| (key.clone(), value.clone()))
2289            .collect(),
2290    )
2291}
2292
2293fn parse_auth_request(
2294    server_id: &McpServerId,
2295    method: &str,
2296    params: &Value,
2297    error: &Value,
2298) -> Option<AuthRequest> {
2299    let code = error.get("code").and_then(Value::as_i64);
2300    let message = error.get("message").and_then(Value::as_str);
2301    let data = error.get("data");
2302
2303    let auth_marker = matches!(code, Some(401 | -32001))
2304        || data
2305            .and_then(|data| data.get("auth_required"))
2306            .and_then(Value::as_bool)
2307            == Some(true)
2308        || data.and_then(|data| data.get("auth")).is_some();
2309
2310    if !auth_marker {
2311        return None;
2312    }
2313
2314    let mut challenge = MetadataMap::new();
2315    challenge.insert("server_id".into(), Value::String(server_id.to_string()));
2316    challenge.insert("method".into(), Value::String(method.into()));
2317
2318    if let Some(code) = code {
2319        challenge.insert("code".into(), Value::Number(code.into()));
2320    }
2321    if let Some(message) = message {
2322        challenge.insert("message".into(), Value::String(message.into()));
2323    }
2324    if let Some(data) = data {
2325        challenge.insert("data".into(), data.clone());
2326    }
2327
2328    Some(AuthRequest {
2329        task_id: None,
2330        id: format!("mcp:{}:{}", server_id, method),
2331        provider: format!("mcp.{}", server_id),
2332        operation: auth_operation_for_method(server_id, method, params),
2333        challenge,
2334    })
2335}
2336
2337fn auth_operation_for_method(
2338    server_id: &McpServerId,
2339    method: &str,
2340    params: &Value,
2341) -> AuthOperation {
2342    match method {
2343        "initialize" => AuthOperation::McpConnect {
2344            server_id: server_id.to_string(),
2345            metadata: MetadataMap::new(),
2346        },
2347        "tools/call" => AuthOperation::McpToolCall {
2348            server_id: server_id.to_string(),
2349            tool_name: params
2350                .get("name")
2351                .and_then(Value::as_str)
2352                .unwrap_or_default()
2353                .to_string(),
2354            input: params
2355                .get("arguments")
2356                .cloned()
2357                .unwrap_or_else(|| json!({})),
2358            metadata: MetadataMap::new(),
2359        },
2360        "resources/read" => AuthOperation::McpResourceRead {
2361            server_id: server_id.to_string(),
2362            resource_id: params
2363                .get("uri")
2364                .and_then(Value::as_str)
2365                .unwrap_or_default()
2366                .to_string(),
2367            metadata: MetadataMap::new(),
2368        },
2369        "prompts/get" => AuthOperation::McpPromptGet {
2370            server_id: server_id.to_string(),
2371            prompt_id: params
2372                .get("name")
2373                .and_then(Value::as_str)
2374                .unwrap_or_default()
2375                .to_string(),
2376            args: params
2377                .get("arguments")
2378                .cloned()
2379                .unwrap_or_else(|| json!({})),
2380            metadata: MetadataMap::new(),
2381        },
2382        other => AuthOperation::Custom {
2383            kind: format!("mcp.{other}"),
2384            payload: params.clone(),
2385            metadata: {
2386                let mut metadata = MetadataMap::new();
2387                metadata.insert("server_id".into(), Value::String(server_id.to_string()));
2388                metadata
2389            },
2390        },
2391    }
2392}
2393
2394fn normalize_mcp_tool_name(server_id: &McpServerId, tool_name: &str) -> String {
2395    let prefix = format!("mcp.{server_id}.");
2396    tool_name
2397        .strip_prefix(&prefix)
2398        .unwrap_or(tool_name)
2399        .to_string()
2400}
2401
2402async fn read_sse_stream<R>(
2403    mut reader: R,
2404    response_url: Url,
2405    frame_tx: mpsc::UnboundedSender<Result<McpFrame, McpError>>,
2406    endpoint_tx: oneshot::Sender<Result<Url, McpError>>,
2407) where
2408    R: AsyncBufRead + Unpin,
2409{
2410    let mut endpoint_tx = Some(endpoint_tx);
2411    loop {
2412        match read_next_sse_event(&mut reader).await {
2413            Ok(Some(event)) => {
2414                if let Some(endpoint) = legacy_sse_event_to_endpoint(&response_url, &event) {
2415                    if let Some(tx) = endpoint_tx.take() {
2416                        let _ = tx.send(endpoint);
2417                    }
2418                    continue;
2419                }
2420
2421                if let Some(frame) = legacy_sse_event_to_frame(event) {
2422                    let _ = frame_tx.send(frame);
2423                }
2424            }
2425            Ok(None) => break,
2426            Err(error) => {
2427                if let Some(tx) = endpoint_tx.take() {
2428                    let _ = tx.send(Err(error));
2429                } else {
2430                    let _ = frame_tx.send(Err(error));
2431                }
2432                return;
2433            }
2434        }
2435    }
2436
2437    if let Some(tx) = endpoint_tx.take() {
2438        let _ = tx.send(Err(McpError::Transport(
2439            "SSE stream ended before endpoint event".into(),
2440        )));
2441    }
2442}
2443
2444fn resolve_sse_endpoint(response_url: &Url, endpoint: &str) -> Result<Url, McpError> {
2445    response_url
2446        .join(endpoint.trim())
2447        .map_err(|error| McpError::Transport(format!("invalid SSE endpoint URL: {error}")))
2448}
2449
/// One dispatched Server-Sent Event, as assembled by `read_next_sse_event`.
#[derive(Debug)]
struct SseEvent {
    /// Value of the `event:` field, or `None` when none was given.
    event_name: Option<String>,
    /// All `data:` lines of the event joined with `\n`.
    data: String,
    /// Value of the `id:` field, if any.
    id: Option<String>,
    /// Value of the `retry:` field in milliseconds, if it parsed.
    retry_ms: Option<u64>,
}
2457
2458async fn read_next_sse_event<R>(reader: &mut R) -> Result<Option<SseEvent>, McpError>
2459where
2460    R: AsyncBufRead + Unpin,
2461{
2462    let mut event_name = None;
2463    let mut data_lines = Vec::new();
2464    let mut id = None;
2465    let mut retry_ms = None;
2466
2467    loop {
2468        let mut line = String::new();
2469        let read = reader.read_line(&mut line).await.map_err(McpError::Io)?;
2470        if read == 0 {
2471            if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2472                return Ok(None);
2473            }
2474            return Ok(Some(SseEvent {
2475                event_name,
2476                data: data_lines.join("\n"),
2477                id,
2478                retry_ms,
2479            }));
2480        }
2481
2482        let line = line.trim_end_matches(['\r', '\n']);
2483        if line.is_empty() {
2484            if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2485                continue;
2486            }
2487            return Ok(Some(SseEvent {
2488                event_name,
2489                data: data_lines.join("\n"),
2490                id,
2491                retry_ms,
2492            }));
2493        }
2494
2495        if line.starts_with(':') {
2496            continue;
2497        }
2498
2499        if let Some(rest) = line.strip_prefix("event:") {
2500            event_name = Some(rest.trim_start().to_string());
2501            continue;
2502        }
2503        if let Some(rest) = line.strip_prefix("data:") {
2504            data_lines.push(rest.trim_start().to_string());
2505            continue;
2506        }
2507        if let Some(rest) = line.strip_prefix("id:") {
2508            id = Some(rest.trim_start().to_string());
2509            continue;
2510        }
2511        if let Some(rest) = line.strip_prefix("retry:") {
2512            retry_ms = rest.trim_start().parse().ok();
2513        }
2514    }
2515}
2516
2517fn legacy_sse_event_to_endpoint(
2518    response_url: &Url,
2519    event: &SseEvent,
2520) -> Option<Result<Url, McpError>> {
2521    if event.event_name.as_deref() != Some("endpoint") {
2522        return None;
2523    }
2524    if event.data.is_empty() {
2525        return Some(Err(McpError::Transport(
2526            "legacy SSE endpoint event is missing data".into(),
2527        )));
2528    }
2529    Some(resolve_sse_endpoint(response_url, &event.data))
2530}
2531
2532fn legacy_sse_event_to_frame(event: SseEvent) -> Option<Result<McpFrame, McpError>> {
2533    let event_name = event.event_name.unwrap_or_else(|| "message".into());
2534    if event_name != "message" || event.data.is_empty() {
2535        return None;
2536    }
2537
2538    Some(
2539        serde_json::from_str(&event.data)
2540            .map_err(McpError::Serialize)
2541            .map(|value| McpFrame { value }),
2542    )
2543}
2544
2545fn streamable_http_event_to_frame(event: SseEvent) -> Result<Option<McpFrame>, McpError> {
2546    let event_name = event.event_name.unwrap_or_else(|| "message".into());
2547    if event_name != "message" || event.data.is_empty() {
2548        return Ok(None);
2549    }
2550
2551    let value = serde_json::from_str(&event.data).map_err(McpError::Serialize)?;
2552    Ok(Some(McpFrame { value }))
2553}
2554
2555fn is_jsonrpc_request(value: &Value) -> bool {
2556    value.get("method").is_some() && value.get("id").is_some()
2557}
2558
2559fn apply_streamable_http_headers(
2560    mut request: reqwest::RequestBuilder,
2561    headers: &[(String, String)],
2562    protocol_version: Option<&str>,
2563    session_id: Option<&str>,
2564) -> reqwest::RequestBuilder {
2565    for (key, value) in headers {
2566        request = request.header(key, value);
2567    }
2568
2569    if let Some(protocol_version) = protocol_version {
2570        request = request.header("MCP-Protocol-Version", protocol_version);
2571    }
2572    if let Some(session_id) = session_id {
2573        request = request.header("MCP-Session-Id", session_id);
2574    }
2575
2576    request
2577}
2578
2579async fn streamable_http_status_error(
2580    operation: &str,
2581    status: StatusCode,
2582    response: reqwest::Response,
2583) -> McpError {
2584    let body = response
2585        .text()
2586        .await
2587        .unwrap_or_else(|_| "<unreadable response body>".into());
2588    McpError::Transport(format!("{operation} failed with status {status}: {body}"))
2589}
2590
2591/// Errors produced by MCP transport, protocol, and lifecycle operations.
2592#[derive(Debug, Error)]
2593pub enum McpError {
2594    /// An underlying I/O error (e.g. spawning a child process or reading from a pipe).
2595    #[error("io error: {0}")]
2596    Io(#[from] std::io::Error),
2597    /// An HTTP-level error from the SSE transport.
2598    #[error("http error: {0}")]
2599    Http(#[from] reqwest::Error),
2600    /// A JSON serialization or deserialization error.
2601    #[error("serialization error: {0}")]
2602    Serialize(#[from] serde_json::Error),
2603    /// A transport-level error (e.g. unexpected disconnection or bad SSE response).
2604    #[error("transport error: {0}")]
2605    Transport(String),
2606    /// An MCP protocol violation (e.g. missing required fields in a response).
2607    #[error("protocol error: {0}")]
2608    Protocol(String),
2609    /// The server requires authentication before the operation can proceed.
2610    /// Contains the [`AuthRequest`] that describes the challenge.
2611    #[error("MCP auth required: {0:?}")]
2612    AuthRequired(Box<AuthRequest>),
2613    /// An error occurred while resolving or replaying authentication.
2614    #[error("auth resolution error: {0}")]
2615    AuthResolution(String),
2616    /// The MCP server returned an error for the invoked method.
2617    #[error("invocation error: {0}")]
2618    Invocation(String),
2619    /// The referenced server ID is not registered in the [`McpServerManager`].
2620    #[error("unknown MCP server: {0}")]
2621    UnknownServer(String),
2622}
2623
2624#[cfg(test)]
2625mod tests {
2626    use std::collections::VecDeque;
2627    use std::sync::{Arc as StdArc, Mutex as StdMutex};
2628
2629    use super::*;
2630    use agentkit_tools_core::{PermissionChecker, PermissionDecision, PermissionRequest};
2631    use tokio::io::{AsyncReadExt, AsyncWriteExt};
2632    use tokio::net::TcpListener;
2633
2634    struct AllowAll;
2635
2636    impl PermissionChecker for AllowAll {
2637        fn evaluate(&self, _request: &dyn PermissionRequest) -> PermissionDecision {
2638            PermissionDecision::Allow
2639        }
2640    }
2641
2642    struct FakeTransport {
2643        recv: VecDeque<Value>,
2644    }
2645
2646    #[async_trait]
2647    impl McpTransport for FakeTransport {
2648        async fn send(&mut self, _message: McpFrame) -> Result<(), McpError> {
2649            Ok(())
2650        }
2651
2652        async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2653            Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2654        }
2655
2656        async fn close(&mut self) -> Result<(), McpError> {
2657            Ok(())
2658        }
2659    }
2660
2661    fn fake_connection(responses: Vec<Value>) -> McpConnection {
2662        McpConnection {
2663            server_id: McpServerId::new("fake"),
2664            transport: Mutex::new(Box::new(FakeTransport {
2665                recv: responses.into(),
2666            })),
2667            auth: Mutex::new(None),
2668            next_id: AtomicU64::new(1),
2669        }
2670    }
2671
2672    #[derive(Clone)]
2673    struct FakeTransportFactory {
2674        responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2675    }
2676
2677    impl FakeTransportFactory {
2678        fn new(sequences: Vec<Vec<Value>>) -> Self {
2679            Self {
2680                responses: StdArc::new(StdMutex::new(sequences.into())),
2681            }
2682        }
2683    }
2684
2685    #[async_trait]
2686    impl McpTransportFactory for FakeTransportFactory {
2687        async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2688            let responses =
2689                self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2690                    McpError::Transport("no fake transport responses left".into())
2691                })?;
2692            Ok(Box::new(FakeTransport {
2693                recv: responses.into(),
2694            }))
2695        }
2696    }
2697
2698    #[tokio::test]
2699    async fn discovery_parses_snapshot() {
2700        let connection = fake_connection(vec![
2701            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
2702            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [{ "uri": "file:///tmp/example.txt", "name": "example.txt", "mimeType": "text/plain" }] } }),
2703            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [{ "name": "summarize", "description": "Summarize", "arguments": [{ "name": "path" }] }] } }),
2704        ]);
2705
2706        let snapshot = connection.discover().await.unwrap();
2707        assert_eq!(snapshot.tools[0].name, "echo");
2708        assert_eq!(snapshot.resources[0].id, "file:///tmp/example.txt");
2709        assert_eq!(snapshot.prompts[0].id, "summarize");
2710    }
2711
2712    #[tokio::test]
2713    async fn tool_adapter_returns_text_output() {
2714        let connection = Arc::new(fake_connection(vec![json!({
2715            "jsonrpc": "2.0",
2716            "id": 1,
2717            "result": { "content": [{ "type": "text", "text": "pong" }] }
2718        })]));
2719        let server_id = connection.server_id().clone();
2720        let adapter = McpToolAdapter::new(
2721            &server_id,
2722            connection,
2723            McpToolDescriptor {
2724                name: "echo".into(),
2725                description: Some("Echo".into()),
2726                input_schema: json!({ "type": "object" }),
2727                metadata: MetadataMap::new(),
2728            },
2729        );
2730        let metadata = MetadataMap::new();
2731        let mut ctx = ToolContext {
2732            capability: CapabilityContext {
2733                session_id: None,
2734                turn_id: None,
2735                metadata: &metadata,
2736            },
2737            permissions: &AllowAll,
2738            resources: &(),
2739            cancellation: None,
2740        };
2741
2742        let result = adapter
2743            .invoke(
2744                ToolRequest {
2745                    call_id: "call-1".into(),
2746                    tool_name: ToolName::new("mcp.fake.echo"),
2747                    input: json!({}),
2748                    session_id: "session-1".into(),
2749                    turn_id: "turn-1".into(),
2750                    metadata: MetadataMap::new(),
2751                },
2752                &mut ctx,
2753            )
2754            .await
2755            .unwrap();
2756
2757        assert_eq!(result.result.output, ToolOutput::Text("pong".into()));
2758    }
2759
2760    #[tokio::test]
2761    async fn request_surfaces_auth_required_errors() {
2762        let connection = fake_connection(vec![json!({
2763            "jsonrpc": "2.0",
2764            "id": 1,
2765            "error": {
2766                "code": -32001,
2767                "message": "authentication required",
2768                "data": {
2769                    "auth_required": true,
2770                    "scope": "secrets.read"
2771                }
2772            }
2773        })]);
2774
2775        let error = connection.call_tool("echo", json!({})).await.unwrap_err();
2776        match error {
2777            McpError::AuthRequired(request) => {
2778                assert_eq!(request.provider, "mcp.fake");
2779                assert_eq!(
2780                    request.challenge.get("method"),
2781                    Some(&Value::String("tools/call".into()))
2782                );
2783                assert!(matches!(
2784                    request.operation,
2785                    AuthOperation::McpToolCall { ref tool_name, .. } if tool_name == "echo"
2786                ));
2787            }
2788            other => panic!("unexpected error: {other:?}"),
2789        }
2790    }
2791
2792    #[tokio::test]
2793    async fn tool_adapter_maps_auth_required_into_tool_error() {
2794        let connection = Arc::new(fake_connection(vec![json!({
2795            "jsonrpc": "2.0",
2796            "id": 1,
2797            "error": {
2798                "code": -32001,
2799                "message": "authentication required",
2800                "data": { "auth_required": true }
2801            }
2802        })]));
2803        let server_id = connection.server_id().clone();
2804        let adapter = McpToolAdapter::new(
2805            &server_id,
2806            connection,
2807            McpToolDescriptor {
2808                name: "echo".into(),
2809                description: Some("Echo".into()),
2810                input_schema: json!({ "type": "object" }),
2811                metadata: MetadataMap::new(),
2812            },
2813        );
2814        let metadata = MetadataMap::new();
2815        let mut ctx = ToolContext {
2816            capability: CapabilityContext {
2817                session_id: None,
2818                turn_id: None,
2819                metadata: &metadata,
2820            },
2821            permissions: &AllowAll,
2822            resources: &(),
2823            cancellation: None,
2824        };
2825
2826        let error = adapter
2827            .invoke(
2828                ToolRequest {
2829                    call_id: "call-1".into(),
2830                    tool_name: ToolName::new("mcp.fake.echo"),
2831                    input: json!({}),
2832                    session_id: "session-1".into(),
2833                    turn_id: "turn-1".into(),
2834                    metadata: MetadataMap::new(),
2835                },
2836                &mut ctx,
2837            )
2838            .await
2839            .unwrap_err();
2840
2841        match error {
2842            ToolError::AuthRequired(request) => {
2843                assert_eq!(request.provider, "mcp.fake");
2844            }
2845            other => panic!("unexpected error: {other:?}"),
2846        }
2847    }
2848
2849    struct RecordingTransport {
2850        recv: VecDeque<Value>,
2851        sent: StdArc<StdMutex<Vec<Value>>>,
2852    }
2853
2854    #[async_trait]
2855    impl McpTransport for RecordingTransport {
2856        async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
2857            self.sent.lock().unwrap().push(message.value);
2858            Ok(())
2859        }
2860
2861        async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2862            Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2863        }
2864
2865        async fn close(&mut self) -> Result<(), McpError> {
2866            Ok(())
2867        }
2868    }
2869
2870    #[derive(Clone)]
2871    struct RecordingTransportFactory {
2872        responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2873        sent: StdArc<StdMutex<Vec<Value>>>,
2874    }
2875
2876    impl RecordingTransportFactory {
2877        fn new(sequences: Vec<Vec<Value>>) -> Self {
2878            Self {
2879                responses: StdArc::new(StdMutex::new(sequences.into())),
2880                sent: StdArc::new(StdMutex::new(Vec::new())),
2881            }
2882        }
2883
2884        fn sent(&self) -> Vec<Value> {
2885            self.sent.lock().unwrap().clone()
2886        }
2887    }
2888
2889    #[async_trait]
2890    impl McpTransportFactory for RecordingTransportFactory {
2891        async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2892            let responses = self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2893                McpError::Transport("no recording transport responses left".into())
2894            })?;
2895            Ok(Box::new(RecordingTransport {
2896                recv: responses.into(),
2897                sent: self.sent.clone(),
2898            }))
2899        }
2900    }
2901
2902    #[tokio::test]
2903    async fn connection_includes_resolved_auth_in_future_requests() {
2904        let factory = RecordingTransportFactory::new(vec![vec![
2905            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2906            json!({ "jsonrpc": "2.0", "id": 1, "result": { "content": [{ "type": "text", "text": "ok" }] } }),
2907        ]]);
2908        let config = McpServerConfig::new(
2909            "recording",
2910            McpTransportBinding::Custom(Arc::new(factory.clone())),
2911        );
2912        let connection = McpConnection::connect(&config).await.unwrap();
2913        let mut auth = MetadataMap::new();
2914        auth.insert("token".into(), json!("secret-token"));
2915        let request = AuthRequest {
2916            task_id: None,
2917            id: "auth-recording-tool".into(),
2918            provider: "mcp.recording".into(),
2919            operation: AuthOperation::McpToolCall {
2920                server_id: "recording".into(),
2921                tool_name: "echo".into(),
2922                input: json!({}),
2923                metadata: MetadataMap::new(),
2924            },
2925            challenge: MetadataMap::new(),
2926        };
2927        connection
2928            .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2929                request,
2930                credentials: auth,
2931            })
2932            .await
2933            .unwrap();
2934
2935        let _ = connection.call_tool("echo", json!({})).await.unwrap();
2936        let sent = factory.sent();
2937        assert!(
2938            sent.iter().any(|value| {
2939                value
2940                    .get("params")
2941                    .and_then(|params| params.get("auth"))
2942                    .and_then(|auth| auth.get("token"))
2943                    == Some(&json!("secret-token"))
2944            }),
2945            "expected an MCP request to include the resolved auth payload, saw {:?}",
2946            sent
2947        );
2948    }
2949
2950    #[tokio::test]
2951    async fn manager_reuses_stored_auth_on_connect() {
2952        let factory = RecordingTransportFactory::new(vec![vec![
2953            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2954            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
2955            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
2956            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
2957        ]]);
2958        let server_id = McpServerId::new("recording");
2959        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
2960            server_id.to_string(),
2961            McpTransportBinding::Custom(Arc::new(factory.clone())),
2962        ));
2963        let mut auth = MetadataMap::new();
2964        auth.insert("token".into(), json!("seed-token"));
2965        let request = AuthRequest {
2966            task_id: None,
2967            id: "auth-recording-connect".into(),
2968            provider: "mcp.recording".into(),
2969            operation: AuthOperation::McpConnect {
2970                server_id: server_id.to_string(),
2971                metadata: MetadataMap::new(),
2972            },
2973            challenge: MetadataMap::new(),
2974        };
2975        manager
2976            .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2977                request,
2978                credentials: auth,
2979            })
2980            .await
2981            .unwrap();
2982
2983        manager.connect_server(&server_id).await.unwrap();
2984        let sent = factory.sent();
2985        assert!(
2986            sent.iter().any(|value| {
2987                value.get("method").and_then(Value::as_str) == Some("initialize")
2988                    && value
2989                        .get("params")
2990                        .and_then(|params| params.get("auth"))
2991                        .and_then(|auth| auth.get("token"))
2992                        == Some(&json!("seed-token"))
2993            }),
2994            "expected initialize to include stored auth, saw {:?}",
2995            sent
2996        );
2997    }
2998
2999    #[tokio::test]
3000    async fn manager_resolves_auth_and_replays_resource_read() {
3001        let factory = RecordingTransportFactory::new(vec![vec![
3002            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3003            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3004            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3005            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3006            json!({
3007                "jsonrpc": "2.0",
3008                "id": 4,
3009                "result": {
3010                    "contents": [
3011                        {
3012                            "uri": "file:///tmp/secret.txt",
3013                            "text": "secret from resource"
3014                        }
3015                    ]
3016                }
3017            }),
3018        ]]);
3019        let server_id = McpServerId::new("recording");
3020        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3021            server_id.to_string(),
3022            McpTransportBinding::Custom(Arc::new(factory.clone())),
3023        ));
3024        let mut auth = MetadataMap::new();
3025        auth.insert("token".into(), json!("resource-token"));
3026        let request = AuthRequest {
3027            task_id: None,
3028            id: "auth-recording-resource".into(),
3029            provider: "mcp.recording".into(),
3030            operation: AuthOperation::McpResourceRead {
3031                server_id: server_id.to_string(),
3032                resource_id: "file:///tmp/secret.txt".into(),
3033                metadata: MetadataMap::new(),
3034            },
3035            challenge: MetadataMap::new(),
3036        };
3037
3038        let result = manager
3039            .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3040                request,
3041                credentials: auth,
3042            })
3043            .await
3044            .unwrap();
3045
3046        match result {
3047            McpOperationResult::Resource(contents) => {
3048                assert_eq!(
3049                    contents.data,
3050                    DataRef::InlineText("secret from resource".into())
3051                );
3052            }
3053            other => panic!("unexpected replay result: {other:?}"),
3054        }
3055
3056        let sent = factory.sent();
3057        assert!(
3058            sent.iter().any(|value| {
3059                value.get("method").and_then(Value::as_str) == Some("resources/read")
3060                    && value
3061                        .get("params")
3062                        .and_then(|params| params.get("auth"))
3063                        .and_then(|auth| auth.get("token"))
3064                        == Some(&json!("resource-token"))
3065            }),
3066            "expected resources/read to include resolved auth, saw {:?}",
3067            sent
3068        );
3069    }
3070
3071    #[tokio::test]
3072    async fn manager_resolves_auth_and_replays_connect() {
3073        let factory = RecordingTransportFactory::new(vec![vec![
3074            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3075            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3076            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3077            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3078        ]]);
3079        let server_id = McpServerId::new("recording");
3080        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3081            server_id.to_string(),
3082            McpTransportBinding::Custom(Arc::new(factory.clone())),
3083        ));
3084        let mut auth = MetadataMap::new();
3085        auth.insert("token".into(), json!("connect-token"));
3086        let request = AuthRequest {
3087            task_id: None,
3088            id: "auth-recording-connect-replay".into(),
3089            provider: "mcp.recording".into(),
3090            operation: AuthOperation::McpConnect {
3091                server_id: server_id.to_string(),
3092                metadata: MetadataMap::new(),
3093            },
3094            challenge: MetadataMap::new(),
3095        };
3096
3097        let result = manager
3098            .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3099                request,
3100                credentials: auth,
3101            })
3102            .await
3103            .unwrap();
3104
3105        match result {
3106            McpOperationResult::Connected(snapshot) => {
3107                assert_eq!(snapshot.server_id, server_id);
3108            }
3109            other => panic!("unexpected replay result: {other:?}"),
3110        }
3111    }
3112
3113    #[tokio::test]
3114    async fn sse_transport_posts_messages_and_receives_frames() {
3115        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3116        let address = listener.local_addr().unwrap();
3117        let requests = StdArc::new(StdMutex::new(Vec::new()));
3118        let captured = requests.clone();
3119
3120        let server = tokio::spawn(async move {
3121            for _ in 0..2 {
3122                let (mut socket, _) = listener.accept().await.unwrap();
3123                let mut buffer = vec![0_u8; 4096];
3124                let read = socket.read(&mut buffer).await.unwrap();
3125                let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3126
3127                if request.starts_with("GET /sse ") {
3128                    let body = concat!(
3129                        "event: endpoint\n",
3130                        "data: /messages\n\n",
3131                        "event: message\n",
3132                        "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3133                    );
3134                    let response = format!(
3135                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3136                        body.len(),
3137                        body
3138                    );
3139                    socket.write_all(response.as_bytes()).await.unwrap();
3140                } else {
3141                    captured.lock().unwrap().push(request);
3142                    socket
3143                        .write_all(
3144                            b"HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n",
3145                        )
3146                        .await
3147                        .unwrap();
3148                }
3149            }
3150        });
3151
3152        let factory =
3153            SseTransportFactory::new(SseTransportConfig::new(format!("http://{address}/sse")));
3154        let mut transport = factory.connect().await.unwrap();
3155        transport
3156            .send(McpFrame {
3157                value: json!({
3158                    "jsonrpc": "2.0",
3159                    "id": 1,
3160                    "method": "tools/list",
3161                    "params": {}
3162                }),
3163            })
3164            .await
3165            .unwrap();
3166        let frame = transport.recv().await.unwrap().unwrap();
3167        transport.close().await.unwrap();
3168        server.await.unwrap();
3169
3170        assert_eq!(frame.value["result"]["tools"], json!([]));
3171        let requests = requests.lock().unwrap();
3172        assert_eq!(requests.len(), 1);
3173        assert!(requests[0].starts_with("POST /messages "));
3174        assert!(requests[0].contains("\"method\":\"tools/list\""));
3175    }
3176
3177    #[tokio::test]
3178    async fn streamable_http_connection_tracks_session_and_protocol_headers() {
3179        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3180        let address = listener.local_addr().unwrap();
3181        let requests = StdArc::new(StdMutex::new(Vec::new()));
3182        let captured = requests.clone();
3183
3184        let server = tokio::spawn(async move {
3185            for _ in 0..4 {
3186                let (mut socket, _) = listener.accept().await.unwrap();
3187                let mut buffer = vec![0_u8; 8192];
3188                let read = socket.read(&mut buffer).await.unwrap();
3189                let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3190                captured.lock().unwrap().push(request.clone());
3191
3192                let response = if request.contains("\"method\":\"initialize\"") {
3193                    let body = "{\"jsonrpc\":\"2.0\",\"id\":0,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"remote\",\"version\":\"1.0.0\"}}}";
3194                    format!(
3195                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\nMCP-Session-Id: session-123\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3196                        body.len(),
3197                        body
3198                    )
3199                } else if request.contains("\"method\":\"notifications/initialized\"") {
3200                    "HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3201                        .to_string()
3202                } else if request.starts_with("DELETE /mcp ") {
3203                    "HTTP/1.1 204 No Content\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3204                        .to_string()
3205                } else {
3206                    let body = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}";
3207                    format!(
3208                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3209                        body.len(),
3210                        body
3211                    )
3212                };
3213
3214                socket.write_all(response.as_bytes()).await.unwrap();
3215            }
3216        });
3217
3218        let config = McpServerConfig::new(
3219            "remote",
3220            McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(format!(
3221                "http://{address}/mcp"
3222            ))),
3223        );
3224        let connection = McpConnection::connect(&config).await.unwrap();
3225        let _ = connection.list_tools().await.unwrap();
3226        connection.close().await.unwrap();
3227        server.await.unwrap();
3228
3229        let requests = requests.lock().unwrap();
3230        assert_eq!(requests.len(), 4);
3231        let normalized = requests
3232            .iter()
3233            .map(|request| request.to_ascii_lowercase())
3234            .collect::<Vec<_>>();
3235        assert!(requests[0].starts_with("POST /mcp "));
3236        assert!(!requests[0].contains("MCP-Session-Id:"));
3237        assert!(normalized[1].contains("mcp-session-id: session-123"));
3238        assert!(normalized[1].contains("mcp-protocol-version: 2025-11-25"));
3239        assert!(normalized[2].contains("mcp-session-id: session-123"));
3240        assert!(normalized[2].contains("mcp-protocol-version: 2025-11-25"));
3241        assert!(requests[3].starts_with("DELETE /mcp "));
3242        assert!(normalized[3].contains("mcp-session-id: session-123"));
3243    }
3244
3245    #[tokio::test]
3246    async fn streamable_http_transport_resumes_sse_streams_until_response_arrives() {
3247        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3248        let address = listener.local_addr().unwrap();
3249        let requests = StdArc::new(StdMutex::new(Vec::new()));
3250        let captured = requests.clone();
3251
3252        let server = tokio::spawn(async move {
3253            for _ in 0..2 {
3254                let (mut socket, _) = listener.accept().await.unwrap();
3255                let mut buffer = vec![0_u8; 8192];
3256                let read = socket.read(&mut buffer).await.unwrap();
3257                let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3258                captured.lock().unwrap().push(request.clone());
3259
3260                let response = if request.starts_with("POST /mcp ") {
3261                    let body = concat!(
3262                        "id: evt-1\n",
3263                        "event: message\n",
3264                        "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/message\",\"params\":{\"phase\":\"stream-start\"}}\n\n"
3265                    );
3266                    format!(
3267                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3268                        body.len(),
3269                        body
3270                    )
3271                } else {
3272                    let body = concat!(
3273                        "id: evt-2\n",
3274                        "event: message\n",
3275                        "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3276                    );
3277                    format!(
3278                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3279                        body.len(),
3280                        body
3281                    )
3282                };
3283
3284                socket.write_all(response.as_bytes()).await.unwrap();
3285            }
3286        });
3287
3288        let factory = StreamableHttpTransportFactory::new(StreamableHttpTransportConfig::new(
3289            format!("http://{address}/mcp"),
3290        ));
3291        let mut transport = factory.connect().await.unwrap();
3292        transport
3293            .send(McpFrame {
3294                value: json!({
3295                    "jsonrpc": "2.0",
3296                    "id": 1,
3297                    "method": "tools/list",
3298                    "params": {}
3299                }),
3300            })
3301            .await
3302            .unwrap();
3303
3304        let first = transport.recv().await.unwrap().unwrap();
3305        let second = transport.recv().await.unwrap().unwrap();
3306        transport.close().await.unwrap();
3307        server.await.unwrap();
3308
3309        assert_eq!(
3310            first.value["method"],
3311            Value::String("notifications/message".into())
3312        );
3313        assert_eq!(second.value["result"]["tools"], json!([]));
3314
3315        let requests = requests.lock().unwrap();
3316        assert_eq!(requests.len(), 2);
3317        assert!(requests[0].starts_with("POST /mcp "));
3318        assert!(requests[1].starts_with("GET /mcp "));
3319        assert!(
3320            requests[1].contains("last-event-id: evt-1")
3321                || requests[1].contains("Last-Event-ID: evt-1")
3322        );
3323    }
3324
3325    #[tokio::test]
3326    async fn server_manager_connects_refreshes_and_aggregates_tools() {
3327        let alpha = McpServerConfig::new(
3328            "alpha",
3329            McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3330                json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "alpha", "version": "1.0.0" } } }),
3331                json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
3332                json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3333                json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3334                json!({ "jsonrpc": "2.0", "id": 4, "result": { "tools": [{ "name": "echo_v2", "description": "Echo 2", "inputSchema": {"type": "object"} }] } }),
3335                json!({ "jsonrpc": "2.0", "id": 5, "result": { "resources": [] } }),
3336                json!({ "jsonrpc": "2.0", "id": 6, "result": { "prompts": [] } }),
3337            ]]))),
3338        );
3339        let beta = McpServerConfig::new(
3340            "beta",
3341            McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3342                json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "beta", "version": "1.0.0" } } }),
3343                json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "search", "description": "Search", "inputSchema": {"type": "object"} }] } }),
3344                json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3345                json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3346            ]]))),
3347        );
3348
3349        let mut manager = McpServerManager::new().with_server(alpha).with_server(beta);
3350
3351        let handles = manager.connect_all().await.unwrap();
3352        assert_eq!(handles.len(), 2);
3353        assert_eq!(
3354            manager
3355                .tool_registry()
3356                .specs()
3357                .into_iter()
3358                .map(|spec| spec.name.0)
3359                .collect::<Vec<_>>(),
3360            vec!["mcp.alpha.echo".to_string(), "mcp.beta.search".to_string()]
3361        );
3362
3363        let refreshed = manager
3364            .refresh_server(&McpServerId::new("alpha"))
3365            .await
3366            .unwrap();
3367        assert_eq!(refreshed.tools[0].name, "echo_v2");
3368        assert_eq!(
3369            manager
3370                .connected_server(&McpServerId::new("alpha"))
3371                .unwrap()
3372                .snapshot()
3373                .tools[0]
3374                .name,
3375            "echo_v2"
3376        );
3377
3378        let capabilities = manager.capability_provider();
3379        assert_eq!(capabilities.invocables().len(), 2);
3380
3381        manager
3382            .disconnect_server(&McpServerId::new("alpha"))
3383            .await
3384            .unwrap();
3385        assert!(
3386            manager
3387                .connected_server(&McpServerId::new("alpha"))
3388                .is_none()
3389        );
3390    }
3391}