use crate::mcp::error::{McpError, McpResult};
use serde::{Deserialize, Serialize};
use std::io::{BufRead, BufReader, Write};
/// A JSON-RPC 2.0 request or notification received from the client.
///
/// A message without an `id` is a notification (per JSON-RPC 2.0, it gets no
/// response).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
    /// Protocol version string; `StdioTransport::read_request` rejects anything
    /// other than `"2.0"`.
    pub jsonrpc: String,
    /// Request identifier; `None` when the message is a notification.
    ///
    /// Omitted from the serialized JSON when `None` rather than written as
    /// `null`: a message carrying `"id": null` is not a notification.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<serde_json::Value>,
    /// Name of the method being invoked.
    pub method: String,
    /// Method parameters, if any.
    ///
    /// Omitted from the serialized JSON when `None`: JSON-RPC 2.0 only allows a
    /// structured value (object or array) here, never `null`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
}
/// A JSON-RPC 2.0 response written back to the client.
///
/// The `success` and `error` constructors each set exactly one of `result` /
/// `error`. The fields are public, so that is not enforced for values built by hand.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
/// Protocol version; the constructors always set `"2.0"`.
pub jsonrpc: String,
/// Echo of the request id. There is no `skip_serializing_if` here, so `None`
/// serializes as `null` (what JSON-RPC 2.0 expects when the id is unknown).
pub id: Option<serde_json::Value>,
/// Successful result payload; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
/// Error object; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}
/// The JSON-RPC 2.0 error object carried in `JsonRpcResponse::error`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcError {
/// Numeric error code (taken from `McpError::code` when built via `From`).
pub code: i32,
/// Human-readable error description.
pub message: String,
/// Optional structured error details; omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
impl JsonRpcResponse {
pub fn success(id: Option<serde_json::Value>, result: serde_json::Value) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: Some(result),
error: None,
}
}
pub fn error(id: Option<serde_json::Value>, err: McpError) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(JsonRpcError {
code: err.code,
message: err.message,
data: err.data,
}),
}
}
}
impl From<McpError> for JsonRpcError {
fn from(err: McpError) -> Self {
Self {
code: err.code,
message: err.message,
data: err.data,
}
}
}
/// Line-delimited JSON-RPC transport over the process's standard streams.
///
/// Messages are read one per line from stdin. Responses are written one per
/// line to stdout.
pub struct StdioTransport {
    /// Buffered stdin handle, shared by request reading and confirmation prompts.
    reader: std::io::BufReader<std::io::Stdin>,
}
impl StdioTransport {
pub fn new() -> Self {
Self {
reader: BufReader::new(std::io::stdin()),
}
}
pub fn read_request(&mut self) -> McpResult<Option<JsonRpcRequest>> {
let mut line = String::new();
match self.reader.read_line(&mut line) {
Ok(0) => {
Ok(None)
}
Ok(_) => {
let line = line.trim();
if line.is_empty() {
return self.read_request();
}
let request: JsonRpcRequest = serde_json::from_str(line)
.map_err(|e| McpError::parse_error(format!("Invalid JSON: {}", e)))?;
if request.jsonrpc != "2.0" {
return Err(McpError::invalid_request(format!(
"Expected JSON-RPC 2.0, got '{}'",
request.jsonrpc
)));
}
Ok(Some(request))
}
Err(e) => Err(McpError::internal_error(format!("Read error: {}", e))),
}
}
pub fn write_response(&self, response: &JsonRpcResponse) -> McpResult<()> {
let json = serde_json::to_string(response)?;
let mut stdout = std::io::stdout().lock();
writeln!(stdout, "{}", json)?;
stdout.flush()?;
Ok(())
}
#[allow(dead_code)] pub fn write_stderr(&self, message: &str) -> McpResult<()> {
let mut stderr = std::io::stderr().lock();
writeln!(stderr, "{}", message)?;
stderr.flush()?;
Ok(())
}
#[allow(dead_code)] pub fn read_confirmation(&mut self) -> McpResult<String> {
let mut line = String::new();
self.reader.read_line(&mut line)?;
Ok(line.trim().to_string())
}
}
impl Default for StdioTransport {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `success` sets the version and a result, and leaves `error` empty.
    #[test]
    fn test_success_response() {
        let JsonRpcResponse {
            jsonrpc,
            result,
            error,
            ..
        } = JsonRpcResponse::success(Some(json!(1)), json!({"status": "ok"}));
        assert_eq!(jsonrpc, "2.0");
        assert!(result.is_some());
        assert!(error.is_none());
    }

    /// `error` carries the `McpError` code through to the response.
    #[test]
    fn test_error_response() {
        let JsonRpcResponse {
            jsonrpc,
            result,
            error,
            ..
        } = JsonRpcResponse::error(Some(json!(1)), McpError::unknown_tool("git"));
        assert_eq!(jsonrpc, "2.0");
        assert!(result.is_none());
        let error = error.expect("error response should carry an error object");
        assert_eq!(error.code, -32001);
    }

    /// A request with an id parses all its members.
    #[test]
    fn test_request_parsing() {
        let raw = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list","params":{}}"#;
        let parsed = serde_json::from_str::<JsonRpcRequest>(raw).unwrap();
        assert_eq!(parsed.jsonrpc, "2.0");
        assert_eq!(parsed.id, Some(json!(1)));
        assert_eq!(parsed.method, "tools/list");
    }

    /// A message without an id (a notification) parses with `id == None`.
    #[test]
    fn test_notification_parsing() {
        let raw = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        let parsed = serde_json::from_str::<JsonRpcRequest>(raw).unwrap();
        assert_eq!(parsed.jsonrpc, "2.0");
        assert_eq!(parsed.id, None);
        assert_eq!(parsed.method, "notifications/initialized");
    }
}