use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use url::Url;
use crate::server::{
completion::{CompletionCallback, RegisteredCompletion},
error::ServerError,
prompt::{PromptBuilder, RegisteredPrompt},
resource::{RegisteredResource, RegisteredResourceTemplate, ResourceTemplate, ReadResourceResult, ReadResourceCallbackFn},
roots::{RegisteredRoots, Root},
sampling::{RegisteredSampling, SamplingRequest, SamplingResult},
tool::{RegisteredTool, ToolBuilder},
Server,
};
use crate::transport::Transport;
use crate::types::{Implementation, Prompt, Resource, ServerCapabilities, Tool};
/// File-local shorthand for a `Result` whose error type is [`ServerError`].
///
/// This shadows the prelude `Result` within this file; code that needs a
/// different error type must name it explicitly (e.g. `anyhow::Result`).
type Result<T> = std::result::Result<T, ServerError>;
/// Wrapper around a [`Server`] that additionally keeps the resources,
/// resource templates, tools, prompts and optional handlers registered
/// through its methods.
pub struct McpServer {
/// The wrapped server; capability registration and `connect` are forwarded to it.
pub server: Server,
/// Resources registered via `resource`, keyed by the URI string passed to that method.
registered_resources: HashMap<String, RegisteredResource>,
/// Templates registered via `resource_template`, keyed by the name passed to that method.
registered_resource_templates: HashMap<String, RegisteredResourceTemplate>,
/// Tools registered via `register_tool`, keyed by the `name` of the supplied tool metadata.
registered_tools: HashMap<String, RegisteredTool>,
/// Prompts registered via `register_prompt`, keyed by the `name` of the supplied prompt metadata.
registered_prompts: HashMap<String, RegisteredPrompt>,
/// Sampling handler set by `register_sampling`; `None` until then.
registered_sampling: Option<RegisteredSampling>,
/// Roots handler set by `register_roots`; `None` until then.
registered_roots: Option<RegisteredRoots>,
/// Completion handler set by `register_completion`; `None` until then.
registered_completion: Option<RegisteredCompletion>,
}
impl McpServer {
pub fn new(server_info: Implementation) -> Self {
Self {
server: Server::new(server_info),
registered_resources: HashMap::new(),
registered_resource_templates: HashMap::new(),
registered_tools: HashMap::new(),
registered_prompts: HashMap::new(),
registered_sampling: None,
registered_roots: None,
registered_completion: None,
}
}
pub fn register_completion(&mut self, handler: impl CompletionCallback + 'static) {
self.registered_completion = Some(RegisteredCompletion {
callback: Box::new(handler),
});
self.server.register_capabilities(ServerCapabilities {
completion: Some(Default::default()),
..Default::default()
});
}
pub fn register_sampling(
&mut self,
callback: impl Fn(SamplingRequest) -> Pin<Box<dyn Future<Output = Result<SamplingResult>> + Send + 'static>>
+ Send
+ Sync
+ 'static,
) {
self.registered_sampling = Some(RegisteredSampling {
callback: Arc::new(callback),
});
self.server.register_capabilities(ServerCapabilities {
sampling: Some(Default::default()),
..Default::default()
});
}
pub fn register_roots(
&mut self,
list_callback: impl Fn() -> Pin<Box<dyn Future<Output = Result<Vec<Root>>> + Send>>
+ Send
+ Sync
+ 'static,
supports_change_notifications: bool,
) {
let wrapped_callback = move || -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<Root>>> + Send>> {
let fut = list_callback();
Box::pin(async move { fut.await.map_err(|e| anyhow::anyhow!("{}", e)) })
};
self.registered_roots = Some(RegisteredRoots::new(
wrapped_callback,
supports_change_notifications,
));
self.server.register_capabilities(ServerCapabilities {
roots: Some(Default::default()),
..Default::default()
});
}
pub async fn connect(&self, _transport: impl Transport) -> Result<()> {
self.server.connect(_transport).await
}
pub fn resource(
&mut self,
name: impl Into<String>,
uri: impl Into<String>,
metadata: Option<Resource>,
read_callback: impl Fn(&Url) -> Pin<Box<dyn Future<Output = ReadResourceResult> + Send + 'static>>
+ Send
+ Sync
+ 'static,
) {
let uri = uri.into();
let name = name.into();
let metadata = metadata.unwrap_or_else(|| Resource {
uri: Url::parse(&uri).unwrap_or_else(|e| {
eprintln!("Warning: Invalid URI '{}': {}", uri, e);
Url::parse("about:invalid").unwrap()
}),
name: name.clone(),
description: None,
mime_type: None,
});
self.registered_resources.insert(
uri.clone(),
RegisteredResource::new(
metadata,
read_callback,
false,
),
);
if self.registered_resources.len() == 1 {
self.server.register_capabilities(ServerCapabilities {
resources: Some(Default::default()),
..Default::default()
});
}
}
pub fn resource_template(
&mut self,
name: impl Into<String>,
template: ResourceTemplate,
metadata: Option<Resource>,
read_callback: impl Fn(&Url) -> Pin<Box<dyn Future<Output = ReadResourceResult> + Send + 'static>>
+ Send
+ Sync
+ 'static,
) {
let name = name.into();
let metadata = metadata.unwrap_or_else(|| Resource {
uri: Url::parse(template.uri_template()).unwrap_or_else(|e| {
eprintln!("Warning: Invalid URI template: {}", e);
Url::parse("about:invalid").unwrap()
}),
name: name.clone(),
description: None,
mime_type: None,
});
self.registered_resource_templates.insert(
name,
RegisteredResourceTemplate {
template,
metadata,
read_callback: Arc::new(ReadResourceCallbackFn(Box::new(read_callback))),
},
);
if self.registered_resource_templates.len() == 1 {
self.server.register_capabilities(ServerCapabilities {
resources: Some(Default::default()),
..Default::default()
});
}
}
pub fn prompt_builder(&self, name: impl Into<String>) -> PromptBuilder {
PromptBuilder::new(name)
}
pub fn register_prompt(&mut self, metadata: impl Into<Prompt>, registered: RegisteredPrompt) {
let metadata = metadata.into();
self.registered_prompts
.insert(metadata.name.clone(), registered);
if self.registered_prompts.len() == 1 {
self.server.register_capabilities(ServerCapabilities {
prompts: Some(Default::default()),
..Default::default()
});
}
}
pub fn tool_builder(&self, name: impl Into<String>) -> ToolBuilder {
ToolBuilder::new(name)
}
pub fn register_tool(&mut self, metadata: impl Into<Tool>, registered: RegisteredTool) {
let metadata = metadata.into();
self.registered_tools.insert(metadata.name.clone(), registered);
if self.registered_tools.len() == 1 {
self.server.register_capabilities(ServerCapabilities {
tools: Some(Default::default()),
..Default::default()
});
}
}
}