Skip to main content

agent_sdk/mcp/
client.rs

1//! MCP client implementation.
2
3use anyhow::{Context, Result, bail};
4use serde_json::{Value, json};
5use std::sync::Arc;
6
7use super::protocol::JsonRpcRequest;
8use super::protocol::{
9    ClientCapabilities, ClientInfo, InitializeParams, InitializeResult, McpPrompt, McpResource,
10    McpToolCallResult, McpToolDefinition, PREFERRED_PROTOCOL_VERSION, PromptGetParams,
11    PromptGetResult, PromptsListResult, ResourceReadParams, ResourceReadResult,
12    ResourcesListResult, ToolCallParams, ToolsListResult, is_known_protocol_version,
13};
14use super::transport::McpTransport;
15
16/// MCP protocol revision this client advertises during `initialize`.
17///
18/// Retained as a public alias of [`PREFERRED_PROTOCOL_VERSION`] for backwards
19/// compatibility. The revision actually used for a connection is whatever the
20/// server selects during the handshake — see [`McpClient::protocol_version`].
21pub const MCP_PROTOCOL_VERSION: &str = PREFERRED_PROTOCOL_VERSION;
22
23/// MCP client for communicating with MCP servers.
24///
25/// The client handles the MCP protocol, including initialization,
26/// tool discovery, and tool execution.
27///
28/// # Example
29///
30/// ```ignore
31/// use agent_sdk::mcp::{McpClient, StdioTransport};
32///
33/// // Spawn server and create client
34/// let transport = StdioTransport::spawn("npx", &["-y", "mcp-server"]).await?;
35/// let client = McpClient::new(transport, "my-server".to_string()).await?;
36///
37/// // List available tools
38/// let tools = client.list_tools().await?;
39///
40/// // Call a tool
41/// let result = client.call_tool("tool_name", json!({"arg": "value"})).await?;
42/// ```
43pub struct McpClient<T: McpTransport> {
44    transport: Arc<T>,
45    server_name: String,
46    server_info: Option<InitializeResult>,
47    /// Protocol revision selected by the server during `initialize`.
48    negotiated_version: Option<String>,
49}
50
51impl<T: McpTransport> McpClient<T> {
52    /// Create a new MCP client and initialize the connection.
53    ///
54    /// # Arguments
55    ///
56    /// * `transport` - The transport to use for communication
57    /// * `server_name` - A name to identify this server connection
58    ///
59    /// # Errors
60    ///
61    /// Returns an error if initialization fails.
62    pub async fn new(transport: Arc<T>, server_name: String) -> Result<Self> {
63        let mut client = Self {
64            transport,
65            server_name,
66            server_info: None,
67            negotiated_version: None,
68        };
69
70        client.initialize().await?;
71
72        Ok(client)
73    }
74
75    /// Create a client without initialization.
76    ///
77    /// Use this if you need to control when initialization happens.
78    #[must_use]
79    pub const fn new_uninitialized(transport: Arc<T>, server_name: String) -> Self {
80        Self {
81            transport,
82            server_name,
83            server_info: None,
84            negotiated_version: None,
85        }
86    }
87
88    /// Initialize the MCP connection.
89    ///
90    /// This must be called before using other methods.
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if the server rejects initialization.
95    pub async fn initialize(&mut self) -> Result<&InitializeResult> {
96        #[cfg(feature = "otel")]
97        let started_at = std::time::Instant::now();
98        #[cfg(feature = "otel")]
99        let mut span = {
100            use crate::observability::langfuse;
101            let mut span = start_mcp_span("mcp.initialize", &self.server_name);
102            langfuse::tag_observation(&mut span, langfuse::ObservationType::Chain);
103            span
104        };
105
106        let result = self.initialize_inner().await;
107
108        #[cfg(feature = "otel")]
109        finish_mcp_span(
110            &mut span,
111            &result,
112            "initialize",
113            &self.server_name,
114            started_at,
115        );
116
117        result?;
118
119        self.server_info
120            .as_ref()
121            .context("Server info not available")
122    }
123
124    async fn initialize_inner(&mut self) -> Result<()> {
125        let params = InitializeParams {
126            protocol_version: PREFERRED_PROTOCOL_VERSION.to_string(),
127            capabilities: ClientCapabilities::default(),
128            client_info: ClientInfo {
129                name: "agent-sdk".to_string(),
130                version: env!("CARGO_PKG_VERSION").to_string(),
131            },
132        };
133
134        let request = JsonRpcRequest::new("initialize", Some(serde_json::to_value(&params)?), 0);
135
136        let response = self.transport.send(request).await?;
137
138        let result: InitializeResult = response
139            .result
140            .map(serde_json::from_value)
141            .transpose()
142            .context("Failed to parse initialize response")?
143            .context("Initialize response missing result")?;
144
145        // Honour the revision the server actually selected. The server may
146        // downgrade to an older revision (e.g. a legacy `2024-11-05` server);
147        // we adapt to its choice rather than insisting on our preference. An
148        // unrecognised revision is not fatal — proceed but log it.
149        let negotiated = result.protocol_version.clone();
150        if !is_known_protocol_version(&negotiated) {
151            log::warn!(
152                "MCP server '{}' negotiated unknown protocol revision '{}' (advertised '{}')",
153                self.server_name,
154                negotiated,
155                PREFERRED_PROTOCOL_VERSION,
156            );
157        }
158        // Inform the transport so out-of-band carriers (HTTP header) can use it.
159        self.transport.set_protocol_version(&negotiated).await;
160        self.negotiated_version = Some(negotiated);
161
162        // Send initialized notification (fire-and-forget)
163        let notification = JsonRpcRequest::new("notifications/initialized", None, 0);
164        let _ = self.transport.send_notification(notification).await;
165
166        self.server_info = Some(result);
167        Ok(())
168    }
169
170    /// Get the server name.
171    #[must_use]
172    pub fn server_name(&self) -> &str {
173        &self.server_name
174    }
175
176    /// Get server info if initialized.
177    #[must_use]
178    pub const fn server_info(&self) -> Option<&InitializeResult> {
179        self.server_info.as_ref()
180    }
181
182    /// The MCP protocol revision negotiated with the server.
183    ///
184    /// Returns `None` until [`McpClient::initialize`] has completed. This is
185    /// the revision the *server* selected, which may be older than
186    /// [`PREFERRED_PROTOCOL_VERSION`] if the server is on a legacy build.
187    #[must_use]
188    pub fn protocol_version(&self) -> Option<&str> {
189        self.negotiated_version.as_deref()
190    }
191
192    /// List available tools from the server.
193    ///
194    /// # Errors
195    ///
196    /// Returns an error if the request fails.
197    pub async fn list_tools(&self) -> Result<Vec<McpToolDefinition>> {
198        #[cfg(feature = "otel")]
199        let started_at = std::time::Instant::now();
200        #[cfg(feature = "otel")]
201        let mut span = {
202            use crate::observability::langfuse;
203            let mut span = start_mcp_span("mcp.tools/list", &self.server_name);
204            langfuse::tag_observation(&mut span, langfuse::ObservationType::Chain);
205            span
206        };
207
208        let result = self.list_tools_inner().await;
209
210        #[cfg(feature = "otel")]
211        {
212            use opentelemetry::KeyValue;
213            use opentelemetry::trace::Span;
214            if let Ok(ref tools) = result {
215                span.set_attribute(KeyValue::new(
216                    "mcp.tools.count",
217                    i64::try_from(tools.len()).unwrap_or(0),
218                ));
219            }
220            finish_mcp_span(
221                &mut span,
222                &result,
223                "tools/list",
224                &self.server_name,
225                started_at,
226            );
227        }
228
229        result
230    }
231
232    async fn list_tools_inner(&self) -> Result<Vec<McpToolDefinition>> {
233        let request = JsonRpcRequest::new("tools/list", None, 0);
234
235        let response = self.transport.send(request).await?;
236
237        let result: ToolsListResult = response
238            .result
239            .map(serde_json::from_value)
240            .transpose()
241            .context("Failed to parse tools/list response")?
242            .context("tools/list response missing result")?;
243
244        Ok(result.tools)
245    }
246
247    /// Call a tool on the server.
248    ///
249    /// # Arguments
250    ///
251    /// * `name` - Tool name to call
252    /// * `arguments` - Tool arguments as JSON
253    ///
254    /// # Errors
255    ///
256    /// Returns an error if the tool call fails.
257    pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<McpToolCallResult> {
258        #[cfg(feature = "otel")]
259        let started_at = std::time::Instant::now();
260        #[cfg(feature = "otel")]
261        let mut span = {
262            use crate::observability::langfuse;
263            use opentelemetry::KeyValue;
264            let mut span = start_mcp_span_with_attrs(
265                "mcp.tools/call",
266                vec![
267                    KeyValue::new("mcp.server.name", self.server_name.clone()),
268                    KeyValue::new("gen_ai.tool.name", name.to_string()),
269                ],
270            );
271            langfuse::tag_observation(&mut span, langfuse::ObservationType::Tool);
272            span
273        };
274
275        let result = self.call_tool_inner(name, arguments).await;
276
277        #[cfg(feature = "otel")]
278        finish_mcp_call_tool_span(
279            &mut span,
280            &result,
281            "tools/call",
282            &self.server_name,
283            started_at,
284        );
285
286        result
287    }
288
289    async fn call_tool_inner(&self, name: &str, arguments: Value) -> Result<McpToolCallResult> {
290        let params = ToolCallParams {
291            name: name.to_string(),
292            arguments: Some(arguments),
293        };
294
295        let request = JsonRpcRequest::new("tools/call", Some(serde_json::to_value(&params)?), 0);
296
297        let response = self.transport.send(request).await?;
298
299        if let Some(ref error) = response.error {
300            bail!("Tool call failed: {} (code {})", error.message, error.code);
301        }
302
303        let result: McpToolCallResult = response
304            .result
305            .map(serde_json::from_value)
306            .transpose()
307            .context("Failed to parse tools/call response")?
308            .context("tools/call response missing result")?;
309
310        Ok(result)
311    }
312
313    /// Call a tool with raw Value arguments.
314    ///
315    /// # Arguments
316    ///
317    /// * `name` - Tool name to call
318    /// * `arguments` - Tool arguments as optional JSON
319    ///
320    /// # Errors
321    ///
322    /// Returns an error if the tool call fails.
323    pub async fn call_tool_raw(
324        &self,
325        name: &str,
326        arguments: Option<Value>,
327    ) -> Result<McpToolCallResult> {
328        let args = arguments.unwrap_or_else(|| json!({}));
329        self.call_tool(name, args).await
330    }
331
332    /// List resources exposed by the server (`resources/list`).
333    ///
334    /// Resources are addressable data (files, database rows, API payloads) the
335    /// server makes available for reading. Returns an empty list if the server
336    /// did not advertise the `resources` capability.
337    ///
338    /// # Errors
339    ///
340    /// Returns an error if the request fails or the response cannot be parsed.
341    pub async fn list_resources(&self) -> Result<Vec<McpResource>> {
342        if !self.supports_resources() {
343            return Ok(Vec::new());
344        }
345        #[cfg(feature = "otel")]
346        let started_at = std::time::Instant::now();
347        #[cfg(feature = "otel")]
348        let mut span = {
349            use crate::observability::langfuse;
350            let mut span = start_mcp_span("mcp.resources/list", &self.server_name);
351            langfuse::tag_observation(&mut span, langfuse::ObservationType::Chain);
352            span
353        };
354
355        let result = self.list_resources_inner().await;
356
357        #[cfg(feature = "otel")]
358        finish_mcp_span(
359            &mut span,
360            &result,
361            "resources/list",
362            &self.server_name,
363            started_at,
364        );
365
366        result
367    }
368
369    async fn list_resources_inner(&self) -> Result<Vec<McpResource>> {
370        let request = JsonRpcRequest::new("resources/list", None, 0);
371        let response = self.transport.send(request).await?;
372        let result: ResourcesListResult = response
373            .result
374            .map(serde_json::from_value)
375            .transpose()
376            .context("Failed to parse resources/list response")?
377            .context("resources/list response missing result")?;
378        Ok(result.resources)
379    }
380
381    /// Read a resource by URI (`resources/read`).
382    ///
383    /// # Errors
384    ///
385    /// Returns an error if the request fails or the response cannot be parsed.
386    pub async fn read_resource(&self, uri: &str) -> Result<ResourceReadResult> {
387        #[cfg(feature = "otel")]
388        let started_at = std::time::Instant::now();
389        #[cfg(feature = "otel")]
390        let mut span = {
391            use crate::observability::langfuse;
392            let mut span = start_mcp_span("mcp.resources/read", &self.server_name);
393            langfuse::tag_observation(&mut span, langfuse::ObservationType::Chain);
394            span
395        };
396
397        let result = self.read_resource_inner(uri).await;
398
399        #[cfg(feature = "otel")]
400        finish_mcp_span(
401            &mut span,
402            &result,
403            "resources/read",
404            &self.server_name,
405            started_at,
406        );
407
408        result
409    }
410
411    async fn read_resource_inner(&self, uri: &str) -> Result<ResourceReadResult> {
412        let params = ResourceReadParams {
413            uri: uri.to_string(),
414        };
415        let request =
416            JsonRpcRequest::new("resources/read", Some(serde_json::to_value(&params)?), 0);
417        let response = self.transport.send(request).await?;
418        let result: ResourceReadResult = response
419            .result
420            .map(serde_json::from_value)
421            .transpose()
422            .context("Failed to parse resources/read response")?
423            .context("resources/read response missing result")?;
424        Ok(result)
425    }
426
427    /// List prompts exposed by the server (`prompts/list`).
428    ///
429    /// Returns an empty list if the server did not advertise the `prompts`
430    /// capability.
431    ///
432    /// # Errors
433    ///
434    /// Returns an error if the request fails or the response cannot be parsed.
435    pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>> {
436        if !self.supports_prompts() {
437            return Ok(Vec::new());
438        }
439        #[cfg(feature = "otel")]
440        let started_at = std::time::Instant::now();
441        #[cfg(feature = "otel")]
442        let mut span = {
443            use crate::observability::langfuse;
444            let mut span = start_mcp_span("mcp.prompts/list", &self.server_name);
445            langfuse::tag_observation(&mut span, langfuse::ObservationType::Chain);
446            span
447        };
448
449        let result = self.list_prompts_inner().await;
450
451        #[cfg(feature = "otel")]
452        finish_mcp_span(
453            &mut span,
454            &result,
455            "prompts/list",
456            &self.server_name,
457            started_at,
458        );
459
460        result
461    }
462
463    async fn list_prompts_inner(&self) -> Result<Vec<McpPrompt>> {
464        let request = JsonRpcRequest::new("prompts/list", None, 0);
465        let response = self.transport.send(request).await?;
466        let result: PromptsListResult = response
467            .result
468            .map(serde_json::from_value)
469            .transpose()
470            .context("Failed to parse prompts/list response")?
471            .context("prompts/list response missing result")?;
472        Ok(result.prompts)
473    }
474
475    /// Fetch and render a prompt by name (`prompts/get`).
476    ///
477    /// # Arguments
478    ///
479    /// * `name` - Prompt name to fetch.
480    /// * `arguments` - Optional arguments to interpolate into the template.
481    ///
482    /// # Errors
483    ///
484    /// Returns an error if the request fails or the response cannot be parsed.
485    pub async fn get_prompt(
486        &self,
487        name: &str,
488        arguments: Option<Value>,
489    ) -> Result<PromptGetResult> {
490        #[cfg(feature = "otel")]
491        let started_at = std::time::Instant::now();
492        #[cfg(feature = "otel")]
493        let mut span = {
494            use crate::observability::langfuse;
495            let mut span = start_mcp_span("mcp.prompts/get", &self.server_name);
496            langfuse::tag_observation(&mut span, langfuse::ObservationType::Chain);
497            span
498        };
499
500        let result = self.get_prompt_inner(name, arguments).await;
501
502        #[cfg(feature = "otel")]
503        finish_mcp_span(
504            &mut span,
505            &result,
506            "prompts/get",
507            &self.server_name,
508            started_at,
509        );
510
511        result
512    }
513
514    async fn get_prompt_inner(
515        &self,
516        name: &str,
517        arguments: Option<Value>,
518    ) -> Result<PromptGetResult> {
519        let params = PromptGetParams {
520            name: name.to_string(),
521            arguments,
522        };
523        let request = JsonRpcRequest::new("prompts/get", Some(serde_json::to_value(&params)?), 0);
524        let response = self.transport.send(request).await?;
525        let result: PromptGetResult = response
526            .result
527            .map(serde_json::from_value)
528            .transpose()
529            .context("Failed to parse prompts/get response")?
530            .context("prompts/get response missing result")?;
531        Ok(result)
532    }
533
534    /// Whether the server advertised the `resources` capability.
535    #[must_use]
536    pub fn supports_resources(&self) -> bool {
537        self.server_info
538            .as_ref()
539            .is_some_and(|info| info.capabilities.resources.is_some())
540    }
541
542    /// Whether the server advertised the `prompts` capability.
543    #[must_use]
544    pub fn supports_prompts(&self) -> bool {
545        self.server_info
546            .as_ref()
547            .is_some_and(|info| info.capabilities.prompts.is_some())
548    }
549
550    /// Close the client connection.
551    ///
552    /// # Errors
553    ///
554    /// Returns an error if the transport fails to close.
555    pub async fn close(&self) -> Result<()> {
556        self.transport.close().await
557    }
558}
559
/// Start a client span for an MCP request, tagged with the target server.
///
/// Thin convenience wrapper over `start_mcp_span_with_attrs` for requests
/// whose only extra attribute is `mcp.server.name`.
#[cfg(feature = "otel")]
fn start_mcp_span(
    name: impl Into<std::borrow::Cow<'static, str>>,
    server_name: &str,
) -> opentelemetry::global::BoxedSpan {
    let server_attr = opentelemetry::KeyValue::new("mcp.server.name", server_name.to_owned());
    start_mcp_span_with_attrs(name, vec![server_attr])
}
571
/// Start a client span for an MCP request with the given attributes.
///
/// After the span is created, baggage entries are copied onto it via
/// `observability::baggage::copy_baggage_to_active_span`.
#[cfg(feature = "otel")]
fn start_mcp_span_with_attrs(
    name: impl Into<std::borrow::Cow<'static, str>>,
    attrs: Vec<opentelemetry::KeyValue>,
) -> opentelemetry::global::BoxedSpan {
    let mut client_span = crate::observability::spans::start_client_span(name, attrs);
    crate::observability::baggage::copy_baggage_to_active_span(&mut client_span);
    client_span
}
582
/// Finish a generic MCP request span: mark it as errored if `result` is an
/// `Err`, record the request duration metric, and end the span.
#[cfg(feature = "otel")]
fn finish_mcp_span<T>(
    span: &mut opentelemetry::global::BoxedSpan,
    result: &Result<T>,
    method: &'static str,
    server_name: &str,
    started_at: std::time::Instant,
) {
    use crate::observability::spans;

    let error_kind = match result {
        Err(err) => {
            // `{:#}` renders the whole anyhow context chain; plain `{}` would
            // only show the outermost context and hide the root cause.
            spans::set_span_error(span, "mcp_error", &format!("{err:#}"));
            Some("mcp_error")
        }
        Ok(_) => None,
    };
    record_mcp_request(span, error_kind, method, server_name, started_at);
}

/// Finish a `tools/call` span.
///
/// Besides transport/protocol failures (`mcp_error`), a successful response
/// whose `is_error` flag is set is recorded as `tool_error`, using the first
/// text content item (or a generic message) as the span error description.
#[cfg(feature = "otel")]
fn finish_mcp_call_tool_span(
    span: &mut opentelemetry::global::BoxedSpan,
    result: &Result<super::protocol::McpToolCallResult>,
    method: &'static str,
    server_name: &str,
    started_at: std::time::Instant,
) {
    use crate::observability::spans;

    let error_kind: Option<&'static str> = match result {
        Ok(tool_result) if tool_result.is_error => {
            let error_text = tool_result
                .content
                .iter()
                .find_map(|c| match c {
                    super::protocol::McpContent::Text { text } => Some(text.as_str()),
                    _ => None,
                })
                .unwrap_or("MCP tool returned error");
            spans::set_span_error(span, "tool_error", error_text);
            Some("tool_error")
        }
        Err(err) => {
            // Full context chain; see `finish_mcp_span`.
            spans::set_span_error(span, "mcp_error", &format!("{err:#}"));
            Some("mcp_error")
        }
        Ok(_) => None,
    };
    record_mcp_request(span, error_kind, method, server_name, started_at);
}

/// Record the MCP request duration metric (tagged with method, server, and
/// the error type when `error_kind` is set) and end the span.
#[cfg(feature = "otel")]
fn record_mcp_request(
    span: &mut opentelemetry::global::BoxedSpan,
    error_kind: Option<&'static str>,
    method: &'static str,
    server_name: &str,
    started_at: std::time::Instant,
) {
    use crate::observability::metrics;
    use opentelemetry::KeyValue;
    use opentelemetry::trace::Span;

    let mut metric_attrs = vec![
        KeyValue::new("mcp.method", method),
        KeyValue::new("mcp.server.name", server_name.to_string()),
    ];
    if let Some(kind) = error_kind {
        metric_attrs.push(KeyValue::new(crate::observability::attrs::ERROR_TYPE, kind));
    }
    let elapsed_secs = started_at.elapsed().as_secs_f64();
    metrics::Metrics::global()
        .mcp_requests_duration
        .record(elapsed_secs, &metric_attrs);
    span.end();
}
657
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_protocol_version() {
        assert!(!MCP_PROTOCOL_VERSION.is_empty());
        // The public alias must stay in lockstep with what `initialize` advertises.
        assert_eq!(MCP_PROTOCOL_VERSION, PREFERRED_PROTOCOL_VERSION);
        // Our own preferred revision must be recognised; otherwise every
        // handshake with an up-to-date server would log an "unknown revision" warning.
        assert!(is_known_protocol_version(MCP_PROTOCOL_VERSION));
    }

    #[test]
    fn test_client_info() {
        let info = ClientInfo {
            name: "test".to_string(),
            version: "1.0.0".to_string(),
        };

        // Check the exact fields rather than substrings of the serialized text,
        // so a value landing under the wrong key is caught.
        let json = serde_json::to_value(&info).expect("serialize");
        assert_eq!(json["name"], "test");
        assert_eq!(json["version"], "1.0.0");
    }
}