Skip to main content

synwire_mcp_adapters/
client.rs

1//! Multi-server MCP client.
2//!
3//! [`MultiServerMcpClient`] connects to N named MCP servers simultaneously,
4//! aggregates their tools, and routes tool calls to the correct server.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use futures_util::future::join_all;
10use serde_json::Value;
11use synwire_core::mcp::traits::{McpServerStatus, McpTransport};
12use tokio::sync::RwLock;
13
14use crate::callbacks::McpCallbacks;
15use crate::error::McpAdapterError;
16use crate::session::McpClientSession;
17
18// ---------------------------------------------------------------------------
19// Connection configuration
20// ---------------------------------------------------------------------------
21
/// Configuration for connecting to a single MCP server.
///
/// Each variant describes a different transport mechanism.
///
/// The [`Debug`] implementation redacts secrets so that configurations can be
/// logged safely. Bearer tokens are printed as `<redacted>`, and for
/// [`Stdio`](Self::Stdio) only the *names* of environment variables are shown,
/// never their values.
#[derive(Clone)]
#[non_exhaustive]
pub enum Connection {
    /// Launch a subprocess and communicate over its stdin/stdout.
    Stdio {
        /// Executable path.
        command: String,
        /// Command-line arguments.
        args: Vec<String>,
        /// Environment variables.
        env: HashMap<String, String>,
    },

    /// Connect via Server-Sent Events (SSE) transport.
    Sse {
        /// SSE endpoint URL.
        url: String,
        /// Optional Bearer token.
        auth_token: Option<String>,
        /// Connection timeout in seconds.
        timeout_secs: Option<u64>,
    },

    /// Connect via Streamable HTTP (MCP 2025-03-26 spec).
    StreamableHttp {
        /// HTTP endpoint URL.
        url: String,
        /// Optional Bearer token.
        auth_token: Option<String>,
        /// Connection timeout in seconds.
        timeout_secs: Option<u64>,
    },

    /// Connect via WebSocket.
    WebSocket {
        /// WebSocket URL (ws:// or wss://).
        url: String,
        /// Optional Bearer token.
        auth_token: Option<String>,
    },
}

impl std::fmt::Debug for Connection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Placeholder printed in place of a secret value.
        const REDACTED: &str = "<redacted>";

        match self {
            Self::Stdio { command, args, env } => {
                // Environment values often carry credentials, so only the keys
                // are printed. They are sorted so the output is stable across runs.
                let mut env_keys: Vec<&str> = env.keys().map(String::as_str).collect();
                env_keys.sort_unstable();
                f.debug_struct("Stdio")
                    .field("command", command)
                    .field("args", args)
                    .field("env_keys", &env_keys)
                    .finish()
            }
            Self::Sse {
                url,
                auth_token,
                timeout_secs,
            } => f
                .debug_struct("Sse")
                .field("url", url)
                .field("auth_token", &auth_token.as_ref().map(|_| REDACTED))
                .field("timeout_secs", timeout_secs)
                .finish(),
            Self::StreamableHttp {
                url,
                auth_token,
                timeout_secs,
            } => f
                .debug_struct("StreamableHttp")
                .field("url", url)
                .field("auth_token", &auth_token.as_ref().map(|_| REDACTED))
                .field("timeout_secs", timeout_secs)
                .finish(),
            Self::WebSocket { url, auth_token } => f
                .debug_struct("WebSocket")
                .field("url", url)
                .field("auth_token", &auth_token.as_ref().map(|_| REDACTED))
                .finish(),
        }
    }
}
66
67impl Connection {
68    /// Creates a transport for this connection configuration.
69    ///
70    /// Returns a `Box<dyn McpTransport>` suitable for use with
71    /// [`McpClientSession`].
72    ///
73    /// # Errors
74    ///
75    /// Returns an error if the underlying transport cannot be initialised
76    /// (e.g. HTTP client TLS failure).
77    pub fn into_transport(
78        self,
79        name: &str,
80    ) -> Result<Box<dyn McpTransport>, synwire_core::agents::error::AgentError> {
81        match self {
82            Self::Stdio { command, args, env } => Ok(Box::new(
83                synwire_agent::mcp::StdioMcpTransport::new(name, command, args, env),
84            )),
85            Self::Sse {
86                url,
87                auth_token,
88                timeout_secs,
89            }
90            | Self::StreamableHttp {
91                url,
92                auth_token,
93                timeout_secs,
94            } => Ok(Box::new(synwire_agent::mcp::HttpMcpTransport::try_new(
95                name,
96                url,
97                auth_token,
98                timeout_secs,
99            )?)),
100            Self::WebSocket { url, auth_token } => Ok(Box::new(
101                crate::transport::WebSocketMcpTransport::new(name, url, auth_token),
102            )),
103        }
104    }
105}
106
107// ---------------------------------------------------------------------------
108// Server entry
109// ---------------------------------------------------------------------------
110
111struct ServerEntry {
112    session: McpClientSession,
113    /// Optional prefix applied to all tool names from this server.
114    tool_name_prefix: Option<String>,
115}
116
117impl std::fmt::Debug for ServerEntry {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("ServerEntry")
120            .field("session", &self.session)
121            .field("tool_name_prefix", &self.tool_name_prefix)
122            .finish()
123    }
124}
125
126// ---------------------------------------------------------------------------
127// MultiServerMcpClient
128// ---------------------------------------------------------------------------
129
130/// Configuration used to build a [`MultiServerMcpClient`].
131#[derive(Debug, Default)]
132pub struct MultiServerMcpClientConfig {
133    /// Named server connections.
134    pub servers: HashMap<String, Connection>,
135    /// Optional prefix applied to all aggregated tool names.
136    ///
137    /// Per-server prefixes can be set via [`with_server_prefix`](Self::with_server_prefix).
138    pub global_tool_prefix: Option<String>,
139    /// Per-server tool name prefixes (override the global prefix for a specific server).
140    pub server_prefixes: HashMap<String, String>,
141}
142
143impl MultiServerMcpClientConfig {
144    /// Creates an empty configuration.
145    #[must_use]
146    pub fn new() -> Self {
147        Self::default()
148    }
149
150    /// Adds a named server to the configuration.
151    #[must_use]
152    pub fn with_server(mut self, name: impl Into<String>, connection: Connection) -> Self {
153        let _ = self.servers.insert(name.into(), connection);
154        self
155    }
156
157    /// Sets a per-server tool name prefix.
158    #[must_use]
159    pub fn with_server_prefix(
160        mut self,
161        server_name: impl Into<String>,
162        prefix: impl Into<String>,
163    ) -> Self {
164        let _ = self
165            .server_prefixes
166            .insert(server_name.into(), prefix.into());
167        self
168    }
169
170    /// Sets the global tool name prefix applied to all servers that lack a
171    /// per-server prefix.
172    #[must_use]
173    pub fn with_global_prefix(mut self, prefix: impl Into<String>) -> Self {
174        self.global_tool_prefix = Some(prefix.into());
175        self
176    }
177}
178
179/// A client that connects to multiple MCP servers simultaneously and
180/// aggregates their tools under a unified interface.
181///
182/// # Connection
183///
184/// Call [`connect`](Self::connect) to establish connections to all configured
185/// servers in parallel. Tools become available immediately after connection.
186///
187/// # Tool naming
188///
189/// Each server may have an optional prefix. When a prefix is configured,
190/// tool names are exposed as `{prefix}/{tool_name}` to avoid collisions
191/// across servers. The original server-local name is preserved for routing.
192pub struct MultiServerMcpClient {
193    servers: Arc<RwLock<HashMap<String, ServerEntry>>>,
194    /// Callbacks for logging, progress, and elicitation events.
195    callbacks: Arc<McpCallbacks>,
196}
197
198impl std::fmt::Debug for MultiServerMcpClient {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        f.debug_struct("MultiServerMcpClient")
201            .field("callbacks", &self.callbacks)
202            .finish_non_exhaustive()
203    }
204}
205
206impl MultiServerMcpClient {
207    /// Connects to all servers in `config` simultaneously and returns a
208    /// fully initialised client.
209    ///
210    /// Connection failures for individual servers are logged and their status
211    /// will show as `Disconnected`. Call [`health`](Self::health) to inspect
212    /// per-server status.
213    ///
214    /// # Errors
215    ///
216    /// Returns [`McpAdapterError`] only if configuration is fundamentally
217    /// invalid. Individual server failures are accumulated silently.
218    pub async fn connect(
219        config: MultiServerMcpClientConfig,
220        callbacks: McpCallbacks,
221    ) -> Result<Self, McpAdapterError> {
222        let callbacks = Arc::new(callbacks);
223
224        // Destructure config to allow partial moves.
225        let MultiServerMcpClientConfig {
226            servers,
227            server_prefixes,
228            global_tool_prefix,
229        } = config;
230
231        // Connect all servers in parallel
232        let connect_futures: Vec<_> = servers
233            .into_iter()
234            .map(|(name, conn)| {
235                let prefix = server_prefixes
236                    .get(&name)
237                    .cloned()
238                    .or_else(|| global_tool_prefix.clone());
239                let transport_result = conn.into_transport(&name);
240                async move {
241                    let transport: Arc<dyn McpTransport> = match transport_result {
242                        Ok(t) => Arc::from(t),
243                        Err(e) => {
244                            tracing::error!(server = %name, error = %e, "Failed to build transport");
245                            return None;
246                        }
247                    };
248                    match McpClientSession::connect(name.clone(), transport).await {
249                        Ok(mut session) => {
250                            // Best-effort tool cache population
251                            if let Err(e) = session.populate_tool_cache().await {
252                                tracing::warn!(
253                                    server = %name,
254                                    error = %e,
255                                    "Failed to populate tool cache"
256                                );
257                            }
258                            Some((
259                                name,
260                                ServerEntry {
261                                    session,
262                                    tool_name_prefix: prefix,
263                                },
264                            ))
265                        }
266                        Err(e) => {
267                            tracing::error!(
268                                server = %name,
269                                error = %e,
270                                "Failed to connect to MCP server"
271                            );
272                            None
273                        }
274                    }
275                }
276            })
277            .collect();
278
279        let results = join_all(connect_futures).await;
280        let servers: HashMap<String, ServerEntry> = results.into_iter().flatten().collect();
281
282        tracing::info!(connected = servers.len(), "MultiServerMcpClient connected");
283
284        Ok(Self {
285            servers: Arc::new(RwLock::new(servers)),
286            callbacks,
287        })
288    }
289
290    /// Returns all aggregated tool descriptors from all connected servers.
291    ///
292    /// Tool names are prefixed when a server prefix is configured.
293    pub async fn get_tool_descriptors(&self) -> Vec<AggregatedToolDescriptor> {
294        let servers = self.servers.read().await;
295        let mut tools = Vec::new();
296
297        for (server_name, entry) in servers.iter() {
298            for descriptor in entry.session.cached_tools() {
299                let exposed_name = entry.tool_name_prefix.as_ref().map_or_else(
300                    || descriptor.name.clone(),
301                    |prefix| format!("{prefix}/{}", descriptor.name),
302                );
303                tools.push(AggregatedToolDescriptor {
304                    exposed_name,
305                    server_name: server_name.clone(),
306                    original_name: descriptor.name.clone(),
307                    description: descriptor.description.clone(),
308                    input_schema: descriptor.input_schema.clone(),
309                });
310            }
311        }
312        drop(servers);
313
314        tools
315    }
316
317    /// Returns health status for all servers.
318    #[allow(clippy::significant_drop_tightening)]
319    pub async fn health(&self) -> Vec<McpServerStatus> {
320        let servers = self.servers.read().await;
321        let status_futures: Vec<_> = servers
322            .values()
323            .map(|entry| entry.session.status())
324            .collect();
325        join_all(status_futures).await
326    }
327
328    /// Calls a tool by its exposed name (including any prefix).
329    ///
330    /// # Errors
331    ///
332    /// - [`McpAdapterError::ToolNotFound`] if no server exposes the given name.
333    /// - [`McpAdapterError::Transport`] if the tool call fails.
334    pub async fn call_tool(
335        &self,
336        exposed_tool_name: &str,
337        arguments: Value,
338    ) -> Result<Value, McpAdapterError> {
339        // Resolve routing and clone the transport Arc, then drop the lock
340        // before the async call to avoid holding the guard across await points.
341        let (server_name, original_name, transport) = {
342            let servers = self.servers.read().await;
343
344            let routing = servers.iter().find_map(|(server_name, entry)| {
345                for descriptor in entry.session.cached_tools() {
346                    let exposed = entry.tool_name_prefix.as_ref().map_or_else(
347                        || descriptor.name.clone(),
348                        |prefix| format!("{prefix}/{}", descriptor.name),
349                    );
350                    if exposed == exposed_tool_name {
351                        return Some((server_name.clone(), descriptor.name.clone()));
352                    }
353                }
354                None
355            });
356
357            let (server_name, original_name) =
358                routing.ok_or_else(|| McpAdapterError::ToolNotFound {
359                    name: exposed_tool_name.to_owned(),
360                })?;
361
362            let transport = servers
363                .get(&server_name)
364                .ok_or_else(|| McpAdapterError::ServerNotFound {
365                    name: server_name.clone(),
366                })?
367                .session
368                .transport()
369                .clone();
370            drop(servers);
371
372            (server_name, original_name, transport)
373        };
374
375        transport
376            .call_tool(&original_name, arguments)
377            .await
378            .map_err(|e| McpAdapterError::Transport {
379                message: format!("Tool '{original_name}' on server '{server_name}' failed: {e}"),
380            })
381    }
382
383    /// Returns a reference to the callbacks bundle.
384    #[must_use]
385    pub fn callbacks(&self) -> &McpCallbacks {
386        &self.callbacks
387    }
388}
389
390// ---------------------------------------------------------------------------
391// AggregatedToolDescriptor
392// ---------------------------------------------------------------------------
393
394/// A tool descriptor with routing metadata for [`MultiServerMcpClient`].
395#[derive(Debug, Clone)]
396pub struct AggregatedToolDescriptor {
397    /// The tool name as exposed by this client (may include server prefix).
398    pub exposed_name: String,
399    /// The name of the server that provides this tool.
400    pub server_name: String,
401    /// The tool's original name on the server.
402    pub original_name: String,
403    /// Human-readable description.
404    pub description: String,
405    /// JSON Schema for the tool's input parameters.
406    pub input_schema: Value,
407}
408
#[cfg(test)]
// Lets tests use `unwrap` without tripping the crate's `unwrap_used` lint.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::pagination::PaginationCursor;

    #[test]
    fn connection_enum_variants_exist() {
        // Building one value of each variant is the point: this must compile.
        let variants = [
            Connection::Stdio {
                command: "mcp-server".into(),
                args: Vec::new(),
                env: HashMap::new(),
            },
            Connection::WebSocket {
                url: "ws://localhost:3000".into(),
                auth_token: None,
            },
            Connection::Sse {
                url: "http://localhost:3000/sse".into(),
                auth_token: None,
                timeout_secs: None,
            },
            Connection::StreamableHttp {
                url: "http://localhost:3000".into(),
                auth_token: None,
                timeout_secs: None,
            },
        ];
        assert_eq!(variants.len(), 4);
    }

    #[test]
    fn config_builder() {
        let websocket = Connection::WebSocket {
            url: "ws://localhost:3000".into(),
            auth_token: None,
        };
        let config = MultiServerMcpClientConfig::new()
            .with_server("s1", websocket)
            .with_server_prefix("s1", "srv1")
            .with_global_prefix("global");

        assert!(config.servers.contains_key("s1"));
        assert_eq!(
            config.server_prefixes.get("s1").map(String::as_str),
            Some("srv1")
        );
        assert_eq!(config.global_tool_prefix.as_deref(), Some("global"));
    }

    #[test]
    fn pagination_used_in_client_context() {
        // PaginationCursor must be usable from this module.
        let mut cursor = PaginationCursor::new();
        let with_token = cursor.advance(Some("token1".into()));
        let without_token = cursor.advance(None);
        assert!(with_token);
        assert!(!without_token);
    }
}