use bytes::{Buf, BufMut, BytesMut};
use std::io::{self, Write};
use tokio_util::codec::{Decoder, Encoder};
use crate::Message;
/// Blank line (`CRLF CRLF`) that separates the header block from the JSON body.
const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

/// Stateful codec for Language Server Protocol frames.
///
/// A frame is a header block containing `Content-Length`, terminated by `\r\n\r\n`, followed
/// by exactly `Content-Length` bytes of JSON.
#[derive(Default, Debug)]
pub struct LspCodec {
    /// Body length taken from an already-consumed header block while the body bytes are still
    /// arriving; `None` when the next bytes are expected to start a new header block.
    content_length: Option<usize>,
}

impl LspCodec {
    /// Creates a codec with no partially decoded frame pending.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
impl Decoder for LspCodec {
    type Item = Message;
    type Error = io::Error;

    /// Decodes one frame (header block, `\r\n\r\n`, JSON body) from the front of `src`.
    ///
    /// Returns `Ok(None)` while `src` does not yet hold a complete header block or body. Once a
    /// header block is parsed it is removed from `src` and its length is kept in
    /// `self.content_length`, so a later call resumes by waiting only for the body.
    ///
    /// NOTE(review): nothing here bounds the header-block or body size; if the peer is not
    /// trusted, confirm the transport limits how much the caller buffers into `src`.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] if the header block is not UTF-8, lacks a
    /// parseable `Content-Length`, or the body is not valid JSON for [`Message`]. A header
    /// error leaves the header bytes in `src`; a body error consumes the body and resets state.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        // Resume with the length of an already-consumed header block, or parse a new one. Binding
        // the length here (instead of re-reading the Option) removes any "missing state" branch.
        let content_length = match self.content_length {
            Some(len) => len,
            None => {
                let Some(header_end) = find_subsequence(src, HEADER_TERMINATOR) else {
                    return Ok(None);
                };
                let len = parse_content_length(&src[..header_end])?;
                src.advance(header_end + HEADER_TERMINATOR.len());
                self.content_length = Some(len);
                len
            }
        };

        if src.len() < content_length {
            return Ok(None);
        }

        let body = src.split_to(content_length);
        // Reset before parsing so a malformed body does not wedge the decoder: the next call
        // starts on a fresh header block.
        self.content_length = None;
        serde_json::from_slice::<Message>(&body)
            .map(Some)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
    }
}
impl Encoder<Message> for LspCodec {
    type Error = io::Error;

    /// Appends `item` to `dst` as one frame: `Content-Length: N\r\n\r\n` followed by the `N`
    /// bytes of its JSON serialization (`N` counts bytes, not characters).
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] if `serde_json` fails to serialize `item`.
    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let payload = match serde_json::to_vec(&item) {
            Ok(bytes) => bytes,
            Err(err) => return Err(io::Error::new(io::ErrorKind::InvalidData, err)),
        };
        let payload_len = payload.len();
        // Reserve header + body space once. The header is 20 fixed bytes plus the decimal
        // length, so 32 extra bytes covers it for any realistic message size.
        dst.reserve(payload_len + 32);
        write!(dst.writer(), "Content-Length: {}\r\n\r\n", payload_len)?;
        dst.extend_from_slice(&payload);
        Ok(())
    }
}
/// Returns the byte offset of the first occurrence of `needle` in `haystack`, or `None` if
/// `needle` does not occur (including when `haystack` is shorter than `needle`).
///
/// An empty `needle` matches at offset 0, mirroring `str::find("")`; without the guard,
/// `slice::windows(0)` would panic.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
/// Extracts the `Content-Length` value from a header block.
///
/// `headers` is the header block without its terminating blank line; headers are separated by
/// `\r\n`. The header name is matched ASCII-case-insensitively and whitespace around the value
/// is trimmed. If several `Content-Length` headers are present, the first one wins.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidData`] if `headers` is not valid UTF-8, if no
/// `Content-Length` header is present, or if its value does not parse as a `usize`.
fn parse_content_length(headers: &[u8]) -> io::Result<usize> {
    let headers_str =
        std::str::from_utf8(headers).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
    // Compare header names in place rather than lowercasing a copy of every line.
    let value = headers_str
        .split("\r\n")
        .filter_map(|line| line.split_once(':'))
        .find_map(|(name, value)| name.eq_ignore_ascii_case("content-length").then_some(value))
        .ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "Missing Content-Length header")
        })?;
    let value = value.trim();
    value.parse::<usize>().map_err(|e| {
        // Name the offending value; a bare `ParseIntError` gives no hint which header failed.
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("invalid Content-Length value {value:?}: {e}"),
        )
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ErrorCode, Notification, Request, Response, ResponseError};
    use serde_json::json;

    /// Frames `body` the way a peer would send it: `Content-Length` header, blank line, body.
    fn frame(body: &str) -> String {
        format!("Content-Length: {}\r\n\r\n{}", body.len(), body)
    }

    #[test]
    fn encode_request_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let req = Request::new(1, "test/method", None);
        let msg = Message::Request(req);
        codec.encode(msg, &mut buf).unwrap();
        let output = std::str::from_utf8(&buf).unwrap();
        assert!(output.starts_with("Content-Length: "));
        assert!(output.contains("\r\n\r\n"));
        let parts: Vec<&str> = output.splitn(2, "\r\n\r\n").collect();
        assert_eq!(parts.len(), 2);
        let body = parts[1];
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(parsed["method"], "test/method");
        assert_eq!(parsed["id"], 1);
        assert_eq!(parsed["jsonrpc"], "2.0");
        let header = parts[0];
        let content_length: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(content_length, body.len());
    }

    #[test]
    fn encode_response_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let resp = Response::ok(42, json!({"result": "value"}));
        let msg = Message::Response(resp);
        codec.encode(msg, &mut buf).unwrap();
        let output = std::str::from_utf8(&buf).unwrap();
        assert!(output.starts_with("Content-Length: "));
        assert!(output.contains("\r\n\r\n"));
        let body = output.split("\r\n\r\n").nth(1).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(parsed["id"], 42);
        assert!(parsed.get("result").is_some());
    }

    #[test]
    fn encode_notification_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let notif = Notification::new("textDocument/didOpen", Some(json!({"uri": "file:///test"})));
        let msg = Message::Notification(notif);
        codec.encode(msg, &mut buf).unwrap();
        let output = std::str::from_utf8(&buf).unwrap();
        assert!(output.starts_with("Content-Length: "));
        let body = output.split("\r\n\r\n").nth(1).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(parsed["method"], "textDocument/didOpen");
        assert!(parsed.get("id").is_none());
    }

    #[test]
    fn decode_complete_message_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        buf.extend_from_slice(frame(json_body).as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
        // `let … else` fails loudly instead of silently skipping the method check.
        let Message::Request(req) = msg else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "test");
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_empty_buffer_returns_none() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn decode_partial_header_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"Content-Length: ");
        assert!(codec.decode(&mut buf).unwrap().is_none());
        buf.extend_from_slice(b"40\r\n");
        assert!(codec.decode(&mut buf).unwrap().is_none());
        buf.extend_from_slice(b"\r\n");
        assert!(codec.decode(&mut buf).unwrap().is_none());
        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        assert_eq!(json_body.len(), 40);
        buf.extend_from_slice(json_body.as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
    }

    #[test]
    fn decode_partial_body_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        buf.extend_from_slice(format!("Content-Length: {}\r\n\r\n", json_body.len()).as_bytes());
        buf.extend_from_slice(&json_body.as_bytes()[..20]);
        assert!(codec.decode(&mut buf).unwrap().is_none());
        buf.extend_from_slice(&json_body.as_bytes()[20..]);
        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
    }

    #[test]
    fn decode_with_additional_headers() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        let framed = format!(
            "Content-Type: application/vscode-jsonrpc; charset=utf-8\r\nContent-Length: {}\r\n\r\n{}",
            json_body.len(),
            json_body
        );
        buf.extend_from_slice(framed.as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_multiple_messages_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let json1 = r#"{"jsonrpc":"2.0","id":1,"method":"first"}"#;
        let json2 = r#"{"jsonrpc":"2.0","id":2,"method":"second"}"#;
        buf.extend_from_slice(frame(json1).as_bytes());
        buf.extend_from_slice(frame(json2).as_bytes());
        let msg1 = codec.decode(&mut buf).unwrap().unwrap();
        let Message::Request(req) = msg1 else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "first");
        assert!(!buf.is_empty());
        let msg2 = codec.decode(&mut buf).unwrap().unwrap();
        let Message::Request(req) = msg2 else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "second");
        assert!(buf.is_empty());
    }

    #[test]
    fn encode_decode_roundtrip_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let request = Message::Request(Request::new(
            123,
            "textDocument/completion",
            Some(json!({"position": {"line": 10}})),
        ));
        let response = Message::Response(Response::ok(456, json!({"items": []})));
        let notification = Message::Notification(Notification::new("textDocument/didSave", None));
        codec.encode(request.clone(), &mut buf).unwrap();
        codec.encode(response.clone(), &mut buf).unwrap();
        codec.encode(notification.clone(), &mut buf).unwrap();
        let decoded_request = codec.decode(&mut buf).unwrap().unwrap();
        assert!(decoded_request.is_request());
        let (Message::Request(orig), Message::Request(dec)) = (&request, &decoded_request) else {
            panic!("Expected request");
        };
        assert_eq!(orig.id, dec.id);
        assert_eq!(orig.method, dec.method);
        let decoded_response = codec.decode(&mut buf).unwrap().unwrap();
        assert!(decoded_response.is_response());
        let decoded_notification = codec.decode(&mut buf).unwrap().unwrap();
        assert!(decoded_notification.is_notification());
        assert!(buf.is_empty());
    }

    #[test]
    fn content_length_byte_count_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        // Multi-byte method name: Content-Length must count UTF-8 bytes, not chars.
        let req = Request::new(1, "test/\u{65E5}\u{672C}", None);
        let msg = Message::Request(req);
        codec.encode(msg, &mut buf).unwrap();
        let output = std::str::from_utf8(&buf).unwrap();
        let parts: Vec<&str> = output.splitn(2, "\r\n\r\n").collect();
        let header = parts[0];
        let body = parts[1];
        let content_length: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(content_length, body.len());
        assert!(body.len() > body.chars().count());
    }

    #[test]
    fn case_insensitive_header_parsing() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        let framed = format!("content-length: {}\r\n\r\n{}", json_body.len(), json_body);
        buf.extend_from_slice(framed.as_bytes());
        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
    }

    #[test]
    fn response_error_roundtrip() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let error = ResponseError::new(ErrorCode::MethodNotFound, "Method not found");
        let resp = Message::Response(Response::err(1, error));
        codec.encode(resp, &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        let Message::Response(r) = decoded else {
            panic!("Expected response");
        };
        assert!(r.error().is_some());
        assert_eq!(r.into_error().unwrap().code, -32601);
    }

    #[test]
    fn decode_invalid_json_returns_error() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let invalid_json = "{ not valid json }";
        buf.extend_from_slice(frame(invalid_json).as_bytes());
        let result = codec.decode(&mut buf);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn decode_recovers_after_invalid_json() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let invalid_json = "{ not valid json }";
        let valid_json = r#"{"jsonrpc":"2.0","id":2,"method":"next"}"#;
        buf.extend_from_slice(frame(invalid_json).as_bytes());
        buf.extend_from_slice(frame(valid_json).as_bytes());
        assert!(codec.decode(&mut buf).is_err());
        // The bad body was consumed, so the following frame decodes normally.
        let msg = codec.decode(&mut buf).unwrap().unwrap();
        let Message::Request(req) = msg else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "next");
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_invalid_content_length_returns_error() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        buf.extend_from_slice(b"Content-Length: abc\r\n\r\n{}");
        let err = codec.decode(&mut buf).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn decode_missing_content_length_returns_error() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();
        let framed = "Some-Other-Header: value\r\n\r\n{}";
        buf.extend_from_slice(framed.as_bytes());
        let result = codec.decode(&mut buf);
        assert!(result.is_err());
    }
}