Skip to main content

distri_types/
mcp.rs

1use std::sync::Arc;
2
3use async_mcp::{
4    server::Server,
5    transport::{ServerInMemoryTransport, Transport},
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use crate::AuthType;
12
13#[async_trait::async_trait]
14pub trait ServerTrait: Send + Sync {
15    async fn listen(&self) -> anyhow::Result<()>;
16}
17
18pub type BuilderFn = dyn Fn(&ServerMetadataWrapper, ServerInMemoryTransport) -> anyhow::Result<Box<dyn ServerTrait>>
19    + Send
20    + Sync;
21
22#[async_trait::async_trait]
23impl<T: Transport> ServerTrait for Server<T> {
24    async fn listen(&self) -> anyhow::Result<()> {
25        self.listen().await
26    }
27}
28
29#[derive(Clone, Serialize, Deserialize, schemars::JsonSchema)]
30pub struct ServerMetadataWrapper {
31    pub server_metadata: McpServerMetadata,
32    #[serde(skip)]
33    pub builder: Option<Arc<BuilderFn>>,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
37#[serde(deny_unknown_fields, tag = "type", rename_all = "lowercase")]
38pub enum TransportType {
39    InMemory,
40    SSE {
41        server_url: String,
42        #[serde(flatten, skip_serializing_if = "Option::is_none")]
43        headers: Option<HashMap<String, String>>,
44    },
45    WS {
46        server_url: String,
47        #[serde(flatten, skip_serializing_if = "Option::is_none")]
48        headers: Option<HashMap<String, String>>,
49    },
50    Stdio {
51        command: String,
52        args: Vec<String>,
53        #[serde(flatten, skip_serializing_if = "Option::is_none")]
54        env_vars: Option<HashMap<String, String>>,
55    },
56}
57
58#[derive(Clone, Serialize, Deserialize, JsonSchema)]
59pub struct McpServerMetadata {
60    #[serde(default)]
61    pub auth_session_key: Option<String>,
62    #[serde(default = "default_transport_type", flatten)]
63    pub mcp_transport: TransportType,
64    #[serde(default)]
65    pub auth_type: Option<AuthType>,
66}
67
68pub fn default_transport_type() -> TransportType {
69    TransportType::InMemory
70}
71
72impl std::fmt::Debug for McpServerMetadata {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_struct("ServerMetadata")
75            .field("auth_session_key", &self.auth_session_key)
76            .field("mcp_transport", &self.mcp_transport)
77            .field("auth_type", &self.auth_type)
78            .finish()
79    }
80}