use axum::{
extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State},
response::Response,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::server::RustMCP;
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<Value>,
pub method: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonRpcNotification {
pub jsonrpc: String,
pub method: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct JsonRpcError {
pub code: i32,
pub message: String,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub data: Option<Value>,
}
pub async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<Arc<RustMCP>>,
) -> Response {
ws.on_upgrade(|socket| handle_socket(socket, state))
}
#[derive(Debug)]
pub struct ClientState {
}
impl ClientState {
pub fn new() -> Self {
Self {}
}
}
impl Default for ClientState {
fn default() -> Self {
Self::new()
}
}
async fn handle_socket(socket: WebSocket, state: Arc<RustMCP>) {
println!("WebSocket connection established");
let client_state = Arc::new(Mutex::new(ClientState::new()));
let (mut sender, mut receiver) = socket.split();
let state_clone = state.clone();
let client_state_clone = client_state.clone();
let receiver_handle = tokio::spawn(async move {
while let Some(Ok(msg)) = receiver.next().await {
if let Message::Text(text) = msg {
println!("Received message: {}", text);
if let Err(e) = handle_message(text, &state_clone, &mut sender, &client_state_clone).await {
eprintln!("Error handling message: {}", e);
break;
}
}
}
});
let _ = receiver_handle.await;
println!("WebSocket connection closed");
}
async fn handle_message(
text: String,
state: &Arc<RustMCP>,
sender: &mut futures::stream::SplitSink<WebSocket, Message>,
_client_state: &Arc<Mutex<ClientState>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&text) {
let response = match request.method.as_str() {
"initialize" => {
let result = serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {
"listChanged": true
},
"resources": {
"subscribe": true,
"listChanged": true
},
"prompts": {
"listChanged": true
}
},
"serverInfo": {
"name": "RustMCP-rs",
"version": "0.1.0"
}
});
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id, result: Some(result),
error: None,
}
},
"notifications/initialized" => {
println!("Received initialized notification, sending success response");
let response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::Value::Number(serde_json::Number::from(0))),
result: Some(serde_json::json!({})),
error: None,
};
if let Ok(response_text) = serde_json::to_string(&response) {
let _ = sender.send(Message::Text(response_text)).await;
}
return Ok(());
},
_ => {
handle_jsonrpc_method(request, state).await
}
};
if let Ok(response_text) = serde_json::to_string(&response) {
sender.send(Message::Text(response_text)).await?;
}
}
Ok(())
}
async fn handle_jsonrpc_method(request: JsonRpcRequest, state: &Arc<RustMCP>) -> JsonRpcResponse {
match request.method.as_str() {
"tools/list" => {
let tools = state.mcp_list_tools();
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id.clone(),
result: Some(serde_json::json!({
"tools": tools
})),
error: None,
}
},
"resources/list" => {
let resources = state.mcp_list_resources();
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id.clone(),
result: Some(serde_json::json!({
"resources": resources
})),
error: None,
}
},
"prompts/list" => {
let prompts = state.mcp_list_prompts();
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id.clone(),
result: Some(serde_json::json!({
"prompts": prompts
})),
error: None,
}
},
"tools/call" => {
if let Some(params) = request.params {
if let Ok(call_params) = serde_json::from_value::<serde_json::Map<String, Value>>(params) {
let name = call_params.get("name").and_then(|v| v.as_str()).unwrap_or("");
let arguments = call_params.get("arguments").cloned();
let arguments_map = if let Some(args) = arguments {
serde_json::from_value::<std::collections::HashMap<String, Value>>(args).ok()
} else {
None
};
match state.mcp_call_tool(name, arguments_map).await {
Ok(result) => JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id.clone(),
result: Some(serde_json::json!({
"content": [{
"type": "text",
"text": format!("{}", result)
}],
"isError": false
})),
error: None,
},
Err(e) => JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id.clone(),
result: Some(serde_json::json!({
"content": [{
"type": "text",
"text": e
}],
"isError": true
})),
error: None,
},
}
} else {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32602,
message: "Invalid params".to_string(),
data: None,
}),
}
}
} else {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32602,
message: "Missing params".to_string(),
data: None,
}),
}
}
},
"resources/read" => {
if let Some(params) = request.params {
if let Ok(read_params) = serde_json::from_value::<serde_json::Map<String, Value>>(params) {
let uri = read_params.get("uri").and_then(|v| v.as_str()).unwrap_or("");
match state.mcp_read_resource(uri) {
Ok(result) => JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: Some(serde_json::json!({
"contents": [{
"uri": uri,
"text": result
}]
})),
error: None,
},
Err(e) => JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32000,
message: e,
data: None,
}),
},
}
} else {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32602,
message: "Invalid params".to_string(),
data: None,
}),
}
}
} else {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id,
result: None,
error: Some(JsonRpcError {
code: -32602,
message: "Missing params".to_string(),
data: None,
}),
}
}
},
_ => {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: request.id.clone(),
result: None,
error: Some(JsonRpcError {
code: -32601,
message: "Method not found".to_string(),
data: None,
}),
}
}
}
}