Skip to main content

distri_types/
mcp.rs

1use std::sync::Arc;
2
3use async_mcp::{
4    server::Server,
5    transport::{ServerInMemoryTransport, Transport},
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use utoipa::ToSchema;
11
12use crate::auth::AuthType;
13
14#[async_trait::async_trait]
15pub trait ServerTrait: Send + Sync {
16    async fn listen(&self) -> anyhow::Result<()>;
17}
18
19pub type BuilderFn = dyn Fn(&ServerMetadataWrapper, ServerInMemoryTransport) -> anyhow::Result<Box<dyn ServerTrait>>
20    + Send
21    + Sync;
22
23#[async_trait::async_trait]
24impl<T: Transport> ServerTrait for Server<T> {
25    async fn listen(&self) -> anyhow::Result<()> {
26        self.listen().await
27    }
28}
29
30#[derive(Clone, Serialize, Deserialize, schemars::JsonSchema)]
31pub struct ServerMetadataWrapper {
32    pub server_metadata: McpServerMetadata,
33    #[serde(skip)]
34    pub builder: Option<Arc<BuilderFn>>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
38#[serde(deny_unknown_fields, tag = "type", rename_all = "lowercase")]
39pub enum TransportType {
40    InMemory,
41    SSE {
42        server_url: String,
43        #[serde(flatten, skip_serializing_if = "Option::is_none")]
44        headers: Option<HashMap<String, String>>,
45    },
46    WS {
47        server_url: String,
48        #[serde(flatten, skip_serializing_if = "Option::is_none")]
49        headers: Option<HashMap<String, String>>,
50    },
51    Stdio {
52        command: String,
53        args: Vec<String>,
54        #[serde(flatten, skip_serializing_if = "Option::is_none")]
55        env_vars: Option<HashMap<String, String>>,
56    },
57}
58
59#[derive(Clone, Serialize, Deserialize, JsonSchema)]
60pub struct McpServerMetadata {
61    #[serde(default)]
62    pub auth_session_key: Option<String>,
63    #[serde(default = "default_transport_type", flatten)]
64    pub mcp_transport: TransportType,
65    #[serde(default)]
66    pub auth_type: Option<AuthType>,
67}
68
69pub fn default_transport_type() -> TransportType {
70    TransportType::InMemory
71}
72
73impl std::fmt::Debug for McpServerMetadata {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_struct("ServerMetadata")
76            .field("auth_session_key", &self.auth_session_key)
77            .field("mcp_transport", &self.mcp_transport)
78            .field("auth_type", &self.auth_type)
79            .finish()
80    }
81}
82
83/// Client-side transport configuration for connecting to an MCP server.
84///
85/// Mirrors the variants the `rmcp` crate supports. `headers` is always optional
86/// — most public servers don't need extra headers, and OAuth bearers are
87/// injected by the connection resolver at pool-connect time rather than being
88/// stored here.
89///
90/// The `Stdio` variant is intentionally absent: connections are workspace
91/// resources that must be portable across hosts, so spawning a local child
92/// process is not user-configurable. In-process A2A servers use the separate
93/// `TransportType` enum on `McpServerMetadata`.
94#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, ToSchema, PartialEq)]
95#[serde(tag = "type", rename_all = "snake_case")]
96pub enum McpClientTransport {
97    /// Streamable HTTP transport (MCP 2025-03-26+ spec). One bidirectional
98    /// endpoint that may upgrade individual responses to SSE.
99    StreamableHttp {
100        url: String,
101        #[serde(default, skip_serializing_if = "Option::is_none")]
102        headers: Option<HashMap<String, String>>,
103    },
104    /// Legacy SSE-only transport (MCP 2024-11-05 spec). Kept for servers that
105    /// haven't migrated yet; prefer Streamable HTTP for new connections.
106    Sse {
107        url: String,
108        #[serde(default, skip_serializing_if = "Option::is_none")]
109        headers: Option<HashMap<String, String>>,
110    },
111}
112
113impl McpClientTransport {
114    pub fn url(&self) -> &str {
115        match self {
116            Self::StreamableHttp { url, .. } | Self::Sse { url, .. } => url.as_str(),
117        }
118    }
119
120    pub fn headers(&self) -> Option<&HashMap<String, String>> {
121        match self {
122            Self::StreamableHttp { headers, .. } | Self::Sse { headers, .. } => headers.as_ref(),
123        }
124    }
125
126    pub fn validate(&self) -> Result<(), String> {
127        let url = self.url();
128        if url.trim().is_empty() {
129            return Err("transport requires a url".to_string());
130        }
131        url::Url::parse(url).map_err(|e| format!("invalid url '{}': {}", url, e))?;
132        Ok(())
133    }
134}
135
136/// A pool-ready handle for one MCP server: identifier + transport + already-
137/// resolved authorization headers.
138///
139/// Built by the host (e.g. distri-cloud) by looking up `kind = Mcp` connections
140/// in scope and resolving their `auth_type` into bearer headers. Passed into
141/// `McpClientPool::new` so the pool itself never has to touch the connection
142/// store.
143#[derive(Debug, Clone)]
144pub struct McpServerHandle {
145    /// Stable name used by agents in `ToolsConfig.mcp[].server`.
146    pub name: String,
147    pub transport: McpClientTransport,
148    /// Headers to merge into the transport at connect time. Typically contains
149    /// `Authorization: Bearer …` if the backing connection is OAuth.
150    pub resolved_headers: HashMap<String, String>,
151    pub enabled: bool,
152}
153
154impl McpServerHandle {
155    pub fn validate(&self) -> Result<(), String> {
156        if self.name.trim().is_empty() {
157            return Err("MCP server handle name must be non-empty".to_string());
158        }
159        if !self
160            .name
161            .chars()
162            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
163        {
164            return Err(format!(
165                "MCP server handle name '{}' must be alphanumeric/underscore/dash only",
166                self.name
167            ));
168        }
169        self.transport.validate()
170    }
171}