use crate::error::{Error, Result};
use crate::types::{
CallToolRequest, CallToolResult, ClientRequest, Content, GetPromptRequest, GetPromptResult,
Implementation, InitializeRequest, InitializeResult, JSONRPCError, JSONRPCResponse,
ListPromptsRequest, ListPromptsResult, ListResourcesRequest, ListResourcesResult,
ListToolsRequest, ListToolsResult, PromptInfo, ReadResourceRequest, ReadResourceResult,
Request, RequestId, ResourceInfo, ServerCapabilities, ToolInfo,
};
use crate::{ErrorCode, SUPPORTED_PROTOCOL_VERSIONS};
use serde_json::json;
use serde_json::Value;
use std::collections::HashMap;
/// A tool that a [`WasmMcpServer`] can advertise and invoke.
///
/// Implementors must be `Send + Sync`.
pub trait WasmTool: Send + Sync {
    /// Returns the metadata describing this tool (name, description, input schema, ...).
    fn info(&self) -> ToolInfo;

    /// Runs the tool with the JSON arguments supplied by the client.
    ///
    /// # Errors
    ///
    /// An `Err` is not sent as a JSON-RPC error by [`WasmMcpServer`]; the server
    /// converts it into an error-flagged tool result whose text is `Error: <message>`.
    fn execute(&self, args: Value) -> Result<Value>;
}
/// A provider of resources that a [`WasmMcpServer`] can list and read.
///
/// Implementors must be `Send + Sync`.
pub trait WasmResource: Send + Sync {
    /// Returns one page of resources.
    ///
    /// `cursor` is `None` for the first page, otherwise the `next_cursor`
    /// this provider returned for the previous page.
    fn list(&self, cursor: Option<String>) -> Result<ListResourcesResult>;

    /// Reads the resource identified by `uri`.
    ///
    /// # Errors
    ///
    /// The server asks providers in turn and uses the first `Ok`, so an
    /// implementation should return `Err` for any URI it cannot serve.
    fn read(&self, uri: &str) -> Result<ReadResourceResult>;

    /// Returns the resource templates offered by this provider.
    ///
    /// The default implementation offers none.
    fn templates(&self) -> Vec<ResourceInfo> {
        vec![]
    }
}
/// A prompt template that a [`WasmMcpServer`] can advertise and render.
///
/// Implementors must be `Send + Sync`.
pub trait WasmPrompt: Send + Sync {
    /// Returns the metadata describing this prompt.
    fn info(&self) -> PromptInfo;

    /// Renders the prompt from the string arguments supplied by the client.
    ///
    /// # Errors
    ///
    /// An `Err` is propagated by the server as a JSON-RPC error response.
    fn generate(&self, args: HashMap<String, String>) -> Result<GetPromptResult>;
}
/// An MCP server whose tools, resources and prompts are fixed at construction
/// time by [`WasmMcpServerBuilder`].
pub struct WasmMcpServer {
    /// Server name and version reported in the `initialize` response.
    info: Implementation,
    /// Capabilities reported in the `initialize` response.
    capabilities: ServerCapabilities,

    /// Tool implementations, keyed by the name used to dispatch tool calls.
    tools: HashMap<String, Box<dyn WasmTool>>,
    /// Tool metadata returned when listing tools, keyed like `tools`.
    tool_infos: HashMap<String, ToolInfo>,

    /// Prompt implementations, keyed by the name used to fetch a prompt.
    prompts: HashMap<String, Box<dyn WasmPrompt>>,
    /// Prompt metadata returned when listing prompts, keyed like `prompts`.
    prompt_infos: HashMap<String, PromptInfo>,

    /// Resource providers, keyed by their registration name.
    resources: HashMap<String, Box<dyn WasmResource>>,
}
impl WasmMcpServer {
pub fn builder() -> WasmMcpServerBuilder {
WasmMcpServerBuilder::new()
}
fn map_error_code(error: &Error) -> ErrorCode {
match error {
Error::Protocol { code, .. } => *code,
_ => ErrorCode::INTERNAL_ERROR,
}
}
pub async fn handle_request(&self, id: RequestId, request: Request) -> JSONRPCResponse {
let result = match request {
Request::Client(client_req) => self.handle_client_request(*client_req).await,
Request::Server(_) => Err(Error::protocol(
ErrorCode::INVALID_REQUEST,
"Server requests not supported in WASM",
)),
};
match result {
Ok(value) => JSONRPCResponse {
jsonrpc: "2.0".to_string(),
id,
payload: crate::types::jsonrpc::ResponsePayload::Result(value),
},
Err(error) => JSONRPCResponse {
jsonrpc: "2.0".to_string(),
id,
payload: crate::types::jsonrpc::ResponsePayload::Error(JSONRPCError {
code: Self::map_error_code(&error).0,
message: error.to_string(),
data: None,
}),
},
}
}
async fn handle_client_request(&self, request: ClientRequest) -> Result<Value> {
match request {
ClientRequest::Initialize(params) => self.handle_initialize(params),
ClientRequest::ListTools(params) => self.handle_list_tools(params),
ClientRequest::CallTool(params) => self.handle_call_tool(params),
ClientRequest::ListResources(params) => self.handle_list_resources(params),
ClientRequest::ReadResource(params) => self.handle_read_resource(params),
ClientRequest::ListPrompts(params) => self.handle_list_prompts(params),
ClientRequest::GetPrompt(params) => self.handle_get_prompt(params),
_ => Err(Error::protocol(
ErrorCode::METHOD_NOT_FOUND,
"Method not supported in WASM",
)),
}
}
fn handle_initialize(&self, params: InitializeRequest) -> Result<Value> {
let negotiated_version = crate::negotiate_protocol_version(¶ms.protocol_version);
let result = InitializeResult {
protocol_version: crate::types::ProtocolVersion(negotiated_version.to_string()),
capabilities: self.capabilities.clone(),
server_info: self.info.clone(),
instructions: None,
};
serde_json::to_value(result).map_err(|e| Error::internal(&e.to_string()))
}
fn handle_list_tools(&self, _params: ListToolsRequest) -> Result<Value> {
let tools: Vec<ToolInfo> = self.tool_infos.values().cloned().collect();
let result = ListToolsResult {
tools,
next_cursor: None,
};
serde_json::to_value(result).map_err(|e| Error::internal(&e.to_string()))
}
fn handle_call_tool(&self, params: CallToolRequest) -> Result<Value> {
let tool = self.tools.get(¶ms.name).ok_or_else(|| {
Error::protocol(
ErrorCode::METHOD_NOT_FOUND,
&format!("Tool '{}' not found", params.name),
)
})?;
let args = params.arguments.clone();
match tool.execute(args) {
Ok(result_value) => {
let content = if let Some(text) = result_value.as_str() {
vec![Content::text(text)]
} else if result_value.is_object() {
vec![Content::text(
serde_json::to_string_pretty(&result_value)
.unwrap_or_else(|_| "{}".to_string()),
)]
} else {
vec![Content::text(result_value.to_string())]
};
let result = CallToolResult::new(content);
serde_json::to_value(result).map_err(|e| Error::internal(&e.to_string()))
},
Err(e) => {
let result = CallToolResult::error(vec![Content::text(format!("Error: {}", e))]);
serde_json::to_value(result).map_err(|e| Error::internal(&e.to_string()))
},
}
}
fn handle_list_resources(&self, params: ListResourcesRequest) -> Result<Value> {
let mut all_resources = Vec::new();
let mut next_cursor = None;
let (provider_name, provider_cursor) = if let Some(cursor) = params.cursor {
if let Some((name, cur)) = cursor.split_once(':') {
(Some(name.to_string()), Some(cur.to_string()))
} else {
(None, Some(cursor))
}
} else {
(None, None)
};
let mut found_provider = provider_name.is_none();
for (name, resource) in &self.resources {
if let Some(ref pname) = provider_name {
if name != pname {
continue;
}
}
if found_provider {
match resource.list(provider_cursor.clone()) {
Ok(result) => {
all_resources.extend(result.resources);
if let Some(cursor) = result.next_cursor {
next_cursor = Some(format!("{}:{}", name, cursor));
}
break; },
Err(_) => continue,
}
}
if provider_name.is_none() {
found_provider = true;
}
}
let result = ListResourcesResult {
resources: all_resources,
next_cursor,
};
serde_json::to_value(result).map_err(|e| Error::internal(&e.to_string()))
}
fn handle_read_resource(&self, params: ReadResourceRequest) -> Result<Value> {
for resource in self.resources.values() {
if let Ok(result) = resource.read(¶ms.uri) {
return serde_json::to_value(result).map_err(|e| Error::internal(&e.to_string()));
}
}
Err(Error::protocol(
ErrorCode::METHOD_NOT_FOUND,
&format!("No resource handler for URI: {}", params.uri),
))
}
fn handle_list_prompts(&self, _params: ListPromptsRequest) -> Result<Value> {
let prompts: Vec<PromptInfo> = self.prompt_infos.values().cloned().collect();
let result = ListPromptsResult {
prompts,
next_cursor: None,
};
serde_json::to_value(result).map_err(|e| Error::internal(&e.to_string()))
}
fn handle_get_prompt(&self, params: GetPromptRequest) -> Result<Value> {
let prompt = self.prompts.get(¶ms.name).ok_or_else(|| {
Error::protocol(
ErrorCode::METHOD_NOT_FOUND,
&format!("Prompt '{}' not found", params.name),
)
})?;
let result = prompt.generate(params.arguments.clone())?;
serde_json::to_value(result).map_err(|e| Error::internal(&e.to_string()))
}
}
/// Incrementally configures a [`WasmMcpServer`]; create with
/// [`WasmMcpServer::builder`] or [`WasmMcpServerBuilder::new`].
pub struct WasmMcpServerBuilder {
    /// Server name reported in the `initialize` response.
    name: String,
    /// Server version reported in the `initialize` response.
    version: String,
    /// Capabilities the built server will advertise.
    capabilities: ServerCapabilities,

    /// Registered tools, keyed by registration name.
    tools: HashMap<String, Box<dyn WasmTool>>,
    /// Metadata captured from each tool at registration, keyed like `tools`.
    tool_infos: HashMap<String, ToolInfo>,

    /// Registered prompts, keyed by registration name.
    prompts: HashMap<String, Box<dyn WasmPrompt>>,
    /// Metadata captured from each prompt at registration, keyed like `prompts`.
    prompt_infos: HashMap<String, PromptInfo>,

    /// Registered resource providers, keyed by registration name.
    resources: HashMap<String, Box<dyn WasmResource>>,
}
impl WasmMcpServerBuilder {
    /// Creates a builder with nothing registered.
    ///
    /// The server name defaults to `wasm-mcp-server`, the version to `1.0.0`,
    /// and the capabilities to `ServerCapabilities::default()`.
    pub fn new() -> Self {
        Self {
            name: "wasm-mcp-server".to_string(),
            version: "1.0.0".to_string(),
            capabilities: ServerCapabilities::default(),
            tools: HashMap::new(),
            resources: HashMap::new(),
            prompts: HashMap::new(),
            tool_infos: HashMap::new(),
            prompt_infos: HashMap::new(),
        }
    }

    /// Sets the server name reported in the `initialize` response.
    pub fn name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Sets the server version reported in the `initialize` response.
    pub fn version(mut self, version: impl Into<String>) -> Self {
        self.version = version.into();
        self
    }

    /// Replaces the capabilities the server will advertise.
    ///
    /// This may be called before or after registering handlers. [`Self::build`]
    /// fills in a default `tools`/`resources`/`prompts` capability for each
    /// registered handler kind that is still unset. Values set here are never
    /// overwritten.
    pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
        self.capabilities = capabilities;
        self
    }

    /// Registers `tool` under `name`, replacing any tool already registered
    /// under that name.
    ///
    /// Tool calls are dispatched by `name`. The tool's [`ToolInfo`] is
    /// therefore stored with its `name` set to `name`, so the name clients see
    /// when listing tools is always one they can call.
    pub fn tool<T: WasmTool + 'static>(mut self, name: impl Into<String>, tool: T) -> Self {
        let name = name.into();
        let mut info = tool.info();
        info.name = name.clone();
        self.tool_infos.insert(name.clone(), info);
        self.tools.insert(name, Box::new(tool));
        self
    }

    /// Registers a resource provider under `name`, replacing any provider
    /// already registered under that name.
    ///
    /// `name` appears in resource pagination cursors and must not contain `:`.
    pub fn resource<R: WasmResource + 'static>(
        mut self,
        name: impl Into<String>,
        resource: R,
    ) -> Self {
        self.resources.insert(name.into(), Box::new(resource));
        self
    }

    /// Registers `prompt` under `name`, replacing any prompt already registered
    /// under that name.
    ///
    /// As with [`Self::tool`], the stored [`PromptInfo`] has its `name` set to
    /// `name` so listing and lookup agree.
    pub fn prompt<P: WasmPrompt + 'static>(mut self, name: impl Into<String>, prompt: P) -> Self {
        let name = name.into();
        let mut info = prompt.info();
        info.name = name.clone();
        self.prompt_infos.insert(name.clone(), info);
        self.prompts.insert(name, Box::new(prompt));
        self
    }

    /// Finishes configuration and returns the server.
    ///
    /// A default capability is advertised for every registered handler kind.
    /// This happens only if the caller has not already configured that
    /// capability through [`Self::capabilities`].
    pub fn build(mut self) -> WasmMcpServer {
        // Defaulting happens here, not at registration, so that the order of
        // `capabilities()` versus `tool()`/`resource()`/`prompt()` calls is
        // irrelevant and explicit settings are preserved.
        if !self.tools.is_empty() && self.capabilities.tools.is_none() {
            self.capabilities.tools = Some(Default::default());
        }
        if !self.resources.is_empty() && self.capabilities.resources.is_none() {
            self.capabilities.resources = Some(Default::default());
        }
        if !self.prompts.is_empty() && self.capabilities.prompts.is_none() {
            self.capabilities.prompts = Some(Default::default());
        }

        WasmMcpServer {
            info: Implementation::new(self.name, self.version),
            capabilities: self.capabilities,
            tools: self.tools,
            resources: self.resources,
            prompts: self.prompts,
            tool_infos: self.tool_infos,
            prompt_infos: self.prompt_infos,
        }
    }
}
/// A [`WasmTool`] built from a closure plus a name, description and input schema.
pub struct SimpleTool<F> {
    /// Closure invoked with the call arguments.
    handler: F,
    /// Tool name reported in its metadata.
    name: String,
    /// Human-readable description reported in its metadata.
    description: String,
    /// JSON Schema describing the accepted arguments.
    input_schema: Value,
}
impl<F> SimpleTool<F>
where
    F: Fn(Value) -> Result<Value> + Send + Sync,
{
    /// Creates a tool backed by `handler`.
    ///
    /// The initial input schema accepts any JSON object. Call
    /// [`Self::with_schema`] to supply a stricter one.
    pub fn new(name: impl Into<String>, description: impl Into<String>, handler: F) -> Self {
        let name = name.into();
        let description = description.into();
        let permissive_schema = json!({
            "type": "object",
            "properties": {},
            "additionalProperties": true
        });
        Self {
            name,
            description,
            input_schema: permissive_schema,
            handler,
        }
    }

    /// Returns the tool with its input schema replaced by `schema`.
    pub fn with_schema(self, schema: Value) -> Self {
        Self {
            input_schema: schema,
            ..self
        }
    }
}
impl<F> WasmTool for SimpleTool<F>
where
    F: Fn(Value) -> Result<Value> + Send + Sync,
{
    /// Forwards the arguments to the wrapped closure and returns its result unchanged.
    fn execute(&self, args: Value) -> Result<Value> {
        let handler = &self.handler;
        handler(args)
    }

    /// Builds metadata from the stored name, description and input schema.
    ///
    /// Every optional field other than the description is left as `None`.
    fn info(&self) -> ToolInfo {
        let name = self.name.clone();
        let description = Some(self.description.clone());
        let input_schema = self.input_schema.clone();
        ToolInfo {
            name,
            title: None,
            description,
            input_schema,
            output_schema: None,
            annotations: None,
            icons: None,
            _meta: None,
            execution: None,
        }
    }
}