1use std::sync::Arc;
2
3use async_mcp::{
4 server::Server,
5 transport::{ServerInMemoryTransport, Transport},
6};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11use crate::AuthType;
12
13#[async_trait::async_trait]
14pub trait ServerTrait: Send + Sync {
15 async fn listen(&self) -> anyhow::Result<()>;
16}
17
18pub type BuilderFn = dyn Fn(&ServerMetadataWrapper, ServerInMemoryTransport) -> anyhow::Result<Box<dyn ServerTrait>>
19 + Send
20 + Sync;
21
22#[async_trait::async_trait]
23impl<T: Transport> ServerTrait for Server<T> {
24 async fn listen(&self) -> anyhow::Result<()> {
25 self.listen().await
26 }
27}
28
29#[derive(Clone, Serialize, Deserialize, schemars::JsonSchema)]
30pub struct ServerMetadataWrapper {
31 pub server_metadata: McpServerMetadata,
32 #[serde(skip)]
33 pub builder: Option<Arc<BuilderFn>>,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
37#[serde(deny_unknown_fields, tag = "type", rename_all = "lowercase")]
38pub enum TransportType {
39 InMemory,
40 SSE {
41 server_url: String,
42 #[serde(flatten, skip_serializing_if = "Option::is_none")]
43 headers: Option<HashMap<String, String>>,
44 },
45 WS {
46 server_url: String,
47 #[serde(flatten, skip_serializing_if = "Option::is_none")]
48 headers: Option<HashMap<String, String>>,
49 },
50 Stdio {
51 command: String,
52 args: Vec<String>,
53 #[serde(flatten, skip_serializing_if = "Option::is_none")]
54 env_vars: Option<HashMap<String, String>>,
55 },
56}
57
58#[derive(Clone, Serialize, Deserialize, JsonSchema)]
59pub struct McpServerMetadata {
60 #[serde(default)]
61 pub auth_session_key: Option<String>,
62 #[serde(default = "default_transport_type", flatten)]
63 pub mcp_transport: TransportType,
64 #[serde(default)]
65 pub auth_type: Option<AuthType>,
66}
67
68pub fn default_transport_type() -> TransportType {
69 TransportType::InMemory
70}
71
72impl std::fmt::Debug for McpServerMetadata {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("ServerMetadata")
75 .field("auth_session_key", &self.auth_session_key)
76 .field("mcp_transport", &self.mcp_transport)
77 .field("auth_type", &self.auth_type)
78 .finish()
79 }
80}