use std::io::{BufRead, Write};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::errors::{McpError, McpErrorKind};
/// MCP protocol revision identifier, written as a date-formatted string.
///
/// NOTE(review): presumably the version this crate advertises or compares
/// during the `initialize` handshake — confirm against callers.
pub const MCP_PROTOCOL_VERSION: &str = "2025-06-18";
/// An incoming JSON-RPC message: a request when `id` is present, or a
/// notification when it is absent (see `is_notification`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
/// Protocol version tag. `read_message` rejects any value other than `"2.0"`.
pub jsonrpc: String,
/// Request identifier, kept as raw JSON. Defaults to `None` when the field is
/// missing from the frame and is omitted from serialized output when `None`.
///
/// NOTE(review): with serde's `Option` handling, an explicit `"id": null`
/// presumably also deserializes to `None` — confirm whether such a frame
/// should really be treated as a notification.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<Value>,
/// Name of the method being invoked.
pub method: String,
/// Method parameters as raw JSON. Defaults to `None` when missing and is
/// omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub params: Option<Value>,
}
impl JsonRpcRequest {
    /// Reports whether this message is a notification, i.e. it carries no `id`.
    pub fn is_notification(&self) -> bool {
        let Self { id, .. } = self;
        id.is_none()
    }
}
/// An outgoing JSON-RPC response. The constructors in this file populate
/// exactly one of `result` or `error`.
#[derive(Debug, Clone, Serialize)]
pub struct JsonRpcResponse {
/// Protocol version tag; the constructors in this file always set `"2.0"`.
pub jsonrpc: &'static str,
/// Identifier supplied by the caller — presumably the `id` of the request
/// being answered (confirm against callers).
pub id: Value,
/// Success payload. Omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<Value>,
/// Error object. Omitted from the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<Value>,
}
impl JsonRpcResponse {
    /// Builds a success response for `id` that carries `result` and no error.
    pub fn success(id: Value, result: Value) -> Self {
        let result = Some(result);
        let error = None;
        Self {
            jsonrpc: "2.0",
            id,
            result,
            error,
        }
    }

    /// Builds an error response for `id`. The error payload is whatever
    /// `err.to_rpc_error_object()` produces; `result` is left unset.
    pub fn error(id: Value, err: &McpError) -> Self {
        let error = Some(err.to_rpc_error_object());
        let result = None;
        Self {
            jsonrpc: "2.0",
            id,
            result,
            error,
        }
    }
}
/// Reads the next JSON-RPC message from a newline-delimited stream.
///
/// Each frame is expected to occupy a single line. Lines that are empty or
/// contain only whitespace are skipped.
///
/// # Returns
///
/// * `Ok(Some(request))` for the next non-blank line that parses successfully.
/// * `Ok(None)` when the reader hits end of input before a non-blank line.
///
/// # Errors
///
/// * [`McpErrorKind::ParseError`] if reading from `r` fails, or if the line
///   does not deserialize into a [`JsonRpcRequest`].
/// * [`McpErrorKind::InvalidRequest`] if the `jsonrpc` field is not `"2.0"`.
pub fn read_message<R: BufRead>(r: &mut R) -> Result<Option<JsonRpcRequest>, McpError> {
    let mut line = String::new();

    // Skip blank lines iteratively. Recursing here would add one stack frame
    // (and one fresh `String`) per blank line, so a peer sending a long run of
    // newlines could exhaust the stack. The loop also reuses a single buffer.
    loop {
        line.clear();
        let bytes = r
            .read_line(&mut line)
            .map_err(|e| McpError::new(McpErrorKind::ParseError, format!("read: {e}")))?;
        if bytes == 0 {
            // End of input: no more frames.
            return Ok(None);
        }
        if !line.trim().is_empty() {
            break;
        }
    }

    let trimmed = line.trim();
    let req: JsonRpcRequest = serde_json::from_str(trimmed)
        .map_err(|e| McpError::new(McpErrorKind::ParseError, format!("invalid frame: {e}")))?;

    // Structurally valid JSON is still rejected unless it declares JSON-RPC 2.0.
    if req.jsonrpc != "2.0" {
        return Err(McpError::new(
            McpErrorKind::InvalidRequest,
            format!("jsonrpc must be \"2.0\", got \"{}\"", req.jsonrpc),
        ));
    }
    Ok(Some(req))
}
/// Serializes `resp` as one line of JSON, writes it to `w` followed by a
/// newline, and flushes the writer.
///
/// # Errors
///
/// Returns an [`std::io::Error`] if serialization fails (the serde error is
/// wrapped via `std::io::Error::other`) or if any write or the flush fails.
pub fn write_response<W: Write>(w: &mut W, resp: &JsonRpcResponse) -> std::io::Result<()> {
    let encoded = match serde_json::to_string(resp) {
        Ok(text) => text,
        Err(e) => return Err(std::io::Error::other(e)),
    };
    w.write_all(encoded.as_bytes())?;
    w.write_all(b"\n")?;
    w.flush()
}
#[cfg(test)]
mod tests {
    //! Unit tests for frame parsing (`read_message`) and response
    //! serialization (`write_response`).

    use super::*;
    use serde_json::json;
    use std::io::Cursor;

    /// A request with an `id` parses and is not classified as a notification.
    #[test]
    fn parse_initialize_request() {
        let raw = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18"}}"#;
        let mut buf = Cursor::new(format!("{raw}\n"));
        let req = read_message(&mut buf).unwrap().unwrap();
        assert_eq!(req.method, "initialize");
        assert_eq!(req.id, Some(json!(1)));
        assert!(!req.is_notification());
    }

    /// A frame without an `id` is classified as a notification.
    #[test]
    fn parse_notification_no_id() {
        let raw = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        let mut buf = Cursor::new(format!("{raw}\n"));
        let req = read_message(&mut buf).unwrap().unwrap();
        assert!(req.is_notification());
    }

    /// An empty stream yields `Ok(None)` rather than an error.
    #[test]
    fn eof_returns_none() {
        let mut empty = Cursor::new(Vec::<u8>::new());
        assert!(read_message(&mut empty).unwrap().is_none());
    }

    /// A line that is not JSON is reported as a parse error.
    #[test]
    fn invalid_json_returns_parse_error() {
        let mut bad = Cursor::new(b"not json\n".to_vec());
        let err = read_message(&mut bad).unwrap_err();
        assert_eq!(err.kind, McpErrorKind::ParseError);
    }

    /// Blank and whitespace-only lines before a frame are skipped.
    #[test]
    fn skips_blank_lines_before_frame() {
        let raw = r#"{"jsonrpc":"2.0","id":7,"method":"ping"}"#;
        let mut buf = Cursor::new(format!("\n   \n\t\n{raw}\n"));
        let req = read_message(&mut buf).unwrap().unwrap();
        assert_eq!(req.method, "ping");
        assert_eq!(req.id, Some(json!(7)));
    }

    /// A stream containing only blank lines ends with `Ok(None)`.
    #[test]
    fn blank_lines_then_eof_returns_none() {
        let mut buf = Cursor::new(b"\n  \n\n".to_vec());
        assert!(read_message(&mut buf).unwrap().is_none());
    }

    /// Well-formed JSON with a `jsonrpc` value other than "2.0" is rejected.
    #[test]
    fn wrong_jsonrpc_version_is_invalid_request() {
        let raw = r#"{"jsonrpc":"1.0","id":1,"method":"initialize"}"#;
        let mut buf = Cursor::new(format!("{raw}\n"));
        let err = read_message(&mut buf).unwrap_err();
        assert_eq!(err.kind, McpErrorKind::InvalidRequest);
    }

    /// A success response serializes to one newline-terminated JSON line.
    #[test]
    fn round_trip_response() {
        let resp = JsonRpcResponse::success(json!("req-1"), json!({"ok": true}));
        let mut out = Vec::new();
        write_response(&mut out, &resp).unwrap();
        let line = String::from_utf8(out).unwrap();
        assert!(line.ends_with('\n'));
        let parsed: serde_json::Value = serde_json::from_str(line.trim()).unwrap();
        assert_eq!(parsed["jsonrpc"], "2.0");
        assert_eq!(parsed["id"], "req-1");
        assert_eq!(parsed["result"]["ok"], true);
    }

    /// An error response exposes the numeric error code under `error.code`.
    #[test]
    fn error_response_carries_code() {
        let err = McpError::kind_only(McpErrorKind::MethodNotFound);
        let resp = JsonRpcResponse::error(json!(42), &err);
        let mut out = Vec::new();
        write_response(&mut out, &resp).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&out).unwrap();
        assert_eq!(parsed["error"]["code"], -32_601);
    }
}