use std::sync::Arc;
use async_mcp::{
server::Server,
transport::{ServerInMemoryTransport, Transport},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use utoipa::ToSchema;
use crate::auth::AuthType;
/// Object-safe abstraction over a runnable MCP server.
///
/// `#[async_trait]` makes the async `listen` method callable through a
/// `Box<dyn ServerTrait>`, which is what [`BuilderFn`] produces. The
/// `Send + Sync` supertraits let such boxes be shared across threads.
#[async_trait::async_trait]
pub trait ServerTrait: Send + Sync {
/// Drives the server. For `async_mcp::server::Server` this delegates to its
/// inherent `listen`; presumably it resolves when the server stops — confirm.
async fn listen(&self) -> anyhow::Result<()>;
}
/// Factory that builds a boxed [`ServerTrait`] from a server's metadata and an
/// in-memory transport, failing with an `anyhow` error if construction fails.
///
/// Held as `Arc<BuilderFn>` in `ServerMetadataWrapper::builder`; the
/// `Send + Sync` bounds allow that shared factory to cross threads.
pub type BuilderFn = dyn Fn(&ServerMetadataWrapper, ServerInMemoryTransport) -> anyhow::Result<Box<dyn ServerTrait>>
+ Send
+ Sync;
#[async_trait::async_trait]
impl<T: Transport> ServerTrait for Server<T> {
async fn listen(&self) -> anyhow::Result<()> {
self.listen().await
}
}
#[derive(Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ServerMetadataWrapper {
pub server_metadata: McpServerMetadata,
#[serde(skip)]
pub builder: Option<Arc<BuilderFn>>,
}
/// Transport configuration for an MCP server.
///
/// Serialized as an internally tagged object. Its `type` key holds the
/// lowercased variant name: `"inmemory"`, `"sse"`, `"ws"`, or `"stdio"`.
///
/// NOTE(review): serde documents `deny_unknown_fields` as unsupported in
/// combination with `flatten`. This enum has flattened fields and is itself
/// flattened into `McpServerMetadata`. The flattened `headers`/`env_vars`
/// maps also gather leftover top-level keys rather than reading a nested
/// `headers`/`env_vars` key. Confirm the intended wire format with a
/// round-trip test before changing these attributes.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields, tag = "type", rename_all = "lowercase")]
pub enum TransportType {
/// In-process transport; carries no configuration.
InMemory,
/// SSE transport reached at `server_url`.
SSE {
server_url: String,
/// Optional string map, flattened into the enclosing object and omitted
/// from output when `None`.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
headers: Option<HashMap<String, String>>,
},
/// WS (presumably WebSocket) transport reached at `server_url`.
WS {
server_url: String,
/// Optional string map, flattened into the enclosing object and omitted
/// from output when `None`.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
headers: Option<HashMap<String, String>>,
},
/// Stdio transport; presumably spawns `command` with `args` — confirm.
Stdio {
command: String,
args: Vec<String>,
/// Optional string map (presumably environment variables for the child
/// process), flattened into the enclosing object and omitted when `None`.
#[serde(flatten, skip_serializing_if = "Option::is_none")]
env_vars: Option<HashMap<String, String>>,
},
}
/// Serializable configuration for one MCP server: its transport plus optional
/// auth settings.
///
/// The transport's fields, including its `type` tag, are flattened into this
/// object rather than nested under an `mcp_transport` key. `Debug` is
/// implemented by hand.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
pub struct McpServerMetadata {
/// Optional session key; `None` when absent from the input.
/// NOTE(review): whether this is a lookup key or a secret is not visible here.
#[serde(default)]
pub auth_session_key: Option<String>,
/// Transport configuration, flattened into the parent object.
/// NOTE(review): serde builds flattened fields straight from the leftover keys
/// and appears not to apply `default` to them. If so, `default_transport_type`
/// is never used and input without a `type` key fails to deserialize. Verify
/// before relying on the InMemory fallback.
#[serde(default = "default_transport_type", flatten)]
pub mcp_transport: TransportType,
/// Optional authentication type; `None` when absent from the input.
#[serde(default)]
pub auth_type: Option<AuthType>,
}
/// Fallback transport named by `McpServerMetadata`'s serde `default`
/// attribute: the in-memory transport.
pub fn default_transport_type() -> TransportType {
    TransportType::InMemory
}
/// Hand-written `Debug` for [`McpServerMetadata`] that prints every field
/// under the type's own name.
///
/// NOTE(review): `auth_session_key` is printed verbatim; if it can ever hold a
/// secret, consider redacting it here.
impl std::fmt::Debug for McpServerMetadata {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Label with the real type name so debug/log output traces back to it.
        f.debug_struct("McpServerMetadata")
            .field("auth_session_key", &self.auth_session_key)
            .field("mcp_transport", &self.mcp_transport)
            .field("auth_type", &self.auth_type)
            .finish()
    }
}
/// Transport a client uses to reach a remote MCP server.
///
/// Serialized as an internally tagged object. The `type` key holds the
/// snake_case variant name (`"streamable_http"` or `"sse"`) next to the
/// variant's `url` and optional `headers` fields.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, ToSchema, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum McpClientTransport {
/// Streamable HTTP transport.
StreamableHttp {
/// Server endpoint URL; checked by `validate`.
url: String,
/// Optional header map. It is `None` when absent from input and omitted
/// from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
headers: Option<HashMap<String, String>>,
},
/// SSE transport.
Sse {
/// Server endpoint URL; checked by `validate`.
url: String,
/// Optional header map. It is `None` when absent from input and omitted
/// from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
headers: Option<HashMap<String, String>>,
},
}
impl McpClientTransport {
    /// Returns the endpoint URL, whichever variant this is.
    pub fn url(&self) -> &str {
        match self {
            Self::StreamableHttp { url, .. } | Self::Sse { url, .. } => url.as_str(),
        }
    }

    /// Returns the configured headers, or `None` if none were set.
    pub fn headers(&self) -> Option<&HashMap<String, String>> {
        match self {
            Self::StreamableHttp { headers, .. } | Self::Sse { headers, .. } => headers.as_ref(),
        }
    }

    /// Checks that the transport's URL is usable.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when any of these holds:
    /// - the URL is empty or whitespace-only;
    /// - the URL does not parse;
    /// - the URL uses a scheme other than `http` or `https`.
    pub fn validate(&self) -> Result<(), String> {
        // Named `raw` so the local does not shadow the `url` crate used below.
        let raw = self.url();
        if raw.trim().is_empty() {
            return Err("transport requires a url".to_string());
        }
        let parsed =
            url::Url::parse(raw).map_err(|e| format!("invalid url '{}': {}", raw, e))?;
        // Both variants are HTTP-based. Without this check, inputs such as
        // "localhost:8080" parse successfully (as scheme "localhost") and
        // only fail later, when a connection is attempted.
        match parsed.scheme() {
            "http" | "https" => Ok(()),
            other => Err(format!(
                "invalid url '{}': unsupported scheme '{}' (expected http or https)",
                raw, other
            )),
        }
    }
}
/// A named MCP server entry: a client transport, its header map and an
/// enabled flag.
#[derive(Debug, Clone)]
pub struct McpServerHandle {
/// Identifier for the server. `validate` requires it to be non-empty and
/// limited to ASCII alphanumerics, `_`, and `-`.
pub name: String,
/// How to reach the server; `validate` also checks its URL.
pub transport: McpClientTransport,
/// Header map for this server.
/// NOTE(review): presumably `transport.headers()` after resolution/templating;
/// confirm against the code that builds this.
pub resolved_headers: HashMap<String, String>,
/// Whether this server is enabled; `validate` does not consult it.
pub enabled: bool,
}
impl McpServerHandle {
    /// Validates the handle's name and then its transport.
    ///
    /// # Errors
    ///
    /// Returns a message when any of these holds:
    /// - the name is blank;
    /// - the name contains a character other than an ASCII alphanumeric, `_` or `-`;
    /// - the transport's own `validate` fails.
    pub fn validate(&self) -> Result<(), String> {
        let name = self.name.as_str();
        // Needed separately: an empty name would pass the per-character check.
        if name.trim().is_empty() {
            return Err("MCP server handle name must be non-empty".to_string());
        }
        let is_allowed = |c: char| c.is_ascii_alphanumeric() || matches!(c, '_' | '-');
        if name.chars().any(|c| !is_allowed(c)) {
            return Err(format!(
                "MCP server handle name '{}' must be alphanumeric/underscore/dash only",
                name
            ));
        }
        self.transport.validate()
    }
}