Skip to main content

agentkit_mcp/
lib.rs

1use std::collections::{BTreeMap, VecDeque};
2use std::fmt;
3use std::process::Stdio;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use agentkit_capabilities::{
9    CapabilityContext, CapabilityError, CapabilityName, CapabilityProvider, Invocable,
10    InvocableOutput, InvocableRequest, InvocableResult, InvocableSpec, PromptContents,
11    PromptDescriptor, PromptId, PromptProvider, ResourceContents, ResourceDescriptor, ResourceId,
12    ResourceProvider,
13};
14use agentkit_core::{
15    DataRef, Item, ItemKind, MetadataMap, Part, TextPart, ToolOutput, ToolResultPart,
16};
17use agentkit_tools_core::{
18    AuthOperation, AuthRequest, AuthResolution, Tool, ToolAnnotations, ToolContext, ToolError,
19    ToolName, ToolRegistry, ToolRequest, ToolResult, ToolSpec,
20};
21use async_trait::async_trait;
22use futures_util::TryStreamExt;
23use reqwest::{Client, StatusCode, Url};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use thiserror::Error;
27use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
28use tokio::process::{Child, ChildStdin, ChildStdout, Command};
29use tokio::sync::{Mutex, mpsc, oneshot};
30use tokio::task::JoinHandle;
31use tokio::time::sleep;
32use tokio_util::io::StreamReader;
33
/// The newest MCP protocol revision string known to this crate; it is also the
/// first entry of `MCP_SUPPORTED_PROTOCOL_VERSIONS`.
const MCP_LATEST_PROTOCOL_VERSION: &str = "2025-11-25";
/// MCP protocol revision strings listed as supported, ordered newest first.
const MCP_SUPPORTED_PROTOCOL_VERSIONS: &[&str] =
    &["2025-11-25", "2025-06-18", "2025-03-26", "2024-11-05"];
37
/// Unique identifier for a registered MCP server.
///
/// Each MCP server in a [`McpServerManager`] is addressed by its `McpServerId`.
/// The inner string is typically a short, human-readable name such as `"filesystem"`
/// or `"github"`.
///
/// This is a newtype over `String`: equality, ordering, and hashing are derived
/// from the wrapped string, and the derived `Default` is the empty identifier.
///
/// # Example
///
/// ```rust
/// use agentkit_mcp::McpServerId;
///
/// let id = McpServerId::new("filesystem");
/// assert_eq!(id.to_string(), "filesystem");
/// ```
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct McpServerId(pub String);
54
55impl McpServerId {
56    /// Creates a new server identifier from any string-like value.
57    pub fn new(value: impl Into<String>) -> Self {
58        Self(value.into())
59    }
60}
61
62impl fmt::Display for McpServerId {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        self.0.fmt(f)
65    }
66}
67
/// Launch description for an MCP server reached over standard I/O.
///
/// The configured command is started as a child process and JSON-RPC messages
/// travel line-by-line over that process's stdin and stdout. This is the usual
/// transport for MCP servers running on the local machine.
///
/// # Example
///
/// ```rust
/// use agentkit_mcp::StdioTransportConfig;
///
/// let config = StdioTransportConfig::new("npx")
///     .with_arg("-y")
///     .with_arg("@modelcontextprotocol/server-filesystem")
///     .with_env("HOME", "/home/user")
///     .with_cwd("/tmp");
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StdioTransportConfig {
    /// Program to execute (for instance `"npx"`, `"python"`, or `"node"`).
    pub command: String,
    /// Arguments handed to the program, in order.
    pub args: Vec<String>,
    /// Extra `(name, value)` environment variables for the child process.
    pub env: Vec<(String, String)>,
    /// Working directory for the child process, when one is set.
    pub cwd: Option<std::path::PathBuf>,
}

impl StdioTransportConfig {
    /// Starts a configuration for `command` with no arguments, no extra
    /// environment variables, and no working directory override.
    pub fn new(command: impl Into<String>) -> Self {
        let command = command.into();
        Self {
            command,
            args: vec![],
            env: vec![],
            cwd: None,
        }
    }

    /// Pushes one more command-line argument and hands the builder back.
    pub fn with_arg(self, arg: impl Into<String>) -> Self {
        let mut args = self.args;
        args.push(arg.into());
        Self { args, ..self }
    }

    /// Records an environment variable for the child and hands the builder back.
    pub fn with_env(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut env = self.env;
        let entry = (key.into(), value.into());
        env.push(entry);
        Self { env, ..self }
    }

    /// Chooses the child's working directory and hands the builder back.
    pub fn with_cwd(self, cwd: impl Into<std::path::PathBuf>) -> Self {
        let cwd = Some(cwd.into());
        Self { cwd, ..self }
    }
}
126
/// Connection settings for an MCP server reached over Server-Sent Events (SSE).
///
/// Intended for remote MCP servers exposed over HTTP: the client opens an SSE
/// stream at the given URL, waits for an `endpoint` event naming the POST
/// endpoint, and exchanges JSON-RPC messages through that endpoint.
///
/// # Example
///
/// ```rust
/// use agentkit_mcp::SseTransportConfig;
///
/// let config = SseTransportConfig::new("https://mcp.example.com/sse")
///     .with_header("Authorization", "Bearer tok_abc123");
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SseTransportConfig {
    /// URL of the SSE stream to open.
    pub url: String,
    /// Extra `(name, value)` HTTP headers attached to every request, such as
    /// authentication tokens.
    pub headers: Vec<(String, String)>,
}

impl SseTransportConfig {
    /// Starts a configuration for `url` with no extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            headers: vec![],
        }
    }

    /// Appends one HTTP header to send with every request and hands the builder back.
    pub fn with_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut headers = self.headers;
        let entry = (key.into(), value.into());
        headers.push(entry);
        Self { headers, ..self }
    }
}
164
/// Connection settings for an MCP server reached over Streamable HTTP.
///
/// Intended for remote MCP servers that expose a single HTTP endpoint accepting
/// JSON-RPC over POST, optionally answering with SSE to stream server messages.
///
/// # Example
///
/// ```rust
/// use agentkit_mcp::StreamableHttpTransportConfig;
///
/// let config = StreamableHttpTransportConfig::new("https://mcp.example.com/mcp")
///     .with_header("Authorization", "Bearer tok_abc123");
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StreamableHttpTransportConfig {
    /// URL of the MCP endpoint.
    pub url: String,
    /// Extra `(name, value)` HTTP headers attached to every request, such as
    /// authentication tokens.
    pub headers: Vec<(String, String)>,
}

impl StreamableHttpTransportConfig {
    /// Starts a configuration for the MCP endpoint at `url` with no extra headers.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            headers: vec![],
        }
    }

    /// Appends one HTTP header to send with every request and hands the builder back.
    pub fn with_header(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut headers = self.headers;
        let entry = (key.into(), value.into());
        headers.push(entry);
        Self { headers, ..self }
    }
}
202
/// Selects which transport an MCP server should use.
///
/// This enum is passed into [`McpServerConfig`] and determines how the client will
/// communicate with the MCP server. The built-in options are [`Stdio`](Self::Stdio),
/// [`StreamableHttp`](Self::StreamableHttp), and the legacy [`Sse`](Self::Sse);
/// use [`Custom`](Self::Custom) to provide your own [`McpTransportFactory`].
///
/// Only `Clone` is derived. The `Custom` variant holds its factory behind an
/// `Arc`, so cloning a binding shares that factory rather than copying it.
#[derive(Clone)]
pub enum McpTransportBinding {
    /// Communicate over the child process's stdin/stdout.
    Stdio(StdioTransportConfig),
    /// Communicate over the MCP Streamable HTTP transport.
    StreamableHttp(StreamableHttpTransportConfig),
    /// Communicate over HTTP Server-Sent Events.
    Sse(SseTransportConfig),
    /// A user-supplied transport factory, shared behind an `Arc`.
    Custom(Arc<dyn McpTransportFactory>),
}
220
/// Full configuration for a single MCP server, combining an identifier, a transport
/// binding, and optional metadata.
///
/// Register one or more of these with [`McpServerManager`] to manage the lifecycle
/// of MCP servers in an agentkit runtime.
///
/// All fields are public, so a configuration can also be adjusted after
/// construction by assigning to its fields directly.
///
/// # Example
///
/// ```rust
/// use agentkit_mcp::{McpServerConfig, McpTransportBinding, StdioTransportConfig};
///
/// let config = McpServerConfig::new(
///     "filesystem",
///     McpTransportBinding::Stdio(
///         StdioTransportConfig::new("npx")
///             .with_arg("-y")
///             .with_arg("@modelcontextprotocol/server-filesystem"),
///     ),
/// );
/// ```
#[derive(Clone)]
pub struct McpServerConfig {
    /// Unique identifier for this server.
    pub id: McpServerId,
    /// Transport binding that determines how communication happens.
    pub transport: McpTransportBinding,
    /// Arbitrary metadata attached to this server configuration.
    pub metadata: MetadataMap,
}
250
251impl McpServerConfig {
252    /// Creates a new server configuration with the given identifier and transport.
253    ///
254    /// # Arguments
255    ///
256    /// * `id` - A unique name for this server (e.g. `"filesystem"`).
257    /// * `transport` - The [`McpTransportBinding`] that determines how to connect.
258    pub fn new(id: impl Into<String>, transport: McpTransportBinding) -> Self {
259        Self {
260            id: McpServerId::new(id),
261            transport,
262            metadata: MetadataMap::new(),
263        }
264    }
265
266    /// Creates a stdio-backed server configuration.
267    pub fn stdio(id: impl Into<String>, command: impl Into<String>) -> Self {
268        Self::new(
269            id,
270            McpTransportBinding::Stdio(StdioTransportConfig::new(command)),
271        )
272    }
273
274    /// Creates an SSE-backed server configuration.
275    pub fn sse(id: impl Into<String>, url: impl Into<String>) -> Self {
276        Self::new(id, McpTransportBinding::Sse(SseTransportConfig::new(url)))
277    }
278
279    /// Creates a Streamable HTTP-backed server configuration.
280    pub fn streamable_http(id: impl Into<String>, url: impl Into<String>) -> Self {
281        Self::new(
282            id,
283            McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(url)),
284        )
285    }
286
287    /// Replaces the configuration metadata.
288    pub fn with_metadata(mut self, metadata: MetadataMap) -> Self {
289        self.metadata = metadata;
290        self
291    }
292}
293
/// A single JSON-RPC frame exchanged with an MCP server.
///
/// This is the low-level wire unit. Most users will not interact with `McpFrame`
/// directly; instead use [`McpConnection`] or the higher-level adapters.
///
/// The frame is a thin wrapper around an untyped JSON value; it performs no
/// validation of the JSON-RPC structure itself.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpFrame {
    /// The raw JSON-RPC value (request, response, or notification).
    pub value: Value,
}
303
/// Factory trait for creating new [`McpTransport`] connections.
///
/// Implement this trait to provide a custom transport mechanism. The built-in
/// [`StdioTransportFactory`], [`StreamableHttpTransportFactory`], and
/// [`SseTransportFactory`] cover the stdio, Streamable HTTP, and legacy SSE
/// transports; use this trait for in-memory, WebSocket, or other custom
/// transports.
///
/// Implementations must be `Send + Sync` because factories are shared behind an
/// `Arc` in [`McpTransportBinding::Custom`].
///
/// # Errors
///
/// Returns [`McpError`] if the connection cannot be established.
#[async_trait]
pub trait McpTransportFactory: Send + Sync {
    /// Establishes a new transport connection and returns it.
    ///
    /// Each call is expected to produce an independent transport instance.
    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError>;
}
319
/// Bidirectional transport for exchanging [`McpFrame`] messages with an MCP server.
///
/// Implement this trait to provide a custom transport. Each transport instance
/// represents a single, live connection.
///
/// All methods take `&mut self`, so a transport is driven by one caller at a time.
///
/// # Errors
///
/// All methods return [`McpError`] on I/O or protocol failures.
#[async_trait]
pub trait McpTransport: Send + Sync {
    /// Sends a JSON-RPC frame to the server.
    async fn send(&mut self, message: McpFrame) -> Result<(), McpError>;
    /// Receives the next JSON-RPC frame from the server, or `None` if the stream has ended.
    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError>;
    /// Closes the transport, releasing any underlying resources.
    async fn close(&mut self) -> Result<(), McpError>;
}
337
338/// Factory that spawns a child process and connects via stdin/stdout.
339///
340/// Created from a [`StdioTransportConfig`]. Each call to
341/// [`connect`](McpTransportFactory::connect) spawns a new child process.
342pub struct StdioTransportFactory {
343    config: StdioTransportConfig,
344}
345
346impl StdioTransportFactory {
347    /// Creates a new factory from the given stdio transport configuration.
348    pub fn new(config: StdioTransportConfig) -> Self {
349        Self { config }
350    }
351}
352
353#[async_trait]
354impl McpTransportFactory for StdioTransportFactory {
355    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
356        let mut command = Command::new(&self.config.command);
357        command.args(&self.config.args);
358        command.stdin(Stdio::piped());
359        command.stdout(Stdio::piped());
360        command.stderr(Stdio::inherit());
361
362        if let Some(cwd) = &self.config.cwd {
363            command.current_dir(cwd);
364        }
365
366        for (key, value) in &self.config.env {
367            command.env(key, value);
368        }
369
370        let mut child = command.spawn().map_err(McpError::Io)?;
371        let stdin = child
372            .stdin
373            .take()
374            .ok_or_else(|| McpError::Transport("failed to capture MCP stdin".into()))?;
375        let stdout = child
376            .stdout
377            .take()
378            .ok_or_else(|| McpError::Transport("failed to capture MCP stdout".into()))?;
379
380        Ok(Box::new(StdioTransport {
381            child,
382            stdin,
383            stdout: BufReader::new(stdout),
384        }))
385    }
386}
387
388/// Factory that opens an HTTP SSE stream and connects via Server-Sent Events.
389///
390/// Created from an [`SseTransportConfig`]. Each call to
391/// [`connect`](McpTransportFactory::connect) opens a new HTTP connection.
392pub struct SseTransportFactory {
393    config: SseTransportConfig,
394}
395
396impl SseTransportFactory {
397    /// Creates a new factory from the given SSE transport configuration.
398    pub fn new(config: SseTransportConfig) -> Self {
399        Self { config }
400    }
401}
402
403/// Factory that connects to a Streamable HTTP MCP endpoint.
404///
405/// Created from a [`StreamableHttpTransportConfig`]. Each call to
406/// [`connect`](McpTransportFactory::connect) creates a new HTTP-backed MCP session.
407pub struct StreamableHttpTransportFactory {
408    config: StreamableHttpTransportConfig,
409}
410
411impl StreamableHttpTransportFactory {
412    /// Creates a new factory from the given Streamable HTTP transport configuration.
413    pub fn new(config: StreamableHttpTransportConfig) -> Self {
414        Self { config }
415    }
416}
417
418#[async_trait]
419impl McpTransportFactory for SseTransportFactory {
420    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
421        let client = Client::builder()
422            .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
423            .build()
424            .map_err(McpError::Http)?;
425
426        let mut request = client
427            .get(&self.config.url)
428            .header("Accept", "text/event-stream")
429            .header("Cache-Control", "no-cache");
430
431        for (key, value) in &self.config.headers {
432            request = request.header(key, value);
433        }
434
435        let response = request.send().await.map_err(McpError::Http)?;
436        let status = response.status();
437        if !status.is_success() {
438            let body = response
439                .text()
440                .await
441                .unwrap_or_else(|_| "<unreadable response body>".into());
442            return Err(McpError::Transport(format!(
443                "SSE connection failed with status {status}: {body}"
444            )));
445        }
446
447        let response_url = response.url().clone();
448        let stream = response.bytes_stream().map_err(std::io::Error::other);
449        let reader = BufReader::new(StreamReader::new(stream));
450        let (frame_tx, frame_rx) = mpsc::unbounded_channel();
451        let (endpoint_tx, endpoint_rx) = oneshot::channel();
452        let read_task = tokio::spawn(read_sse_stream(reader, response_url, frame_tx, endpoint_tx));
453
454        let endpoint_url = endpoint_rx
455            .await
456            .map_err(|_| McpError::Transport("SSE stream closed before endpoint event".into()))??;
457
458        Ok(Box::new(SseTransport {
459            client,
460            endpoint_url,
461            headers: self.config.headers.clone(),
462            frame_rx,
463            read_task,
464        }))
465    }
466}
467
468#[async_trait]
469impl McpTransportFactory for StreamableHttpTransportFactory {
470    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
471        let client = Client::builder()
472            .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
473            .build()
474            .map_err(McpError::Http)?;
475
476        let endpoint_url = Url::parse(&self.config.url)
477            .map_err(|error| McpError::Transport(format!("invalid MCP endpoint URL: {error}")))?;
478
479        Ok(Box::new(StreamableHttpTransport {
480            client,
481            endpoint_url,
482            headers: self.config.headers.clone(),
483            protocol_version: None,
484            session_id: None,
485            pending_frames: VecDeque::new(),
486        }))
487    }
488}
489
/// Transport state for a server running as a spawned child process.
struct StdioTransport {
    /// Handle to the spawned server process; killed when the transport is closed.
    child: Child,
    /// Write half: each outgoing frame is written here as one JSON line.
    stdin: ChildStdin,
    /// Buffered read half: incoming frames are read from here line by line.
    stdout: BufReader<ChildStdout>,
}
495
/// Transport state for the SSE transport.
struct SseTransport {
    /// HTTP client used to POST outgoing frames.
    client: Client,
    /// POST endpoint announced by the server during the SSE handshake.
    endpoint_url: Url,
    /// Extra headers from the configuration, added to every POST.
    headers: Vec<(String, String)>,
    /// Receiving end of the channel fed by the background stream reader; carries
    /// decoded frames or the errors hit while reading.
    frame_rx: mpsc::UnboundedReceiver<Result<McpFrame, McpError>>,
    /// Handle to the background task reading the SSE stream; aborted on close.
    read_task: JoinHandle<()>,
}
503
/// Transport state for the Streamable HTTP transport.
struct StreamableHttpTransport {
    /// HTTP client used for POST, GET (stream resumption), and DELETE requests.
    client: Client,
    /// The single MCP endpoint all requests are sent to.
    endpoint_url: Url,
    /// Extra headers from the configuration, applied to every request.
    headers: Vec<(String, String)>,
    /// Protocol version taken from the `initialize` result, once one was seen.
    protocol_version: Option<String>,
    /// Session id taken from the `MCP-Session-Id` header of the `initialize`
    /// response, if the server sent one.
    session_id: Option<String>,
    /// Frames collected while sending, handed out in order by `recv`.
    pending_frames: VecDeque<McpFrame>,
}
512
513#[async_trait]
514impl McpTransport for StdioTransport {
515    async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
516        let mut encoded = serde_json::to_vec(&message.value).map_err(McpError::Serialize)?;
517        encoded.push(b'\n');
518        self.stdin.write_all(&encoded).await.map_err(McpError::Io)?;
519        self.stdin.flush().await.map_err(McpError::Io)?;
520        Ok(())
521    }
522
523    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
524        let mut line = String::new();
525        let read = self
526            .stdout
527            .read_line(&mut line)
528            .await
529            .map_err(McpError::Io)?;
530        if read == 0 {
531            return Ok(None);
532        }
533
534        let value = serde_json::from_str(line.trim()).map_err(McpError::Serialize)?;
535        Ok(Some(McpFrame { value }))
536    }
537
538    async fn close(&mut self) -> Result<(), McpError> {
539        let _ = self.stdin.shutdown().await;
540        let _ = self.child.kill().await;
541        Ok(())
542    }
543}
544
545#[async_trait]
546impl McpTransport for SseTransport {
547    async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
548        let mut request = self
549            .client
550            .post(self.endpoint_url.clone())
551            .header("Content-Type", "application/json");
552
553        for (key, value) in &self.headers {
554            request = request.header(key, value);
555        }
556
557        let response = request
558            .json(&message.value)
559            .send()
560            .await
561            .map_err(McpError::Http)?;
562        let status = response.status();
563        if !status.is_success() {
564            let body = response
565                .text()
566                .await
567                .unwrap_or_else(|_| "<unreadable response body>".into());
568            return Err(McpError::Transport(format!(
569                "SSE POST failed with status {status}: {body}"
570            )));
571        }
572
573        Ok(())
574    }
575
576    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
577        match self.frame_rx.recv().await {
578            Some(Ok(frame)) => Ok(Some(frame)),
579            Some(Err(error)) => Err(error),
580            None => Ok(None),
581        }
582    }
583
584    async fn close(&mut self) -> Result<(), McpError> {
585        self.read_task.abort();
586        Ok(())
587    }
588}
589
590#[async_trait]
591impl McpTransport for StreamableHttpTransport {
592    async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
593        let is_request = is_jsonrpc_request(&message.value);
594        let request_id = message.value.get("id").cloned();
595        let is_initialize =
596            message.value.get("method").and_then(Value::as_str) == Some("initialize");
597
598        let mut request = self
599            .client
600            .post(self.endpoint_url.clone())
601            .header("Content-Type", "application/json")
602            .header("Accept", "application/json, text/event-stream");
603
604        request = apply_streamable_http_headers(
605            request,
606            &self.headers,
607            self.protocol_version.as_deref(),
608            self.session_id.as_deref(),
609        );
610
611        let response = request
612            .json(&message.value)
613            .send()
614            .await
615            .map_err(McpError::Http)?;
616
617        if is_initialize {
618            self.capture_session_id(response.headers());
619        }
620
621        let status = response.status();
622        if !status.is_success() {
623            return Err(
624                streamable_http_status_error("Streamable HTTP POST", status, response).await,
625            );
626        }
627
628        if !is_request {
629            return Ok(());
630        }
631
632        let content_type = response
633            .headers()
634            .get(reqwest::header::CONTENT_TYPE)
635            .and_then(|value| value.to_str().ok())
636            .unwrap_or_default()
637            .to_string();
638
639        if content_type.starts_with("application/json") {
640            let value = response.json::<Value>().await.map_err(McpError::Http)?;
641            self.maybe_update_protocol_version(&message.value, &value)?;
642            self.pending_frames.push_back(McpFrame { value });
643            return Ok(());
644        }
645
646        if !content_type.starts_with("text/event-stream") {
647            let body = response
648                .text()
649                .await
650                .unwrap_or_else(|_| "<unreadable response body>".into());
651            return Err(McpError::Transport(format!(
652                "unexpected Streamable HTTP response content type {content_type:?}: {body}"
653            )));
654        }
655
656        let request_id = request_id.ok_or_else(|| {
657            McpError::Protocol("JSON-RPC request over Streamable HTTP is missing an id".into())
658        })?;
659        self.collect_streamable_http_response(response, &message.value, &request_id)
660            .await
661    }
662
663    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
664        Ok(self.pending_frames.pop_front())
665    }
666
667    async fn close(&mut self) -> Result<(), McpError> {
668        let Some(session_id) = self.session_id.clone() else {
669            return Ok(());
670        };
671
672        let mut request = self.client.delete(self.endpoint_url.clone());
673        request = apply_streamable_http_headers(
674            request,
675            &self.headers,
676            self.protocol_version.as_deref(),
677            Some(session_id.as_str()),
678        );
679
680        let response = request.send().await.map_err(McpError::Http)?;
681        if response.status().is_success()
682            || response.status() == StatusCode::METHOD_NOT_ALLOWED
683            || response.status() == StatusCode::NOT_FOUND
684        {
685            self.session_id = None;
686            return Ok(());
687        }
688
689        Err(
690            streamable_http_status_error("Streamable HTTP DELETE", response.status(), response)
691                .await,
692        )
693    }
694}
695
696impl StreamableHttpTransport {
697    async fn collect_streamable_http_response(
698        &mut self,
699        response: reqwest::Response,
700        request_message: &Value,
701        request_id: &Value,
702    ) -> Result<(), McpError> {
703        let mut retry_delay = Duration::from_millis(0);
704        let mut last_event_id = None;
705        let mut saw_response = false;
706
707        saw_response |= self
708            .read_streamable_http_events(
709                response,
710                request_message,
711                request_id,
712                &mut last_event_id,
713                &mut retry_delay,
714            )
715            .await?;
716
717        while !saw_response && last_event_id.is_some() {
718            if !retry_delay.is_zero() {
719                sleep(retry_delay).await;
720            }
721
722            let response = self
723                .resume_streamable_http_stream(last_event_id.as_deref().unwrap())
724                .await?;
725            saw_response |= self
726                .read_streamable_http_events(
727                    response,
728                    request_message,
729                    request_id,
730                    &mut last_event_id,
731                    &mut retry_delay,
732                )
733                .await?;
734        }
735
736        Ok(())
737    }
738
739    async fn read_streamable_http_events(
740        &mut self,
741        response: reqwest::Response,
742        request_message: &Value,
743        request_id: &Value,
744        last_event_id: &mut Option<String>,
745        retry_delay: &mut Duration,
746    ) -> Result<bool, McpError> {
747        let stream = response.bytes_stream().map_err(std::io::Error::other);
748        let mut reader = BufReader::new(StreamReader::new(stream));
749        let mut saw_response = false;
750
751        while let Some(event) = read_next_sse_event(&mut reader).await? {
752            if let Some(id) = event.id.clone() {
753                *last_event_id = Some(id);
754            }
755            if let Some(retry_ms) = event.retry_ms {
756                *retry_delay = Duration::from_millis(retry_ms);
757            }
758
759            let Some(frame) = streamable_http_event_to_frame(event)? else {
760                continue;
761            };
762
763            self.maybe_update_protocol_version(request_message, &frame.value)?;
764            if frame.value.get("id") == Some(request_id) {
765                saw_response = true;
766            }
767            self.pending_frames.push_back(frame);
768        }
769
770        Ok(saw_response)
771    }
772
773    async fn resume_streamable_http_stream(
774        &self,
775        last_event_id: &str,
776    ) -> Result<reqwest::Response, McpError> {
777        let mut request = self
778            .client
779            .get(self.endpoint_url.clone())
780            .header("Accept", "text/event-stream")
781            .header("Cache-Control", "no-cache")
782            .header("Last-Event-ID", last_event_id);
783
784        request = apply_streamable_http_headers(
785            request,
786            &self.headers,
787            self.protocol_version.as_deref(),
788            self.session_id.as_deref(),
789        );
790
791        let response = request.send().await.map_err(McpError::Http)?;
792        let status = response.status();
793        if !status.is_success() {
794            return Err(
795                streamable_http_status_error("Streamable HTTP GET", status, response).await,
796            );
797        }
798
799        let content_type = response
800            .headers()
801            .get(reqwest::header::CONTENT_TYPE)
802            .and_then(|value| value.to_str().ok())
803            .unwrap_or_default();
804        if !content_type.starts_with("text/event-stream") {
805            let content_type = content_type.to_string();
806            let body = response
807                .text()
808                .await
809                .unwrap_or_else(|_| "<unreadable response body>".into());
810            return Err(McpError::Transport(format!(
811                "Streamable HTTP GET expected text/event-stream, got {content_type:?}: {body}"
812            )));
813        }
814
815        Ok(response)
816    }
817
818    fn maybe_update_protocol_version(
819        &mut self,
820        request_message: &Value,
821        response_value: &Value,
822    ) -> Result<(), McpError> {
823        if request_message.get("method").and_then(Value::as_str) != Some("initialize") {
824            return Ok(());
825        }
826
827        let protocol_version = response_value
828            .get("result")
829            .and_then(|result| result.get("protocolVersion"))
830            .and_then(Value::as_str);
831
832        if let Some(protocol_version) = protocol_version {
833            self.protocol_version = Some(protocol_version.to_string());
834        }
835
836        Ok(())
837    }
838
839    fn capture_session_id(&mut self, headers: &reqwest::header::HeaderMap) {
840        self.session_id = headers
841            .get("MCP-Session-Id")
842            .and_then(|value| value.to_str().ok())
843            .map(|value| value.to_string());
844    }
845}
846
/// Descriptor for a tool advertised by an MCP server.
///
/// Returned as part of a [`McpDiscoverySnapshot`] after server discovery. The
/// [`input_schema`](Self::input_schema) field is the JSON Schema that describes
/// the tool's expected input.
///
/// The schema is stored as an untyped JSON value; this type does not validate it.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpToolDescriptor {
    /// The tool name as reported by the MCP server.
    pub name: String,
    /// Optional human-readable description of the tool.
    pub description: Option<String>,
    /// JSON Schema describing the tool's input parameters.
    pub input_schema: Value,
    /// Arbitrary metadata attached to this descriptor.
    pub metadata: MetadataMap,
}
863
/// Descriptor for a resource advertised by an MCP server.
///
/// Resources represent data that the server can provide (e.g. files, database
/// records). Each resource is identified by a URI, stored in [`id`](Self::id).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct McpResourceDescriptor {
    /// The resource URI (e.g. `"file:///tmp/example.txt"`).
    pub id: String,
    /// Human-readable name of the resource.
    pub name: String,
    /// Optional description of the resource.
    pub description: Option<String>,
    /// Optional MIME type (e.g. `"text/plain"`, `"application/json"`).
    pub mime_type: Option<String>,
    /// Arbitrary metadata attached to this descriptor.
    pub metadata: MetadataMap,
}
881
/// Descriptor for a prompt template advertised by an MCP server.
///
/// Prompts are reusable message templates that can be parameterized with arguments.
/// The [`input_schema`](Self::input_schema) describes the expected arguments.
///
/// [`McpCapabilityProvider::from_snapshot`] converts each of these into a
/// [`PromptDescriptor`], using `id` as the [`PromptId`] and copying the
/// remaining fields unchanged.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpPromptDescriptor {
    /// Unique identifier for the prompt (typically the same as `name`).
    pub id: String,
    /// Human-readable name of the prompt.
    pub name: String,
    /// Optional description of what the prompt does.
    pub description: Option<String>,
    /// JSON Schema describing the prompt's input arguments.
    pub input_schema: Value,
    /// Arbitrary metadata attached to this descriptor.
    pub metadata: MetadataMap,
}
899
/// A snapshot of all capabilities discovered from a single MCP server.
///
/// Obtained by calling [`McpConnection::discover`] or as part of a
/// [`McpServerHandle`]. Contains the full list of tools, resources, and prompts
/// that the server advertised at discovery time.
///
/// The snapshot is plain data: it is not updated when the server's
/// capabilities change after it was taken.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct McpDiscoverySnapshot {
    /// The server this snapshot was taken from.
    pub server_id: McpServerId,
    /// Tools advertised by the server.
    pub tools: Vec<McpToolDescriptor>,
    /// Resources advertised by the server.
    pub resources: Vec<McpResourceDescriptor>,
    /// Prompts advertised by the server.
    pub prompts: Vec<McpPromptDescriptor>,
    /// Arbitrary metadata attached to this snapshot.
    ///
    /// [`McpConnection::discover`] leaves this map empty.
    pub metadata: MetadataMap,
}
918
/// A live connection to a single MCP server.
///
/// Handles JSON-RPC request/response framing, automatic auth enrichment, and
/// high-level methods for tool calls, resource reads, prompt retrieval, and
/// capability discovery.
///
/// Create a connection with [`McpConnection::connect`] or indirectly through
/// [`McpServerManager::connect_server`].
///
/// All request methods take `&self`; the transport and the stored credentials
/// are each guarded by an async mutex, so a connection can be shared behind an
/// [`Arc`].
///
/// # Example
///
/// ```rust,no_run
/// use agentkit_mcp::{McpConnection, McpServerConfig, McpTransportBinding, StdioTransportConfig};
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let config = McpServerConfig::new(
///     "filesystem",
///     McpTransportBinding::Stdio(StdioTransportConfig::new("npx")
///         .with_arg("-y")
///         .with_arg("@modelcontextprotocol/server-filesystem")),
/// );
///
/// let connection = McpConnection::connect(&config).await?;
/// let snapshot = connection.discover().await?;
/// println!("found {} tools", snapshot.tools.len());
/// # Ok(())
/// # }
/// ```
pub struct McpConnection {
    /// Identifier of the server, copied from the config at connect time.
    server_id: McpServerId,
    /// The transport used for all frames. A request keeps this lock from the
    /// moment it sends its frame until the matching response has been received.
    transport: Mutex<Box<dyn McpTransport>>,
    /// Credentials merged into outgoing request params under the `auth` key, if any.
    auth: Mutex<Option<MetadataMap>>,
    /// JSON-RPC id for the next request. Starts at 1 because the `initialize`
    /// handshake uses id 0.
    next_id: AtomicU64,
}
954
/// The result of replaying an MCP operation after auth resolution.
///
/// Returned by [`McpConnection::replay_auth_operation`] and
/// [`McpServerManager::resolve_auth_and_resume`].
///
/// Each variant carries the value that the corresponding operation returns
/// when it is invoked directly.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum McpOperationResult {
    /// The server was successfully (re)connected; contains the discovery snapshot.
    Connected(McpDiscoverySnapshot),
    /// A tool call completed; contains the raw JSON result.
    Tool(Value),
    /// A resource was read successfully.
    Resource(ResourceContents),
    /// A prompt was retrieved successfully.
    Prompt(PromptContents),
}
970
971impl McpConnection {
    /// Connects to an MCP server, performs the JSON-RPC `initialize` handshake, and
    /// returns a ready-to-use connection.
    ///
    /// No credentials are attached to the handshake: this delegates to the
    /// internal `connect_with_auth` with `auth` set to `None`.
    ///
    /// # Errors
    ///
    /// Returns [`McpError`] if the transport fails to connect, the handshake is
    /// rejected, or the server requires authentication ([`McpError::AuthRequired`]).
    pub async fn connect(config: &McpServerConfig) -> Result<Self, McpError> {
        Self::connect_with_auth(config, None).await
    }
982
983    async fn connect_with_auth(
984        config: &McpServerConfig,
985        auth: Option<&MetadataMap>,
986    ) -> Result<Self, McpError> {
987        let factory: Arc<dyn McpTransportFactory> = match &config.transport {
988            McpTransportBinding::Stdio(binding) => {
989                Arc::new(StdioTransportFactory::new(binding.clone()))
990            }
991            McpTransportBinding::StreamableHttp(binding) => {
992                Arc::new(StreamableHttpTransportFactory::new(binding.clone()))
993            }
994            McpTransportBinding::Sse(binding) => {
995                Arc::new(SseTransportFactory::new(binding.clone()))
996            }
997            McpTransportBinding::Custom(factory) => factory.clone(),
998        };
999
1000        let mut transport = factory.connect().await?;
1001        let mut params = serde_json::Map::new();
1002        params.insert(
1003            "protocolVersion".into(),
1004            Value::String(MCP_LATEST_PROTOCOL_VERSION.into()),
1005        );
1006        params.insert("capabilities".into(), json!({}));
1007        params.insert(
1008            "clientInfo".into(),
1009            json!({
1010                "name": "agentkit-mcp",
1011                "version": env!("CARGO_PKG_VERSION")
1012            }),
1013        );
1014        if let Some(auth) = auth {
1015            params.insert("auth".into(), metadata_to_value(auth));
1016        }
1017        let init_params = Value::Object(params.clone());
1018        transport
1019            .send(McpFrame {
1020                value: json!({
1021                    "jsonrpc": "2.0",
1022                    "id": 0,
1023                    "method": "initialize",
1024                    "params": init_params.clone()
1025                }),
1026            })
1027            .await?;
1028        let init_response = transport.recv().await?.ok_or_else(|| {
1029            McpError::Transport("transport closed during MCP initialization".into())
1030        })?;
1031        if let Some(error) = init_response.value.get("error") {
1032            if let Some(auth_request) =
1033                parse_auth_request(&config.id, "initialize", &init_params, error)
1034            {
1035                return Err(McpError::AuthRequired(Box::new(auth_request)));
1036            }
1037            return Err(McpError::Invocation(error.to_string()));
1038        }
1039        let negotiated_protocol_version = init_response
1040            .value
1041            .get("result")
1042            .and_then(|result| result.get("protocolVersion"))
1043            .and_then(Value::as_str)
1044            .ok_or_else(|| {
1045                McpError::Protocol("initialize response missing result.protocolVersion".into())
1046            })?;
1047        if !MCP_SUPPORTED_PROTOCOL_VERSIONS.contains(&negotiated_protocol_version) {
1048            return Err(McpError::Protocol(format!(
1049                "unsupported MCP protocol version negotiated during initialize: {negotiated_protocol_version}"
1050            )));
1051        }
1052        transport
1053            .send(McpFrame {
1054                value: json!({
1055                    "jsonrpc": "2.0",
1056                    "method": "notifications/initialized",
1057                    "params": {}
1058                }),
1059            })
1060            .await?;
1061
1062        Ok(Self {
1063            server_id: config.id.clone(),
1064            transport: Mutex::new(transport),
1065            auth: Mutex::new(auth.cloned()),
1066            next_id: AtomicU64::new(1),
1067        })
1068    }
1069
    /// Returns the [`McpServerId`] for this connection.
    ///
    /// This is the id copied from the [`McpServerConfig`] the connection was
    /// created from.
    pub fn server_id(&self) -> &McpServerId {
        &self.server_id
    }
1074
1075    /// Closes the underlying transport, shutting down the connection to the server.
1076    ///
1077    /// # Errors
1078    ///
1079    /// Returns [`McpError`] if the transport cannot be closed cleanly.
1080    pub async fn close(&self) -> Result<(), McpError> {
1081        let mut transport = self.transport.lock().await;
1082        transport.close().await
1083    }
1084
1085    /// Stores or clears authentication credentials for future requests on this
1086    /// connection.
1087    ///
1088    /// After calling this method with [`AuthResolution::Provided`], every subsequent
1089    /// JSON-RPC request will include the credentials in an `auth` field.
1090    ///
1091    /// # Errors
1092    ///
1093    /// Returns [`McpError`] if the resolution cannot be applied.
1094    pub async fn resolve_auth(&self, resolution: AuthResolution) -> Result<(), McpError> {
1095        let mut auth = self.auth.lock().await;
1096        match resolution {
1097            AuthResolution::Provided { credentials, .. } => {
1098                *auth = Some(credentials);
1099            }
1100            AuthResolution::Cancelled { .. } => {
1101                *auth = None;
1102            }
1103        }
1104        Ok(())
1105    }
1106
1107    /// Performs full capability discovery by listing tools, resources, and prompts.
1108    ///
1109    /// Returns an [`McpDiscoverySnapshot`] containing everything the server advertises.
1110    ///
1111    /// # Errors
1112    ///
1113    /// Returns [`McpError`] if any of the list requests fail.
1114    pub async fn discover(&self) -> Result<McpDiscoverySnapshot, McpError> {
1115        Ok(McpDiscoverySnapshot {
1116            server_id: self.server_id.clone(),
1117            tools: self.list_tools().await?,
1118            resources: self.list_resources().await?,
1119            prompts: self.list_prompts().await?,
1120            metadata: MetadataMap::new(),
1121        })
1122    }
1123
1124    /// Lists all tools advertised by the connected MCP server.
1125    ///
1126    /// # Errors
1127    ///
1128    /// Returns [`McpError`] if the `tools/list` request fails.
1129    pub async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>, McpError> {
1130        let result = self.request("tools/list", json!({})).await?;
1131        result
1132            .get("tools")
1133            .and_then(Value::as_array)
1134            .cloned()
1135            .unwrap_or_default()
1136            .into_iter()
1137            .map(parse_tool_descriptor)
1138            .collect()
1139    }
1140
1141    /// Lists all resources advertised by the connected MCP server.
1142    ///
1143    /// # Errors
1144    ///
1145    /// Returns [`McpError`] if the `resources/list` request fails.
1146    pub async fn list_resources(&self) -> Result<Vec<McpResourceDescriptor>, McpError> {
1147        let result = self.request("resources/list", json!({})).await?;
1148        result
1149            .get("resources")
1150            .and_then(Value::as_array)
1151            .cloned()
1152            .unwrap_or_default()
1153            .into_iter()
1154            .map(parse_resource_descriptor)
1155            .collect()
1156    }
1157
1158    /// Lists all prompts advertised by the connected MCP server.
1159    ///
1160    /// # Errors
1161    ///
1162    /// Returns [`McpError`] if the `prompts/list` request fails.
1163    pub async fn list_prompts(&self) -> Result<Vec<McpPromptDescriptor>, McpError> {
1164        let result = self.request("prompts/list", json!({})).await?;
1165        result
1166            .get("prompts")
1167            .and_then(Value::as_array)
1168            .cloned()
1169            .unwrap_or_default()
1170            .into_iter()
1171            .map(parse_prompt_descriptor)
1172            .collect()
1173    }
1174
1175    /// Invokes a tool on the MCP server and returns the raw JSON result.
1176    ///
1177    /// # Arguments
1178    ///
1179    /// * `name` - The tool name as it appears in the server's tool list.
1180    /// * `arguments` - A JSON value matching the tool's input schema.
1181    ///
1182    /// # Errors
1183    ///
1184    /// Returns [`McpError::AuthRequired`] if the server demands authentication,
1185    /// or another [`McpError`] variant on transport or protocol failures.
1186    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value, McpError> {
1187        self.request(
1188            "tools/call",
1189            json!({
1190                "name": name,
1191                "arguments": arguments,
1192            }),
1193        )
1194        .await
1195    }
1196
1197    /// Reads a resource from the MCP server by URI.
1198    ///
1199    /// # Arguments
1200    ///
1201    /// * `uri` - The resource URI (e.g. `"file:///tmp/example.txt"`).
1202    ///
1203    /// # Errors
1204    ///
1205    /// Returns [`McpError`] if the resource cannot be read or the response is malformed.
1206    pub async fn read_resource(&self, uri: &str) -> Result<ResourceContents, McpError> {
1207        let result = self
1208            .request(
1209                "resources/read",
1210                json!({
1211                    "uri": uri,
1212                }),
1213            )
1214            .await?;
1215        let content = result
1216            .get("contents")
1217            .and_then(Value::as_array)
1218            .and_then(|values| values.first())
1219            .cloned()
1220            .ok_or_else(|| McpError::Protocol("resources/read returned no contents".into()))?;
1221
1222        let data = if let Some(text) = content.get("text").and_then(Value::as_str) {
1223            DataRef::InlineText(text.into())
1224        } else if let Some(found_uri) = content.get("uri").and_then(Value::as_str) {
1225            DataRef::Uri(found_uri.into())
1226        } else {
1227            return Err(McpError::Protocol(
1228                "unsupported resource content shape".into(),
1229            ));
1230        };
1231
1232        Ok(ResourceContents {
1233            data,
1234            metadata: MetadataMap::new(),
1235        })
1236    }
1237
1238    /// Retrieves a prompt from the MCP server, rendering it with the given arguments.
1239    ///
1240    /// # Arguments
1241    ///
1242    /// * `name` - The prompt name as it appears in the server's prompt list.
1243    /// * `arguments` - A JSON value containing the prompt's input arguments.
1244    ///
1245    /// # Errors
1246    ///
1247    /// Returns [`McpError`] if the prompt cannot be retrieved or the response is malformed.
1248    pub async fn get_prompt(
1249        &self,
1250        name: &str,
1251        arguments: Value,
1252    ) -> Result<PromptContents, McpError> {
1253        let result = self
1254            .request(
1255                "prompts/get",
1256                json!({
1257                    "name": name,
1258                    "arguments": arguments,
1259                }),
1260            )
1261            .await?;
1262        let items = result
1263            .get("messages")
1264            .and_then(Value::as_array)
1265            .cloned()
1266            .unwrap_or_default()
1267            .into_iter()
1268            .map(parse_prompt_message)
1269            .collect::<Result<Vec<_>, _>>()?;
1270
1271        Ok(PromptContents {
1272            items,
1273            metadata: MetadataMap::new(),
1274        })
1275    }
1276
1277    async fn request(&self, method: &str, params: Value) -> Result<Value, McpError> {
1278        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1279        let params = self.enrich_params(params.clone()).await;
1280        let mut transport = self.transport.lock().await;
1281        transport
1282            .send(McpFrame {
1283                value: json!({
1284                    "jsonrpc": "2.0",
1285                    "id": id,
1286                    "method": method,
1287                    "params": params,
1288                }),
1289            })
1290            .await?;
1291
1292        loop {
1293            let Some(frame) = transport.recv().await? else {
1294                return Err(McpError::Transport(
1295                    "transport closed while waiting for MCP response".into(),
1296                ));
1297            };
1298
1299            if frame.value.get("id").and_then(Value::as_u64) != Some(id) {
1300                continue;
1301            }
1302
1303            if let Some(error) = frame.value.get("error") {
1304                if let Some(auth_request) =
1305                    parse_auth_request(&self.server_id, method, &params, error)
1306                {
1307                    return Err(McpError::AuthRequired(Box::new(auth_request)));
1308                }
1309                return Err(McpError::Invocation(error.to_string()));
1310            }
1311
1312            return frame
1313                .value
1314                .get("result")
1315                .cloned()
1316                .ok_or_else(|| McpError::Protocol("MCP response missing result".into()));
1317        }
1318    }
1319
1320    async fn enrich_params(&self, params: Value) -> Value {
1321        let auth = self.auth.lock().await;
1322        let Some(auth) = auth.as_ref() else {
1323            return params;
1324        };
1325
1326        match params {
1327            Value::Object(mut object) => {
1328                object
1329                    .entry("auth")
1330                    .or_insert_with(|| metadata_to_value(auth));
1331                Value::Object(object)
1332            }
1333            other => other,
1334        }
1335    }
1336
1337    /// Replays an MCP operation that previously failed with an auth challenge.
1338    ///
1339    /// This is called after credentials have been resolved via [`resolve_auth`](Self::resolve_auth).
1340    /// The operation is re-issued with the stored credentials attached.
1341    ///
1342    /// # Errors
1343    ///
1344    /// Returns [`McpError::AuthResolution`] if the operation targets a different server,
1345    /// or other [`McpError`] variants if the replayed operation itself fails.
1346    pub async fn replay_auth_operation(
1347        &self,
1348        operation: &AuthOperation,
1349    ) -> Result<McpOperationResult, McpError> {
1350        match operation {
1351            AuthOperation::McpToolCall {
1352                server_id,
1353                tool_name,
1354                input,
1355                ..
1356            } => {
1357                self.ensure_server_match(server_id)?;
1358                self.call_tool(tool_name, input.clone())
1359                    .await
1360                    .map(McpOperationResult::Tool)
1361            }
1362            AuthOperation::McpResourceRead {
1363                server_id,
1364                resource_id,
1365                ..
1366            } => {
1367                self.ensure_server_match(server_id)?;
1368                self.read_resource(resource_id)
1369                    .await
1370                    .map(McpOperationResult::Resource)
1371            }
1372            AuthOperation::McpPromptGet {
1373                server_id,
1374                prompt_id,
1375                args,
1376                ..
1377            } => {
1378                self.ensure_server_match(server_id)?;
1379                self.get_prompt(prompt_id, args.clone())
1380                    .await
1381                    .map(McpOperationResult::Prompt)
1382            }
1383            AuthOperation::ToolCall {
1384                tool_name,
1385                input,
1386                metadata,
1387                ..
1388            } => {
1389                if let Some(server_id) = metadata.get("server_id").and_then(Value::as_str) {
1390                    self.ensure_server_match(server_id)?;
1391                }
1392                let tool_name = normalize_mcp_tool_name(self.server_id(), tool_name);
1393                self.call_tool(&tool_name, input.clone())
1394                    .await
1395                    .map(McpOperationResult::Tool)
1396            }
1397            AuthOperation::McpConnect { .. } => Err(McpError::AuthResolution(
1398                "connect operations must be replayed through the server manager".into(),
1399            )),
1400            AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
1401                "unsupported auth operation for replay: {kind}"
1402            ))),
1403        }
1404    }
1405
1406    fn ensure_server_match(&self, server_id: &str) -> Result<(), McpError> {
1407        if self.server_id.0 == server_id {
1408            Ok(())
1409        } else {
1410            Err(McpError::AuthResolution(format!(
1411                "auth operation targets server {server_id}, but connection is for {}",
1412                self.server_id
1413            )))
1414        }
1415    }
1416}
1417
/// Adapter that exposes an MCP tool as an [`Invocable`] for the capabilities system.
///
/// This is the capabilities-layer adapter. For the tool-layer adapter, see
/// [`McpToolAdapter`]. Names take the form `mcp_<server_id>_<tool_name>`,
/// chosen so they can satisfy provider validators that only allow
/// `[a-zA-Z0-9_-]` (e.g. Anthropic on Vertex).
///
/// NOTE(review): the constructor interpolates the server id and tool name
/// without sanitizing them, so the character-set guarantee presumably relies on
/// both already being restricted to that set — confirm.
pub struct McpInvocable {
    /// Connection used to issue the tool call.
    connection: Arc<McpConnection>,
    /// Descriptor from discovery; its `name` is what is sent to the server.
    descriptor: McpToolDescriptor,
    /// Spec derived from the descriptor at construction time.
    spec: InvocableSpec,
}
1429
1430impl McpInvocable {
1431    /// Creates a new invocable adapter for the given MCP tool.
1432    ///
1433    /// # Arguments
1434    ///
1435    /// * `connection` - A shared connection to the MCP server that owns the tool.
1436    /// * `descriptor` - The tool descriptor obtained from discovery.
1437    pub fn new(connection: Arc<McpConnection>, descriptor: McpToolDescriptor) -> Self {
1438        let spec = InvocableSpec {
1439            name: CapabilityName::new(format!(
1440                "mcp_{}_{}",
1441                connection.server_id(),
1442                descriptor.name
1443            )),
1444            description: descriptor
1445                .description
1446                .clone()
1447                .unwrap_or_else(|| descriptor.name.clone()),
1448            input_schema: descriptor.input_schema.clone(),
1449            metadata: descriptor.metadata.clone(),
1450        };
1451
1452        Self {
1453            connection,
1454            descriptor,
1455            spec,
1456        }
1457    }
1458}
1459
1460#[async_trait]
1461impl Invocable for McpInvocable {
1462    fn spec(&self) -> &InvocableSpec {
1463        &self.spec
1464    }
1465
1466    async fn invoke(
1467        &self,
1468        request: InvocableRequest,
1469        _ctx: &mut CapabilityContext<'_>,
1470    ) -> Result<InvocableResult, CapabilityError> {
1471        let result = self
1472            .connection
1473            .call_tool(&self.descriptor.name, request.input)
1474            .await
1475            .map_err(|error| match error {
1476                McpError::AuthRequired(request) => {
1477                    CapabilityError::Unavailable(format!("auth required: {:?}", request))
1478                }
1479                other => CapabilityError::ExecutionFailed(other.to_string()),
1480            })?;
1481
1482        Ok(InvocableResult {
1483            output: value_to_invocable_output(result),
1484            metadata: MetadataMap::new(),
1485        })
1486    }
1487}
1488
/// Adapter that exposes a single MCP resource as a [`ResourceProvider`].
///
/// Created automatically by [`McpCapabilityProvider::from_snapshot`] for each
/// resource discovered on the server.
pub struct McpResourceHandle {
    /// Connection used to read the resource.
    connection: Arc<McpConnection>,
    /// The one descriptor this handle reports from `list_resources`.
    descriptor: ResourceDescriptor,
}
1497
1498#[async_trait]
1499impl ResourceProvider for McpResourceHandle {
1500    async fn list_resources(&self) -> Result<Vec<ResourceDescriptor>, CapabilityError> {
1501        Ok(vec![self.descriptor.clone()])
1502    }
1503
1504    async fn read_resource(
1505        &self,
1506        id: &ResourceId,
1507        _ctx: &mut CapabilityContext<'_>,
1508    ) -> Result<ResourceContents, CapabilityError> {
1509        self.connection
1510            .read_resource(&id.0)
1511            .await
1512            .map_err(|error| match error {
1513                McpError::AuthRequired(request) => {
1514                    CapabilityError::Unavailable(format!("auth required: {:?}", request))
1515                }
1516                other => CapabilityError::ExecutionFailed(other.to_string()),
1517            })
1518    }
1519}
1520
/// Adapter that exposes a single MCP prompt as a [`PromptProvider`].
///
/// Created automatically by [`McpCapabilityProvider::from_snapshot`] for each
/// prompt discovered on the server.
pub struct McpPromptHandle {
    /// Connection used to fetch the prompt.
    connection: Arc<McpConnection>,
    /// The one descriptor this handle reports from `list_prompts`.
    descriptor: PromptDescriptor,
}
1529
1530#[async_trait]
1531impl PromptProvider for McpPromptHandle {
1532    async fn list_prompts(&self) -> Result<Vec<PromptDescriptor>, CapabilityError> {
1533        Ok(vec![self.descriptor.clone()])
1534    }
1535
1536    async fn get_prompt(
1537        &self,
1538        id: &PromptId,
1539        args: Value,
1540        _ctx: &mut CapabilityContext<'_>,
1541    ) -> Result<PromptContents, CapabilityError> {
1542        self.connection
1543            .get_prompt(&id.0, args)
1544            .await
1545            .map_err(|error| match error {
1546                McpError::AuthRequired(request) => {
1547                    CapabilityError::Unavailable(format!("auth required: {:?}", request))
1548                }
1549                other => CapabilityError::ExecutionFailed(other.to_string()),
1550            })
1551    }
1552}
1553
/// A [`CapabilityProvider`] that surfaces MCP tools, resources, and prompts into the
/// agentkit capabilities system.
///
/// Built from a discovery snapshot, this provider wraps each MCP tool as an
/// [`McpInvocable`], each resource as an [`McpResourceHandle`], and each prompt as
/// an [`McpPromptHandle`].
///
/// The adapters are created once, when the provider is built; the provider
/// does not re-run discovery afterwards.
///
/// # Example
///
/// ```rust,no_run
/// use std::sync::Arc;
/// use agentkit_mcp::{McpCapabilityProvider, McpServerConfig, McpTransportBinding, StdioTransportConfig};
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let config = McpServerConfig::new(
///     "filesystem",
///     McpTransportBinding::Stdio(StdioTransportConfig::new("npx")
///         .with_arg("-y")
///         .with_arg("@modelcontextprotocol/server-filesystem")),
/// );
/// let (connection, provider, snapshot) = McpCapabilityProvider::connect(&config).await?;
/// // `provider` implements CapabilityProvider and can be registered with an agent.
/// # Ok(())
/// # }
/// ```
pub struct McpCapabilityProvider {
    /// One adapter per discovered tool.
    invocables: Vec<Arc<dyn Invocable>>,
    /// One handle per discovered resource.
    resources: Vec<Arc<dyn ResourceProvider>>,
    /// One handle per discovered prompt.
    prompts: Vec<Arc<dyn PromptProvider>>,
}
1585
1586impl McpCapabilityProvider {
1587    /// Creates a capability provider from an existing connection and its discovery
1588    /// snapshot.
1589    ///
1590    /// Each tool, resource, and prompt in the snapshot is wrapped in the appropriate
1591    /// adapter type.
1592    pub fn from_snapshot(connection: Arc<McpConnection>, snapshot: &McpDiscoverySnapshot) -> Self {
1593        let invocables = snapshot
1594            .tools
1595            .iter()
1596            .cloned()
1597            .map(|descriptor| {
1598                Arc::new(McpInvocable::new(connection.clone(), descriptor)) as Arc<dyn Invocable>
1599            })
1600            .collect();
1601
1602        let resources = snapshot
1603            .resources
1604            .iter()
1605            .cloned()
1606            .map(|descriptor| {
1607                Arc::new(McpResourceHandle {
1608                    connection: connection.clone(),
1609                    descriptor: ResourceDescriptor {
1610                        id: ResourceId::new(descriptor.id),
1611                        name: descriptor.name,
1612                        description: descriptor.description,
1613                        mime_type: descriptor.mime_type,
1614                        metadata: descriptor.metadata,
1615                    },
1616                }) as Arc<dyn ResourceProvider>
1617            })
1618            .collect();
1619
1620        let prompts = snapshot
1621            .prompts
1622            .iter()
1623            .cloned()
1624            .map(|descriptor| {
1625                Arc::new(McpPromptHandle {
1626                    connection: connection.clone(),
1627                    descriptor: PromptDescriptor {
1628                        id: PromptId::new(descriptor.id),
1629                        name: descriptor.name,
1630                        description: descriptor.description,
1631                        input_schema: descriptor.input_schema,
1632                        metadata: descriptor.metadata,
1633                    },
1634                }) as Arc<dyn PromptProvider>
1635            })
1636            .collect();
1637
1638        Self {
1639            invocables,
1640            resources,
1641            prompts,
1642        }
1643    }
1644
1645    /// Merges multiple capability providers into a single provider.
1646    ///
1647    /// This is useful when managing several MCP servers through a
1648    /// [`McpServerManager`] and you want one combined provider for the agent.
1649    pub fn merge<I>(providers: I) -> Self
1650    where
1651        I: IntoIterator<Item = Self>,
1652    {
1653        let mut invocables = Vec::new();
1654        let mut resources = Vec::new();
1655        let mut prompts = Vec::new();
1656
1657        for provider in providers {
1658            invocables.extend(provider.invocables);
1659            resources.extend(provider.resources);
1660            prompts.extend(provider.prompts);
1661        }
1662
1663        Self {
1664            invocables,
1665            resources,
1666            prompts,
1667        }
1668    }
1669
1670    /// Connects to an MCP server, performs discovery, and builds a capability
1671    /// provider in one step.
1672    ///
1673    /// Returns the shared connection, the provider, and the discovery snapshot.
1674    ///
1675    /// # Errors
1676    ///
1677    /// Returns [`McpError`] if connection or discovery fails.
1678    pub async fn connect(
1679        config: &McpServerConfig,
1680    ) -> Result<(Arc<McpConnection>, Self, McpDiscoverySnapshot), McpError> {
1681        let connection = Arc::new(McpConnection::connect(config).await?);
1682        let snapshot = connection.discover().await?;
1683        let provider = Self::from_snapshot(connection.clone(), &snapshot);
1684
1685        Ok((connection, provider, snapshot))
1686    }
1687}
1688
1689impl CapabilityProvider for McpCapabilityProvider {
1690    fn invocables(&self) -> Vec<Arc<dyn Invocable>> {
1691        self.invocables.clone()
1692    }
1693
1694    fn resources(&self) -> Vec<Arc<dyn ResourceProvider>> {
1695        self.resources.clone()
1696    }
1697
1698    fn prompts(&self) -> Vec<Arc<dyn PromptProvider>> {
1699        self.prompts.clone()
1700    }
1701}
1702
/// A connected MCP server together with its configuration and discovery snapshot.
///
/// Obtained from [`McpServerManager::connect_server`] or
/// [`McpServerManager::connect_all`]. Provides convenience methods to create
/// tool registries and capability providers from the server's discovered capabilities.
///
/// Cloning a handle copies the config and snapshot but shares the same
/// underlying connection, which is held behind an [`Arc`].
#[derive(Clone)]
pub struct McpServerHandle {
    /// Configuration the server was connected with.
    config: McpServerConfig,
    /// Shared live connection to the server.
    connection: Arc<McpConnection>,
    /// Capabilities discovered on the server.
    snapshot: McpDiscoverySnapshot,
}
1714
1715impl McpServerHandle {
1716    /// Returns the original configuration used to connect this server.
1717    pub fn config(&self) -> &McpServerConfig {
1718        &self.config
1719    }
1720
1721    /// Returns the server's unique identifier.
1722    pub fn server_id(&self) -> &McpServerId {
1723        self.connection.server_id()
1724    }
1725
1726    /// Returns a shared reference to the underlying [`McpConnection`].
1727    pub fn connection(&self) -> Arc<McpConnection> {
1728        self.connection.clone()
1729    }
1730
1731    /// Returns the discovery snapshot captured when the server was connected.
1732    pub fn snapshot(&self) -> &McpDiscoverySnapshot {
1733        &self.snapshot
1734    }
1735
1736    /// Builds a [`ToolRegistry`] containing an [`McpToolAdapter`] for each tool
1737    /// discovered on this server.
1738    pub fn tool_registry(&self) -> ToolRegistry {
1739        self.snapshot
1740            .tools
1741            .iter()
1742            .cloned()
1743            .fold(ToolRegistry::new(), |registry, descriptor| {
1744                registry.with(McpToolAdapter::new(
1745                    self.server_id(),
1746                    self.connection.clone(),
1747                    descriptor,
1748                ))
1749            })
1750    }
1751
1752    /// Builds an [`McpCapabilityProvider`] from this server's discovery snapshot.
1753    pub fn capability_provider(&self) -> McpCapabilityProvider {
1754        McpCapabilityProvider::from_snapshot(self.connection.clone(), &self.snapshot)
1755    }
1756}
1757
1758/// Manages the lifecycle of one or more MCP servers: registration, connection,
1759/// discovery, refresh, disconnection, and auth resolution.
1760///
1761/// This is the primary entry point for integrating MCP servers into an agentkit
1762/// application. Register server configurations, connect them, and then obtain a
1763/// combined [`ToolRegistry`] or [`McpCapabilityProvider`] for use in an agent loop.
1764///
1765/// # Example
1766///
1767/// ```rust,no_run
1768/// use agentkit_mcp::{
1769///     McpServerConfig, McpServerManager, McpTransportBinding, StdioTransportConfig,
1770/// };
1771///
1772/// # #[tokio::main]
1773/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
1774/// let mut manager = McpServerManager::new()
1775///     .with_server(McpServerConfig::new(
1776///         "filesystem",
1777///         McpTransportBinding::Stdio(
1778///             StdioTransportConfig::new("npx")
1779///                 .with_arg("-y")
1780///                 .with_arg("@modelcontextprotocol/server-filesystem"),
1781///         ),
1782///     ))
1783///     .with_server(McpServerConfig::new(
1784///         "github",
1785///         McpTransportBinding::Stdio(
1786///             StdioTransportConfig::new("npx")
1787///                 .with_arg("-y")
1788///                 .with_arg("@modelcontextprotocol/server-github"),
1789///         ),
1790///     ));
1791///
1792/// let handles = manager.connect_all().await?;
1793/// let registry = manager.tool_registry();
1794/// println!("tools: {:?}", registry.specs().iter().map(|s| &s.name).collect::<Vec<_>>());
1795/// # Ok(())
1796/// # }
1797/// ```
1798#[derive(Default)]
1799pub struct McpServerManager {
1800    configs: BTreeMap<McpServerId, McpServerConfig>,
1801    connections: BTreeMap<McpServerId, McpServerHandle>,
1802    auth: BTreeMap<McpServerId, MetadataMap>,
1803}
1804
1805impl McpServerManager {
1806    /// Creates an empty server manager with no registered servers.
1807    pub fn new() -> Self {
1808        Self::default()
1809    }
1810
1811    /// Registers a server configuration and returns `self` for chaining.
1812    ///
1813    /// The server is not connected until [`connect_server`](Self::connect_server) or
1814    /// [`connect_all`](Self::connect_all) is called.
1815    pub fn with_server(mut self, config: McpServerConfig) -> Self {
1816        self.register_server(config);
1817        self
1818    }
1819
1820    /// Registers a server configuration by mutable reference.
1821    ///
1822    /// The server is not connected until [`connect_server`](Self::connect_server) or
1823    /// [`connect_all`](Self::connect_all) is called.
1824    pub fn register_server(&mut self, config: McpServerConfig) -> &mut Self {
1825        self.configs.insert(config.id.clone(), config);
1826        self
1827    }
1828
1829    /// Returns the handle for a connected server, or `None` if it is not connected.
1830    pub fn connected_server(&self, server_id: &McpServerId) -> Option<&McpServerHandle> {
1831        self.connections.get(server_id)
1832    }
1833
1834    /// Returns handles for all currently connected servers.
1835    pub fn connected_servers(&self) -> Vec<&McpServerHandle> {
1836        self.connections.values().collect()
1837    }
1838
1839    /// Connects a single registered server by its identifier.
1840    ///
1841    /// Performs the MCP handshake and full capability discovery.
1842    ///
1843    /// # Errors
1844    ///
1845    /// Returns [`McpError::UnknownServer`] if the server ID has not been registered,
1846    /// or other [`McpError`] variants if connection or discovery fails.
1847    pub async fn connect_server(
1848        &mut self,
1849        server_id: &McpServerId,
1850    ) -> Result<McpServerHandle, McpError> {
1851        let config = self
1852            .configs
1853            .get(server_id)
1854            .cloned()
1855            .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1856        let connection =
1857            Arc::new(McpConnection::connect_with_auth(&config, self.auth.get(server_id)).await?);
1858        let snapshot = connection.discover().await?;
1859        let handle = McpServerHandle {
1860            config,
1861            connection,
1862            snapshot,
1863        };
1864        self.connections.insert(server_id.clone(), handle.clone());
1865        Ok(handle)
1866    }
1867
1868    /// Connects all registered servers sequentially.
1869    ///
1870    /// Returns a handle for each server in registration order. If any server fails
1871    /// to connect, the error is returned immediately and remaining servers are
1872    /// not attempted.
1873    ///
1874    /// # Errors
1875    ///
1876    /// Returns the first [`McpError`] encountered during connection.
1877    pub async fn connect_all(&mut self) -> Result<Vec<McpServerHandle>, McpError> {
1878        let server_ids = self.configs.keys().cloned().collect::<Vec<_>>();
1879        let mut handles = Vec::with_capacity(server_ids.len());
1880
1881        for server_id in server_ids {
1882            handles.push(self.connect_server(&server_id).await?);
1883        }
1884
1885        Ok(handles)
1886    }
1887
1888    /// Re-discovers capabilities for a connected server, updating the stored snapshot.
1889    ///
1890    /// Call this after the server's capabilities may have changed (e.g. after
1891    /// installing a plugin).
1892    ///
1893    /// # Errors
1894    ///
1895    /// Returns [`McpError::UnknownServer`] if the server is not connected, or other
1896    /// [`McpError`] variants if discovery fails.
1897    pub async fn refresh_server(
1898        &mut self,
1899        server_id: &McpServerId,
1900    ) -> Result<McpDiscoverySnapshot, McpError> {
1901        let handle = self
1902            .connections
1903            .get_mut(server_id)
1904            .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1905        let snapshot = handle.connection.discover().await?;
1906        handle.snapshot = snapshot.clone();
1907        Ok(snapshot)
1908    }
1909
1910    /// Disconnects a server and removes it from the active connections.
1911    ///
1912    /// The server configuration remains registered and can be reconnected later
1913    /// with [`connect_server`](Self::connect_server).
1914    ///
1915    /// # Errors
1916    ///
1917    /// Returns [`McpError::UnknownServer`] if the server is not connected.
1918    pub async fn disconnect_server(&mut self, server_id: &McpServerId) -> Result<(), McpError> {
1919        let Some(handle) = self.connections.remove(server_id) else {
1920            return Err(McpError::UnknownServer(server_id.to_string()));
1921        };
1922        handle.connection.close().await
1923    }
1924
1925    /// Stores or clears authentication credentials for a server and, if already
1926    /// connected, updates the live connection as well.
1927    ///
1928    /// # Errors
1929    ///
1930    /// Returns [`McpError::UnknownServer`] if the server ID from the resolution
1931    /// does not match any registered server.
1932    pub async fn resolve_auth(&mut self, resolution: AuthResolution) -> Result<(), McpError> {
1933        let server_id = resolution
1934            .request()
1935            .server_id()
1936            .ok_or_else(|| McpError::AuthResolution("auth resolution missing server id".into()))?;
1937        let server_id = McpServerId::new(server_id);
1938        match &resolution {
1939            AuthResolution::Provided { credentials, .. } => {
1940                self.auth.insert(server_id.clone(), credentials.clone());
1941            }
1942            AuthResolution::Cancelled { .. } => {
1943                self.auth.remove(&server_id);
1944            }
1945        }
1946
1947        if let Some(handle) = self.connections.get(&server_id) {
1948            handle.connection.resolve_auth(resolution).await?;
1949            return Ok(());
1950        }
1951
1952        if self.configs.contains_key(&server_id) {
1953            Ok(())
1954        } else {
1955            Err(McpError::UnknownServer(server_id.to_string()))
1956        }
1957    }
1958
1959    /// Resolves authentication and immediately replays the operation that originally
1960    /// triggered the auth challenge.
1961    ///
1962    /// This is a convenience method combining [`resolve_auth`](Self::resolve_auth)
1963    /// and [`replay_auth_request`](Self::replay_auth_request).
1964    ///
1965    /// # Errors
1966    ///
1967    /// Returns [`McpError`] if auth resolution or the replayed operation fails.
1968    pub async fn resolve_auth_and_resume(
1969        &mut self,
1970        resolution: AuthResolution,
1971    ) -> Result<McpOperationResult, McpError> {
1972        let request = resolution.request().clone();
1973        self.resolve_auth(resolution).await?;
1974        self.replay_auth_request(&request).await
1975    }
1976
1977    /// Replays an auth request's original MCP operation using stored credentials.
1978    ///
1979    /// For connect operations the server is (re)connected. For tool calls, resource
1980    /// reads, and prompt retrievals the request is re-issued on the existing or
1981    /// newly established connection.
1982    ///
1983    /// # Errors
1984    ///
1985    /// Returns [`McpError`] if the operation cannot be replayed.
1986    pub async fn replay_auth_request(
1987        &mut self,
1988        request: &AuthRequest,
1989    ) -> Result<McpOperationResult, McpError> {
1990        match &request.operation {
1991            AuthOperation::McpConnect { server_id, .. } => {
1992                let server_id = McpServerId::new(server_id);
1993                let handle = self.connect_server(&server_id).await?;
1994                Ok(McpOperationResult::Connected(handle.snapshot.clone()))
1995            }
1996            AuthOperation::McpToolCall { server_id, .. }
1997            | AuthOperation::McpResourceRead { server_id, .. }
1998            | AuthOperation::McpPromptGet { server_id, .. } => {
1999                let connection = self.connection_for_auth_server(server_id).await?;
2000                connection.replay_auth_operation(&request.operation).await
2001            }
2002            AuthOperation::ToolCall { metadata, .. } => {
2003                let server_id = metadata
2004                    .get("server_id")
2005                    .and_then(Value::as_str)
2006                    .ok_or_else(|| {
2007                        McpError::AuthResolution(
2008                            "tool-call auth replay requires metadata.server_id".into(),
2009                        )
2010                    })?;
2011                let connection = self.connection_for_auth_server(server_id).await?;
2012                connection.replay_auth_operation(&request.operation).await
2013            }
2014            AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
2015                "unsupported auth operation for replay: {kind}"
2016            ))),
2017        }
2018    }
2019
2020    async fn connection_for_auth_server(
2021        &mut self,
2022        server_id: &str,
2023    ) -> Result<Arc<McpConnection>, McpError> {
2024        let server_id = McpServerId::new(server_id);
2025        if !self.connections.contains_key(&server_id) {
2026            self.connect_server(&server_id).await?;
2027        }
2028        self.connections
2029            .get(&server_id)
2030            .map(McpServerHandle::connection)
2031            .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))
2032    }
2033
2034    /// Builds a combined [`ToolRegistry`] containing [`McpToolAdapter`]s for every
2035    /// tool discovered across all connected servers.
2036    ///
2037    /// Tool names are prefixed as `mcp_<server_id>_<tool_name>`.
2038    pub fn tool_registry(&self) -> ToolRegistry {
2039        self.connections
2040            .values()
2041            .fold(ToolRegistry::new(), |mut registry, handle| {
2042                for tool in handle.snapshot.tools.iter().cloned() {
2043                    registry.register(McpToolAdapter::new(
2044                        handle.server_id(),
2045                        handle.connection.clone(),
2046                        tool,
2047                    ));
2048                }
2049                registry
2050            })
2051    }
2052
2053    /// Builds a combined [`McpCapabilityProvider`] from all connected servers,
2054    /// merging their tools, resources, and prompts.
2055    pub fn capability_provider(&self) -> McpCapabilityProvider {
2056        McpCapabilityProvider::merge(
2057            self.connections
2058                .values()
2059                .map(McpServerHandle::capability_provider),
2060        )
2061    }
2062}
2063
2064/// Adapter that exposes an MCP tool as an agentkit [`Tool`].
2065///
2066/// This is the tool-layer adapter for the tool registry. For the capabilities-layer
2067/// adapter, see [`McpInvocable`]. Tool names are prefixed as
2068/// `mcp_<server_id>_<tool_name>`.
2069///
2070/// # Example
2071///
2072/// ```rust
2073/// use std::sync::Arc;
2074/// use agentkit_core::MetadataMap;
2075/// use agentkit_mcp::{McpToolAdapter, McpToolDescriptor, McpServerId};
2076/// # // McpToolAdapter::new requires a connection which we cannot construct in a doc test,
2077/// # // so this example only shows the construction pattern.
2078/// ```
2079pub struct McpToolAdapter {
2080    descriptor: McpToolDescriptor,
2081    connection: Arc<McpConnection>,
2082    spec: ToolSpec,
2083}
2084
2085impl McpToolAdapter {
2086    /// Creates a new tool adapter for the given MCP tool.
2087    ///
2088    /// # Arguments
2089    ///
2090    /// * `server_id` - The server's identifier, used to namespace the tool name.
2091    /// * `connection` - A shared connection to the owning MCP server.
2092    /// * `descriptor` - The tool descriptor obtained from discovery.
2093    pub fn new(
2094        server_id: &McpServerId,
2095        connection: Arc<McpConnection>,
2096        descriptor: McpToolDescriptor,
2097    ) -> Self {
2098        let spec = ToolSpec {
2099            name: ToolName::new(format!("mcp_{}_{}", server_id, descriptor.name)),
2100            description: descriptor
2101                .description
2102                .clone()
2103                .unwrap_or_else(|| descriptor.name.clone()),
2104            input_schema: descriptor.input_schema.clone(),
2105            annotations: ToolAnnotations::default(),
2106            metadata: descriptor.metadata.clone(),
2107        };
2108
2109        Self {
2110            descriptor,
2111            connection,
2112            spec,
2113        }
2114    }
2115}
2116
2117#[async_trait]
2118impl Tool for McpToolAdapter {
2119    fn spec(&self) -> &ToolSpec {
2120        &self.spec
2121    }
2122
2123    async fn invoke(
2124        &self,
2125        request: ToolRequest,
2126        _ctx: &mut ToolContext<'_>,
2127    ) -> Result<ToolResult, ToolError> {
2128        let result = self
2129            .connection
2130            .call_tool(&self.descriptor.name, request.input)
2131            .await
2132            .map_err(|error| match error {
2133                McpError::AuthRequired(request) => ToolError::AuthRequired(request),
2134                other => ToolError::ExecutionFailed(other.to_string()),
2135            })?;
2136
2137        Ok(ToolResult {
2138            result: ToolResultPart {
2139                call_id: request.call_id,
2140                output: invocable_output_to_tool_output(value_to_invocable_output(result)),
2141                is_error: false,
2142                metadata: MetadataMap::new(),
2143            },
2144            duration: None,
2145            metadata: MetadataMap::new(),
2146        })
2147    }
2148}
2149
2150fn parse_tool_descriptor(value: Value) -> Result<McpToolDescriptor, McpError> {
2151    Ok(McpToolDescriptor {
2152        name: required_string(&value, "name")?,
2153        description: value
2154            .get("description")
2155            .and_then(Value::as_str)
2156            .map(str::to_owned),
2157        input_schema: value
2158            .get("inputSchema")
2159            .cloned()
2160            .unwrap_or_else(|| json!({ "type": "object" })),
2161        metadata: MetadataMap::new(),
2162    })
2163}
2164
2165fn parse_resource_descriptor(value: Value) -> Result<McpResourceDescriptor, McpError> {
2166    Ok(McpResourceDescriptor {
2167        id: required_string(&value, "uri")?,
2168        name: value
2169            .get("name")
2170            .and_then(Value::as_str)
2171            .map(str::to_owned)
2172            .unwrap_or_else(|| {
2173                value
2174                    .get("uri")
2175                    .and_then(Value::as_str)
2176                    .unwrap_or_default()
2177                    .to_string()
2178            }),
2179        description: value
2180            .get("description")
2181            .and_then(Value::as_str)
2182            .map(str::to_owned),
2183        mime_type: value
2184            .get("mimeType")
2185            .and_then(Value::as_str)
2186            .map(str::to_owned),
2187        metadata: MetadataMap::new(),
2188    })
2189}
2190
2191fn parse_prompt_descriptor(value: Value) -> Result<McpPromptDescriptor, McpError> {
2192    let name = required_string(&value, "name")?;
2193    let properties = value
2194        .get("arguments")
2195        .and_then(Value::as_array)
2196        .cloned()
2197        .unwrap_or_default()
2198        .into_iter()
2199        .filter_map(|arg| {
2200            let name = arg.get("name")?.as_str()?.to_string();
2201            Some((name, json!({ "type": "string" })))
2202        })
2203        .collect::<serde_json::Map<String, Value>>();
2204
2205    Ok(McpPromptDescriptor {
2206        id: name.clone(),
2207        name,
2208        description: value
2209            .get("description")
2210            .and_then(Value::as_str)
2211            .map(str::to_owned),
2212        input_schema: json!({
2213            "type": "object",
2214            "properties": properties,
2215        }),
2216        metadata: MetadataMap::new(),
2217    })
2218}
2219
2220fn parse_prompt_message(value: Value) -> Result<Item, McpError> {
2221    let role = value.get("role").and_then(Value::as_str).unwrap_or("user");
2222    let kind = match role {
2223        "assistant" => ItemKind::Assistant,
2224        "system" => ItemKind::System,
2225        _ => ItemKind::User,
2226    };
2227
2228    let content = value.get("content").cloned().unwrap_or(Value::Null);
2229    let text = if let Some(text) = content.get("text").and_then(Value::as_str) {
2230        text.to_string()
2231    } else if let Some(text) = content.as_str() {
2232        text.to_string()
2233    } else {
2234        content.to_string()
2235    };
2236
2237    Ok(Item {
2238        id: None,
2239        kind,
2240        parts: vec![Part::Text(TextPart {
2241            text,
2242            metadata: MetadataMap::new(),
2243        })],
2244        metadata: MetadataMap::new(),
2245    })
2246}
2247
2248fn required_string(value: &Value, field: &str) -> Result<String, McpError> {
2249    value
2250        .get(field)
2251        .and_then(Value::as_str)
2252        .map(str::to_owned)
2253        .ok_or_else(|| McpError::Protocol(format!("missing string field {field}")))
2254}
2255
2256fn value_to_invocable_output(value: Value) -> InvocableOutput {
2257    if let Some(content) = value.get("content").and_then(Value::as_array) {
2258        let text = content
2259            .iter()
2260            .filter_map(|item| item.get("text").and_then(Value::as_str))
2261            .collect::<Vec<_>>()
2262            .join("\n");
2263        if !text.is_empty() {
2264            return InvocableOutput::Text(text);
2265        }
2266    }
2267
2268    if let Some(text) = value.as_str() {
2269        InvocableOutput::Text(text.to_string())
2270    } else {
2271        InvocableOutput::Structured(value)
2272    }
2273}
2274
2275fn invocable_output_to_tool_output(output: InvocableOutput) -> ToolOutput {
2276    match output {
2277        InvocableOutput::Text(text) => ToolOutput::Text(text),
2278        InvocableOutput::Structured(value) => ToolOutput::Structured(value),
2279        InvocableOutput::Items(items) => {
2280            ToolOutput::Parts(items.into_iter().flat_map(|item| item.parts).collect())
2281        }
2282        InvocableOutput::Data(data) => ToolOutput::Structured(json!({ "data": data })),
2283    }
2284}
2285
2286fn metadata_to_value(metadata: &MetadataMap) -> Value {
2287    Value::Object(
2288        metadata
2289            .iter()
2290            .map(|(key, value)| (key.clone(), value.clone()))
2291            .collect(),
2292    )
2293}
2294
2295fn parse_auth_request(
2296    server_id: &McpServerId,
2297    method: &str,
2298    params: &Value,
2299    error: &Value,
2300) -> Option<AuthRequest> {
2301    let code = error.get("code").and_then(Value::as_i64);
2302    let message = error.get("message").and_then(Value::as_str);
2303    let data = error.get("data");
2304
2305    let auth_marker = matches!(code, Some(401 | -32001))
2306        || data
2307            .and_then(|data| data.get("auth_required"))
2308            .and_then(Value::as_bool)
2309            == Some(true)
2310        || data.and_then(|data| data.get("auth")).is_some();
2311
2312    if !auth_marker {
2313        return None;
2314    }
2315
2316    let mut challenge = MetadataMap::new();
2317    challenge.insert("server_id".into(), Value::String(server_id.to_string()));
2318    challenge.insert("method".into(), Value::String(method.into()));
2319
2320    if let Some(code) = code {
2321        challenge.insert("code".into(), Value::Number(code.into()));
2322    }
2323    if let Some(message) = message {
2324        challenge.insert("message".into(), Value::String(message.into()));
2325    }
2326    if let Some(data) = data {
2327        challenge.insert("data".into(), data.clone());
2328    }
2329
2330    Some(AuthRequest {
2331        task_id: None,
2332        id: format!("mcp:{}:{}", server_id, method),
2333        provider: format!("mcp.{}", server_id),
2334        operation: auth_operation_for_method(server_id, method, params),
2335        challenge,
2336    })
2337}
2338
2339fn auth_operation_for_method(
2340    server_id: &McpServerId,
2341    method: &str,
2342    params: &Value,
2343) -> AuthOperation {
2344    match method {
2345        "initialize" => AuthOperation::McpConnect {
2346            server_id: server_id.to_string(),
2347            metadata: MetadataMap::new(),
2348        },
2349        "tools/call" => AuthOperation::McpToolCall {
2350            server_id: server_id.to_string(),
2351            tool_name: params
2352                .get("name")
2353                .and_then(Value::as_str)
2354                .unwrap_or_default()
2355                .to_string(),
2356            input: params
2357                .get("arguments")
2358                .cloned()
2359                .unwrap_or_else(|| json!({})),
2360            metadata: MetadataMap::new(),
2361        },
2362        "resources/read" => AuthOperation::McpResourceRead {
2363            server_id: server_id.to_string(),
2364            resource_id: params
2365                .get("uri")
2366                .and_then(Value::as_str)
2367                .unwrap_or_default()
2368                .to_string(),
2369            metadata: MetadataMap::new(),
2370        },
2371        "prompts/get" => AuthOperation::McpPromptGet {
2372            server_id: server_id.to_string(),
2373            prompt_id: params
2374                .get("name")
2375                .and_then(Value::as_str)
2376                .unwrap_or_default()
2377                .to_string(),
2378            args: params
2379                .get("arguments")
2380                .cloned()
2381                .unwrap_or_else(|| json!({})),
2382            metadata: MetadataMap::new(),
2383        },
2384        other => AuthOperation::Custom {
2385            kind: format!("mcp.{other}"),
2386            payload: params.clone(),
2387            metadata: {
2388                let mut metadata = MetadataMap::new();
2389                metadata.insert("server_id".into(), Value::String(server_id.to_string()));
2390                metadata
2391            },
2392        },
2393    }
2394}
2395
2396fn normalize_mcp_tool_name(server_id: &McpServerId, tool_name: &str) -> String {
2397    let prefix = format!("mcp_{server_id}_");
2398    tool_name
2399        .strip_prefix(&prefix)
2400        .unwrap_or(tool_name)
2401        .to_string()
2402}
2403
2404async fn read_sse_stream<R>(
2405    mut reader: R,
2406    response_url: Url,
2407    frame_tx: mpsc::UnboundedSender<Result<McpFrame, McpError>>,
2408    endpoint_tx: oneshot::Sender<Result<Url, McpError>>,
2409) where
2410    R: AsyncBufRead + Unpin,
2411{
2412    let mut endpoint_tx = Some(endpoint_tx);
2413    loop {
2414        match read_next_sse_event(&mut reader).await {
2415            Ok(Some(event)) => {
2416                if let Some(endpoint) = legacy_sse_event_to_endpoint(&response_url, &event) {
2417                    if let Some(tx) = endpoint_tx.take() {
2418                        let _ = tx.send(endpoint);
2419                    }
2420                    continue;
2421                }
2422
2423                if let Some(frame) = legacy_sse_event_to_frame(event) {
2424                    let _ = frame_tx.send(frame);
2425                }
2426            }
2427            Ok(None) => break,
2428            Err(error) => {
2429                if let Some(tx) = endpoint_tx.take() {
2430                    let _ = tx.send(Err(error));
2431                } else {
2432                    let _ = frame_tx.send(Err(error));
2433                }
2434                return;
2435            }
2436        }
2437    }
2438
2439    if let Some(tx) = endpoint_tx.take() {
2440        let _ = tx.send(Err(McpError::Transport(
2441            "SSE stream ended before endpoint event".into(),
2442        )));
2443    }
2444}
2445
2446fn resolve_sse_endpoint(response_url: &Url, endpoint: &str) -> Result<Url, McpError> {
2447    response_url
2448        .join(endpoint.trim())
2449        .map_err(|error| McpError::Transport(format!("invalid SSE endpoint URL: {error}")))
2450}
2451
/// One server-sent event, as assembled from the field lines of a stream.
#[derive(Debug)]
struct SseEvent {
    /// Value of the last `event:` line, if the event had one.
    event_name: Option<String>,
    /// Values of all `data:` lines, joined with `\n`.
    data: String,
    /// Value of the last `id:` line, if the event had one.
    id: Option<String>,
    /// Value of the last `retry:` line, when it parsed as an unsigned integer.
    retry_ms: Option<u64>,
}
2459
2460async fn read_next_sse_event<R>(reader: &mut R) -> Result<Option<SseEvent>, McpError>
2461where
2462    R: AsyncBufRead + Unpin,
2463{
2464    let mut event_name = None;
2465    let mut data_lines = Vec::new();
2466    let mut id = None;
2467    let mut retry_ms = None;
2468
2469    loop {
2470        let mut line = String::new();
2471        let read = reader.read_line(&mut line).await.map_err(McpError::Io)?;
2472        if read == 0 {
2473            if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2474                return Ok(None);
2475            }
2476            return Ok(Some(SseEvent {
2477                event_name,
2478                data: data_lines.join("\n"),
2479                id,
2480                retry_ms,
2481            }));
2482        }
2483
2484        let line = line.trim_end_matches(['\r', '\n']);
2485        if line.is_empty() {
2486            if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2487                continue;
2488            }
2489            return Ok(Some(SseEvent {
2490                event_name,
2491                data: data_lines.join("\n"),
2492                id,
2493                retry_ms,
2494            }));
2495        }
2496
2497        if line.starts_with(':') {
2498            continue;
2499        }
2500
2501        if let Some(rest) = line.strip_prefix("event:") {
2502            event_name = Some(rest.trim_start().to_string());
2503            continue;
2504        }
2505        if let Some(rest) = line.strip_prefix("data:") {
2506            data_lines.push(rest.trim_start().to_string());
2507            continue;
2508        }
2509        if let Some(rest) = line.strip_prefix("id:") {
2510            id = Some(rest.trim_start().to_string());
2511            continue;
2512        }
2513        if let Some(rest) = line.strip_prefix("retry:") {
2514            retry_ms = rest.trim_start().parse().ok();
2515        }
2516    }
2517}
2518
2519fn legacy_sse_event_to_endpoint(
2520    response_url: &Url,
2521    event: &SseEvent,
2522) -> Option<Result<Url, McpError>> {
2523    if event.event_name.as_deref() != Some("endpoint") {
2524        return None;
2525    }
2526    if event.data.is_empty() {
2527        return Some(Err(McpError::Transport(
2528            "legacy SSE endpoint event is missing data".into(),
2529        )));
2530    }
2531    Some(resolve_sse_endpoint(response_url, &event.data))
2532}
2533
2534fn legacy_sse_event_to_frame(event: SseEvent) -> Option<Result<McpFrame, McpError>> {
2535    let event_name = event.event_name.unwrap_or_else(|| "message".into());
2536    if event_name != "message" || event.data.is_empty() {
2537        return None;
2538    }
2539
2540    Some(
2541        serde_json::from_str(&event.data)
2542            .map_err(McpError::Serialize)
2543            .map(|value| McpFrame { value }),
2544    )
2545}
2546
2547fn streamable_http_event_to_frame(event: SseEvent) -> Result<Option<McpFrame>, McpError> {
2548    let event_name = event.event_name.unwrap_or_else(|| "message".into());
2549    if event_name != "message" || event.data.is_empty() {
2550        return Ok(None);
2551    }
2552
2553    let value = serde_json::from_str(&event.data).map_err(McpError::Serialize)?;
2554    Ok(Some(McpFrame { value }))
2555}
2556
2557fn is_jsonrpc_request(value: &Value) -> bool {
2558    value.get("method").is_some() && value.get("id").is_some()
2559}
2560
2561fn apply_streamable_http_headers(
2562    mut request: reqwest::RequestBuilder,
2563    headers: &[(String, String)],
2564    protocol_version: Option<&str>,
2565    session_id: Option<&str>,
2566) -> reqwest::RequestBuilder {
2567    for (key, value) in headers {
2568        request = request.header(key, value);
2569    }
2570
2571    if let Some(protocol_version) = protocol_version {
2572        request = request.header("MCP-Protocol-Version", protocol_version);
2573    }
2574    if let Some(session_id) = session_id {
2575        request = request.header("MCP-Session-Id", session_id);
2576    }
2577
2578    request
2579}
2580
2581async fn streamable_http_status_error(
2582    operation: &str,
2583    status: StatusCode,
2584    response: reqwest::Response,
2585) -> McpError {
2586    let body = response
2587        .text()
2588        .await
2589        .unwrap_or_else(|_| "<unreadable response body>".into());
2590    McpError::Transport(format!("{operation} failed with status {status}: {body}"))
2591}
2592
/// Errors produced by MCP transport, protocol, and lifecycle operations.
///
/// The `Io`, `Http`, and `Serialize` variants wrap their source error and are
/// marked `#[from]`, so those errors convert into `McpError` through `?`. The
/// remaining variants carry a human-readable message or, for `AuthRequired`,
/// the pending auth request.
#[derive(Debug, Error)]
pub enum McpError {
    /// An underlying I/O error (e.g. spawning a child process or reading from a pipe).
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// An HTTP-level error reported by `reqwest`, the client used by the
    /// HTTP-based transports.
    #[error("http error: {0}")]
    Http(#[from] reqwest::Error),
    /// A JSON serialization or deserialization error, including an SSE `data`
    /// payload that is not valid JSON.
    #[error("serialization error: {0}")]
    Serialize(#[from] serde_json::Error),
    /// A transport-level error (e.g. unexpected disconnection, a bad SSE
    /// response, or an HTTP request that failed with an error status).
    #[error("transport error: {0}")]
    Transport(String),
    /// An MCP protocol violation (e.g. missing required fields in a response).
    #[error("protocol error: {0}")]
    Protocol(String),
    /// The server requires authentication before the operation can proceed.
    /// Contains the boxed [`AuthRequest`] that describes the challenge.
    #[error("MCP auth required: {0:?}")]
    AuthRequired(Box<AuthRequest>),
    /// An error occurred while resolving or replaying authentication.
    #[error("auth resolution error: {0}")]
    AuthResolution(String),
    /// The MCP server returned an error for the invoked method.
    #[error("invocation error: {0}")]
    Invocation(String),
    /// The referenced server ID is not registered in the [`McpServerManager`].
    /// The payload is the offending server ID.
    #[error("unknown MCP server: {0}")]
    UnknownServer(String),
}
2625
2626#[cfg(test)]
2627mod tests {
2628    use std::collections::VecDeque;
2629    use std::sync::{Arc as StdArc, Mutex as StdMutex};
2630
2631    use super::*;
2632    use agentkit_tools_core::{PermissionChecker, PermissionDecision, PermissionRequest};
2633    use tokio::io::{AsyncReadExt, AsyncWriteExt};
2634    use tokio::net::TcpListener;
2635
2636    struct AllowAll;
2637
2638    impl PermissionChecker for AllowAll {
2639        fn evaluate(&self, _request: &dyn PermissionRequest) -> PermissionDecision {
2640            PermissionDecision::Allow
2641        }
2642    }
2643
2644    struct FakeTransport {
2645        recv: VecDeque<Value>,
2646    }
2647
2648    #[async_trait]
2649    impl McpTransport for FakeTransport {
2650        async fn send(&mut self, _message: McpFrame) -> Result<(), McpError> {
2651            Ok(())
2652        }
2653
2654        async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2655            Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2656        }
2657
2658        async fn close(&mut self) -> Result<(), McpError> {
2659            Ok(())
2660        }
2661    }
2662
2663    fn fake_connection(responses: Vec<Value>) -> McpConnection {
2664        McpConnection {
2665            server_id: McpServerId::new("fake"),
2666            transport: Mutex::new(Box::new(FakeTransport {
2667                recv: responses.into(),
2668            })),
2669            auth: Mutex::new(None),
2670            next_id: AtomicU64::new(1),
2671        }
2672    }
2673
2674    #[derive(Clone)]
2675    struct FakeTransportFactory {
2676        responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2677    }
2678
2679    impl FakeTransportFactory {
2680        fn new(sequences: Vec<Vec<Value>>) -> Self {
2681            Self {
2682                responses: StdArc::new(StdMutex::new(sequences.into())),
2683            }
2684        }
2685    }
2686
2687    #[async_trait]
2688    impl McpTransportFactory for FakeTransportFactory {
2689        async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2690            let responses =
2691                self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2692                    McpError::Transport("no fake transport responses left".into())
2693                })?;
2694            Ok(Box::new(FakeTransport {
2695                recv: responses.into(),
2696            }))
2697        }
2698    }
2699
2700    #[tokio::test]
2701    async fn discovery_parses_snapshot() {
2702        let connection = fake_connection(vec![
2703            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
2704            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [{ "uri": "file:///tmp/example.txt", "name": "example.txt", "mimeType": "text/plain" }] } }),
2705            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [{ "name": "summarize", "description": "Summarize", "arguments": [{ "name": "path" }] }] } }),
2706        ]);
2707
2708        let snapshot = connection.discover().await.unwrap();
2709        assert_eq!(snapshot.tools[0].name, "echo");
2710        assert_eq!(snapshot.resources[0].id, "file:///tmp/example.txt");
2711        assert_eq!(snapshot.prompts[0].id, "summarize");
2712    }
2713
2714    #[tokio::test]
2715    async fn tool_adapter_returns_text_output() {
2716        let connection = Arc::new(fake_connection(vec![json!({
2717            "jsonrpc": "2.0",
2718            "id": 1,
2719            "result": { "content": [{ "type": "text", "text": "pong" }] }
2720        })]));
2721        let server_id = connection.server_id().clone();
2722        let adapter = McpToolAdapter::new(
2723            &server_id,
2724            connection,
2725            McpToolDescriptor {
2726                name: "echo".into(),
2727                description: Some("Echo".into()),
2728                input_schema: json!({ "type": "object" }),
2729                metadata: MetadataMap::new(),
2730            },
2731        );
2732        let metadata = MetadataMap::new();
2733        let mut ctx = ToolContext {
2734            capability: CapabilityContext {
2735                session_id: None,
2736                turn_id: None,
2737                metadata: &metadata,
2738            },
2739            permissions: &AllowAll,
2740            resources: &(),
2741            cancellation: None,
2742        };
2743
2744        let result = adapter
2745            .invoke(
2746                ToolRequest {
2747                    call_id: "call-1".into(),
2748                    tool_name: ToolName::new("mcp_fake_echo"),
2749                    input: json!({}),
2750                    session_id: "session-1".into(),
2751                    turn_id: "turn-1".into(),
2752                    metadata: MetadataMap::new(),
2753                },
2754                &mut ctx,
2755            )
2756            .await
2757            .unwrap();
2758
2759        assert_eq!(result.result.output, ToolOutput::Text("pong".into()));
2760    }
2761
2762    #[tokio::test]
2763    async fn request_surfaces_auth_required_errors() {
2764        let connection = fake_connection(vec![json!({
2765            "jsonrpc": "2.0",
2766            "id": 1,
2767            "error": {
2768                "code": -32001,
2769                "message": "authentication required",
2770                "data": {
2771                    "auth_required": true,
2772                    "scope": "secrets.read"
2773                }
2774            }
2775        })]);
2776
2777        let error = connection.call_tool("echo", json!({})).await.unwrap_err();
2778        match error {
2779            McpError::AuthRequired(request) => {
2780                assert_eq!(request.provider, "mcp.fake");
2781                assert_eq!(
2782                    request.challenge.get("method"),
2783                    Some(&Value::String("tools/call".into()))
2784                );
2785                assert!(matches!(
2786                    request.operation,
2787                    AuthOperation::McpToolCall { ref tool_name, .. } if tool_name == "echo"
2788                ));
2789            }
2790            other => panic!("unexpected error: {other:?}"),
2791        }
2792    }
2793
2794    #[tokio::test]
2795    async fn tool_adapter_maps_auth_required_into_tool_error() {
2796        let connection = Arc::new(fake_connection(vec![json!({
2797            "jsonrpc": "2.0",
2798            "id": 1,
2799            "error": {
2800                "code": -32001,
2801                "message": "authentication required",
2802                "data": { "auth_required": true }
2803            }
2804        })]));
2805        let server_id = connection.server_id().clone();
2806        let adapter = McpToolAdapter::new(
2807            &server_id,
2808            connection,
2809            McpToolDescriptor {
2810                name: "echo".into(),
2811                description: Some("Echo".into()),
2812                input_schema: json!({ "type": "object" }),
2813                metadata: MetadataMap::new(),
2814            },
2815        );
2816        let metadata = MetadataMap::new();
2817        let mut ctx = ToolContext {
2818            capability: CapabilityContext {
2819                session_id: None,
2820                turn_id: None,
2821                metadata: &metadata,
2822            },
2823            permissions: &AllowAll,
2824            resources: &(),
2825            cancellation: None,
2826        };
2827
2828        let error = adapter
2829            .invoke(
2830                ToolRequest {
2831                    call_id: "call-1".into(),
2832                    tool_name: ToolName::new("mcp_fake_echo"),
2833                    input: json!({}),
2834                    session_id: "session-1".into(),
2835                    turn_id: "turn-1".into(),
2836                    metadata: MetadataMap::new(),
2837                },
2838                &mut ctx,
2839            )
2840            .await
2841            .unwrap_err();
2842
2843        match error {
2844            ToolError::AuthRequired(request) => {
2845                assert_eq!(request.provider, "mcp.fake");
2846            }
2847            other => panic!("unexpected error: {other:?}"),
2848        }
2849    }
2850
2851    struct RecordingTransport {
2852        recv: VecDeque<Value>,
2853        sent: StdArc<StdMutex<Vec<Value>>>,
2854    }
2855
2856    #[async_trait]
2857    impl McpTransport for RecordingTransport {
2858        async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
2859            self.sent.lock().unwrap().push(message.value);
2860            Ok(())
2861        }
2862
2863        async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2864            Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2865        }
2866
2867        async fn close(&mut self) -> Result<(), McpError> {
2868            Ok(())
2869        }
2870    }
2871
2872    #[derive(Clone)]
2873    struct RecordingTransportFactory {
2874        responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2875        sent: StdArc<StdMutex<Vec<Value>>>,
2876    }
2877
2878    impl RecordingTransportFactory {
2879        fn new(sequences: Vec<Vec<Value>>) -> Self {
2880            Self {
2881                responses: StdArc::new(StdMutex::new(sequences.into())),
2882                sent: StdArc::new(StdMutex::new(Vec::new())),
2883            }
2884        }
2885
2886        fn sent(&self) -> Vec<Value> {
2887            self.sent.lock().unwrap().clone()
2888        }
2889    }
2890
2891    #[async_trait]
2892    impl McpTransportFactory for RecordingTransportFactory {
2893        async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2894            let responses = self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2895                McpError::Transport("no recording transport responses left".into())
2896            })?;
2897            Ok(Box::new(RecordingTransport {
2898                recv: responses.into(),
2899                sent: self.sent.clone(),
2900            }))
2901        }
2902    }
2903
2904    #[tokio::test]
2905    async fn connection_includes_resolved_auth_in_future_requests() {
2906        let factory = RecordingTransportFactory::new(vec![vec![
2907            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2908            json!({ "jsonrpc": "2.0", "id": 1, "result": { "content": [{ "type": "text", "text": "ok" }] } }),
2909        ]]);
2910        let config = McpServerConfig::new(
2911            "recording",
2912            McpTransportBinding::Custom(Arc::new(factory.clone())),
2913        );
2914        let connection = McpConnection::connect(&config).await.unwrap();
2915        let mut auth = MetadataMap::new();
2916        auth.insert("token".into(), json!("secret-token"));
2917        let request = AuthRequest {
2918            task_id: None,
2919            id: "auth-recording-tool".into(),
2920            provider: "mcp.recording".into(),
2921            operation: AuthOperation::McpToolCall {
2922                server_id: "recording".into(),
2923                tool_name: "echo".into(),
2924                input: json!({}),
2925                metadata: MetadataMap::new(),
2926            },
2927            challenge: MetadataMap::new(),
2928        };
2929        connection
2930            .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2931                request,
2932                credentials: auth,
2933            })
2934            .await
2935            .unwrap();
2936
2937        let _ = connection.call_tool("echo", json!({})).await.unwrap();
2938        let sent = factory.sent();
2939        assert!(
2940            sent.iter().any(|value| {
2941                value
2942                    .get("params")
2943                    .and_then(|params| params.get("auth"))
2944                    .and_then(|auth| auth.get("token"))
2945                    == Some(&json!("secret-token"))
2946            }),
2947            "expected an MCP request to include the resolved auth payload, saw {:?}",
2948            sent
2949        );
2950    }
2951
2952    #[tokio::test]
2953    async fn manager_reuses_stored_auth_on_connect() {
2954        let factory = RecordingTransportFactory::new(vec![vec![
2955            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2956            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
2957            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
2958            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
2959        ]]);
2960        let server_id = McpServerId::new("recording");
2961        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
2962            server_id.to_string(),
2963            McpTransportBinding::Custom(Arc::new(factory.clone())),
2964        ));
2965        let mut auth = MetadataMap::new();
2966        auth.insert("token".into(), json!("seed-token"));
2967        let request = AuthRequest {
2968            task_id: None,
2969            id: "auth-recording-connect".into(),
2970            provider: "mcp.recording".into(),
2971            operation: AuthOperation::McpConnect {
2972                server_id: server_id.to_string(),
2973                metadata: MetadataMap::new(),
2974            },
2975            challenge: MetadataMap::new(),
2976        };
2977        manager
2978            .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2979                request,
2980                credentials: auth,
2981            })
2982            .await
2983            .unwrap();
2984
2985        manager.connect_server(&server_id).await.unwrap();
2986        let sent = factory.sent();
2987        assert!(
2988            sent.iter().any(|value| {
2989                value.get("method").and_then(Value::as_str) == Some("initialize")
2990                    && value
2991                        .get("params")
2992                        .and_then(|params| params.get("auth"))
2993                        .and_then(|auth| auth.get("token"))
2994                        == Some(&json!("seed-token"))
2995            }),
2996            "expected initialize to include stored auth, saw {:?}",
2997            sent
2998        );
2999    }
3000
3001    #[tokio::test]
3002    async fn manager_resolves_auth_and_replays_resource_read() {
3003        let factory = RecordingTransportFactory::new(vec![vec![
3004            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3005            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3006            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3007            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3008            json!({
3009                "jsonrpc": "2.0",
3010                "id": 4,
3011                "result": {
3012                    "contents": [
3013                        {
3014                            "uri": "file:///tmp/secret.txt",
3015                            "text": "secret from resource"
3016                        }
3017                    ]
3018                }
3019            }),
3020        ]]);
3021        let server_id = McpServerId::new("recording");
3022        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3023            server_id.to_string(),
3024            McpTransportBinding::Custom(Arc::new(factory.clone())),
3025        ));
3026        let mut auth = MetadataMap::new();
3027        auth.insert("token".into(), json!("resource-token"));
3028        let request = AuthRequest {
3029            task_id: None,
3030            id: "auth-recording-resource".into(),
3031            provider: "mcp.recording".into(),
3032            operation: AuthOperation::McpResourceRead {
3033                server_id: server_id.to_string(),
3034                resource_id: "file:///tmp/secret.txt".into(),
3035                metadata: MetadataMap::new(),
3036            },
3037            challenge: MetadataMap::new(),
3038        };
3039
3040        let result = manager
3041            .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3042                request,
3043                credentials: auth,
3044            })
3045            .await
3046            .unwrap();
3047
3048        match result {
3049            McpOperationResult::Resource(contents) => {
3050                assert_eq!(
3051                    contents.data,
3052                    DataRef::InlineText("secret from resource".into())
3053                );
3054            }
3055            other => panic!("unexpected replay result: {other:?}"),
3056        }
3057
3058        let sent = factory.sent();
3059        assert!(
3060            sent.iter().any(|value| {
3061                value.get("method").and_then(Value::as_str) == Some("resources/read")
3062                    && value
3063                        .get("params")
3064                        .and_then(|params| params.get("auth"))
3065                        .and_then(|auth| auth.get("token"))
3066                        == Some(&json!("resource-token"))
3067            }),
3068            "expected resources/read to include resolved auth, saw {:?}",
3069            sent
3070        );
3071    }
3072
3073    #[tokio::test]
3074    async fn manager_resolves_auth_and_replays_connect() {
3075        let factory = RecordingTransportFactory::new(vec![vec![
3076            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3077            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3078            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3079            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3080        ]]);
3081        let server_id = McpServerId::new("recording");
3082        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3083            server_id.to_string(),
3084            McpTransportBinding::Custom(Arc::new(factory.clone())),
3085        ));
3086        let mut auth = MetadataMap::new();
3087        auth.insert("token".into(), json!("connect-token"));
3088        let request = AuthRequest {
3089            task_id: None,
3090            id: "auth-recording-connect-replay".into(),
3091            provider: "mcp.recording".into(),
3092            operation: AuthOperation::McpConnect {
3093                server_id: server_id.to_string(),
3094                metadata: MetadataMap::new(),
3095            },
3096            challenge: MetadataMap::new(),
3097        };
3098
3099        let result = manager
3100            .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3101                request,
3102                credentials: auth,
3103            })
3104            .await
3105            .unwrap();
3106
3107        match result {
3108            McpOperationResult::Connected(snapshot) => {
3109                assert_eq!(snapshot.server_id, server_id);
3110            }
3111            other => panic!("unexpected replay result: {other:?}"),
3112        }
3113    }
3114
3115    #[tokio::test]
3116    async fn sse_transport_posts_messages_and_receives_frames() {
3117        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3118        let address = listener.local_addr().unwrap();
3119        let requests = StdArc::new(StdMutex::new(Vec::new()));
3120        let captured = requests.clone();
3121
3122        let server = tokio::spawn(async move {
3123            for _ in 0..2 {
3124                let (mut socket, _) = listener.accept().await.unwrap();
3125                let mut buffer = vec![0_u8; 4096];
3126                let read = socket.read(&mut buffer).await.unwrap();
3127                let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3128
3129                if request.starts_with("GET /sse ") {
3130                    let body = concat!(
3131                        "event: endpoint\n",
3132                        "data: /messages\n\n",
3133                        "event: message\n",
3134                        "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3135                    );
3136                    let response = format!(
3137                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3138                        body.len(),
3139                        body
3140                    );
3141                    socket.write_all(response.as_bytes()).await.unwrap();
3142                } else {
3143                    captured.lock().unwrap().push(request);
3144                    socket
3145                        .write_all(
3146                            b"HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n",
3147                        )
3148                        .await
3149                        .unwrap();
3150                }
3151            }
3152        });
3153
3154        let factory =
3155            SseTransportFactory::new(SseTransportConfig::new(format!("http://{address}/sse")));
3156        let mut transport = factory.connect().await.unwrap();
3157        transport
3158            .send(McpFrame {
3159                value: json!({
3160                    "jsonrpc": "2.0",
3161                    "id": 1,
3162                    "method": "tools/list",
3163                    "params": {}
3164                }),
3165            })
3166            .await
3167            .unwrap();
3168        let frame = transport.recv().await.unwrap().unwrap();
3169        transport.close().await.unwrap();
3170        server.await.unwrap();
3171
3172        assert_eq!(frame.value["result"]["tools"], json!([]));
3173        let requests = requests.lock().unwrap();
3174        assert_eq!(requests.len(), 1);
3175        assert!(requests[0].starts_with("POST /messages "));
3176        assert!(requests[0].contains("\"method\":\"tools/list\""));
3177    }
3178
3179    #[tokio::test]
3180    async fn streamable_http_connection_tracks_session_and_protocol_headers() {
3181        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3182        let address = listener.local_addr().unwrap();
3183        let requests = StdArc::new(StdMutex::new(Vec::new()));
3184        let captured = requests.clone();
3185
3186        let server = tokio::spawn(async move {
3187            for _ in 0..4 {
3188                let (mut socket, _) = listener.accept().await.unwrap();
3189                let mut buffer = vec![0_u8; 8192];
3190                let read = socket.read(&mut buffer).await.unwrap();
3191                let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3192                captured.lock().unwrap().push(request.clone());
3193
3194                let response = if request.contains("\"method\":\"initialize\"") {
3195                    let body = "{\"jsonrpc\":\"2.0\",\"id\":0,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"remote\",\"version\":\"1.0.0\"}}}";
3196                    format!(
3197                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\nMCP-Session-Id: session-123\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3198                        body.len(),
3199                        body
3200                    )
3201                } else if request.contains("\"method\":\"notifications/initialized\"") {
3202                    "HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3203                        .to_string()
3204                } else if request.starts_with("DELETE /mcp ") {
3205                    "HTTP/1.1 204 No Content\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3206                        .to_string()
3207                } else {
3208                    let body = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}";
3209                    format!(
3210                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3211                        body.len(),
3212                        body
3213                    )
3214                };
3215
3216                socket.write_all(response.as_bytes()).await.unwrap();
3217            }
3218        });
3219
3220        let config = McpServerConfig::new(
3221            "remote",
3222            McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(format!(
3223                "http://{address}/mcp"
3224            ))),
3225        );
3226        let connection = McpConnection::connect(&config).await.unwrap();
3227        let _ = connection.list_tools().await.unwrap();
3228        connection.close().await.unwrap();
3229        server.await.unwrap();
3230
3231        let requests = requests.lock().unwrap();
3232        assert_eq!(requests.len(), 4);
3233        let normalized = requests
3234            .iter()
3235            .map(|request| request.to_ascii_lowercase())
3236            .collect::<Vec<_>>();
3237        assert!(requests[0].starts_with("POST /mcp "));
3238        assert!(!requests[0].contains("MCP-Session-Id:"));
3239        assert!(normalized[1].contains("mcp-session-id: session-123"));
3240        assert!(normalized[1].contains("mcp-protocol-version: 2025-11-25"));
3241        assert!(normalized[2].contains("mcp-session-id: session-123"));
3242        assert!(normalized[2].contains("mcp-protocol-version: 2025-11-25"));
3243        assert!(requests[3].starts_with("DELETE /mcp "));
3244        assert!(normalized[3].contains("mcp-session-id: session-123"));
3245    }
3246
3247    #[tokio::test]
3248    async fn streamable_http_transport_resumes_sse_streams_until_response_arrives() {
3249        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3250        let address = listener.local_addr().unwrap();
3251        let requests = StdArc::new(StdMutex::new(Vec::new()));
3252        let captured = requests.clone();
3253
3254        let server = tokio::spawn(async move {
3255            for _ in 0..2 {
3256                let (mut socket, _) = listener.accept().await.unwrap();
3257                let mut buffer = vec![0_u8; 8192];
3258                let read = socket.read(&mut buffer).await.unwrap();
3259                let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3260                captured.lock().unwrap().push(request.clone());
3261
3262                let response = if request.starts_with("POST /mcp ") {
3263                    let body = concat!(
3264                        "id: evt-1\n",
3265                        "event: message\n",
3266                        "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/message\",\"params\":{\"phase\":\"stream-start\"}}\n\n"
3267                    );
3268                    format!(
3269                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3270                        body.len(),
3271                        body
3272                    )
3273                } else {
3274                    let body = concat!(
3275                        "id: evt-2\n",
3276                        "event: message\n",
3277                        "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3278                    );
3279                    format!(
3280                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3281                        body.len(),
3282                        body
3283                    )
3284                };
3285
3286                socket.write_all(response.as_bytes()).await.unwrap();
3287            }
3288        });
3289
3290        let factory = StreamableHttpTransportFactory::new(StreamableHttpTransportConfig::new(
3291            format!("http://{address}/mcp"),
3292        ));
3293        let mut transport = factory.connect().await.unwrap();
3294        transport
3295            .send(McpFrame {
3296                value: json!({
3297                    "jsonrpc": "2.0",
3298                    "id": 1,
3299                    "method": "tools/list",
3300                    "params": {}
3301                }),
3302            })
3303            .await
3304            .unwrap();
3305
3306        let first = transport.recv().await.unwrap().unwrap();
3307        let second = transport.recv().await.unwrap().unwrap();
3308        transport.close().await.unwrap();
3309        server.await.unwrap();
3310
3311        assert_eq!(
3312            first.value["method"],
3313            Value::String("notifications/message".into())
3314        );
3315        assert_eq!(second.value["result"]["tools"], json!([]));
3316
3317        let requests = requests.lock().unwrap();
3318        assert_eq!(requests.len(), 2);
3319        assert!(requests[0].starts_with("POST /mcp "));
3320        assert!(requests[1].starts_with("GET /mcp "));
3321        assert!(
3322            requests[1].contains("last-event-id: evt-1")
3323                || requests[1].contains("Last-Event-ID: evt-1")
3324        );
3325    }
3326
3327    #[tokio::test]
    // Exercises `McpServerManager` across two servers backed by scripted fake
    // transports: connect both, check the aggregated tool registry, refresh one
    // server, inspect the capability provider, then disconnect one server.
3328    async fn server_manager_connects_refreshes_and_aggregates_tools() {
        // "alpha" is scripted with seven JSON-RPC responses. Id 0 carries an
        // initialize-style result (protocolVersion / capabilities / serverInfo),
        // ids 1-3 answer a tools / resources / prompts listing that exposes the
        // tool `echo`, and ids 4-6 answer a second listing round that exposes
        // `echo_v2` instead.
        // NOTE(review): presumably ids 0-3 are consumed by `connect_all` and
        // ids 4-6 by `refresh_server`; the request order lives in the manager,
        // outside this test -- confirm before reordering these fixtures.
        // NOTE(review): the outer `vec![...]` passed to `FakeTransportFactory::new`
        // looks like one script per connection -- confirm against the factory.
3329        let alpha = McpServerConfig::new(
3330            "alpha",
3331            McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3332                json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "alpha", "version": "1.0.0" } } }),
3333                json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
3334                json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3335                json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3336                json!({ "jsonrpc": "2.0", "id": 4, "result": { "tools": [{ "name": "echo_v2", "description": "Echo 2", "inputSchema": {"type": "object"} }] } }),
3337                json!({ "jsonrpc": "2.0", "id": 5, "result": { "resources": [] } }),
3338                json!({ "jsonrpc": "2.0", "id": 6, "result": { "prompts": [] } }),
3339            ]]))),
3340        );
        // "beta" is scripted with a single round only (ids 0-3) exposing the
        // tool `search`; this test never refreshes beta.
3341        let beta = McpServerConfig::new(
3342            "beta",
3343            McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3344                json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "beta", "version": "1.0.0" } } }),
3345                json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "search", "description": "Search", "inputSchema": {"type": "object"} }] } }),
3346                json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3347                json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3348            ]]))),
3349        );
3350
3351        let mut manager = McpServerManager::new().with_server(alpha).with_server(beta);
3352
        // Connecting must yield one handle per configured server.
3353        let handles = manager.connect_all().await.unwrap();
3354        assert_eq!(handles.len(), 2);
        // The aggregated registry is expected to list each tool under a
        // `mcp_<server>_<tool>` name, alpha's entry before beta's.
3355        assert_eq!(
3356            manager
3357                .tool_registry()
3358                .specs()
3359                .into_iter()
3360                .map(|spec| spec.name.0)
3361                .collect::<Vec<_>>(),
3362            vec!["mcp_alpha_echo".to_string(), "mcp_beta_search".to_string()]
3363        );
3364
        // Refreshing alpha must return a snapshot whose first tool is `echo_v2`...
3365        let refreshed = manager
3366            .refresh_server(&McpServerId::new("alpha"))
3367            .await
3368            .unwrap();
3369        assert_eq!(refreshed.tools[0].name, "echo_v2");
        // ...and the snapshot read back from the manager's connected-server
        // entry must show the same tool, not the pre-refresh `echo`.
3370        assert_eq!(
3371            manager
3372                .connected_server(&McpServerId::new("alpha"))
3373                .unwrap()
3374                .snapshot()
3375                .tools[0]
3376                .name,
3377            "echo_v2"
3378        );
3379
        // Two invocables are expected from the capability provider.
        // NOTE(review): presumably one per server tool (alpha's `echo_v2` and
        // beta's `search`) -- confirm against `capability_provider`.
3380        let capabilities = manager.capability_provider();
3381        assert_eq!(capabilities.invocables().len(), 2);
3382
        // After disconnecting alpha, the manager must no longer report it as
        // connected. Beta is left connected; the test does not tear it down.
3383        manager
3384            .disconnect_server(&McpServerId::new("alpha"))
3385            .await
3386            .unwrap();
3387        assert!(
3388            manager
3389                .connected_server(&McpServerId::new("alpha"))
3390                .is_none()
3391        );
3392    }
3393}