use std::io::{self, BufRead, Write};
use crate::protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
/// A message received from the peer, classified by how it deserialized.
#[derive(Debug)]
pub enum IncomingMessage {
    /// The line deserialized as a [`JsonRpcRequest`].
    Request(JsonRpcRequest),
    /// The line deserialized as a [`JsonRpcNotification`].
    Notification(JsonRpcNotification),
}
/// Line-oriented transport built from a buffered reader half and a writer half.
///
/// Both halves are boxed trait objects, so the same type can sit on top of the
/// process's standard streams or on in-memory buffers.
pub struct StdioTransport {
    /// Source of incoming lines.
    reader: Box<dyn BufRead + Send>,
    /// Sink for outgoing lines.
    writer: Box<dyn Write + Send>,
}
impl StdioTransport {
    /// Creates a transport wired to the process's standard input and output.
    ///
    /// Standard input is wrapped in an [`io::BufReader`] so it can be read
    /// line by line.
    pub fn stdio() -> Self {
        Self {
            reader: Box::new(io::BufReader::new(io::stdin())),
            writer: Box::new(io::stdout()),
        }
    }

    /// Creates a transport over caller-supplied reader and writer halves.
    ///
    /// Only compiled for tests, where in-memory buffers stand in for the
    /// standard streams.
    #[cfg(test)]
    pub fn new(reader: Box<dyn BufRead + Send>, writer: Box<dyn Write + Send>) -> Self {
        Self { reader, writer }
    }

    /// Reads the next message, one message per line.
    ///
    /// Lines that are empty after trimming whitespace are skipped. Each
    /// remaining line is deserialized as a [`JsonRpcRequest`] first and, if
    /// that fails, as a [`JsonRpcNotification`].
    ///
    /// Returns `Ok(None)` only when the reader reaches end of input.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from the underlying reader, and returns
    /// [`io::ErrorKind::InvalidData`] when a non-blank line deserializes as
    /// neither message type.
    pub fn read_message(&mut self) -> io::Result<Option<IncomingMessage>> {
        let mut buf = String::new();
        loop {
            buf.clear();
            // `read_line` reports 0 bytes only at end of input.
            if self.reader.read_line(&mut buf)? == 0 {
                return Ok(None);
            }

            let line = buf.trim();
            if line.is_empty() {
                // A blank line carries no message. Returning `None` here would
                // be indistinguishable from end of input, so keep reading.
                continue;
            }

            tracing::debug!("Received: {}", line);

            if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(line) {
                return Ok(Some(IncomingMessage::Request(request)));
            }
            if let Ok(notification) = serde_json::from_str::<JsonRpcNotification>(line) {
                return Ok(Some(IncomingMessage::Notification(notification)));
            }

            tracing::warn!("Failed to parse message: {}", line);
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Invalid JSON-RPC message: {}", line),
            ));
        }
    }

    /// Serializes `response` to JSON and writes it as a single line.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] if serialization fails, or the
    /// underlying I/O error if writing or flushing fails.
    pub fn write_response(&mut self, response: &JsonRpcResponse) -> io::Result<()> {
        let json = serde_json::to_string(response).map_err(Self::serialization_error)?;
        tracing::debug!("Sending: {}", json);
        self.write_line(&json)
    }

    /// Serializes `notification` to JSON and writes it as a single line.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] if serialization fails, or the
    /// underlying I/O error if writing or flushing fails.
    pub fn write_notification(&mut self, notification: &JsonRpcNotification) -> io::Result<()> {
        let json = serde_json::to_string(notification).map_err(Self::serialization_error)?;
        tracing::debug!("Sending notification: {}", json);
        self.write_line(&json)
    }

    /// Writes `json` followed by a newline, then flushes the writer.
    fn write_line(&mut self, json: &str) -> io::Result<()> {
        writeln!(self.writer, "{}", json)?;
        self.writer.flush()
    }

    /// Wraps a serialization failure in an `InvalidData` I/O error.
    fn serialization_error<E: std::fmt::Display>(err: E) -> io::Error {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Serialization error: {}", err),
        )
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`StdioTransport`], driven by in-memory readers and writers.

    use super::*;
    use crate::protocol::RequestId;
    use std::io::Cursor;
    use std::sync::{Arc, Mutex};

    /// Writer that appends to a shared buffer, so a test can still inspect the
    /// bytes after ownership of the writer has moved into the transport.
    struct SharedWriter(Arc<Mutex<Vec<u8>>>);

    impl std::io::Write for SharedWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.0.lock().unwrap().extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    /// Builds a transport that reads `input` and writes into a throwaway buffer.
    fn transport_reading(input: &str) -> StdioTransport {
        StdioTransport::new(
            Box::new(Cursor::new(input.to_owned())),
            Box::new(Vec::<u8>::new()),
        )
    }

    /// Builds a transport with empty input, plus a handle to everything it writes.
    fn capturing_transport() -> (StdioTransport, Arc<Mutex<Vec<u8>>>) {
        let buffer = Arc::new(Mutex::new(Vec::new()));
        let writer = SharedWriter(Arc::clone(&buffer));
        let transport = StdioTransport::new(
            Box::new(Cursor::new(Vec::<u8>::new())),
            Box::new(writer),
        );
        (transport, buffer)
    }

    /// Returns the captured bytes as a UTF-8 string.
    fn captured(buffer: &Mutex<Vec<u8>>) -> String {
        String::from_utf8(buffer.lock().unwrap().clone()).unwrap()
    }

    #[test]
    fn test_read_request() {
        let input = r#"{"jsonrpc":"2.0","id":1,"method":"test","params":{}}"#;
        let mut transport = transport_reading(&format!("{}\n", input));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Request(req)) => {
                assert_eq!(req.method, "test");
                assert_eq!(req.id, RequestId::Number(1));
            }
            other => panic!("Expected request, got {:?}", other),
        }
    }

    #[test]
    fn test_read_notification() {
        let input = r#"{"jsonrpc":"2.0","method":"initialized"}"#;
        let mut transport = transport_reading(&format!("{}\n", input));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Notification(notif)) => {
                assert_eq!(notif.method, "initialized");
            }
            other => panic!("Expected notification, got {:?}", other),
        }
    }

    #[test]
    fn test_write_response() {
        let (mut transport, buffer) = capturing_transport();
        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"test": true}));

        transport.write_response(&response).unwrap();

        let output = captured(&buffer);
        assert!(output.contains("\"jsonrpc\":\"2.0\""));
        assert!(output.contains("\"id\":1"));
        assert!(output.ends_with('\n'));
    }

    #[test]
    fn test_read_eof() {
        let mut transport = transport_reading("");
        assert!(transport.read_message().unwrap().is_none());
    }

    #[test]
    fn test_read_empty_line() {
        let mut transport = transport_reading("\n");
        assert!(transport.read_message().unwrap().is_none());
    }

    #[test]
    fn test_read_invalid_json() {
        let mut transport = transport_reading("not valid json\n");

        let err = transport
            .read_message()
            .expect_err("malformed input must be rejected");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_write_notification() {
        let (mut transport, buffer) = capturing_transport();
        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "test/notification".to_string(),
            params: Some(serde_json::json!({"key": "value"})),
        };

        transport.write_notification(&notification).unwrap();

        let output = captured(&buffer);
        assert!(output.contains("\"jsonrpc\":\"2.0\""));
        assert!(output.contains("\"method\":\"test/notification\""));
        assert!(output.ends_with('\n'));
    }

    #[test]
    fn test_write_notification_without_params() {
        let (mut transport, buffer) = capturing_transport();
        let notification = JsonRpcNotification {
            jsonrpc: "2.0".to_string(),
            method: "initialized".to_string(),
            params: None,
        };

        transport.write_notification(&notification).unwrap();

        let output = captured(&buffer);
        assert!(output.contains("\"method\":\"initialized\""));
    }

    #[test]
    fn test_read_request_with_string_id() {
        let input = r#"{"jsonrpc":"2.0","id":"abc","method":"ping"}"#;
        let mut transport = transport_reading(&format!("{}\n", input));

        match transport.read_message().unwrap() {
            Some(IncomingMessage::Request(req)) => {
                assert_eq!(req.method, "ping");
                assert_eq!(req.id, RequestId::String("abc".to_string()));
            }
            other => panic!("Expected request, got {:?}", other),
        }
    }
}