use crate::{
context::{Context, Inject},
errors::RequestError,
handler::{PromptHandler, ToolHandler},
};
use kuri_mcp_protocol::{
jsonrpc::{ErrorData, JsonRpcRequest, JsonRpcResponse, Params, SendableMessage},
messages::{
CallToolResult, GetPromptResult, Implementation, InitializeResult, ListPromptsResult,
ListResourcesResult, ListToolsResult, PromptsCapability, ReadResourceResult,
ResourcesCapability, ServerCapabilities, ToolsCapability,
},
prompt::{Prompt, PromptError, PromptMessage, PromptMessageRole},
resource::{ResourceContents, ResourceError},
tool::ToolError,
Content,
};
use serde_json::json;
use serde_json::Value;
use std::task::Poll;
use std::{collections::HashMap, future::Future, pin::Pin};
use std::{convert::Infallible, rc::Rc};
use tower::Service;
/// Tool handlers keyed by name. Handlers are stored as `Rc` trait objects, so the map is
/// single-threaded (`Rc` is neither `Send` nor `Sync`).
type Tools = HashMap<String, Rc<dyn ToolHandler>>;
/// Prompt handlers keyed by name, stored the same way as [`Tools`].
type Prompts = HashMap<String, Rc<dyn PromptHandler>>;
/// An MCP server: a named set of tool and prompt handlers plus the shared context that is
/// passed to them on every call.
///
/// The handler maps and the context sit behind `Rc`, so `clone()` copies the two strings and
/// bumps three reference counts; all clones share the same handlers and context. Because of
/// the `Rc`s the service is not `Send`.
#[derive(Clone)]
pub struct MCPService {
/// Server name.
name: String,
/// Free-form description of the server.
description: String,
/// Tool handlers, looked up by tool name.
tools: Rc<Tools>,
/// Prompt handlers, looked up by prompt name.
prompts: Rc<Prompts>,
/// Context handed by reference to every tool and prompt invocation.
ctx: Rc<Context>,
}
/// Incrementally assembles an [`MCPService`].
///
/// Start with [`MCPServiceBuilder::new`], chain any number of `with_*` calls, and finish with
/// [`MCPServiceBuilder::build`]. Every method consumes the builder and hands it back, so the
/// calls compose as a single expression.
pub struct MCPServiceBuilder {
    /// Server name, moved into the service unchanged.
    name: String,
    /// Server description, moved into the service unchanged.
    description: String,
    /// Tools registered so far, keyed by the name each handler reports.
    tools: Tools,
    /// Prompts registered so far, keyed by the name each handler reports.
    prompts: Prompts,
    /// State registered so far through [`MCPServiceBuilder::with_state`].
    ctx: Context,
}

impl MCPServiceBuilder {
    /// Creates a builder with no tools, no prompts and a default (empty) context.
    pub fn new(name: String, description: String) -> Self {
        let tools = HashMap::new();
        let prompts = HashMap::new();
        let ctx = Context::default();
        Self {
            name,
            description,
            tools,
            prompts,
            ctx,
        }
    }

    /// Registers `tool` under the name returned by its `name()` method.
    ///
    /// Registering a second tool with the same name replaces the first one.
    pub fn with_tool(mut self, tool: impl ToolHandler) -> Self {
        let key = tool.name().to_string();
        let handler: Rc<dyn ToolHandler> = Rc::new(tool);
        self.tools.insert(key, handler);
        self
    }

    /// Registers `prompt` under the name returned by its `name()` method.
    ///
    /// Registering a second prompt with the same name replaces the first one.
    pub fn with_prompt(mut self, prompt: impl PromptHandler) -> Self {
        let key = prompt.name().to_string();
        let handler: Rc<dyn PromptHandler> = Rc::new(prompt);
        self.prompts.insert(key, handler);
        self
    }

    /// Adds `state` to the context that is later passed to every tool and prompt call.
    pub fn with_state<T: 'static>(mut self, state: Inject<T>) -> Self {
        self.ctx.insert(state);
        self
    }

    /// Finalises the builder, wrapping the handler maps and the context in `Rc` so that the
    /// resulting service can be cloned cheaply.
    pub fn build(self) -> MCPService {
        let Self {
            name,
            description,
            tools,
            prompts,
            ctx,
        } = self;
        MCPService {
            name,
            description,
            tools: Rc::new(tools),
            prompts: Rc::new(prompts),
            ctx: Rc::new(ctx),
        }
    }
}
/// Builder for the [`ServerCapabilities`] value a server reports about itself.
///
/// Every capability starts out as `None`; a capability is only set once the matching
/// `with_*` method has been called.
#[derive(Default)]
pub struct CapabilitiesBuilder {
    /// Tools capability, if enabled.
    tools: Option<ToolsCapability>,
    /// Prompts capability, if enabled.
    prompts: Option<PromptsCapability>,
    /// Resources capability, if enabled.
    resources: Option<ResourcesCapability>,
}

impl CapabilitiesBuilder {
    /// Creates a builder with no capability set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables the tools capability with the given `list_changed` flag.
    pub fn with_tools(self, list_changed: bool) -> Self {
        let capability = ToolsCapability {
            list_changed: Some(list_changed),
        };
        Self {
            tools: Some(capability),
            ..self
        }
    }

    /// Enables the prompts capability with the given `list_changed` flag.
    pub fn with_prompts(self, list_changed: bool) -> Self {
        let capability = PromptsCapability {
            list_changed: Some(list_changed),
        };
        Self {
            prompts: Some(capability),
            ..self
        }
    }

    /// Enables the resources capability with the given `subscribe` and `list_changed` flags.
    // Nothing visible in this file calls this method yet, hence the lint exemption.
    #[allow(dead_code)]
    pub fn with_resources(self, subscribe: bool, list_changed: bool) -> Self {
        let capability = ResourcesCapability {
            subscribe: Some(subscribe),
            list_changed: Some(list_changed),
        };
        Self {
            resources: Some(capability),
            ..self
        }
    }

    /// Consumes the builder and produces the final [`ServerCapabilities`].
    pub fn build(self) -> ServerCapabilities {
        let Self {
            tools,
            prompts,
            resources,
        } = self;
        ServerCapabilities {
            tools,
            prompts,
            resources,
        }
    }
}
/// The operations a server offers: identity, capabilities, tools, resources and prompts.
///
/// The asynchronous operations return boxed futures. None of them is declared `Send`.
trait MCPServiceTrait: 'static {
/// Returns the server's name.
fn name(&self) -> String;
/// Returns the usage instructions the server provides to clients.
fn instructions(&self) -> String;
/// Returns the capabilities the server supports.
fn capabilities(&self) -> ServerCapabilities;
/// Returns a description of every available tool.
fn list_tools(&self) -> Vec<kuri_mcp_protocol::tool::Tool>;
/// Invokes the tool called `tool_name` with `arguments`.
///
/// The returned future may borrow from `self` (it is bounded by `'_`), and resolves to the
/// tool's output as content items or to a [`ToolError`].
fn call_tool(
&self,
tool_name: &str,
arguments: Value,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + '_>>;
/// Returns a description of every available resource.
fn list_resources(&self) -> Vec<kuri_mcp_protocol::resource::Resource>;
/// Reads the resource identified by `uri` and resolves to its text.
///
/// Unlike the other asynchronous methods, the returned future is `'static`, so it cannot
/// borrow from `self` or from `uri`.
fn read_resource(
&self,
uri: &str,
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + 'static>>;
/// Returns a description of every available prompt.
fn list_prompts(&self) -> Vec<Prompt>;
/// Renders the prompt called `prompt_name` with `arguments`.
///
/// The returned future may borrow from `self` (it is bounded by `'_`), and resolves to the
/// prompt text or to a [`PromptError`].
fn get_prompt(
&self,
prompt_name: &str,
arguments: HashMap<String, serde_json::Value>,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + '_>>;
}
impl MCPServiceTrait for MCPService {
fn name(&self) -> String {
self.name.clone()
}
fn instructions(&self) -> String {
self.description.clone()
}
fn capabilities(&self) -> kuri_mcp_protocol::messages::ServerCapabilities {
let mut builder = CapabilitiesBuilder::new();
if !self.tools.is_empty() {
builder = builder.with_tools(false);
}
if !self.prompts.is_empty() {
builder = builder.with_prompts(false);
}
builder.build()
}
fn list_tools(&self) -> Vec<kuri_mcp_protocol::tool::Tool> {
self.tools
.iter()
.map(|(name, tool)| {
kuri_mcp_protocol::tool::Tool::new(name.clone(), tool.description(), tool.schema())
})
.collect()
}
fn call_tool(
&self,
tool_name: &str,
arguments: serde_json::Value,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + '_>> {
let tool = match self.tools.get(tool_name) {
Some(tool) => tool,
None => {
return Box::pin(futures::future::ready(Err(ToolError::NotFound(
tool_name.to_string(),
))))
}
};
Box::pin(async move {
let res = tool.call(&self.ctx, arguments).await?;
let contents = match res {
serde_json::Value::Number(n) => vec![Content::text(n.to_string())],
serde_json::Value::String(s) => vec![Content::text(s)],
serde_json::Value::Bool(b) => vec![Content::text(b.to_string())],
serde_json::Value::Array(_) => serde_json::from_value(res)
.map_err(|e| ToolError::ExecutionError(e.to_string()))?,
serde_json::Value::Null => vec![],
serde_json::Value::Object(_) => serde_json::from_value(res)
.map_err(|e| ToolError::ExecutionError(e.to_string()))?,
};
Ok(contents)
})
}
fn list_resources(&self) -> Vec<kuri_mcp_protocol::resource::Resource> {
vec![]
}
fn read_resource(
&self,
_uri: &str,
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + 'static>> {
Box::pin(futures::future::ready(Err(ResourceError::ExecutionError(
"Reading resources is not yet implemented".into(),
))))
}
fn list_prompts(&self) -> Vec<Prompt> {
self.prompts
.values()
.map(|prompt| Prompt::new(prompt.name(), prompt.description(), prompt.arguments()))
.collect()
}
fn get_prompt(
&self,
prompt_name: &str,
arguments: HashMap<String, serde_json::Value>,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + '_>> {
let prompt = match self.prompts.get(prompt_name) {
Some(prompt) => prompt,
None => {
return Box::pin(futures::future::ready(Err(PromptError::NotFound(
prompt_name.to_string(),
))));
}
};
Box::pin(async move {
let result = prompt.call(&self.ctx, arguments).await?;
Ok(result)
})
}
}
// The handlers below are written as `fn ... -> impl Future` wrapping an `async move` block,
// which is exactly what this lint flags.
// NOTE(review): presumably done to spell out the captured lifetime (`+ '_`) — confirm.
#[allow(clippy::manual_async_fn)]
impl MCPService {
    /// Handles `ping` by replying with an empty JSON object.
    fn handle_ping(
        &self,
        req: JsonRpcRequest,
    ) -> impl Future<Output = Result<JsonRpcResponse, RequestError>> {
        async move { Ok(JsonRpcResponse::success(req.id, json!({}))) }
    }

    /// Handles `initialize`: reports protocol version `2024-11-05`, the server's capabilities,
    /// its name, this crate's version and the instructions.
    ///
    /// # Errors
    ///
    /// [`RequestError::Internal`] if the result cannot be serialized.
    fn handle_initialize(
        &self,
        req: JsonRpcRequest,
    ) -> impl Future<Output = Result<JsonRpcResponse, RequestError>> + '_ {
        async move {
            let result = InitializeResult {
                protocol_version: "2024-11-05".to_string(),
                capabilities: self.capabilities(),
                server_info: Implementation {
                    name: self.name(),
                    version: env!("CARGO_PKG_VERSION").to_string(),
                },
                instructions: Some(self.instructions()),
            };
            let result = serde_json::to_value(result)
                .map_err(|e| RequestError::Internal(format!("JSON serialization error: {}", e)))?;
            Ok(JsonRpcResponse::success(req.id, result))
        }
    }

    /// Handles `tools/list` by returning every registered tool.
    ///
    /// # Errors
    ///
    /// [`RequestError::Internal`] if the result cannot be serialized.
    fn handle_tools_list(
        &self,
        req: JsonRpcRequest,
    ) -> impl Future<Output = Result<JsonRpcResponse, RequestError>> + '_ {
        async move {
            let tools = self.list_tools();
            let result = ListToolsResult { tools };
            let result = serde_json::to_value(result)
                .map_err(|e| RequestError::Internal(format!("JSON serialization error: {}", e)))?;
            Ok(JsonRpcResponse::success(req.id, result))
        }
    }

    /// Handles `tools/call`.
    ///
    /// Reads the string parameter `name` and the optional parameter `arguments` (`null` when
    /// absent), then runs the tool. A tool failure other than "not found" or "invalid
    /// parameters" is reported inside a successful response with `is_error` set to `true`,
    /// not as a JSON-RPC error.
    ///
    /// # Errors
    ///
    /// * [`RequestError::InvalidParams`] when the parameters are missing or not a map, `name` is
    ///   missing or not a string, the tool does not exist, or the tool rejects its arguments.
    /// * [`RequestError::Internal`] if the result cannot be serialized.
    fn handle_tools_call(
        &self,
        req: JsonRpcRequest,
    ) -> impl Future<Output = Result<JsonRpcResponse, RequestError>> + '_ {
        async move {
            let params = match req.params {
                Some(Params::Map(map)) => map,
                Some(_) => {
                    return Err(RequestError::InvalidParams(
                        "Parameters must be a map-like object".into(),
                    ))
                }
                None => return Err(RequestError::InvalidParams("The request was empty".into())),
            };
            let name = params
                .get("name")
                .and_then(Value::as_str)
                .ok_or_else(|| RequestError::InvalidParams("No tool name was provided".into()))?;
            let arguments = params.get("arguments").cloned().unwrap_or(Value::Null);
            let result = match self.call_tool(name, arguments).await {
                Ok(content) => CallToolResult {
                    content,
                    is_error: None,
                },
                // The caller asked for something that cannot be served: a protocol-level error.
                Err(ToolError::NotFound(e)) => {
                    return Err(RequestError::InvalidParams(format!(
                        "Tool not found: {}",
                        e
                    )));
                }
                Err(ToolError::InvalidParameters(e)) => {
                    return Err(RequestError::InvalidParams(format!(
                        "Invalid tool arguments: {}",
                        e
                    )));
                }
                // The tool ran and failed: report the failure in-band.
                Err(err) => CallToolResult {
                    content: vec![Content::text(err.to_string())],
                    is_error: Some(true),
                },
            };
            let result = serde_json::to_value(result)
                .map_err(|e| RequestError::Internal(format!("JSON serialization error: {}", e)))?;
            Ok(JsonRpcResponse::success(req.id, result))
        }
    }

    /// Handles `resources/list` by returning every available resource.
    ///
    /// # Errors
    ///
    /// [`RequestError::Internal`] if the result cannot be serialized.
    fn handle_resources_list(
        &self,
        req: JsonRpcRequest,
    ) -> impl Future<Output = Result<JsonRpcResponse, RequestError>> + '_ {
        async move {
            let resources = self.list_resources();
            let result = ListResourcesResult { resources };
            let result = serde_json::to_value(result)
                .map_err(|e| RequestError::Internal(format!("JSON serialization error: {}", e)))?;
            Ok(JsonRpcResponse::success(req.id, result))
        }
    }

    /// Handles `resources/read`.
    ///
    /// Reads the string parameter `uri` and returns the resource as a single `text/plain`
    /// text content entry.
    ///
    /// # Errors
    ///
    /// * [`RequestError::InvalidParams`] when the parameters are missing or not a map, or `uri`
    ///   is missing or not a string.
    /// * The [`ResourceError`] from reading the resource, converted into a [`RequestError`].
    /// * [`RequestError::Internal`] if the result cannot be serialized.
    fn handle_resources_read(
        &self,
        req: JsonRpcRequest,
    ) -> impl Future<Output = Result<JsonRpcResponse, RequestError>> + '_ {
        async move {
            let params = match req.params {
                Some(Params::Map(map)) => map,
                Some(_) => {
                    return Err(RequestError::InvalidParams(
                        "Parameters must be a map-like object".into(),
                    ))
                }
                None => return Err(RequestError::InvalidParams("Missing parameters".into())),
            };
            let uri = params
                .get("uri")
                .and_then(Value::as_str)
                .ok_or_else(|| RequestError::InvalidParams("Missing resource URI".into()))?;
            let contents = self.read_resource(uri).await.map_err(RequestError::from)?;
            let result = ReadResourceResult {
                contents: vec![ResourceContents::TextResourceContents {
                    uri: uri.to_string(),
                    mime_type: Some("text/plain".to_string()),
                    text: contents,
                }],
            };
            let result = serde_json::to_value(result)
                .map_err(|e| RequestError::Internal(format!("JSON serialization error: {}", e)))?;
            Ok(JsonRpcResponse::success(req.id, result))
        }
    }

    /// Handles `prompts/list` by returning every registered prompt.
    ///
    /// # Errors
    ///
    /// [`RequestError::Internal`] if the result cannot be serialized.
    fn handle_prompts_list(
        &self,
        req: JsonRpcRequest,
    ) -> impl Future<Output = Result<JsonRpcResponse, RequestError>> + '_ {
        async move {
            let prompts = self.list_prompts();
            let result = ListPromptsResult { prompts };
            let result = serde_json::to_value(result)
                .map_err(|e| RequestError::Internal(format!("JSON serialization error: {}", e)))?;
            Ok(JsonRpcResponse::success(req.id, result))
        }
    }

    /// Handles `prompts/get`.
    ///
    /// Reads the string parameter `name` and the optional object parameter `arguments`, renders
    /// the prompt and returns it as a single user-role text message. An absent or `null`
    /// `arguments` is treated as "no arguments", so prompts that take none can be fetched
    /// without sending an empty object.
    ///
    /// # Errors
    ///
    /// * [`RequestError::InvalidParams`] when the parameters are missing or not a map, `name` is
    ///   missing or not a string, `arguments` is present but not an object, the prompt does not
    ///   exist, or the prompt rejects its arguments.
    /// * [`RequestError::Internal`] when the prompt fails internally or the result cannot be
    ///   serialized.
    fn handle_prompts_get(
        &self,
        req: JsonRpcRequest,
    ) -> impl Future<Output = Result<JsonRpcResponse, RequestError>> + '_ {
        async move {
            let params = match req.params {
                Some(Params::Map(map)) => map,
                Some(_) => {
                    return Err(RequestError::InvalidParams(
                        "Parameters must be a map-like object when calling `prompts/get`".into(),
                    ))
                }
                None => return Err(RequestError::InvalidParams("Missing parameters".into())),
            };
            let prompt_name = params
                .get("name")
                .and_then(Value::as_str)
                .ok_or_else(|| RequestError::InvalidParams("Missing prompt name".into()))?;
            // `arguments` is optional in the protocol: only a value of the wrong type is an
            // error. Whether the prompt needs arguments is for the prompt handler to decide.
            let arguments: HashMap<String, serde_json::Value> = match params.get("arguments") {
                None | Some(Value::Null) => HashMap::new(),
                Some(Value::Object(map)) => {
                    map.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
                }
                Some(_) => {
                    return Err(RequestError::InvalidParams(
                        "Prompt arguments must be an object".into(),
                    ))
                }
            };
            let prompt_message =
                self.get_prompt(prompt_name, arguments)
                    .await
                    .map_err(|e| match e {
                        PromptError::InvalidParameters(_) | PromptError::NotFound(_) => {
                            RequestError::InvalidParams(e.to_string())
                        }
                        PromptError::InternalError(_) => RequestError::Internal(e.to_string()),
                    })?;
            let messages = vec![PromptMessage::new_text(
                PromptMessageRole::User,
                prompt_message,
            )];
            let result = serde_json::to_value(GetPromptResult {
                description: None,
                messages,
            })
            .map_err(|e| RequestError::Internal(format!("JSON serialization error: {}", e)))?;
            Ok(JsonRpcResponse::success(req.id, result))
        }
    }
}
impl Service<SendableMessage> for MCPService {
type Response = Option<JsonRpcResponse>;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: SendableMessage) -> Self::Future {
let this = self.clone();
Box::pin(async move {
if let SendableMessage::Request(req) = req {
let id = req.id.clone();
let result = match req.method.as_str() {
"ping" => this.handle_ping(req).await,
"initialize" => this.handle_initialize(req).await,
"tools/list" => this.handle_tools_list(req).await,
"tools/call" => this.handle_tools_call(req).await,
"resources/list" => this.handle_resources_list(req).await,
"resources/read" => this.handle_resources_read(req).await,
"prompts/list" => this.handle_prompts_list(req).await,
"prompts/get" => this.handle_prompts_get(req).await,
_ => Err(RequestError::MethodNotFound(req.method)),
};
let response = match result {
Ok(response) => response,
Err(e) => {
let error = ErrorData::from(e);
JsonRpcResponse::error(id, error)
}
};
Ok(Some(response))
} else {
Ok(None)
}
})
}
}