use futures::stream::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::{collections::HashMap, error::Error};
use tokio::{
io::{self as tokio_io, AsyncWriteExt, BufReader},
};
use tokio_util::codec::{FramedRead, LinesCodec};
/// An incoming JSON-RPC 2.0 request (or notification) as read from the transport.
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonRpcRequest {
    /// Protocol version string sent by the client. This struct does not check that it is "2.0".
    pub jsonrpc: String,
    /// Request id to echo back in the response. A missing `id` member deserializes to
    /// `Value::Null` via `default_id`, so after parsing an absent id and an explicit `null`
    /// are indistinguishable.
    #[serde(default = "default_id")]
    pub id: Value,
    /// Name of the method to invoke, e.g. "tools/call".
    pub method: String,
    /// Method parameters; `None` when the `params` member is absent or `null`.
    pub params: Option<Value>,
}
/// Serde default for [`JsonRpcRequest::id`]: a request without an `id` member is given a
/// JSON `null` id.
fn default_id() -> Value {
    // `Value`'s `Default` impl yields `Value::Null`.
    Value::default()
}
/// A JSON-RPC 2.0 response object.
///
/// Exactly one of `result` / `error` is expected to be `Some`. Both are omitted from the
/// serialized output when `None`. JSON-RPC 2.0 requires that `result` MUST NOT be present
/// on an error response, and that `error` MUST NOT be present on success. Previously, error
/// responses were serialized with a spurious `"result": null`. On deserialization, a missing
/// `result` still maps to `None`.
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonRpcResponse {
    /// Protocol version; always "2.0" in responses built by `McpServer`.
    pub jsonrpc: String,
    /// Id of the request being answered, or `null` when it could not be determined.
    pub id: Value,
    /// Method result; present only on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<Value>,
    /// Error details; present only on failure.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
/// The `error` member of a JSON-RPC 2.0 response.
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonRpcError {
    /// Numeric error code. This file uses the standard codes (-32700 parse error,
    /// -32602 invalid params, -32601 method not found) and -32000 for tool failures.
    pub code: i32,
    /// Short human-readable description of the error.
    pub message: String,
    /// Optional structured details; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
/// Type-erased tool callback stored by [`McpServer`]. It takes the `arguments` value of a
/// `tools/call` request and returns the `JoinHandle` of a Tokio task. That task resolves to
/// the tool output (`Ok`) or an error message (`Err`).
pub type ToolHandler = Box<dyn Fn(Value) -> tokio::task::JoinHandle<Result<Value, String>> + Send + Sync>;
/// A minimal MCP server that answers newline-delimited JSON-RPC requests (see `run_stdio`).
pub struct McpServer {
    /// Server name reported in the `initialize` response's `serverInfo`.
    name: String,
    /// Server version reported in the `initialize` response's `serverInfo`.
    version: String,
    /// Registered tools keyed by name: (description, input JSON schema, handler).
    tools: HashMap<String, (String, Value, ToolHandler)>,
}
impl McpServer {
pub fn new(name: &str, version: &str) -> Self {
McpServer {
name: name.to_string(),
version: version.to_string(),
tools: HashMap::new(),
}
}
pub fn add_tool<F, Fut>(&mut self, name: &str, description: &str, input_schema: Value, handler: F)
where
F: Fn(Value) -> Fut + Send + Sync + Clone + 'static,
Fut: std::future::Future<Output = Result<Value, String>> + Send + 'static,
{
let handler_boxed: ToolHandler = Box::new(move |args| {
let handler_clone = handler.clone();
tokio::spawn(async move {
handler_clone(args).await
})
});
self.tools.insert(
name.to_string(),
(description.to_string(), input_schema, handler_boxed),
);
}
async fn handle_message(&self, msg: String) -> Option<String> {
eprintln!("Received message: {}", msg);
let request: JsonRpcRequest = match serde_json::from_str(&msg) {
Ok(req) => req,
Err(e) => {
eprintln!("Error parsing request: {}", e);
let error_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: Value::Null,
result: None,
error: Some(JsonRpcError {
code: -32700,
message: "Parse error".to_string(),
data: Some(json!({"error": e.to_string()}))
})
};
let response_str = serde_json::to_string(&error_response).unwrap();
eprintln!("Sending error response: {}", response_str);
return Some(response_str);
}
};
if request.method == "initialize" {
let response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(json!({
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {
"listChanged": true
}
},
"serverInfo": {
"name": self.name,
"version": self.version
}
})),
error: None,
};
let response_str = serde_json::to_string(&response).unwrap();
eprintln!("Sending initialize response: {}", response_str);
return Some(response_str);
}
if request.method == "tools/list" {
let tools: Vec<Value> = self.tools
.iter()
.map(|(name, (description, input_schema, _))| {
json!({
"name": name,
"description": description,
"inputSchema": input_schema,
})
})
.collect();
let response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(json!({
"tools": tools,
})),
error: None,
};
return Some(serde_json::to_string(&response).unwrap());
}
if request.method == "resources/list" {
let response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(json!({
"resources": []
})),
error: None,
};
return Some(serde_json::to_string(&response).unwrap());
}
if request.method == "prompts/list" {
let response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(json!({
"prompts": []
})),
error: None,
};
return Some(serde_json::to_string(&response).unwrap());
}
if request.method == "tools/call" {
let params = match request.params {
Some(p) => p,
None => {
let error_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32602,
message: "Invalid params".to_string(),
data: None
})
};
return Some(serde_json::to_string(&error_response).unwrap());
}
};
let tool_name = match params.get("name") {
Some(Value::String(name)) => name,
_ => {
let error_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32602,
message: "Tool name required".to_string(),
data: None
})
};
return Some(serde_json::to_string(&error_response).unwrap());
}
};
let arguments = match params.get("arguments") {
Some(args) => args.clone(),
None => json!({}),
};
if let Some((_, _, handler)) = self.tools.get(tool_name) {
let handle = handler(arguments);
match handle.await.unwrap() {
Ok(val) => {
let response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(json!({
"content": [
{
"type": "text",
"text": val.to_string()
}
],
"isError": false
})),
error: None,
};
return Some(serde_json::to_string(&response).unwrap());
},
Err(e) => {
let error_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32000, message: e,
data: None
})
};
return Some(serde_json::to_string(&error_response).unwrap());
},
};
} else {
let error_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32601,
message: format!("Tool not found: {}", tool_name),
data: None
})
};
return Some(serde_json::to_string(&error_response).unwrap());
}
}
if request.method.starts_with("notifications/") {
eprintln!("Handling notification: {}, not sending a response", request.method);
return None;
}
eprintln!("Unknown method: {}", request.method);
let error_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32601,
message: format!("Method not found: {}", request.method),
data: None
})
};
let response_str = serde_json::to_string(&error_response).unwrap();
eprintln!("Sending method not found response: {}", response_str);
Some(response_str)
}
pub async fn run_stdio(self) -> Result<(), Box<dyn Error>> {
let stdin = tokio_io::stdin();
let mut stdout = tokio_io::stdout();
let reader = BufReader::new(stdin);
let mut lines = FramedRead::new(reader, LinesCodec::new());
while let Some(line_result) = lines.next().await {
match line_result {
Ok(line) => {
eprintln!("Processing line: {}", line);
if let Some(response) = self.handle_message(line).await {
eprintln!("Sending response: {}", response);
stdout.write_all(response.as_bytes()).await?;
stdout.write_all(b"\n").await?;
stdout.flush().await?;
} else {
eprintln!("No response to send (notification handling)");
}
}
Err(e) => eprintln!("Error reading line: {}", e),
}
}
Ok(())
}
}