Skip to main content

agentkit_mcp/
lib.rs

1use std::collections::{BTreeMap, VecDeque};
2use std::fmt;
3use std::process::Stdio;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::time::Duration;
7
8use agentkit_capabilities::{
9    CapabilityContext, CapabilityError, CapabilityName, CapabilityProvider, Invocable,
10    InvocableOutput, InvocableRequest, InvocableResult, InvocableSpec, PromptContents,
11    PromptDescriptor, PromptId, PromptProvider, ResourceContents, ResourceDescriptor, ResourceId,
12    ResourceProvider,
13};
14use agentkit_core::{
15    DataRef, Item, ItemKind, MetadataMap, Part, TextPart, ToolOutput, ToolResultPart,
16};
17use agentkit_http::{
18    HeaderMap, Http, HttpError, HttpRequestBuilder, HttpResponse, StatusCode, header as http_header,
19};
20use agentkit_tools_core::{
21    AuthOperation, AuthRequest, AuthResolution, Tool, ToolAnnotations, ToolContext, ToolError,
22    ToolName, ToolRegistry, ToolRequest, ToolResult, ToolSpec,
23};
24use async_trait::async_trait;
25use futures_util::TryStreamExt;
26use serde::{Deserialize, Serialize};
27use serde_json::{Value, json};
28use thiserror::Error;
29use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
30use tokio::process::{Child, ChildStdin, ChildStdout, Command};
31use tokio::sync::{Mutex, mpsc, oneshot};
32use tokio::task::JoinHandle;
33use tokio::time::sleep;
34use tokio_util::io::StreamReader;
35use url::Url;
36
/// Newest MCP protocol revision known to this crate.
const MCP_LATEST_PROTOCOL_VERSION: &str = "2025-11-25";

/// Every MCP protocol revision this crate supports, newest first.
///
/// The first entry is [`MCP_LATEST_PROTOCOL_VERSION`] itself, so the two
/// constants cannot drift apart when a new revision is added.
const MCP_SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[
    MCP_LATEST_PROTOCOL_VERSION,
    "2025-06-18",
    "2025-03-26",
    "2024-11-05",
];
40
41/// Unique identifier for a registered MCP server.
42///
43/// Each MCP server in a [`McpServerManager`] is addressed by its `McpServerId`.
44/// The inner string is typically a short, human-readable name such as `"filesystem"`
45/// or `"github"`.
46///
47/// # Example
48///
49/// ```rust
50/// use agentkit_mcp::McpServerId;
51///
52/// let id = McpServerId::new("filesystem");
53/// assert_eq!(id.to_string(), "filesystem");
54/// ```
55#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
56pub struct McpServerId(pub String);
57
58impl McpServerId {
59    /// Creates a new server identifier from any string-like value.
60    pub fn new(value: impl Into<String>) -> Self {
61        Self(value.into())
62    }
63}
64
65impl fmt::Display for McpServerId {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        self.0.fmt(f)
68    }
69}
70
71/// Configuration for an MCP server that communicates over standard I/O (stdin/stdout).
72///
73/// This is the most common transport for local MCP servers. The specified command is
74/// spawned as a child process, and JSON-RPC messages are exchanged line-by-line over
75/// its stdin and stdout streams.
76///
77/// # Example
78///
79/// ```rust
80/// use agentkit_mcp::StdioTransportConfig;
81///
82/// let config = StdioTransportConfig::new("npx")
83///     .with_arg("-y")
84///     .with_arg("@modelcontextprotocol/server-filesystem")
85///     .with_env("HOME", "/home/user")
86///     .with_cwd("/tmp");
87/// ```
88#[derive(Clone, Debug, PartialEq, Eq)]
89pub struct StdioTransportConfig {
90    /// The executable to launch (e.g. `"npx"`, `"python"`, `"node"`).
91    pub command: String,
92    /// Command-line arguments passed to the executable.
93    pub args: Vec<String>,
94    /// Additional environment variables set for the child process.
95    pub env: Vec<(String, String)>,
96    /// Optional working directory for the child process.
97    pub cwd: Option<std::path::PathBuf>,
98}
99
100impl StdioTransportConfig {
101    /// Creates a new stdio transport configuration for the given command.
102    pub fn new(command: impl Into<String>) -> Self {
103        Self {
104            command: command.into(),
105            args: Vec::new(),
106            env: Vec::new(),
107            cwd: None,
108        }
109    }
110
111    /// Appends a command-line argument. Returns `self` for chaining.
112    pub fn with_arg(mut self, arg: impl Into<String>) -> Self {
113        self.args.push(arg.into());
114        self
115    }
116
117    /// Adds an environment variable for the child process. Returns `self` for chaining.
118    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
119        self.env.push((key.into(), value.into()));
120        self
121    }
122
123    /// Sets the working directory for the child process. Returns `self` for chaining.
124    pub fn with_cwd(mut self, cwd: impl Into<std::path::PathBuf>) -> Self {
125        self.cwd = Some(cwd.into());
126        self
127    }
128}
129
130/// Configuration for an MCP server that communicates over Server-Sent Events (SSE).
131///
132/// Use this transport for remote MCP servers exposed over HTTP. The client opens an
133/// SSE stream to the given URL, receives an `endpoint` event pointing to the POST
134/// endpoint, and then exchanges JSON-RPC messages over that endpoint.
135///
136/// Auth headers and other per-request customisation live on the [`Http`] client
137/// — either via `reqwest::ClientBuilder::default_headers` or a custom
138/// [`agentkit_http::HttpClient`] implementation — passed in through
139/// [`with_client`](Self::with_client).
140#[derive(Clone, Debug)]
141pub struct SseTransportConfig {
142    /// The SSE endpoint URL to connect to.
143    pub url: String,
144    /// HTTP client used for all requests on this transport. When `None`, the
145    /// factory builds a default reqwest-backed client with agentkit's user agent.
146    pub client: Option<Http>,
147}
148
149impl SseTransportConfig {
150    /// Creates a new SSE transport configuration for the given URL.
151    pub fn new(url: impl Into<String>) -> Self {
152        Self {
153            url: url.into(),
154            client: None,
155        }
156    }
157
158    /// Supplies a pre-configured HTTP client. Use this to attach auth headers
159    /// (via `reqwest::ClientBuilder::default_headers`), install retry/tracing
160    /// middleware, or plug in a non-reqwest backend.
161    pub fn with_client(mut self, client: Http) -> Self {
162        self.client = Some(client);
163        self
164    }
165}
166
167/// Configuration for an MCP server that communicates over Streamable HTTP.
168///
169/// Use this transport for modern remote MCP servers that expose a single HTTP
170/// endpoint supporting JSON-RPC over POST, with optional SSE responses for
171/// streaming server messages.
172///
173/// Auth headers and other per-request customisation live on the [`Http`] client
174/// passed in via [`with_client`](Self::with_client).
175#[derive(Clone, Debug)]
176pub struct StreamableHttpTransportConfig {
177    /// The MCP endpoint URL to connect to.
178    pub url: String,
179    /// HTTP client used for all requests on this transport. When `None`, the
180    /// factory builds a default reqwest-backed client with agentkit's user agent.
181    pub client: Option<Http>,
182}
183
184impl StreamableHttpTransportConfig {
185    /// Creates a new Streamable HTTP transport configuration for the given MCP endpoint.
186    pub fn new(url: impl Into<String>) -> Self {
187        Self {
188            url: url.into(),
189            client: None,
190        }
191    }
192
193    /// Supplies a pre-configured HTTP client. See
194    /// [`SseTransportConfig::with_client`] for the typical use cases.
195    pub fn with_client(mut self, client: Http) -> Self {
196        self.client = Some(client);
197        self
198    }
199}
200
201/// Selects which transport an MCP server should use.
202///
203/// This enum is passed into [`McpServerConfig`] and determines how the client will
204/// communicate with the MCP server. The built-in options are [`Stdio`](Self::Stdio),
205/// [`StreamableHttp`](Self::StreamableHttp), and the legacy [`Sse`](Self::Sse);
206/// use [`Custom`](Self::Custom) to provide your own [`McpTransportFactory`].
207#[derive(Clone)]
208pub enum McpTransportBinding {
209    /// Communicate over the child process's stdin/stdout.
210    Stdio(StdioTransportConfig),
211    /// Communicate over the MCP Streamable HTTP transport.
212    StreamableHttp(StreamableHttpTransportConfig),
213    /// Communicate over HTTP Server-Sent Events.
214    Sse(SseTransportConfig),
215    /// A user-supplied transport factory.
216    Custom(Arc<dyn McpTransportFactory>),
217}
218
219/// Full configuration for a single MCP server, combining an identifier, a transport
220/// binding, and optional metadata.
221///
222/// Register one or more of these with [`McpServerManager`] to manage the lifecycle
223/// of MCP servers in an agentkit runtime.
224///
225/// # Example
226///
227/// ```rust
228/// use agentkit_mcp::{McpServerConfig, McpTransportBinding, StdioTransportConfig};
229///
230/// let config = McpServerConfig::new(
231///     "filesystem",
232///     McpTransportBinding::Stdio(
233///         StdioTransportConfig::new("npx")
234///             .with_arg("-y")
235///             .with_arg("@modelcontextprotocol/server-filesystem"),
236///     ),
237/// );
238/// ```
239#[derive(Clone)]
240pub struct McpServerConfig {
241    /// Unique identifier for this server.
242    pub id: McpServerId,
243    /// Transport binding that determines how communication happens.
244    pub transport: McpTransportBinding,
245    /// Arbitrary metadata attached to this server configuration.
246    pub metadata: MetadataMap,
247}
248
249impl McpServerConfig {
250    /// Creates a new server configuration with the given identifier and transport.
251    ///
252    /// # Arguments
253    ///
254    /// * `id` - A unique name for this server (e.g. `"filesystem"`).
255    /// * `transport` - The [`McpTransportBinding`] that determines how to connect.
256    pub fn new(id: impl Into<String>, transport: McpTransportBinding) -> Self {
257        Self {
258            id: McpServerId::new(id),
259            transport,
260            metadata: MetadataMap::new(),
261        }
262    }
263
264    /// Creates a stdio-backed server configuration.
265    pub fn stdio(id: impl Into<String>, command: impl Into<String>) -> Self {
266        Self::new(
267            id,
268            McpTransportBinding::Stdio(StdioTransportConfig::new(command)),
269        )
270    }
271
272    /// Creates an SSE-backed server configuration.
273    pub fn sse(id: impl Into<String>, url: impl Into<String>) -> Self {
274        Self::new(id, McpTransportBinding::Sse(SseTransportConfig::new(url)))
275    }
276
277    /// Creates a Streamable HTTP-backed server configuration.
278    pub fn streamable_http(id: impl Into<String>, url: impl Into<String>) -> Self {
279        Self::new(
280            id,
281            McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(url)),
282        )
283    }
284
285    /// Replaces the configuration metadata.
286    pub fn with_metadata(mut self, metadata: MetadataMap) -> Self {
287        self.metadata = metadata;
288        self
289    }
290}
291
292/// A single JSON-RPC frame exchanged with an MCP server.
293///
294/// This is the low-level wire unit. Most users will not interact with `McpFrame`
295/// directly; instead use [`McpConnection`] or the higher-level adapters.
296#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
297pub struct McpFrame {
298    /// The raw JSON-RPC value (request, response, or notification).
299    pub value: Value,
300}
301
302/// Factory trait for creating new [`McpTransport`] connections.
303///
304/// Implement this trait to provide a custom transport mechanism. The built-in
305/// [`StdioTransportFactory`] and [`SseTransportFactory`] cover the two standard
306/// MCP transports; use this trait for in-memory, WebSocket, or other custom
307/// transports.
308///
309/// # Errors
310///
311/// Returns [`McpError`] if the connection cannot be established.
312#[async_trait]
313pub trait McpTransportFactory: Send + Sync {
314    /// Establishes a new transport connection and returns it.
315    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError>;
316}
317
318/// Bidirectional transport for exchanging [`McpFrame`] messages with an MCP server.
319///
320/// Implement this trait to provide a custom transport. Each transport instance
321/// represents a single, live connection.
322///
323/// # Errors
324///
325/// All methods return [`McpError`] on I/O or protocol failures.
326#[async_trait]
327pub trait McpTransport: Send + Sync {
328    /// Sends a JSON-RPC frame to the server.
329    async fn send(&mut self, message: McpFrame) -> Result<(), McpError>;
330    /// Receives the next JSON-RPC frame from the server, or `None` if the stream has ended.
331    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError>;
332    /// Closes the transport, releasing any underlying resources.
333    async fn close(&mut self) -> Result<(), McpError>;
334}
335
336/// Factory that spawns a child process and connects via stdin/stdout.
337///
338/// Created from a [`StdioTransportConfig`]. Each call to
339/// [`connect`](McpTransportFactory::connect) spawns a new child process.
340pub struct StdioTransportFactory {
341    config: StdioTransportConfig,
342}
343
344impl StdioTransportFactory {
345    /// Creates a new factory from the given stdio transport configuration.
346    pub fn new(config: StdioTransportConfig) -> Self {
347        Self { config }
348    }
349}
350
351#[async_trait]
352impl McpTransportFactory for StdioTransportFactory {
353    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
354        let mut command = Command::new(&self.config.command);
355        command.args(&self.config.args);
356        command.stdin(Stdio::piped());
357        command.stdout(Stdio::piped());
358        command.stderr(Stdio::inherit());
359
360        if let Some(cwd) = &self.config.cwd {
361            command.current_dir(cwd);
362        }
363
364        for (key, value) in &self.config.env {
365            command.env(key, value);
366        }
367
368        let mut child = command.spawn().map_err(McpError::Io)?;
369        let stdin = child
370            .stdin
371            .take()
372            .ok_or_else(|| McpError::Transport("failed to capture MCP stdin".into()))?;
373        let stdout = child
374            .stdout
375            .take()
376            .ok_or_else(|| McpError::Transport("failed to capture MCP stdout".into()))?;
377
378        Ok(Box::new(StdioTransport {
379            child,
380            stdin,
381            stdout: BufReader::new(stdout),
382        }))
383    }
384}
385
386/// Factory that opens an HTTP SSE stream and connects via Server-Sent Events.
387///
388/// Created from an [`SseTransportConfig`]. Each call to
389/// [`connect`](McpTransportFactory::connect) opens a new HTTP connection.
390pub struct SseTransportFactory {
391    config: SseTransportConfig,
392}
393
394impl SseTransportFactory {
395    /// Creates a new factory from the given SSE transport configuration.
396    pub fn new(config: SseTransportConfig) -> Self {
397        Self { config }
398    }
399}
400
401/// Factory that connects to a Streamable HTTP MCP endpoint.
402///
403/// Created from a [`StreamableHttpTransportConfig`]. Each call to
404/// [`connect`](McpTransportFactory::connect) creates a new HTTP-backed MCP session.
405pub struct StreamableHttpTransportFactory {
406    config: StreamableHttpTransportConfig,
407}
408
409impl StreamableHttpTransportFactory {
410    /// Creates a new factory from the given Streamable HTTP transport configuration.
411    pub fn new(config: StreamableHttpTransportConfig) -> Self {
412        Self { config }
413    }
414}
415
416#[async_trait]
417impl McpTransportFactory for SseTransportFactory {
418    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
419        let client = resolve_http_client(self.config.client.as_ref())?;
420
421        let response = client
422            .get(self.config.url.as_str())
423            .header("Accept", "text/event-stream")
424            .header("Cache-Control", "no-cache")
425            .send()
426            .await?;
427
428        let status = response.status();
429        if !status.is_success() {
430            let body = response
431                .text()
432                .await
433                .unwrap_or_else(|_| "<unreadable response body>".into());
434            return Err(McpError::Transport(format!(
435                "SSE connection failed with status {status}: {body}"
436            )));
437        }
438
439        let response_url = Url::parse(response.url())
440            .map_err(|error| McpError::Transport(format!("invalid SSE response URL: {error}")))?;
441        let stream = response.bytes_stream().map_err(std::io::Error::other);
442        let reader = BufReader::new(StreamReader::new(stream));
443        let (frame_tx, frame_rx) = mpsc::unbounded_channel();
444        let (endpoint_tx, endpoint_rx) = oneshot::channel();
445        let read_task = tokio::spawn(read_sse_stream(reader, response_url, frame_tx, endpoint_tx));
446
447        let endpoint_url = endpoint_rx
448            .await
449            .map_err(|_| McpError::Transport("SSE stream closed before endpoint event".into()))??;
450
451        Ok(Box::new(SseTransport {
452            client,
453            endpoint_url,
454            frame_rx,
455            read_task,
456        }))
457    }
458}
459
460#[async_trait]
461impl McpTransportFactory for StreamableHttpTransportFactory {
462    async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
463        let client = resolve_http_client(self.config.client.as_ref())?;
464
465        let endpoint_url = Url::parse(&self.config.url)
466            .map_err(|error| McpError::Transport(format!("invalid MCP endpoint URL: {error}")))?;
467
468        Ok(Box::new(StreamableHttpTransport {
469            client,
470            endpoint_url,
471            protocol_version: None,
472            session_id: None,
473            pending_frames: VecDeque::new(),
474        }))
475    }
476}
477
478fn resolve_http_client(configured: Option<&Http>) -> Result<Http, McpError> {
479    if let Some(client) = configured {
480        return Ok(client.clone());
481    }
482    #[cfg(feature = "reqwest-client")]
483    {
484        reqwest::Client::builder()
485            .user_agent(concat!("agentkit-mcp/", env!("CARGO_PKG_VERSION")))
486            .build()
487            .map(Http::new)
488            .map_err(|error| McpError::Http(HttpError::request(error)))
489    }
490    #[cfg(not(feature = "reqwest-client"))]
491    {
492        Err(McpError::Transport(
493            "no HTTP client configured; enable the `reqwest-client` feature or supply one via `with_client`".into(),
494        ))
495    }
496}
497
498struct StdioTransport {
499    child: Child,
500    stdin: ChildStdin,
501    stdout: BufReader<ChildStdout>,
502}
503
504struct SseTransport {
505    client: Http,
506    endpoint_url: Url,
507    frame_rx: mpsc::UnboundedReceiver<Result<McpFrame, McpError>>,
508    read_task: JoinHandle<()>,
509}
510
511struct StreamableHttpTransport {
512    client: Http,
513    endpoint_url: Url,
514    protocol_version: Option<String>,
515    session_id: Option<String>,
516    pending_frames: VecDeque<McpFrame>,
517}
518
519#[async_trait]
520impl McpTransport for StdioTransport {
521    async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
522        let mut encoded = serde_json::to_vec(&message.value).map_err(McpError::Serialize)?;
523        encoded.push(b'\n');
524        self.stdin.write_all(&encoded).await.map_err(McpError::Io)?;
525        self.stdin.flush().await.map_err(McpError::Io)?;
526        Ok(())
527    }
528
529    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
530        let mut line = String::new();
531        let read = self
532            .stdout
533            .read_line(&mut line)
534            .await
535            .map_err(McpError::Io)?;
536        if read == 0 {
537            return Ok(None);
538        }
539
540        let value = serde_json::from_str(line.trim()).map_err(McpError::Serialize)?;
541        Ok(Some(McpFrame { value }))
542    }
543
544    async fn close(&mut self) -> Result<(), McpError> {
545        let _ = self.stdin.shutdown().await;
546        let _ = self.child.kill().await;
547        Ok(())
548    }
549}
550
551#[async_trait]
552impl McpTransport for SseTransport {
553    async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
554        let response = self
555            .client
556            .post(self.endpoint_url.as_str())
557            .header("Content-Type", "application/json")
558            .json(&message.value)
559            .send()
560            .await?;
561        let status = response.status();
562        if !status.is_success() {
563            let body = response
564                .text()
565                .await
566                .unwrap_or_else(|_| "<unreadable response body>".into());
567            return Err(McpError::Transport(format!(
568                "SSE POST failed with status {status}: {body}"
569            )));
570        }
571
572        Ok(())
573    }
574
575    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
576        match self.frame_rx.recv().await {
577            Some(Ok(frame)) => Ok(Some(frame)),
578            Some(Err(error)) => Err(error),
579            None => Ok(None),
580        }
581    }
582
583    async fn close(&mut self) -> Result<(), McpError> {
584        self.read_task.abort();
585        Ok(())
586    }
587}
588
589#[async_trait]
590impl McpTransport for StreamableHttpTransport {
591    async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
592        let is_request = is_jsonrpc_request(&message.value);
593        let request_id = message.value.get("id").cloned();
594        let is_initialize =
595            message.value.get("method").and_then(Value::as_str) == Some("initialize");
596
597        let mut request = self
598            .client
599            .post(self.endpoint_url.as_str())
600            .header("Content-Type", "application/json")
601            .header("Accept", "application/json, text/event-stream");
602
603        request = apply_streamable_http_headers(
604            request,
605            self.protocol_version.as_deref(),
606            self.session_id.as_deref(),
607        );
608
609        let response = request.json(&message.value).send().await?;
610
611        if is_initialize {
612            self.capture_session_id(response.headers());
613        }
614
615        let status = response.status();
616        if !status.is_success() {
617            return Err(
618                streamable_http_status_error("Streamable HTTP POST", status, response).await,
619            );
620        }
621
622        if !is_request {
623            return Ok(());
624        }
625
626        let content_type = response
627            .headers()
628            .get(http_header::CONTENT_TYPE)
629            .and_then(|value| value.to_str().ok())
630            .unwrap_or_default()
631            .to_string();
632
633        if content_type.starts_with("application/json") {
634            let value: Value = response.json().await?;
635            self.maybe_update_protocol_version(&message.value, &value)?;
636            self.pending_frames.push_back(McpFrame { value });
637            return Ok(());
638        }
639
640        if !content_type.starts_with("text/event-stream") {
641            let body = response
642                .text()
643                .await
644                .unwrap_or_else(|_| "<unreadable response body>".into());
645            return Err(McpError::Transport(format!(
646                "unexpected Streamable HTTP response content type {content_type:?}: {body}"
647            )));
648        }
649
650        let request_id = request_id.ok_or_else(|| {
651            McpError::Protocol("JSON-RPC request over Streamable HTTP is missing an id".into())
652        })?;
653        self.collect_streamable_http_response(response, &message.value, &request_id)
654            .await
655    }
656
657    async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
658        Ok(self.pending_frames.pop_front())
659    }
660
661    async fn close(&mut self) -> Result<(), McpError> {
662        let Some(session_id) = self.session_id.clone() else {
663            return Ok(());
664        };
665
666        let mut request = self.client.delete(self.endpoint_url.as_str());
667        request = apply_streamable_http_headers(
668            request,
669            self.protocol_version.as_deref(),
670            Some(session_id.as_str()),
671        );
672
673        let response = request.send().await?;
674        if response.status().is_success()
675            || response.status() == StatusCode::METHOD_NOT_ALLOWED
676            || response.status() == StatusCode::NOT_FOUND
677        {
678            self.session_id = None;
679            return Ok(());
680        }
681
682        Err(
683            streamable_http_status_error("Streamable HTTP DELETE", response.status(), response)
684                .await,
685        )
686    }
687}
688
689impl StreamableHttpTransport {
690    async fn collect_streamable_http_response(
691        &mut self,
692        response: HttpResponse,
693        request_message: &Value,
694        request_id: &Value,
695    ) -> Result<(), McpError> {
696        let mut retry_delay = Duration::from_millis(0);
697        let mut last_event_id = None;
698        let mut saw_response = false;
699
700        saw_response |= self
701            .read_streamable_http_events(
702                response,
703                request_message,
704                request_id,
705                &mut last_event_id,
706                &mut retry_delay,
707            )
708            .await?;
709
710        while !saw_response && last_event_id.is_some() {
711            if !retry_delay.is_zero() {
712                sleep(retry_delay).await;
713            }
714
715            let response = self
716                .resume_streamable_http_stream(last_event_id.as_deref().unwrap())
717                .await?;
718            saw_response |= self
719                .read_streamable_http_events(
720                    response,
721                    request_message,
722                    request_id,
723                    &mut last_event_id,
724                    &mut retry_delay,
725                )
726                .await?;
727        }
728
729        Ok(())
730    }
731
732    async fn read_streamable_http_events(
733        &mut self,
734        response: HttpResponse,
735        request_message: &Value,
736        request_id: &Value,
737        last_event_id: &mut Option<String>,
738        retry_delay: &mut Duration,
739    ) -> Result<bool, McpError> {
740        let stream = response.bytes_stream().map_err(std::io::Error::other);
741        let mut reader = BufReader::new(StreamReader::new(stream));
742        let mut saw_response = false;
743
744        while let Some(event) = read_next_sse_event(&mut reader).await? {
745            if let Some(id) = event.id.clone() {
746                *last_event_id = Some(id);
747            }
748            if let Some(retry_ms) = event.retry_ms {
749                *retry_delay = Duration::from_millis(retry_ms);
750            }
751
752            let Some(frame) = streamable_http_event_to_frame(event)? else {
753                continue;
754            };
755
756            self.maybe_update_protocol_version(request_message, &frame.value)?;
757            if frame.value.get("id") == Some(request_id) {
758                saw_response = true;
759            }
760            self.pending_frames.push_back(frame);
761        }
762
763        Ok(saw_response)
764    }
765
766    async fn resume_streamable_http_stream(
767        &self,
768        last_event_id: &str,
769    ) -> Result<HttpResponse, McpError> {
770        let mut request = self
771            .client
772            .get(self.endpoint_url.as_str())
773            .header("Accept", "text/event-stream")
774            .header("Cache-Control", "no-cache")
775            .header("Last-Event-ID", last_event_id);
776
777        request = apply_streamable_http_headers(
778            request,
779            self.protocol_version.as_deref(),
780            self.session_id.as_deref(),
781        );
782
783        let response = request.send().await?;
784        let status = response.status();
785        if !status.is_success() {
786            return Err(
787                streamable_http_status_error("Streamable HTTP GET", status, response).await,
788            );
789        }
790
791        let content_type = response
792            .headers()
793            .get(http_header::CONTENT_TYPE)
794            .and_then(|value| value.to_str().ok())
795            .unwrap_or_default();
796        if !content_type.starts_with("text/event-stream") {
797            let content_type = content_type.to_string();
798            let body = response
799                .text()
800                .await
801                .unwrap_or_else(|_| "<unreadable response body>".into());
802            return Err(McpError::Transport(format!(
803                "Streamable HTTP GET expected text/event-stream, got {content_type:?}: {body}"
804            )));
805        }
806
807        Ok(response)
808    }
809
810    fn maybe_update_protocol_version(
811        &mut self,
812        request_message: &Value,
813        response_value: &Value,
814    ) -> Result<(), McpError> {
815        if request_message.get("method").and_then(Value::as_str) != Some("initialize") {
816            return Ok(());
817        }
818
819        let protocol_version = response_value
820            .get("result")
821            .and_then(|result| result.get("protocolVersion"))
822            .and_then(Value::as_str);
823
824        if let Some(protocol_version) = protocol_version {
825            self.protocol_version = Some(protocol_version.to_string());
826        }
827
828        Ok(())
829    }
830
831    fn capture_session_id(&mut self, headers: &HeaderMap) {
832        self.session_id = headers
833            .get("MCP-Session-Id")
834            .and_then(|value| value.to_str().ok())
835            .map(|value| value.to_string());
836    }
837}
838
839/// Descriptor for a tool advertised by an MCP server.
840///
841/// Returned as part of a [`McpDiscoverySnapshot`] after server discovery. The
842/// [`input_schema`](Self::input_schema) field is the JSON Schema that describes
843/// the tool's expected input.
844#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
845pub struct McpToolDescriptor {
846    /// The tool name as reported by the MCP server.
847    pub name: String,
848    /// Optional human-readable description of the tool.
849    pub description: Option<String>,
850    /// JSON Schema describing the tool's input parameters.
851    pub input_schema: Value,
852    /// Arbitrary metadata attached to this descriptor.
853    pub metadata: MetadataMap,
854}
855
856/// Descriptor for a resource advertised by an MCP server.
857///
858/// Resources represent data that the server can provide (e.g. files, database
859/// records). Each resource is identified by a URI.
860#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
861pub struct McpResourceDescriptor {
862    /// The resource URI (e.g. `"file:///tmp/example.txt"`).
863    pub id: String,
864    /// Human-readable name of the resource.
865    pub name: String,
866    /// Optional description of the resource.
867    pub description: Option<String>,
868    /// Optional MIME type (e.g. `"text/plain"`, `"application/json"`).
869    pub mime_type: Option<String>,
870    /// Arbitrary metadata attached to this descriptor.
871    pub metadata: MetadataMap,
872}
873
874/// Descriptor for a prompt template advertised by an MCP server.
875///
876/// Prompts are reusable message templates that can be parameterized with arguments.
877/// The [`input_schema`](Self::input_schema) describes the expected arguments.
878#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
879pub struct McpPromptDescriptor {
880    /// Unique identifier for the prompt (typically the same as `name`).
881    pub id: String,
882    /// Human-readable name of the prompt.
883    pub name: String,
884    /// Optional description of what the prompt does.
885    pub description: Option<String>,
886    /// JSON Schema describing the prompt's input arguments.
887    pub input_schema: Value,
888    /// Arbitrary metadata attached to this descriptor.
889    pub metadata: MetadataMap,
890}
891
892/// A snapshot of all capabilities discovered from a single MCP server.
893///
894/// Obtained by calling [`McpConnection::discover`] or as part of a
895/// [`McpServerHandle`]. Contains the full list of tools, resources, and prompts
896/// that the server advertised at discovery time.
897#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
898pub struct McpDiscoverySnapshot {
899    /// The server this snapshot was taken from.
900    pub server_id: McpServerId,
901    /// Tools advertised by the server.
902    pub tools: Vec<McpToolDescriptor>,
903    /// Resources advertised by the server.
904    pub resources: Vec<McpResourceDescriptor>,
905    /// Prompts advertised by the server.
906    pub prompts: Vec<McpPromptDescriptor>,
907    /// Arbitrary metadata attached to this snapshot.
908    pub metadata: MetadataMap,
909}
910
911/// A live connection to a single MCP server.
912///
913/// Handles JSON-RPC request/response framing, automatic auth enrichment, and
914/// high-level methods for tool calls, resource reads, prompt retrieval, and
915/// capability discovery.
916///
917/// Create a connection with [`McpConnection::connect`] or indirectly through
918/// [`McpServerManager::connect_server`].
919///
920/// # Example
921///
922/// ```rust,no_run
923/// use agentkit_mcp::{McpConnection, McpServerConfig, McpTransportBinding, StdioTransportConfig};
924///
925/// # #[tokio::main]
926/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
927/// let config = McpServerConfig::new(
928///     "filesystem",
929///     McpTransportBinding::Stdio(StdioTransportConfig::new("npx")
930///         .with_arg("-y")
931///         .with_arg("@modelcontextprotocol/server-filesystem")),
932/// );
933///
934/// let connection = McpConnection::connect(&config).await?;
935/// let snapshot = connection.discover().await?;
936/// println!("found {} tools", snapshot.tools.len());
937/// # Ok(())
938/// # }
939/// ```
940pub struct McpConnection {
941    server_id: McpServerId,
942    transport: Mutex<Box<dyn McpTransport>>,
943    auth: Mutex<Option<MetadataMap>>,
944    next_id: AtomicU64,
945}
946
947/// The result of replaying an MCP operation after auth resolution.
948///
949/// Returned by [`McpConnection::replay_auth_operation`] and
950/// [`McpServerManager::resolve_auth_and_resume`].
951#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
952pub enum McpOperationResult {
953    /// The server was successfully (re)connected; contains the discovery snapshot.
954    Connected(McpDiscoverySnapshot),
955    /// A tool call completed; contains the raw JSON result.
956    Tool(Value),
957    /// A resource was read successfully.
958    Resource(ResourceContents),
959    /// A prompt was retrieved successfully.
960    Prompt(PromptContents),
961}
962
963impl McpConnection {
964    /// Connects to an MCP server, performs the JSON-RPC `initialize` handshake, and
965    /// returns a ready-to-use connection.
966    ///
967    /// # Errors
968    ///
969    /// Returns [`McpError`] if the transport fails to connect, the handshake is
970    /// rejected, or the server requires authentication ([`McpError::AuthRequired`]).
971    pub async fn connect(config: &McpServerConfig) -> Result<Self, McpError> {
972        Self::connect_with_auth(config, None).await
973    }
974
975    async fn connect_with_auth(
976        config: &McpServerConfig,
977        auth: Option<&MetadataMap>,
978    ) -> Result<Self, McpError> {
979        let factory: Arc<dyn McpTransportFactory> = match &config.transport {
980            McpTransportBinding::Stdio(binding) => {
981                Arc::new(StdioTransportFactory::new(binding.clone()))
982            }
983            McpTransportBinding::StreamableHttp(binding) => {
984                Arc::new(StreamableHttpTransportFactory::new(binding.clone()))
985            }
986            McpTransportBinding::Sse(binding) => {
987                Arc::new(SseTransportFactory::new(binding.clone()))
988            }
989            McpTransportBinding::Custom(factory) => factory.clone(),
990        };
991
992        let mut transport = factory.connect().await?;
993        let mut params = serde_json::Map::new();
994        params.insert(
995            "protocolVersion".into(),
996            Value::String(MCP_LATEST_PROTOCOL_VERSION.into()),
997        );
998        params.insert("capabilities".into(), json!({}));
999        params.insert(
1000            "clientInfo".into(),
1001            json!({
1002                "name": "agentkit-mcp",
1003                "version": env!("CARGO_PKG_VERSION")
1004            }),
1005        );
1006        if let Some(auth) = auth {
1007            params.insert("auth".into(), metadata_to_value(auth));
1008        }
1009        let init_params = Value::Object(params.clone());
1010        transport
1011            .send(McpFrame {
1012                value: json!({
1013                    "jsonrpc": "2.0",
1014                    "id": 0,
1015                    "method": "initialize",
1016                    "params": init_params.clone()
1017                }),
1018            })
1019            .await?;
1020        let init_response = transport.recv().await?.ok_or_else(|| {
1021            McpError::Transport("transport closed during MCP initialization".into())
1022        })?;
1023        if let Some(error) = init_response.value.get("error") {
1024            if let Some(auth_request) =
1025                parse_auth_request(&config.id, "initialize", &init_params, error)
1026            {
1027                return Err(McpError::AuthRequired(Box::new(auth_request)));
1028            }
1029            return Err(McpError::Invocation(error.to_string()));
1030        }
1031        let negotiated_protocol_version = init_response
1032            .value
1033            .get("result")
1034            .and_then(|result| result.get("protocolVersion"))
1035            .and_then(Value::as_str)
1036            .ok_or_else(|| {
1037                McpError::Protocol("initialize response missing result.protocolVersion".into())
1038            })?;
1039        if !MCP_SUPPORTED_PROTOCOL_VERSIONS.contains(&negotiated_protocol_version) {
1040            return Err(McpError::Protocol(format!(
1041                "unsupported MCP protocol version negotiated during initialize: {negotiated_protocol_version}"
1042            )));
1043        }
1044        transport
1045            .send(McpFrame {
1046                value: json!({
1047                    "jsonrpc": "2.0",
1048                    "method": "notifications/initialized",
1049                    "params": {}
1050                }),
1051            })
1052            .await?;
1053
1054        Ok(Self {
1055            server_id: config.id.clone(),
1056            transport: Mutex::new(transport),
1057            auth: Mutex::new(auth.cloned()),
1058            next_id: AtomicU64::new(1),
1059        })
1060    }
1061
1062    /// Returns the [`McpServerId`] for this connection.
1063    pub fn server_id(&self) -> &McpServerId {
1064        &self.server_id
1065    }
1066
1067    /// Closes the underlying transport, shutting down the connection to the server.
1068    ///
1069    /// # Errors
1070    ///
1071    /// Returns [`McpError`] if the transport cannot be closed cleanly.
1072    pub async fn close(&self) -> Result<(), McpError> {
1073        let mut transport = self.transport.lock().await;
1074        transport.close().await
1075    }
1076
1077    /// Stores or clears authentication credentials for future requests on this
1078    /// connection.
1079    ///
1080    /// After calling this method with [`AuthResolution::Provided`], every subsequent
1081    /// JSON-RPC request will include the credentials in an `auth` field.
1082    ///
1083    /// # Errors
1084    ///
1085    /// Returns [`McpError`] if the resolution cannot be applied.
1086    pub async fn resolve_auth(&self, resolution: AuthResolution) -> Result<(), McpError> {
1087        let mut auth = self.auth.lock().await;
1088        match resolution {
1089            AuthResolution::Provided { credentials, .. } => {
1090                *auth = Some(credentials);
1091            }
1092            AuthResolution::Cancelled { .. } => {
1093                *auth = None;
1094            }
1095        }
1096        Ok(())
1097    }
1098
1099    /// Performs full capability discovery by listing tools, resources, and prompts.
1100    ///
1101    /// Returns an [`McpDiscoverySnapshot`] containing everything the server advertises.
1102    ///
1103    /// # Errors
1104    ///
1105    /// Returns [`McpError`] if any of the list requests fail.
1106    pub async fn discover(&self) -> Result<McpDiscoverySnapshot, McpError> {
1107        Ok(McpDiscoverySnapshot {
1108            server_id: self.server_id.clone(),
1109            tools: self.list_tools().await?,
1110            resources: self.list_resources().await?,
1111            prompts: self.list_prompts().await?,
1112            metadata: MetadataMap::new(),
1113        })
1114    }
1115
1116    /// Lists all tools advertised by the connected MCP server.
1117    ///
1118    /// # Errors
1119    ///
1120    /// Returns [`McpError`] if the `tools/list` request fails.
1121    pub async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>, McpError> {
1122        let result = self.request("tools/list", json!({})).await?;
1123        result
1124            .get("tools")
1125            .and_then(Value::as_array)
1126            .cloned()
1127            .unwrap_or_default()
1128            .into_iter()
1129            .map(parse_tool_descriptor)
1130            .collect()
1131    }
1132
1133    /// Lists all resources advertised by the connected MCP server.
1134    ///
1135    /// # Errors
1136    ///
1137    /// Returns [`McpError`] if the `resources/list` request fails.
1138    pub async fn list_resources(&self) -> Result<Vec<McpResourceDescriptor>, McpError> {
1139        let result = self.request("resources/list", json!({})).await?;
1140        result
1141            .get("resources")
1142            .and_then(Value::as_array)
1143            .cloned()
1144            .unwrap_or_default()
1145            .into_iter()
1146            .map(parse_resource_descriptor)
1147            .collect()
1148    }
1149
1150    /// Lists all prompts advertised by the connected MCP server.
1151    ///
1152    /// # Errors
1153    ///
1154    /// Returns [`McpError`] if the `prompts/list` request fails.
1155    pub async fn list_prompts(&self) -> Result<Vec<McpPromptDescriptor>, McpError> {
1156        let result = self.request("prompts/list", json!({})).await?;
1157        result
1158            .get("prompts")
1159            .and_then(Value::as_array)
1160            .cloned()
1161            .unwrap_or_default()
1162            .into_iter()
1163            .map(parse_prompt_descriptor)
1164            .collect()
1165    }
1166
1167    /// Invokes a tool on the MCP server and returns the raw JSON result.
1168    ///
1169    /// # Arguments
1170    ///
1171    /// * `name` - The tool name as it appears in the server's tool list.
1172    /// * `arguments` - A JSON value matching the tool's input schema.
1173    ///
1174    /// # Errors
1175    ///
1176    /// Returns [`McpError::AuthRequired`] if the server demands authentication,
1177    /// or another [`McpError`] variant on transport or protocol failures.
1178    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value, McpError> {
1179        self.request(
1180            "tools/call",
1181            json!({
1182                "name": name,
1183                "arguments": arguments,
1184            }),
1185        )
1186        .await
1187    }
1188
1189    /// Reads a resource from the MCP server by URI.
1190    ///
1191    /// # Arguments
1192    ///
1193    /// * `uri` - The resource URI (e.g. `"file:///tmp/example.txt"`).
1194    ///
1195    /// # Errors
1196    ///
1197    /// Returns [`McpError`] if the resource cannot be read or the response is malformed.
1198    pub async fn read_resource(&self, uri: &str) -> Result<ResourceContents, McpError> {
1199        let result = self
1200            .request(
1201                "resources/read",
1202                json!({
1203                    "uri": uri,
1204                }),
1205            )
1206            .await?;
1207        let content = result
1208            .get("contents")
1209            .and_then(Value::as_array)
1210            .and_then(|values| values.first())
1211            .cloned()
1212            .ok_or_else(|| McpError::Protocol("resources/read returned no contents".into()))?;
1213
1214        let data = if let Some(text) = content.get("text").and_then(Value::as_str) {
1215            DataRef::InlineText(text.into())
1216        } else if let Some(found_uri) = content.get("uri").and_then(Value::as_str) {
1217            DataRef::Uri(found_uri.into())
1218        } else {
1219            return Err(McpError::Protocol(
1220                "unsupported resource content shape".into(),
1221            ));
1222        };
1223
1224        Ok(ResourceContents {
1225            data,
1226            metadata: MetadataMap::new(),
1227        })
1228    }
1229
1230    /// Retrieves a prompt from the MCP server, rendering it with the given arguments.
1231    ///
1232    /// # Arguments
1233    ///
1234    /// * `name` - The prompt name as it appears in the server's prompt list.
1235    /// * `arguments` - A JSON value containing the prompt's input arguments.
1236    ///
1237    /// # Errors
1238    ///
1239    /// Returns [`McpError`] if the prompt cannot be retrieved or the response is malformed.
1240    pub async fn get_prompt(
1241        &self,
1242        name: &str,
1243        arguments: Value,
1244    ) -> Result<PromptContents, McpError> {
1245        let result = self
1246            .request(
1247                "prompts/get",
1248                json!({
1249                    "name": name,
1250                    "arguments": arguments,
1251                }),
1252            )
1253            .await?;
1254        let items = result
1255            .get("messages")
1256            .and_then(Value::as_array)
1257            .cloned()
1258            .unwrap_or_default()
1259            .into_iter()
1260            .map(parse_prompt_message)
1261            .collect::<Result<Vec<_>, _>>()?;
1262
1263        Ok(PromptContents {
1264            items,
1265            metadata: MetadataMap::new(),
1266        })
1267    }
1268
1269    async fn request(&self, method: &str, params: Value) -> Result<Value, McpError> {
1270        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
1271        let params = self.enrich_params(params.clone()).await;
1272        let mut transport = self.transport.lock().await;
1273        transport
1274            .send(McpFrame {
1275                value: json!({
1276                    "jsonrpc": "2.0",
1277                    "id": id,
1278                    "method": method,
1279                    "params": params,
1280                }),
1281            })
1282            .await?;
1283
1284        loop {
1285            let Some(frame) = transport.recv().await? else {
1286                return Err(McpError::Transport(
1287                    "transport closed while waiting for MCP response".into(),
1288                ));
1289            };
1290
1291            if frame.value.get("id").and_then(Value::as_u64) != Some(id) {
1292                continue;
1293            }
1294
1295            if let Some(error) = frame.value.get("error") {
1296                if let Some(auth_request) =
1297                    parse_auth_request(&self.server_id, method, &params, error)
1298                {
1299                    return Err(McpError::AuthRequired(Box::new(auth_request)));
1300                }
1301                return Err(McpError::Invocation(error.to_string()));
1302            }
1303
1304            return frame
1305                .value
1306                .get("result")
1307                .cloned()
1308                .ok_or_else(|| McpError::Protocol("MCP response missing result".into()));
1309        }
1310    }
1311
1312    async fn enrich_params(&self, params: Value) -> Value {
1313        let auth = self.auth.lock().await;
1314        let Some(auth) = auth.as_ref() else {
1315            return params;
1316        };
1317
1318        match params {
1319            Value::Object(mut object) => {
1320                object
1321                    .entry("auth")
1322                    .or_insert_with(|| metadata_to_value(auth));
1323                Value::Object(object)
1324            }
1325            other => other,
1326        }
1327    }
1328
1329    /// Replays an MCP operation that previously failed with an auth challenge.
1330    ///
1331    /// This is called after credentials have been resolved via [`resolve_auth`](Self::resolve_auth).
1332    /// The operation is re-issued with the stored credentials attached.
1333    ///
1334    /// # Errors
1335    ///
1336    /// Returns [`McpError::AuthResolution`] if the operation targets a different server,
1337    /// or other [`McpError`] variants if the replayed operation itself fails.
1338    pub async fn replay_auth_operation(
1339        &self,
1340        operation: &AuthOperation,
1341    ) -> Result<McpOperationResult, McpError> {
1342        match operation {
1343            AuthOperation::McpToolCall {
1344                server_id,
1345                tool_name,
1346                input,
1347                ..
1348            } => {
1349                self.ensure_server_match(server_id)?;
1350                self.call_tool(tool_name, input.clone())
1351                    .await
1352                    .map(McpOperationResult::Tool)
1353            }
1354            AuthOperation::McpResourceRead {
1355                server_id,
1356                resource_id,
1357                ..
1358            } => {
1359                self.ensure_server_match(server_id)?;
1360                self.read_resource(resource_id)
1361                    .await
1362                    .map(McpOperationResult::Resource)
1363            }
1364            AuthOperation::McpPromptGet {
1365                server_id,
1366                prompt_id,
1367                args,
1368                ..
1369            } => {
1370                self.ensure_server_match(server_id)?;
1371                self.get_prompt(prompt_id, args.clone())
1372                    .await
1373                    .map(McpOperationResult::Prompt)
1374            }
1375            AuthOperation::ToolCall {
1376                tool_name,
1377                input,
1378                metadata,
1379                ..
1380            } => {
1381                if let Some(server_id) = metadata.get("server_id").and_then(Value::as_str) {
1382                    self.ensure_server_match(server_id)?;
1383                }
1384                let tool_name = normalize_mcp_tool_name(self.server_id(), tool_name);
1385                self.call_tool(&tool_name, input.clone())
1386                    .await
1387                    .map(McpOperationResult::Tool)
1388            }
1389            AuthOperation::McpConnect { .. } => Err(McpError::AuthResolution(
1390                "connect operations must be replayed through the server manager".into(),
1391            )),
1392            AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
1393                "unsupported auth operation for replay: {kind}"
1394            ))),
1395        }
1396    }
1397
1398    fn ensure_server_match(&self, server_id: &str) -> Result<(), McpError> {
1399        if self.server_id.0 == server_id {
1400            Ok(())
1401        } else {
1402            Err(McpError::AuthResolution(format!(
1403                "auth operation targets server {server_id}, but connection is for {}",
1404                self.server_id
1405            )))
1406        }
1407    }
1408}
1409
1410/// Adapter that exposes an MCP tool as an [`Invocable`] for the capabilities system.
1411///
1412/// This is the capabilities-layer adapter. For the tool-layer adapter, see
1413/// [`McpToolAdapter`]. Names are prefixed with `mcp_<server_id>_<tool_name>`
1414/// so they satisfy provider validators that only allow `[a-zA-Z0-9_-]`
1415/// (e.g. Anthropic on Vertex).
1416pub struct McpInvocable {
1417    connection: Arc<McpConnection>,
1418    descriptor: McpToolDescriptor,
1419    spec: InvocableSpec,
1420}
1421
1422impl McpInvocable {
1423    /// Creates a new invocable adapter for the given MCP tool.
1424    ///
1425    /// # Arguments
1426    ///
1427    /// * `connection` - A shared connection to the MCP server that owns the tool.
1428    /// * `descriptor` - The tool descriptor obtained from discovery.
1429    pub fn new(connection: Arc<McpConnection>, descriptor: McpToolDescriptor) -> Self {
1430        let spec = InvocableSpec {
1431            name: CapabilityName::new(format!(
1432                "mcp_{}_{}",
1433                connection.server_id(),
1434                descriptor.name
1435            )),
1436            description: descriptor
1437                .description
1438                .clone()
1439                .unwrap_or_else(|| descriptor.name.clone()),
1440            input_schema: descriptor.input_schema.clone(),
1441            metadata: descriptor.metadata.clone(),
1442        };
1443
1444        Self {
1445            connection,
1446            descriptor,
1447            spec,
1448        }
1449    }
1450}
1451
1452#[async_trait]
1453impl Invocable for McpInvocable {
1454    fn spec(&self) -> &InvocableSpec {
1455        &self.spec
1456    }
1457
1458    async fn invoke(
1459        &self,
1460        request: InvocableRequest,
1461        _ctx: &mut CapabilityContext<'_>,
1462    ) -> Result<InvocableResult, CapabilityError> {
1463        let result = self
1464            .connection
1465            .call_tool(&self.descriptor.name, request.input)
1466            .await
1467            .map_err(|error| match error {
1468                McpError::AuthRequired(request) => {
1469                    CapabilityError::Unavailable(format!("auth required: {:?}", request))
1470                }
1471                other => CapabilityError::ExecutionFailed(other.to_string()),
1472            })?;
1473
1474        Ok(InvocableResult {
1475            output: value_to_invocable_output(result),
1476            metadata: MetadataMap::new(),
1477        })
1478    }
1479}
1480
1481/// Adapter that exposes a single MCP resource as a [`ResourceProvider`].
1482///
1483/// Created automatically by [`McpCapabilityProvider::from_snapshot`] for each
1484/// resource discovered on the server.
1485pub struct McpResourceHandle {
1486    connection: Arc<McpConnection>,
1487    descriptor: ResourceDescriptor,
1488}
1489
1490#[async_trait]
1491impl ResourceProvider for McpResourceHandle {
1492    async fn list_resources(&self) -> Result<Vec<ResourceDescriptor>, CapabilityError> {
1493        Ok(vec![self.descriptor.clone()])
1494    }
1495
1496    async fn read_resource(
1497        &self,
1498        id: &ResourceId,
1499        _ctx: &mut CapabilityContext<'_>,
1500    ) -> Result<ResourceContents, CapabilityError> {
1501        self.connection
1502            .read_resource(&id.0)
1503            .await
1504            .map_err(|error| match error {
1505                McpError::AuthRequired(request) => {
1506                    CapabilityError::Unavailable(format!("auth required: {:?}", request))
1507                }
1508                other => CapabilityError::ExecutionFailed(other.to_string()),
1509            })
1510    }
1511}
1512
1513/// Adapter that exposes a single MCP prompt as a [`PromptProvider`].
1514///
1515/// Created automatically by [`McpCapabilityProvider::from_snapshot`] for each
1516/// prompt discovered on the server.
1517pub struct McpPromptHandle {
1518    connection: Arc<McpConnection>,
1519    descriptor: PromptDescriptor,
1520}
1521
1522#[async_trait]
1523impl PromptProvider for McpPromptHandle {
1524    async fn list_prompts(&self) -> Result<Vec<PromptDescriptor>, CapabilityError> {
1525        Ok(vec![self.descriptor.clone()])
1526    }
1527
1528    async fn get_prompt(
1529        &self,
1530        id: &PromptId,
1531        args: Value,
1532        _ctx: &mut CapabilityContext<'_>,
1533    ) -> Result<PromptContents, CapabilityError> {
1534        self.connection
1535            .get_prompt(&id.0, args)
1536            .await
1537            .map_err(|error| match error {
1538                McpError::AuthRequired(request) => {
1539                    CapabilityError::Unavailable(format!("auth required: {:?}", request))
1540                }
1541                other => CapabilityError::ExecutionFailed(other.to_string()),
1542            })
1543    }
1544}
1545
1546/// A [`CapabilityProvider`] that surfaces MCP tools, resources, and prompts into the
1547/// agentkit capabilities system.
1548///
1549/// Built from a discovery snapshot, this provider wraps each MCP tool as an
1550/// [`McpInvocable`], each resource as an [`McpResourceHandle`], and each prompt as
1551/// an [`McpPromptHandle`].
1552///
1553/// # Example
1554///
1555/// ```rust,no_run
1556/// use std::sync::Arc;
1557/// use agentkit_mcp::{McpCapabilityProvider, McpServerConfig, McpTransportBinding, StdioTransportConfig};
1558///
1559/// # #[tokio::main]
1560/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
1561/// let config = McpServerConfig::new(
1562///     "filesystem",
1563///     McpTransportBinding::Stdio(StdioTransportConfig::new("npx")
1564///         .with_arg("-y")
1565///         .with_arg("@modelcontextprotocol/server-filesystem")),
1566/// );
1567/// let (connection, provider, snapshot) = McpCapabilityProvider::connect(&config).await?;
1568/// // `provider` implements CapabilityProvider and can be registered with an agent.
1569/// # Ok(())
1570/// # }
1571/// ```
1572pub struct McpCapabilityProvider {
1573    invocables: Vec<Arc<dyn Invocable>>,
1574    resources: Vec<Arc<dyn ResourceProvider>>,
1575    prompts: Vec<Arc<dyn PromptProvider>>,
1576}
1577
1578impl McpCapabilityProvider {
1579    /// Creates a capability provider from an existing connection and its discovery
1580    /// snapshot.
1581    ///
1582    /// Each tool, resource, and prompt in the snapshot is wrapped in the appropriate
1583    /// adapter type.
1584    pub fn from_snapshot(connection: Arc<McpConnection>, snapshot: &McpDiscoverySnapshot) -> Self {
1585        let invocables = snapshot
1586            .tools
1587            .iter()
1588            .cloned()
1589            .map(|descriptor| {
1590                Arc::new(McpInvocable::new(connection.clone(), descriptor)) as Arc<dyn Invocable>
1591            })
1592            .collect();
1593
1594        let resources = snapshot
1595            .resources
1596            .iter()
1597            .cloned()
1598            .map(|descriptor| {
1599                Arc::new(McpResourceHandle {
1600                    connection: connection.clone(),
1601                    descriptor: ResourceDescriptor {
1602                        id: ResourceId::new(descriptor.id),
1603                        name: descriptor.name,
1604                        description: descriptor.description,
1605                        mime_type: descriptor.mime_type,
1606                        metadata: descriptor.metadata,
1607                    },
1608                }) as Arc<dyn ResourceProvider>
1609            })
1610            .collect();
1611
1612        let prompts = snapshot
1613            .prompts
1614            .iter()
1615            .cloned()
1616            .map(|descriptor| {
1617                Arc::new(McpPromptHandle {
1618                    connection: connection.clone(),
1619                    descriptor: PromptDescriptor {
1620                        id: PromptId::new(descriptor.id),
1621                        name: descriptor.name,
1622                        description: descriptor.description,
1623                        input_schema: descriptor.input_schema,
1624                        metadata: descriptor.metadata,
1625                    },
1626                }) as Arc<dyn PromptProvider>
1627            })
1628            .collect();
1629
1630        Self {
1631            invocables,
1632            resources,
1633            prompts,
1634        }
1635    }
1636
1637    /// Merges multiple capability providers into a single provider.
1638    ///
1639    /// This is useful when managing several MCP servers through a
1640    /// [`McpServerManager`] and you want one combined provider for the agent.
1641    pub fn merge<I>(providers: I) -> Self
1642    where
1643        I: IntoIterator<Item = Self>,
1644    {
1645        let mut invocables = Vec::new();
1646        let mut resources = Vec::new();
1647        let mut prompts = Vec::new();
1648
1649        for provider in providers {
1650            invocables.extend(provider.invocables);
1651            resources.extend(provider.resources);
1652            prompts.extend(provider.prompts);
1653        }
1654
1655        Self {
1656            invocables,
1657            resources,
1658            prompts,
1659        }
1660    }
1661
1662    /// Connects to an MCP server, performs discovery, and builds a capability
1663    /// provider in one step.
1664    ///
1665    /// Returns the shared connection, the provider, and the discovery snapshot.
1666    ///
1667    /// # Errors
1668    ///
1669    /// Returns [`McpError`] if connection or discovery fails.
1670    pub async fn connect(
1671        config: &McpServerConfig,
1672    ) -> Result<(Arc<McpConnection>, Self, McpDiscoverySnapshot), McpError> {
1673        let connection = Arc::new(McpConnection::connect(config).await?);
1674        let snapshot = connection.discover().await?;
1675        let provider = Self::from_snapshot(connection.clone(), &snapshot);
1676
1677        Ok((connection, provider, snapshot))
1678    }
1679}
1680
1681impl CapabilityProvider for McpCapabilityProvider {
1682    fn invocables(&self) -> Vec<Arc<dyn Invocable>> {
1683        self.invocables.clone()
1684    }
1685
1686    fn resources(&self) -> Vec<Arc<dyn ResourceProvider>> {
1687        self.resources.clone()
1688    }
1689
1690    fn prompts(&self) -> Vec<Arc<dyn PromptProvider>> {
1691        self.prompts.clone()
1692    }
1693}
1694
1695/// A connected MCP server together with its configuration and discovery snapshot.
1696///
1697/// Obtained from [`McpServerManager::connect_server`] or
1698/// [`McpServerManager::connect_all`]. Provides convenience methods to create
1699/// tool registries and capability providers from the server's discovered capabilities.
1700#[derive(Clone)]
1701pub struct McpServerHandle {
1702    config: McpServerConfig,
1703    connection: Arc<McpConnection>,
1704    snapshot: McpDiscoverySnapshot,
1705}
1706
1707impl McpServerHandle {
1708    /// Returns the original configuration used to connect this server.
1709    pub fn config(&self) -> &McpServerConfig {
1710        &self.config
1711    }
1712
1713    /// Returns the server's unique identifier.
1714    pub fn server_id(&self) -> &McpServerId {
1715        self.connection.server_id()
1716    }
1717
1718    /// Returns a shared reference to the underlying [`McpConnection`].
1719    pub fn connection(&self) -> Arc<McpConnection> {
1720        self.connection.clone()
1721    }
1722
1723    /// Returns the discovery snapshot captured when the server was connected.
1724    pub fn snapshot(&self) -> &McpDiscoverySnapshot {
1725        &self.snapshot
1726    }
1727
1728    /// Builds a [`ToolRegistry`] containing an [`McpToolAdapter`] for each tool
1729    /// discovered on this server.
1730    pub fn tool_registry(&self) -> ToolRegistry {
1731        self.snapshot
1732            .tools
1733            .iter()
1734            .cloned()
1735            .fold(ToolRegistry::new(), |registry, descriptor| {
1736                registry.with(McpToolAdapter::new(
1737                    self.server_id(),
1738                    self.connection.clone(),
1739                    descriptor,
1740                ))
1741            })
1742    }
1743
1744    /// Builds an [`McpCapabilityProvider`] from this server's discovery snapshot.
1745    pub fn capability_provider(&self) -> McpCapabilityProvider {
1746        McpCapabilityProvider::from_snapshot(self.connection.clone(), &self.snapshot)
1747    }
1748}
1749
1750/// Manages the lifecycle of one or more MCP servers: registration, connection,
1751/// discovery, refresh, disconnection, and auth resolution.
1752///
1753/// This is the primary entry point for integrating MCP servers into an agentkit
1754/// application. Register server configurations, connect them, and then obtain a
1755/// combined [`ToolRegistry`] or [`McpCapabilityProvider`] for use in an agent loop.
1756///
1757/// # Example
1758///
1759/// ```rust,no_run
1760/// use agentkit_mcp::{
1761///     McpServerConfig, McpServerManager, McpTransportBinding, StdioTransportConfig,
1762/// };
1763///
1764/// # #[tokio::main]
1765/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
1766/// let mut manager = McpServerManager::new()
1767///     .with_server(McpServerConfig::new(
1768///         "filesystem",
1769///         McpTransportBinding::Stdio(
1770///             StdioTransportConfig::new("npx")
1771///                 .with_arg("-y")
1772///                 .with_arg("@modelcontextprotocol/server-filesystem"),
1773///         ),
1774///     ))
1775///     .with_server(McpServerConfig::new(
1776///         "github",
1777///         McpTransportBinding::Stdio(
1778///             StdioTransportConfig::new("npx")
1779///                 .with_arg("-y")
1780///                 .with_arg("@modelcontextprotocol/server-github"),
1781///         ),
1782///     ));
1783///
1784/// let handles = manager.connect_all().await?;
1785/// let registry = manager.tool_registry();
1786/// println!("tools: {:?}", registry.specs().iter().map(|s| &s.name).collect::<Vec<_>>());
1787/// # Ok(())
1788/// # }
1789/// ```
1790#[derive(Default)]
1791pub struct McpServerManager {
1792    configs: BTreeMap<McpServerId, McpServerConfig>,
1793    connections: BTreeMap<McpServerId, McpServerHandle>,
1794    auth: BTreeMap<McpServerId, MetadataMap>,
1795}
1796
1797impl McpServerManager {
1798    /// Creates an empty server manager with no registered servers.
1799    pub fn new() -> Self {
1800        Self::default()
1801    }
1802
1803    /// Registers a server configuration and returns `self` for chaining.
1804    ///
1805    /// The server is not connected until [`connect_server`](Self::connect_server) or
1806    /// [`connect_all`](Self::connect_all) is called.
1807    pub fn with_server(mut self, config: McpServerConfig) -> Self {
1808        self.register_server(config);
1809        self
1810    }
1811
1812    /// Registers a server configuration by mutable reference.
1813    ///
1814    /// The server is not connected until [`connect_server`](Self::connect_server) or
1815    /// [`connect_all`](Self::connect_all) is called.
1816    pub fn register_server(&mut self, config: McpServerConfig) -> &mut Self {
1817        self.configs.insert(config.id.clone(), config);
1818        self
1819    }
1820
1821    /// Returns the handle for a connected server, or `None` if it is not connected.
1822    pub fn connected_server(&self, server_id: &McpServerId) -> Option<&McpServerHandle> {
1823        self.connections.get(server_id)
1824    }
1825
1826    /// Returns handles for all currently connected servers.
1827    pub fn connected_servers(&self) -> Vec<&McpServerHandle> {
1828        self.connections.values().collect()
1829    }
1830
1831    /// Connects a single registered server by its identifier.
1832    ///
1833    /// Performs the MCP handshake and full capability discovery.
1834    ///
1835    /// # Errors
1836    ///
1837    /// Returns [`McpError::UnknownServer`] if the server ID has not been registered,
1838    /// or other [`McpError`] variants if connection or discovery fails.
1839    pub async fn connect_server(
1840        &mut self,
1841        server_id: &McpServerId,
1842    ) -> Result<McpServerHandle, McpError> {
1843        let config = self
1844            .configs
1845            .get(server_id)
1846            .cloned()
1847            .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1848        let connection =
1849            Arc::new(McpConnection::connect_with_auth(&config, self.auth.get(server_id)).await?);
1850        let snapshot = connection.discover().await?;
1851        let handle = McpServerHandle {
1852            config,
1853            connection,
1854            snapshot,
1855        };
1856        self.connections.insert(server_id.clone(), handle.clone());
1857        Ok(handle)
1858    }
1859
1860    /// Connects all registered servers sequentially.
1861    ///
1862    /// Returns a handle for each server in registration order. If any server fails
1863    /// to connect, the error is returned immediately and remaining servers are
1864    /// not attempted.
1865    ///
1866    /// # Errors
1867    ///
1868    /// Returns the first [`McpError`] encountered during connection.
1869    pub async fn connect_all(&mut self) -> Result<Vec<McpServerHandle>, McpError> {
1870        let server_ids = self.configs.keys().cloned().collect::<Vec<_>>();
1871        let mut handles = Vec::with_capacity(server_ids.len());
1872
1873        for server_id in server_ids {
1874            handles.push(self.connect_server(&server_id).await?);
1875        }
1876
1877        Ok(handles)
1878    }
1879
1880    /// Re-discovers capabilities for a connected server, updating the stored snapshot.
1881    ///
1882    /// Call this after the server's capabilities may have changed (e.g. after
1883    /// installing a plugin).
1884    ///
1885    /// # Errors
1886    ///
1887    /// Returns [`McpError::UnknownServer`] if the server is not connected, or other
1888    /// [`McpError`] variants if discovery fails.
1889    pub async fn refresh_server(
1890        &mut self,
1891        server_id: &McpServerId,
1892    ) -> Result<McpDiscoverySnapshot, McpError> {
1893        let handle = self
1894            .connections
1895            .get_mut(server_id)
1896            .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))?;
1897        let snapshot = handle.connection.discover().await?;
1898        handle.snapshot = snapshot.clone();
1899        Ok(snapshot)
1900    }
1901
1902    /// Disconnects a server and removes it from the active connections.
1903    ///
1904    /// The server configuration remains registered and can be reconnected later
1905    /// with [`connect_server`](Self::connect_server).
1906    ///
1907    /// # Errors
1908    ///
1909    /// Returns [`McpError::UnknownServer`] if the server is not connected.
1910    pub async fn disconnect_server(&mut self, server_id: &McpServerId) -> Result<(), McpError> {
1911        let Some(handle) = self.connections.remove(server_id) else {
1912            return Err(McpError::UnknownServer(server_id.to_string()));
1913        };
1914        handle.connection.close().await
1915    }
1916
1917    /// Stores or clears authentication credentials for a server and, if already
1918    /// connected, updates the live connection as well.
1919    ///
1920    /// # Errors
1921    ///
1922    /// Returns [`McpError::UnknownServer`] if the server ID from the resolution
1923    /// does not match any registered server.
1924    pub async fn resolve_auth(&mut self, resolution: AuthResolution) -> Result<(), McpError> {
1925        let server_id = resolution
1926            .request()
1927            .server_id()
1928            .ok_or_else(|| McpError::AuthResolution("auth resolution missing server id".into()))?;
1929        let server_id = McpServerId::new(server_id);
1930        match &resolution {
1931            AuthResolution::Provided { credentials, .. } => {
1932                self.auth.insert(server_id.clone(), credentials.clone());
1933            }
1934            AuthResolution::Cancelled { .. } => {
1935                self.auth.remove(&server_id);
1936            }
1937        }
1938
1939        if let Some(handle) = self.connections.get(&server_id) {
1940            handle.connection.resolve_auth(resolution).await?;
1941            return Ok(());
1942        }
1943
1944        if self.configs.contains_key(&server_id) {
1945            Ok(())
1946        } else {
1947            Err(McpError::UnknownServer(server_id.to_string()))
1948        }
1949    }
1950
1951    /// Resolves authentication and immediately replays the operation that originally
1952    /// triggered the auth challenge.
1953    ///
1954    /// This is a convenience method combining [`resolve_auth`](Self::resolve_auth)
1955    /// and [`replay_auth_request`](Self::replay_auth_request).
1956    ///
1957    /// # Errors
1958    ///
1959    /// Returns [`McpError`] if auth resolution or the replayed operation fails.
1960    pub async fn resolve_auth_and_resume(
1961        &mut self,
1962        resolution: AuthResolution,
1963    ) -> Result<McpOperationResult, McpError> {
1964        let request = resolution.request().clone();
1965        self.resolve_auth(resolution).await?;
1966        self.replay_auth_request(&request).await
1967    }
1968
1969    /// Replays an auth request's original MCP operation using stored credentials.
1970    ///
1971    /// For connect operations the server is (re)connected. For tool calls, resource
1972    /// reads, and prompt retrievals the request is re-issued on the existing or
1973    /// newly established connection.
1974    ///
1975    /// # Errors
1976    ///
1977    /// Returns [`McpError`] if the operation cannot be replayed.
1978    pub async fn replay_auth_request(
1979        &mut self,
1980        request: &AuthRequest,
1981    ) -> Result<McpOperationResult, McpError> {
1982        match &request.operation {
1983            AuthOperation::McpConnect { server_id, .. } => {
1984                let server_id = McpServerId::new(server_id);
1985                let handle = self.connect_server(&server_id).await?;
1986                Ok(McpOperationResult::Connected(handle.snapshot.clone()))
1987            }
1988            AuthOperation::McpToolCall { server_id, .. }
1989            | AuthOperation::McpResourceRead { server_id, .. }
1990            | AuthOperation::McpPromptGet { server_id, .. } => {
1991                let connection = self.connection_for_auth_server(server_id).await?;
1992                connection.replay_auth_operation(&request.operation).await
1993            }
1994            AuthOperation::ToolCall { metadata, .. } => {
1995                let server_id = metadata
1996                    .get("server_id")
1997                    .and_then(Value::as_str)
1998                    .ok_or_else(|| {
1999                        McpError::AuthResolution(
2000                            "tool-call auth replay requires metadata.server_id".into(),
2001                        )
2002                    })?;
2003                let connection = self.connection_for_auth_server(server_id).await?;
2004                connection.replay_auth_operation(&request.operation).await
2005            }
2006            AuthOperation::Custom { kind, .. } => Err(McpError::AuthResolution(format!(
2007                "unsupported auth operation for replay: {kind}"
2008            ))),
2009        }
2010    }
2011
2012    async fn connection_for_auth_server(
2013        &mut self,
2014        server_id: &str,
2015    ) -> Result<Arc<McpConnection>, McpError> {
2016        let server_id = McpServerId::new(server_id);
2017        if !self.connections.contains_key(&server_id) {
2018            self.connect_server(&server_id).await?;
2019        }
2020        self.connections
2021            .get(&server_id)
2022            .map(McpServerHandle::connection)
2023            .ok_or_else(|| McpError::UnknownServer(server_id.to_string()))
2024    }
2025
2026    /// Builds a combined [`ToolRegistry`] containing [`McpToolAdapter`]s for every
2027    /// tool discovered across all connected servers.
2028    ///
2029    /// Tool names are prefixed as `mcp_<server_id>_<tool_name>`.
2030    pub fn tool_registry(&self) -> ToolRegistry {
2031        self.connections
2032            .values()
2033            .fold(ToolRegistry::new(), |mut registry, handle| {
2034                for tool in handle.snapshot.tools.iter().cloned() {
2035                    registry.register(McpToolAdapter::new(
2036                        handle.server_id(),
2037                        handle.connection.clone(),
2038                        tool,
2039                    ));
2040                }
2041                registry
2042            })
2043    }
2044
2045    /// Builds a combined [`McpCapabilityProvider`] from all connected servers,
2046    /// merging their tools, resources, and prompts.
2047    pub fn capability_provider(&self) -> McpCapabilityProvider {
2048        McpCapabilityProvider::merge(
2049            self.connections
2050                .values()
2051                .map(McpServerHandle::capability_provider),
2052        )
2053    }
2054}
2055
2056/// Adapter that exposes an MCP tool as an agentkit [`Tool`].
2057///
2058/// This is the tool-layer adapter for the tool registry. For the capabilities-layer
2059/// adapter, see [`McpInvocable`]. Tool names are prefixed as
2060/// `mcp_<server_id>_<tool_name>`.
2061///
2062/// # Example
2063///
2064/// ```rust
2065/// use std::sync::Arc;
2066/// use agentkit_core::MetadataMap;
2067/// use agentkit_mcp::{McpToolAdapter, McpToolDescriptor, McpServerId};
2068/// # // McpToolAdapter::new requires a connection which we cannot construct in a doc test,
2069/// # // so this example only shows the construction pattern.
2070/// ```
2071pub struct McpToolAdapter {
2072    descriptor: McpToolDescriptor,
2073    connection: Arc<McpConnection>,
2074    spec: ToolSpec,
2075}
2076
2077impl McpToolAdapter {
2078    /// Creates a new tool adapter for the given MCP tool.
2079    ///
2080    /// # Arguments
2081    ///
2082    /// * `server_id` - The server's identifier, used to namespace the tool name.
2083    /// * `connection` - A shared connection to the owning MCP server.
2084    /// * `descriptor` - The tool descriptor obtained from discovery.
2085    pub fn new(
2086        server_id: &McpServerId,
2087        connection: Arc<McpConnection>,
2088        descriptor: McpToolDescriptor,
2089    ) -> Self {
2090        let spec = ToolSpec {
2091            name: ToolName::new(format!("mcp_{}_{}", server_id, descriptor.name)),
2092            description: descriptor
2093                .description
2094                .clone()
2095                .unwrap_or_else(|| descriptor.name.clone()),
2096            input_schema: descriptor.input_schema.clone(),
2097            annotations: ToolAnnotations::default(),
2098            metadata: descriptor.metadata.clone(),
2099        };
2100
2101        Self {
2102            descriptor,
2103            connection,
2104            spec,
2105        }
2106    }
2107}
2108
2109#[async_trait]
2110impl Tool for McpToolAdapter {
2111    fn spec(&self) -> &ToolSpec {
2112        &self.spec
2113    }
2114
2115    async fn invoke(
2116        &self,
2117        request: ToolRequest,
2118        _ctx: &mut ToolContext<'_>,
2119    ) -> Result<ToolResult, ToolError> {
2120        let result = self
2121            .connection
2122            .call_tool(&self.descriptor.name, request.input)
2123            .await
2124            .map_err(|error| match error {
2125                McpError::AuthRequired(request) => ToolError::AuthRequired(request),
2126                other => ToolError::ExecutionFailed(other.to_string()),
2127            })?;
2128
2129        Ok(ToolResult {
2130            result: ToolResultPart {
2131                call_id: request.call_id,
2132                output: invocable_output_to_tool_output(value_to_invocable_output(result)),
2133                is_error: false,
2134                metadata: MetadataMap::new(),
2135            },
2136            duration: None,
2137            metadata: MetadataMap::new(),
2138        })
2139    }
2140}
2141
2142fn parse_tool_descriptor(value: Value) -> Result<McpToolDescriptor, McpError> {
2143    Ok(McpToolDescriptor {
2144        name: required_string(&value, "name")?,
2145        description: value
2146            .get("description")
2147            .and_then(Value::as_str)
2148            .map(str::to_owned),
2149        input_schema: value
2150            .get("inputSchema")
2151            .cloned()
2152            .unwrap_or_else(|| json!({ "type": "object" })),
2153        metadata: MetadataMap::new(),
2154    })
2155}
2156
2157fn parse_resource_descriptor(value: Value) -> Result<McpResourceDescriptor, McpError> {
2158    Ok(McpResourceDescriptor {
2159        id: required_string(&value, "uri")?,
2160        name: value
2161            .get("name")
2162            .and_then(Value::as_str)
2163            .map(str::to_owned)
2164            .unwrap_or_else(|| {
2165                value
2166                    .get("uri")
2167                    .and_then(Value::as_str)
2168                    .unwrap_or_default()
2169                    .to_string()
2170            }),
2171        description: value
2172            .get("description")
2173            .and_then(Value::as_str)
2174            .map(str::to_owned),
2175        mime_type: value
2176            .get("mimeType")
2177            .and_then(Value::as_str)
2178            .map(str::to_owned),
2179        metadata: MetadataMap::new(),
2180    })
2181}
2182
2183fn parse_prompt_descriptor(value: Value) -> Result<McpPromptDescriptor, McpError> {
2184    let name = required_string(&value, "name")?;
2185    let properties = value
2186        .get("arguments")
2187        .and_then(Value::as_array)
2188        .cloned()
2189        .unwrap_or_default()
2190        .into_iter()
2191        .filter_map(|arg| {
2192            let name = arg.get("name")?.as_str()?.to_string();
2193            Some((name, json!({ "type": "string" })))
2194        })
2195        .collect::<serde_json::Map<String, Value>>();
2196
2197    Ok(McpPromptDescriptor {
2198        id: name.clone(),
2199        name,
2200        description: value
2201            .get("description")
2202            .and_then(Value::as_str)
2203            .map(str::to_owned),
2204        input_schema: json!({
2205            "type": "object",
2206            "properties": properties,
2207        }),
2208        metadata: MetadataMap::new(),
2209    })
2210}
2211
2212fn parse_prompt_message(value: Value) -> Result<Item, McpError> {
2213    let role = value.get("role").and_then(Value::as_str).unwrap_or("user");
2214    let kind = match role {
2215        "assistant" => ItemKind::Assistant,
2216        "system" => ItemKind::System,
2217        _ => ItemKind::User,
2218    };
2219
2220    let content = value.get("content").cloned().unwrap_or(Value::Null);
2221    let text = if let Some(text) = content.get("text").and_then(Value::as_str) {
2222        text.to_string()
2223    } else if let Some(text) = content.as_str() {
2224        text.to_string()
2225    } else {
2226        content.to_string()
2227    };
2228
2229    Ok(Item {
2230        id: None,
2231        kind,
2232        parts: vec![Part::Text(TextPart {
2233            text,
2234            metadata: MetadataMap::new(),
2235        })],
2236        metadata: MetadataMap::new(),
2237    })
2238}
2239
2240fn required_string(value: &Value, field: &str) -> Result<String, McpError> {
2241    value
2242        .get(field)
2243        .and_then(Value::as_str)
2244        .map(str::to_owned)
2245        .ok_or_else(|| McpError::Protocol(format!("missing string field {field}")))
2246}
2247
2248fn value_to_invocable_output(value: Value) -> InvocableOutput {
2249    if let Some(content) = value.get("content").and_then(Value::as_array) {
2250        let text = content
2251            .iter()
2252            .filter_map(|item| item.get("text").and_then(Value::as_str))
2253            .collect::<Vec<_>>()
2254            .join("\n");
2255        if !text.is_empty() {
2256            return InvocableOutput::Text(text);
2257        }
2258    }
2259
2260    if let Some(text) = value.as_str() {
2261        InvocableOutput::Text(text.to_string())
2262    } else {
2263        InvocableOutput::Structured(value)
2264    }
2265}
2266
2267fn invocable_output_to_tool_output(output: InvocableOutput) -> ToolOutput {
2268    match output {
2269        InvocableOutput::Text(text) => ToolOutput::Text(text),
2270        InvocableOutput::Structured(value) => ToolOutput::Structured(value),
2271        InvocableOutput::Items(items) => {
2272            ToolOutput::Parts(items.into_iter().flat_map(|item| item.parts).collect())
2273        }
2274        InvocableOutput::Data(data) => ToolOutput::Structured(json!({ "data": data })),
2275    }
2276}
2277
2278fn metadata_to_value(metadata: &MetadataMap) -> Value {
2279    Value::Object(
2280        metadata
2281            .iter()
2282            .map(|(key, value)| (key.clone(), value.clone()))
2283            .collect(),
2284    )
2285}
2286
2287fn parse_auth_request(
2288    server_id: &McpServerId,
2289    method: &str,
2290    params: &Value,
2291    error: &Value,
2292) -> Option<AuthRequest> {
2293    let code = error.get("code").and_then(Value::as_i64);
2294    let message = error.get("message").and_then(Value::as_str);
2295    let data = error.get("data");
2296
2297    let auth_marker = matches!(code, Some(401 | -32001))
2298        || data
2299            .and_then(|data| data.get("auth_required"))
2300            .and_then(Value::as_bool)
2301            == Some(true)
2302        || data.and_then(|data| data.get("auth")).is_some();
2303
2304    if !auth_marker {
2305        return None;
2306    }
2307
2308    let mut challenge = MetadataMap::new();
2309    challenge.insert("server_id".into(), Value::String(server_id.to_string()));
2310    challenge.insert("method".into(), Value::String(method.into()));
2311
2312    if let Some(code) = code {
2313        challenge.insert("code".into(), Value::Number(code.into()));
2314    }
2315    if let Some(message) = message {
2316        challenge.insert("message".into(), Value::String(message.into()));
2317    }
2318    if let Some(data) = data {
2319        challenge.insert("data".into(), data.clone());
2320    }
2321
2322    Some(AuthRequest {
2323        task_id: None,
2324        id: format!("mcp:{}:{}", server_id, method),
2325        provider: format!("mcp.{}", server_id),
2326        operation: auth_operation_for_method(server_id, method, params),
2327        challenge,
2328    })
2329}
2330
2331fn auth_operation_for_method(
2332    server_id: &McpServerId,
2333    method: &str,
2334    params: &Value,
2335) -> AuthOperation {
2336    match method {
2337        "initialize" => AuthOperation::McpConnect {
2338            server_id: server_id.to_string(),
2339            metadata: MetadataMap::new(),
2340        },
2341        "tools/call" => AuthOperation::McpToolCall {
2342            server_id: server_id.to_string(),
2343            tool_name: params
2344                .get("name")
2345                .and_then(Value::as_str)
2346                .unwrap_or_default()
2347                .to_string(),
2348            input: params
2349                .get("arguments")
2350                .cloned()
2351                .unwrap_or_else(|| json!({})),
2352            metadata: MetadataMap::new(),
2353        },
2354        "resources/read" => AuthOperation::McpResourceRead {
2355            server_id: server_id.to_string(),
2356            resource_id: params
2357                .get("uri")
2358                .and_then(Value::as_str)
2359                .unwrap_or_default()
2360                .to_string(),
2361            metadata: MetadataMap::new(),
2362        },
2363        "prompts/get" => AuthOperation::McpPromptGet {
2364            server_id: server_id.to_string(),
2365            prompt_id: params
2366                .get("name")
2367                .and_then(Value::as_str)
2368                .unwrap_or_default()
2369                .to_string(),
2370            args: params
2371                .get("arguments")
2372                .cloned()
2373                .unwrap_or_else(|| json!({})),
2374            metadata: MetadataMap::new(),
2375        },
2376        other => AuthOperation::Custom {
2377            kind: format!("mcp.{other}"),
2378            payload: params.clone(),
2379            metadata: {
2380                let mut metadata = MetadataMap::new();
2381                metadata.insert("server_id".into(), Value::String(server_id.to_string()));
2382                metadata
2383            },
2384        },
2385    }
2386}
2387
2388fn normalize_mcp_tool_name(server_id: &McpServerId, tool_name: &str) -> String {
2389    let prefix = format!("mcp_{server_id}_");
2390    tool_name
2391        .strip_prefix(&prefix)
2392        .unwrap_or(tool_name)
2393        .to_string()
2394}
2395
2396async fn read_sse_stream<R>(
2397    mut reader: R,
2398    response_url: Url,
2399    frame_tx: mpsc::UnboundedSender<Result<McpFrame, McpError>>,
2400    endpoint_tx: oneshot::Sender<Result<Url, McpError>>,
2401) where
2402    R: AsyncBufRead + Unpin,
2403{
2404    let mut endpoint_tx = Some(endpoint_tx);
2405    loop {
2406        match read_next_sse_event(&mut reader).await {
2407            Ok(Some(event)) => {
2408                if let Some(endpoint) = legacy_sse_event_to_endpoint(&response_url, &event) {
2409                    if let Some(tx) = endpoint_tx.take() {
2410                        let _ = tx.send(endpoint);
2411                    }
2412                    continue;
2413                }
2414
2415                if let Some(frame) = legacy_sse_event_to_frame(event) {
2416                    let _ = frame_tx.send(frame);
2417                }
2418            }
2419            Ok(None) => break,
2420            Err(error) => {
2421                if let Some(tx) = endpoint_tx.take() {
2422                    let _ = tx.send(Err(error));
2423                } else {
2424                    let _ = frame_tx.send(Err(error));
2425                }
2426                return;
2427            }
2428        }
2429    }
2430
2431    if let Some(tx) = endpoint_tx.take() {
2432        let _ = tx.send(Err(McpError::Transport(
2433            "SSE stream ended before endpoint event".into(),
2434        )));
2435    }
2436}
2437
2438fn resolve_sse_endpoint(response_url: &Url, endpoint: &str) -> Result<Url, McpError> {
2439    response_url
2440        .join(endpoint.trim())
2441        .map_err(|error| McpError::Transport(format!("invalid SSE endpoint URL: {error}")))
2442}
2443
/// One event assembled from a Server-Sent Events stream.
#[derive(Debug)]
struct SseEvent {
    /// Value of the last `event:` field, if any.
    event_name: Option<String>,
    /// All `data:` field values of the event, joined with `\n`.
    data: String,
    /// Value of the last `id:` field, if any.
    id: Option<String>,
    /// Reconnection delay from the `retry:` field, in milliseconds.
    retry_ms: Option<u64>,
}
2451
2452async fn read_next_sse_event<R>(reader: &mut R) -> Result<Option<SseEvent>, McpError>
2453where
2454    R: AsyncBufRead + Unpin,
2455{
2456    let mut event_name = None;
2457    let mut data_lines = Vec::new();
2458    let mut id = None;
2459    let mut retry_ms = None;
2460
2461    loop {
2462        let mut line = String::new();
2463        let read = reader.read_line(&mut line).await.map_err(McpError::Io)?;
2464        if read == 0 {
2465            if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2466                return Ok(None);
2467            }
2468            return Ok(Some(SseEvent {
2469                event_name,
2470                data: data_lines.join("\n"),
2471                id,
2472                retry_ms,
2473            }));
2474        }
2475
2476        let line = line.trim_end_matches(['\r', '\n']);
2477        if line.is_empty() {
2478            if event_name.is_none() && data_lines.is_empty() && id.is_none() && retry_ms.is_none() {
2479                continue;
2480            }
2481            return Ok(Some(SseEvent {
2482                event_name,
2483                data: data_lines.join("\n"),
2484                id,
2485                retry_ms,
2486            }));
2487        }
2488
2489        if line.starts_with(':') {
2490            continue;
2491        }
2492
2493        if let Some(rest) = line.strip_prefix("event:") {
2494            event_name = Some(rest.trim_start().to_string());
2495            continue;
2496        }
2497        if let Some(rest) = line.strip_prefix("data:") {
2498            data_lines.push(rest.trim_start().to_string());
2499            continue;
2500        }
2501        if let Some(rest) = line.strip_prefix("id:") {
2502            id = Some(rest.trim_start().to_string());
2503            continue;
2504        }
2505        if let Some(rest) = line.strip_prefix("retry:") {
2506            retry_ms = rest.trim_start().parse().ok();
2507        }
2508    }
2509}
2510
2511fn legacy_sse_event_to_endpoint(
2512    response_url: &Url,
2513    event: &SseEvent,
2514) -> Option<Result<Url, McpError>> {
2515    if event.event_name.as_deref() != Some("endpoint") {
2516        return None;
2517    }
2518    if event.data.is_empty() {
2519        return Some(Err(McpError::Transport(
2520            "legacy SSE endpoint event is missing data".into(),
2521        )));
2522    }
2523    Some(resolve_sse_endpoint(response_url, &event.data))
2524}
2525
2526fn legacy_sse_event_to_frame(event: SseEvent) -> Option<Result<McpFrame, McpError>> {
2527    let event_name = event.event_name.unwrap_or_else(|| "message".into());
2528    if event_name != "message" || event.data.is_empty() {
2529        return None;
2530    }
2531
2532    Some(
2533        serde_json::from_str(&event.data)
2534            .map_err(McpError::Serialize)
2535            .map(|value| McpFrame { value }),
2536    )
2537}
2538
2539fn streamable_http_event_to_frame(event: SseEvent) -> Result<Option<McpFrame>, McpError> {
2540    let event_name = event.event_name.unwrap_or_else(|| "message".into());
2541    if event_name != "message" || event.data.is_empty() {
2542        return Ok(None);
2543    }
2544
2545    let value = serde_json::from_str(&event.data).map_err(McpError::Serialize)?;
2546    Ok(Some(McpFrame { value }))
2547}
2548
2549fn is_jsonrpc_request(value: &Value) -> bool {
2550    value.get("method").is_some() && value.get("id").is_some()
2551}
2552
2553fn apply_streamable_http_headers(
2554    mut request: HttpRequestBuilder,
2555    protocol_version: Option<&str>,
2556    session_id: Option<&str>,
2557) -> HttpRequestBuilder {
2558    if let Some(protocol_version) = protocol_version {
2559        request = request.header("MCP-Protocol-Version", protocol_version);
2560    }
2561    if let Some(session_id) = session_id {
2562        request = request.header("MCP-Session-Id", session_id);
2563    }
2564
2565    request
2566}
2567
2568async fn streamable_http_status_error(
2569    operation: &str,
2570    status: StatusCode,
2571    response: HttpResponse,
2572) -> McpError {
2573    let body = response
2574        .text()
2575        .await
2576        .unwrap_or_else(|_| "<unreadable response body>".into());
2577    McpError::Transport(format!("{operation} failed with status {status}: {body}"))
2578}
2579
2580/// Errors produced by MCP transport, protocol, and lifecycle operations.
2581#[derive(Debug, Error)]
2582pub enum McpError {
2583    /// An underlying I/O error (e.g. spawning a child process or reading from a pipe).
2584    #[error("io error: {0}")]
2585    Io(#[from] std::io::Error),
2586    /// An HTTP-level error surfaced by the configured [`agentkit_http::HttpClient`].
2587    #[error("http error: {0}")]
2588    Http(#[from] HttpError),
2589    /// A JSON serialization or deserialization error.
2590    #[error("serialization error: {0}")]
2591    Serialize(#[from] serde_json::Error),
2592    /// A transport-level error (e.g. unexpected disconnection or bad SSE response).
2593    #[error("transport error: {0}")]
2594    Transport(String),
2595    /// An MCP protocol violation (e.g. missing required fields in a response).
2596    #[error("protocol error: {0}")]
2597    Protocol(String),
2598    /// The server requires authentication before the operation can proceed.
2599    /// Contains the [`AuthRequest`] that describes the challenge.
2600    #[error("MCP auth required: {0:?}")]
2601    AuthRequired(Box<AuthRequest>),
2602    /// An error occurred while resolving or replaying authentication.
2603    #[error("auth resolution error: {0}")]
2604    AuthResolution(String),
2605    /// The MCP server returned an error for the invoked method.
2606    #[error("invocation error: {0}")]
2607    Invocation(String),
2608    /// The referenced server ID is not registered in the [`McpServerManager`].
2609    #[error("unknown MCP server: {0}")]
2610    UnknownServer(String),
2611}
2612
2613#[cfg(test)]
2614mod tests {
2615    use std::collections::VecDeque;
2616    use std::sync::{Arc as StdArc, Mutex as StdMutex};
2617
2618    use super::*;
2619    use agentkit_tools_core::{PermissionChecker, PermissionDecision, PermissionRequest};
2620    use tokio::io::{AsyncReadExt, AsyncWriteExt};
2621    use tokio::net::TcpListener;
2622
2623    struct AllowAll;
2624
2625    impl PermissionChecker for AllowAll {
2626        fn evaluate(&self, _request: &dyn PermissionRequest) -> PermissionDecision {
2627            PermissionDecision::Allow
2628        }
2629    }
2630
2631    struct FakeTransport {
2632        recv: VecDeque<Value>,
2633    }
2634
2635    #[async_trait]
2636    impl McpTransport for FakeTransport {
2637        async fn send(&mut self, _message: McpFrame) -> Result<(), McpError> {
2638            Ok(())
2639        }
2640
2641        async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2642            Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2643        }
2644
2645        async fn close(&mut self) -> Result<(), McpError> {
2646            Ok(())
2647        }
2648    }
2649
2650    fn fake_connection(responses: Vec<Value>) -> McpConnection {
2651        McpConnection {
2652            server_id: McpServerId::new("fake"),
2653            transport: Mutex::new(Box::new(FakeTransport {
2654                recv: responses.into(),
2655            })),
2656            auth: Mutex::new(None),
2657            next_id: AtomicU64::new(1),
2658        }
2659    }
2660
2661    #[derive(Clone)]
2662    struct FakeTransportFactory {
2663        responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2664    }
2665
2666    impl FakeTransportFactory {
2667        fn new(sequences: Vec<Vec<Value>>) -> Self {
2668            Self {
2669                responses: StdArc::new(StdMutex::new(sequences.into())),
2670            }
2671        }
2672    }
2673
2674    #[async_trait]
2675    impl McpTransportFactory for FakeTransportFactory {
2676        async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2677            let responses =
2678                self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2679                    McpError::Transport("no fake transport responses left".into())
2680                })?;
2681            Ok(Box::new(FakeTransport {
2682                recv: responses.into(),
2683            }))
2684        }
2685    }
2686
2687    #[tokio::test]
2688    async fn discovery_parses_snapshot() {
2689        let connection = fake_connection(vec![
2690            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
2691            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [{ "uri": "file:///tmp/example.txt", "name": "example.txt", "mimeType": "text/plain" }] } }),
2692            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [{ "name": "summarize", "description": "Summarize", "arguments": [{ "name": "path" }] }] } }),
2693        ]);
2694
2695        let snapshot = connection.discover().await.unwrap();
2696        assert_eq!(snapshot.tools[0].name, "echo");
2697        assert_eq!(snapshot.resources[0].id, "file:///tmp/example.txt");
2698        assert_eq!(snapshot.prompts[0].id, "summarize");
2699    }
2700
2701    #[tokio::test]
2702    async fn tool_adapter_returns_text_output() {
2703        let connection = Arc::new(fake_connection(vec![json!({
2704            "jsonrpc": "2.0",
2705            "id": 1,
2706            "result": { "content": [{ "type": "text", "text": "pong" }] }
2707        })]));
2708        let server_id = connection.server_id().clone();
2709        let adapter = McpToolAdapter::new(
2710            &server_id,
2711            connection,
2712            McpToolDescriptor {
2713                name: "echo".into(),
2714                description: Some("Echo".into()),
2715                input_schema: json!({ "type": "object" }),
2716                metadata: MetadataMap::new(),
2717            },
2718        );
2719        let metadata = MetadataMap::new();
2720        let mut ctx = ToolContext {
2721            capability: CapabilityContext {
2722                session_id: None,
2723                turn_id: None,
2724                metadata: &metadata,
2725            },
2726            permissions: &AllowAll,
2727            resources: &(),
2728            cancellation: None,
2729        };
2730
2731        let result = adapter
2732            .invoke(
2733                ToolRequest {
2734                    call_id: "call-1".into(),
2735                    tool_name: ToolName::new("mcp_fake_echo"),
2736                    input: json!({}),
2737                    session_id: "session-1".into(),
2738                    turn_id: "turn-1".into(),
2739                    metadata: MetadataMap::new(),
2740                },
2741                &mut ctx,
2742            )
2743            .await
2744            .unwrap();
2745
2746        assert_eq!(result.result.output, ToolOutput::Text("pong".into()));
2747    }
2748
2749    #[tokio::test]
2750    async fn request_surfaces_auth_required_errors() {
2751        let connection = fake_connection(vec![json!({
2752            "jsonrpc": "2.0",
2753            "id": 1,
2754            "error": {
2755                "code": -32001,
2756                "message": "authentication required",
2757                "data": {
2758                    "auth_required": true,
2759                    "scope": "secrets.read"
2760                }
2761            }
2762        })]);
2763
2764        let error = connection.call_tool("echo", json!({})).await.unwrap_err();
2765        match error {
2766            McpError::AuthRequired(request) => {
2767                assert_eq!(request.provider, "mcp.fake");
2768                assert_eq!(
2769                    request.challenge.get("method"),
2770                    Some(&Value::String("tools/call".into()))
2771                );
2772                assert!(matches!(
2773                    request.operation,
2774                    AuthOperation::McpToolCall { ref tool_name, .. } if tool_name == "echo"
2775                ));
2776            }
2777            other => panic!("unexpected error: {other:?}"),
2778        }
2779    }
2780
2781    #[tokio::test]
2782    async fn tool_adapter_maps_auth_required_into_tool_error() {
2783        let connection = Arc::new(fake_connection(vec![json!({
2784            "jsonrpc": "2.0",
2785            "id": 1,
2786            "error": {
2787                "code": -32001,
2788                "message": "authentication required",
2789                "data": { "auth_required": true }
2790            }
2791        })]));
2792        let server_id = connection.server_id().clone();
2793        let adapter = McpToolAdapter::new(
2794            &server_id,
2795            connection,
2796            McpToolDescriptor {
2797                name: "echo".into(),
2798                description: Some("Echo".into()),
2799                input_schema: json!({ "type": "object" }),
2800                metadata: MetadataMap::new(),
2801            },
2802        );
2803        let metadata = MetadataMap::new();
2804        let mut ctx = ToolContext {
2805            capability: CapabilityContext {
2806                session_id: None,
2807                turn_id: None,
2808                metadata: &metadata,
2809            },
2810            permissions: &AllowAll,
2811            resources: &(),
2812            cancellation: None,
2813        };
2814
2815        let error = adapter
2816            .invoke(
2817                ToolRequest {
2818                    call_id: "call-1".into(),
2819                    tool_name: ToolName::new("mcp_fake_echo"),
2820                    input: json!({}),
2821                    session_id: "session-1".into(),
2822                    turn_id: "turn-1".into(),
2823                    metadata: MetadataMap::new(),
2824                },
2825                &mut ctx,
2826            )
2827            .await
2828            .unwrap_err();
2829
2830        match error {
2831            ToolError::AuthRequired(request) => {
2832                assert_eq!(request.provider, "mcp.fake");
2833            }
2834            other => panic!("unexpected error: {other:?}"),
2835        }
2836    }
2837
2838    struct RecordingTransport {
2839        recv: VecDeque<Value>,
2840        sent: StdArc<StdMutex<Vec<Value>>>,
2841    }
2842
2843    #[async_trait]
2844    impl McpTransport for RecordingTransport {
2845        async fn send(&mut self, message: McpFrame) -> Result<(), McpError> {
2846            self.sent.lock().unwrap().push(message.value);
2847            Ok(())
2848        }
2849
2850        async fn recv(&mut self) -> Result<Option<McpFrame>, McpError> {
2851            Ok(self.recv.pop_front().map(|value| McpFrame { value }))
2852        }
2853
2854        async fn close(&mut self) -> Result<(), McpError> {
2855            Ok(())
2856        }
2857    }
2858
2859    #[derive(Clone)]
2860    struct RecordingTransportFactory {
2861        responses: StdArc<StdMutex<VecDeque<Vec<Value>>>>,
2862        sent: StdArc<StdMutex<Vec<Value>>>,
2863    }
2864
2865    impl RecordingTransportFactory {
2866        fn new(sequences: Vec<Vec<Value>>) -> Self {
2867            Self {
2868                responses: StdArc::new(StdMutex::new(sequences.into())),
2869                sent: StdArc::new(StdMutex::new(Vec::new())),
2870            }
2871        }
2872
2873        fn sent(&self) -> Vec<Value> {
2874            self.sent.lock().unwrap().clone()
2875        }
2876    }
2877
2878    #[async_trait]
2879    impl McpTransportFactory for RecordingTransportFactory {
2880        async fn connect(&self) -> Result<Box<dyn McpTransport>, McpError> {
2881            let responses = self.responses.lock().unwrap().pop_front().ok_or_else(|| {
2882                McpError::Transport("no recording transport responses left".into())
2883            })?;
2884            Ok(Box::new(RecordingTransport {
2885                recv: responses.into(),
2886                sent: self.sent.clone(),
2887            }))
2888        }
2889    }
2890
2891    #[tokio::test]
2892    async fn connection_includes_resolved_auth_in_future_requests() {
2893        let factory = RecordingTransportFactory::new(vec![vec![
2894            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2895            json!({ "jsonrpc": "2.0", "id": 1, "result": { "content": [{ "type": "text", "text": "ok" }] } }),
2896        ]]);
2897        let config = McpServerConfig::new(
2898            "recording",
2899            McpTransportBinding::Custom(Arc::new(factory.clone())),
2900        );
2901        let connection = McpConnection::connect(&config).await.unwrap();
2902        let mut auth = MetadataMap::new();
2903        auth.insert("token".into(), json!("secret-token"));
2904        let request = AuthRequest {
2905            task_id: None,
2906            id: "auth-recording-tool".into(),
2907            provider: "mcp.recording".into(),
2908            operation: AuthOperation::McpToolCall {
2909                server_id: "recording".into(),
2910                tool_name: "echo".into(),
2911                input: json!({}),
2912                metadata: MetadataMap::new(),
2913            },
2914            challenge: MetadataMap::new(),
2915        };
2916        connection
2917            .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2918                request,
2919                credentials: auth,
2920            })
2921            .await
2922            .unwrap();
2923
2924        let _ = connection.call_tool("echo", json!({})).await.unwrap();
2925        let sent = factory.sent();
2926        assert!(
2927            sent.iter().any(|value| {
2928                value
2929                    .get("params")
2930                    .and_then(|params| params.get("auth"))
2931                    .and_then(|auth| auth.get("token"))
2932                    == Some(&json!("secret-token"))
2933            }),
2934            "expected an MCP request to include the resolved auth payload, saw {:?}",
2935            sent
2936        );
2937    }
2938
2939    #[tokio::test]
2940    async fn manager_reuses_stored_auth_on_connect() {
2941        let factory = RecordingTransportFactory::new(vec![vec![
2942            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2943            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
2944            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
2945            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
2946        ]]);
2947        let server_id = McpServerId::new("recording");
2948        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
2949            server_id.to_string(),
2950            McpTransportBinding::Custom(Arc::new(factory.clone())),
2951        ));
2952        let mut auth = MetadataMap::new();
2953        auth.insert("token".into(), json!("seed-token"));
2954        let request = AuthRequest {
2955            task_id: None,
2956            id: "auth-recording-connect".into(),
2957            provider: "mcp.recording".into(),
2958            operation: AuthOperation::McpConnect {
2959                server_id: server_id.to_string(),
2960                metadata: MetadataMap::new(),
2961            },
2962            challenge: MetadataMap::new(),
2963        };
2964        manager
2965            .resolve_auth(agentkit_tools_core::AuthResolution::Provided {
2966                request,
2967                credentials: auth,
2968            })
2969            .await
2970            .unwrap();
2971
2972        manager.connect_server(&server_id).await.unwrap();
2973        let sent = factory.sent();
2974        assert!(
2975            sent.iter().any(|value| {
2976                value.get("method").and_then(Value::as_str) == Some("initialize")
2977                    && value
2978                        .get("params")
2979                        .and_then(|params| params.get("auth"))
2980                        .and_then(|auth| auth.get("token"))
2981                        == Some(&json!("seed-token"))
2982            }),
2983            "expected initialize to include stored auth, saw {:?}",
2984            sent
2985        );
2986    }
2987
2988    #[tokio::test]
2989    async fn manager_resolves_auth_and_replays_resource_read() {
2990        let factory = RecordingTransportFactory::new(vec![vec![
2991            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
2992            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
2993            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
2994            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
2995            json!({
2996                "jsonrpc": "2.0",
2997                "id": 4,
2998                "result": {
2999                    "contents": [
3000                        {
3001                            "uri": "file:///tmp/secret.txt",
3002                            "text": "secret from resource"
3003                        }
3004                    ]
3005                }
3006            }),
3007        ]]);
3008        let server_id = McpServerId::new("recording");
3009        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3010            server_id.to_string(),
3011            McpTransportBinding::Custom(Arc::new(factory.clone())),
3012        ));
3013        let mut auth = MetadataMap::new();
3014        auth.insert("token".into(), json!("resource-token"));
3015        let request = AuthRequest {
3016            task_id: None,
3017            id: "auth-recording-resource".into(),
3018            provider: "mcp.recording".into(),
3019            operation: AuthOperation::McpResourceRead {
3020                server_id: server_id.to_string(),
3021                resource_id: "file:///tmp/secret.txt".into(),
3022                metadata: MetadataMap::new(),
3023            },
3024            challenge: MetadataMap::new(),
3025        };
3026
3027        let result = manager
3028            .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3029                request,
3030                credentials: auth,
3031            })
3032            .await
3033            .unwrap();
3034
3035        match result {
3036            McpOperationResult::Resource(contents) => {
3037                assert_eq!(
3038                    contents.data,
3039                    DataRef::InlineText("secret from resource".into())
3040                );
3041            }
3042            other => panic!("unexpected replay result: {other:?}"),
3043        }
3044
3045        let sent = factory.sent();
3046        assert!(
3047            sent.iter().any(|value| {
3048                value.get("method").and_then(Value::as_str) == Some("resources/read")
3049                    && value
3050                        .get("params")
3051                        .and_then(|params| params.get("auth"))
3052                        .and_then(|auth| auth.get("token"))
3053                        == Some(&json!("resource-token"))
3054            }),
3055            "expected resources/read to include resolved auth, saw {:?}",
3056            sent
3057        );
3058    }
3059
3060    #[tokio::test]
3061    async fn manager_resolves_auth_and_replays_connect() {
3062        let factory = RecordingTransportFactory::new(vec![vec![
3063            json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "recording", "version": "1.0.0" } } }),
3064            json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [] } }),
3065            json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3066            json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3067        ]]);
3068        let server_id = McpServerId::new("recording");
3069        let mut manager = McpServerManager::new().with_server(McpServerConfig::new(
3070            server_id.to_string(),
3071            McpTransportBinding::Custom(Arc::new(factory.clone())),
3072        ));
3073        let mut auth = MetadataMap::new();
3074        auth.insert("token".into(), json!("connect-token"));
3075        let request = AuthRequest {
3076            task_id: None,
3077            id: "auth-recording-connect-replay".into(),
3078            provider: "mcp.recording".into(),
3079            operation: AuthOperation::McpConnect {
3080                server_id: server_id.to_string(),
3081                metadata: MetadataMap::new(),
3082            },
3083            challenge: MetadataMap::new(),
3084        };
3085
3086        let result = manager
3087            .resolve_auth_and_resume(agentkit_tools_core::AuthResolution::Provided {
3088                request,
3089                credentials: auth,
3090            })
3091            .await
3092            .unwrap();
3093
3094        match result {
3095            McpOperationResult::Connected(snapshot) => {
3096                assert_eq!(snapshot.server_id, server_id);
3097            }
3098            other => panic!("unexpected replay result: {other:?}"),
3099        }
3100    }
3101
3102    #[tokio::test]
3103    async fn sse_transport_posts_messages_and_receives_frames() {
3104        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3105        let address = listener.local_addr().unwrap();
3106        let requests = StdArc::new(StdMutex::new(Vec::new()));
3107        let captured = requests.clone();
3108
3109        let server = tokio::spawn(async move {
3110            for _ in 0..2 {
3111                let (mut socket, _) = listener.accept().await.unwrap();
3112                let mut buffer = vec![0_u8; 4096];
3113                let read = socket.read(&mut buffer).await.unwrap();
3114                let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3115
3116                if request.starts_with("GET /sse ") {
3117                    let body = concat!(
3118                        "event: endpoint\n",
3119                        "data: /messages\n\n",
3120                        "event: message\n",
3121                        "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3122                    );
3123                    let response = format!(
3124                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3125                        body.len(),
3126                        body
3127                    );
3128                    socket.write_all(response.as_bytes()).await.unwrap();
3129                } else {
3130                    captured.lock().unwrap().push(request);
3131                    socket
3132                        .write_all(
3133                            b"HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n",
3134                        )
3135                        .await
3136                        .unwrap();
3137                }
3138            }
3139        });
3140
3141        let factory =
3142            SseTransportFactory::new(SseTransportConfig::new(format!("http://{address}/sse")));
3143        let mut transport = factory.connect().await.unwrap();
3144        transport
3145            .send(McpFrame {
3146                value: json!({
3147                    "jsonrpc": "2.0",
3148                    "id": 1,
3149                    "method": "tools/list",
3150                    "params": {}
3151                }),
3152            })
3153            .await
3154            .unwrap();
3155        let frame = transport.recv().await.unwrap().unwrap();
3156        transport.close().await.unwrap();
3157        server.await.unwrap();
3158
3159        assert_eq!(frame.value["result"]["tools"], json!([]));
3160        let requests = requests.lock().unwrap();
3161        assert_eq!(requests.len(), 1);
3162        assert!(requests[0].starts_with("POST /messages "));
3163        assert!(requests[0].contains("\"method\":\"tools/list\""));
3164    }
3165
3166    #[tokio::test]
3167    async fn streamable_http_connection_tracks_session_and_protocol_headers() {
3168        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3169        let address = listener.local_addr().unwrap();
3170        let requests = StdArc::new(StdMutex::new(Vec::new()));
3171        let captured = requests.clone();
3172
3173        let server = tokio::spawn(async move {
3174            for _ in 0..4 {
3175                let (mut socket, _) = listener.accept().await.unwrap();
3176                let mut buffer = vec![0_u8; 8192];
3177                let read = socket.read(&mut buffer).await.unwrap();
3178                let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3179                captured.lock().unwrap().push(request.clone());
3180
3181                let response = if request.contains("\"method\":\"initialize\"") {
3182                    let body = "{\"jsonrpc\":\"2.0\",\"id\":0,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"remote\",\"version\":\"1.0.0\"}}}";
3183                    format!(
3184                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\nMCP-Session-Id: session-123\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3185                        body.len(),
3186                        body
3187                    )
3188                } else if request.contains("\"method\":\"notifications/initialized\"") {
3189                    "HTTP/1.1 202 Accepted\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3190                        .to_string()
3191                } else if request.starts_with("DELETE /mcp ") {
3192                    "HTTP/1.1 204 No Content\r\ncontent-length: 0\r\nconnection: close\r\n\r\n"
3193                        .to_string()
3194                } else {
3195                    let body = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}";
3196                    format!(
3197                        "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3198                        body.len(),
3199                        body
3200                    )
3201                };
3202
3203                socket.write_all(response.as_bytes()).await.unwrap();
3204            }
3205        });
3206
3207        let config = McpServerConfig::new(
3208            "remote",
3209            McpTransportBinding::StreamableHttp(StreamableHttpTransportConfig::new(format!(
3210                "http://{address}/mcp"
3211            ))),
3212        );
3213        let connection = McpConnection::connect(&config).await.unwrap();
3214        let _ = connection.list_tools().await.unwrap();
3215        connection.close().await.unwrap();
3216        server.await.unwrap();
3217
3218        let requests = requests.lock().unwrap();
3219        assert_eq!(requests.len(), 4);
3220        let normalized = requests
3221            .iter()
3222            .map(|request| request.to_ascii_lowercase())
3223            .collect::<Vec<_>>();
3224        assert!(requests[0].starts_with("POST /mcp "));
3225        assert!(!requests[0].contains("MCP-Session-Id:"));
3226        assert!(normalized[1].contains("mcp-session-id: session-123"));
3227        assert!(normalized[1].contains("mcp-protocol-version: 2025-11-25"));
3228        assert!(normalized[2].contains("mcp-session-id: session-123"));
3229        assert!(normalized[2].contains("mcp-protocol-version: 2025-11-25"));
3230        assert!(requests[3].starts_with("DELETE /mcp "));
3231        assert!(normalized[3].contains("mcp-session-id: session-123"));
3232    }
3233
3234    #[tokio::test]
3235    async fn streamable_http_transport_resumes_sse_streams_until_response_arrives() {
3236        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3237        let address = listener.local_addr().unwrap();
3238        let requests = StdArc::new(StdMutex::new(Vec::new()));
3239        let captured = requests.clone();
3240
3241        let server = tokio::spawn(async move {
3242            for _ in 0..2 {
3243                let (mut socket, _) = listener.accept().await.unwrap();
3244                let mut buffer = vec![0_u8; 8192];
3245                let read = socket.read(&mut buffer).await.unwrap();
3246                let request = String::from_utf8_lossy(&buffer[..read]).to_string();
3247                captured.lock().unwrap().push(request.clone());
3248
3249                let response = if request.starts_with("POST /mcp ") {
3250                    let body = concat!(
3251                        "id: evt-1\n",
3252                        "event: message\n",
3253                        "data: {\"jsonrpc\":\"2.0\",\"method\":\"notifications/message\",\"params\":{\"phase\":\"stream-start\"}}\n\n"
3254                    );
3255                    format!(
3256                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3257                        body.len(),
3258                        body
3259                    )
3260                } else {
3261                    let body = concat!(
3262                        "id: evt-2\n",
3263                        "event: message\n",
3264                        "data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}\n\n"
3265                    );
3266                    format!(
3267                        "HTTP/1.1 200 OK\r\ncontent-type: text/event-stream\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}",
3268                        body.len(),
3269                        body
3270                    )
3271                };
3272
3273                socket.write_all(response.as_bytes()).await.unwrap();
3274            }
3275        });
3276
3277        let factory = StreamableHttpTransportFactory::new(StreamableHttpTransportConfig::new(
3278            format!("http://{address}/mcp"),
3279        ));
3280        let mut transport = factory.connect().await.unwrap();
3281        transport
3282            .send(McpFrame {
3283                value: json!({
3284                    "jsonrpc": "2.0",
3285                    "id": 1,
3286                    "method": "tools/list",
3287                    "params": {}
3288                }),
3289            })
3290            .await
3291            .unwrap();
3292
3293        let first = transport.recv().await.unwrap().unwrap();
3294        let second = transport.recv().await.unwrap().unwrap();
3295        transport.close().await.unwrap();
3296        server.await.unwrap();
3297
3298        assert_eq!(
3299            first.value["method"],
3300            Value::String("notifications/message".into())
3301        );
3302        assert_eq!(second.value["result"]["tools"], json!([]));
3303
3304        let requests = requests.lock().unwrap();
3305        assert_eq!(requests.len(), 2);
3306        assert!(requests[0].starts_with("POST /mcp "));
3307        assert!(requests[1].starts_with("GET /mcp "));
3308        assert!(
3309            requests[1].contains("last-event-id: evt-1")
3310                || requests[1].contains("Last-Event-ID: evt-1")
3311        );
3312    }
3313
3314    #[tokio::test]
3315    async fn server_manager_connects_refreshes_and_aggregates_tools() {
3316        let alpha = McpServerConfig::new(
3317            "alpha",
3318            McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3319                json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "alpha", "version": "1.0.0" } } }),
3320                json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "echo", "description": "Echo", "inputSchema": {"type": "object"} }] } }),
3321                json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3322                json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3323                json!({ "jsonrpc": "2.0", "id": 4, "result": { "tools": [{ "name": "echo_v2", "description": "Echo 2", "inputSchema": {"type": "object"} }] } }),
3324                json!({ "jsonrpc": "2.0", "id": 5, "result": { "resources": [] } }),
3325                json!({ "jsonrpc": "2.0", "id": 6, "result": { "prompts": [] } }),
3326            ]]))),
3327        );
3328        let beta = McpServerConfig::new(
3329            "beta",
3330            McpTransportBinding::Custom(Arc::new(FakeTransportFactory::new(vec![vec![
3331                json!({ "jsonrpc": "2.0", "id": 0, "result": { "protocolVersion": "2025-11-25", "capabilities": {}, "serverInfo": { "name": "beta", "version": "1.0.0" } } }),
3332                json!({ "jsonrpc": "2.0", "id": 1, "result": { "tools": [{ "name": "search", "description": "Search", "inputSchema": {"type": "object"} }] } }),
3333                json!({ "jsonrpc": "2.0", "id": 2, "result": { "resources": [] } }),
3334                json!({ "jsonrpc": "2.0", "id": 3, "result": { "prompts": [] } }),
3335            ]]))),
3336        );
3337
3338        let mut manager = McpServerManager::new().with_server(alpha).with_server(beta);
3339
3340        let handles = manager.connect_all().await.unwrap();
3341        assert_eq!(handles.len(), 2);
3342        assert_eq!(
3343            manager
3344                .tool_registry()
3345                .specs()
3346                .into_iter()
3347                .map(|spec| spec.name.0)
3348                .collect::<Vec<_>>(),
3349            vec!["mcp_alpha_echo".to_string(), "mcp_beta_search".to_string()]
3350        );
3351
3352        let refreshed = manager
3353            .refresh_server(&McpServerId::new("alpha"))
3354            .await
3355            .unwrap();
3356        assert_eq!(refreshed.tools[0].name, "echo_v2");
3357        assert_eq!(
3358            manager
3359                .connected_server(&McpServerId::new("alpha"))
3360                .unwrap()
3361                .snapshot()
3362                .tools[0]
3363                .name,
3364            "echo_v2"
3365        );
3366
3367        let capabilities = manager.capability_provider();
3368        assert_eq!(capabilities.invocables().len(), 2);
3369
3370        manager
3371            .disconnect_server(&McpServerId::new("alpha"))
3372            .await
3373            .unwrap();
3374        assert!(
3375            manager
3376                .connected_server(&McpServerId::new("alpha"))
3377                .is_none()
3378        );
3379    }
3380}