#[cfg(feature = "runtime-agnostic")]
use async_codec_lite::{Decoder, Encoder};
#[cfg(feature = "runtime-tokio")]
use tokio_util::codec::{Decoder, Encoder};
use bytes::{Buf, BufMut, BytesMut};
use std::{
io::{self, Write},
marker::PhantomData,
};
use thiserror::Error;
/// Errors produced while encoding or decoding framed Language Server Protocol messages.
#[derive(Debug, Error)]
pub enum ParseError {
    /// The message body could not be serialized to, or deserialized from, JSON by `serde_json`.
    #[error("failed to parse JSON body: {0}")]
    Body(serde_json::Error),
    /// Writing an encoded message into the output buffer failed with an I/O error.
    #[error("failed to encode response: {0}")]
    Encode(io::Error),
    /// `httparse` rejected the header block.
    #[error("failed to parse headers: {0}")]
    Httparse(httparse::Error),
    /// The `Content-Length` header value could not be parsed as a length.
    #[error("invalid content length value")]
    InvalidLength,
    /// No `Content-Length` header could be read from the header block.
    #[error("missing required `Content-Length` header")]
    MissingHeader,
    /// A `Content-Length` header value or a message body was not valid UTF-8.
    #[error("request contains invalid UTF-8: {0}")]
    Utf8(std::str::Utf8Error),
}
impl From<io::Error> for ParseError {
fn from(error: io::Error) -> Self {
ParseError::Encode(error)
}
}
impl From<serde_json::Error> for ParseError {
fn from(error: serde_json::Error) -> Self {
ParseError::Body(error)
}
}
impl From<std::str::Utf8Error> for ParseError {
fn from(error: std::str::Utf8Error) -> Self {
ParseError::Utf8(error)
}
}
/// Framing codec for Language Server Protocol messages: a header block carrying
/// `Content-Length`, a blank line, then a JSON body that is decoded into / encoded from `T`.
///
/// The fields hold decoder state between `decode` calls, so a frame that arrives across
/// several reads has its header block parsed only once.
pub struct LanguageServerCodec<T> {
    /// Header parse error recorded by `decode` for the current frame.
    http_error: Option<httparse::Error>,
    /// Byte length of the current frame's header block, once it has been parsed.
    headers_len: Option<usize>,
    /// Body length announced by the current frame's `Content-Length` header.
    content_len: Option<usize>,
    _marker: PhantomData<T>,
}

// `Clone` and `Debug` are written by hand because `#[derive]` would require `T: Clone` /
// `T: Debug`, even though `T` only appears inside `PhantomData`. The output matches what the
// derives produced.
impl<T> Clone for LanguageServerCodec<T> {
    fn clone(&self) -> Self {
        LanguageServerCodec {
            http_error: self.http_error,
            headers_len: self.headers_len,
            content_len: self.content_len,
            _marker: PhantomData,
        }
    }
}

impl<T> std::fmt::Debug for LanguageServerCodec<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LanguageServerCodec")
            .field("http_error", &self.http_error)
            .field("headers_len", &self.headers_len)
            .field("content_len", &self.content_len)
            .field("_marker", &self._marker)
            .finish()
    }
}
impl<T> LanguageServerCodec<T> {
fn reset(&mut self) {
self.http_error = None;
self.headers_len = None;
self.content_len = None;
}
}
impl<T> Default for LanguageServerCodec<T> {
fn default() -> Self {
LanguageServerCodec {
http_error: None,
headers_len: None,
content_len: None,
_marker: PhantomData,
}
}
}
#[cfg(feature = "runtime-agnostic")]
impl<T: serde::Serialize> Encoder for LanguageServerCodec<T> {
    type Error = ParseError;
    type Item = T;

    /// Serializes `item` to JSON and appends one frame to `dst`: a `Content-Length` header,
    /// a blank line, then the body.
    ///
    /// # Errors
    ///
    /// [`ParseError::Body`] if serialization fails; [`ParseError::Encode`] if writing into
    /// `dst` fails.
    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let body = serde_json::to_string(&item)?;
        log::trace!("-> {}", body);
        let body_len = body.len();
        // Fixed framing overhead: "Content-Length: " (16 bytes) + "\r\n\r\n" (4 bytes).
        dst.reserve(20 + number_of_digits(body_len) + body_len);
        let mut out = dst.writer();
        write!(out, "Content-Length: {}\r\n\r\n{}", body_len, body)
            .and_then(|()| out.flush())
            .map_err(ParseError::from)
    }
}
#[cfg(feature = "runtime-tokio")]
impl<T: serde::Serialize> Encoder<T> for LanguageServerCodec<T> {
    type Error = ParseError;

    /// Appends `item` to `dst` as one framed message: `Content-Length` header, blank line,
    /// JSON body.
    ///
    /// # Errors
    ///
    /// [`ParseError::Body`] if `item` cannot be serialized; [`ParseError::Encode`] if the
    /// write into `dst` fails.
    fn encode(&mut self, item: T, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let json = serde_json::to_string(&item)?;
        log::trace!("-> {}", json);
        let json_len = json.len();
        // Header text "Content-Length: " plus the "\r\n\r\n" separator is 20 bytes.
        let frame_capacity = json_len + number_of_digits(json_len) + 20;
        dst.reserve(frame_capacity);

        let mut sink = dst.writer();
        write!(sink, "Content-Length: {}\r\n\r\n{}", json_len, json)?;
        sink.flush()?;
        Ok(())
    }
}
/// Returns how many decimal digits are needed to print `n`.
///
/// `0` is printed as `"0"` and therefore counts as one digit. The previous loop returned `0`
/// for it. Used to size buffer reservations for the `Content-Length` header.
#[inline]
fn number_of_digits(n: usize) -> usize {
    // Every number has at least one digit; count each further power of ten.
    let mut rest = n / 10;
    let mut digits = 1;
    while rest > 0 {
        rest /= 10;
        digits += 1;
    }
    digits
}
impl<T: serde::de::DeserializeOwned> Decoder for LanguageServerCodec<T> {
type Error = ParseError;
type Item = T;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if self.headers_len.is_none() {
{
let dst = &mut [httparse::EMPTY_HEADER; 2];
match httparse::parse_headers(src, dst) {
Ok(httparse::Status::Complete((header_len, headers))) => {
self.headers_len = Some(header_len);
for header in headers {
if header.name == "Content-Length" {
let content_len = std::str::from_utf8(header.value)?;
let content_len = content_len.parse().map_err(|_| ParseError::InvalidLength)?;
self.content_len = Some(content_len);
}
}
},
Ok(httparse::Status::Partial) => return Ok(None),
Err(error) => {
self.http_error = Some(error);
},
}
}
}
if let (Some(headers_len), Some(content_len)) = (self.headers_len, self.content_len) {
let delta = headers_len + content_len;
if src.len() < delta {
return Ok(None);
}
let message = &src[headers_len .. delta];
let message = std::str::from_utf8(message)?;
log::trace!("<- {}", message);
let data = match serde_json::from_str(message) {
Ok(parsed) => Ok(Some(parsed)),
Err(err) => Err(err.into()),
};
self.reset();
src.advance(delta);
data
} else {
self.reset();
if let Some(offset) = twoway::find_bytes(src, b"Content-Length") {
src.advance(offset);
}
if let Some(http_error) = self.http_error {
Err(ParseError::Httparse(http_error))
} else {
Err(ParseError::MissingHeader)
}
}
}
}
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use serde_json::Value;

    use super::*;

    #[test]
    fn decodes_invalid_content_length() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#.to_string();
        let content_len = "Content-Length: foo".to_string();
        let encoded = format!("{}\r\n\r\n{}", content_len, decoded);
        let mut codec = LanguageServerCodec::<()>::default();
        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer);
        assert!(
            matches!(message, Err(ParseError::InvalidLength)),
            "unexpected result: {:?}",
            message
        );
    }

    #[test]
    fn decode_long_messages() {
        let padding = "data".repeat(5000);
        let decoded = format!(r#"{{ "jsonrpc" : "2.0", "method" : "foo", "params": "{}" }}"#, padding);
        let content_len = format!("Content-Length: {}", decoded.len());
        let encoded = format!("{}\r\n\r\n{}", content_len, decoded);
        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded: Value = serde_json::from_str(&decoded).unwrap();
        assert_eq!(message, Some(decoded));
    }

    #[test]
    fn decode_optional_content_type() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#.to_string();
        let content_len = format!("Content-Length: {}", decoded.len());
        let content_type =
            "Content-Type: application/vscode-jsonrpc; charset=utf-8; foo=\"bar\\nbaz\\\"qux\\\"\"".to_string();
        let encoded = format!("{}\r\n{}\r\n\r\n{}", content_len, content_type, decoded);
        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded: Value = serde_json::from_str(&decoded).unwrap();
        assert_eq!(message, Some(decoded));
    }

    #[test]
    fn decode_partial() {
        let content_len = "Content-Length: 42".to_string();
        let encoded = format!("{}\r\n\r\n", content_len);
        let mut codec = LanguageServerCodec::<()>::default();
        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer);
        assert!(matches!(message, Ok(None)), "unexpected result: {:?}", message);
    }

    #[test]
    fn decode_consecutive_messages() {
        let first = r#"{"jsonrpc":"2.0","method":"initialized","params":{}}"#;
        let second = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let encoded = format!(
            "Content-Length: {}\r\n\r\n{}Content-Length: {}\r\n\r\n{}",
            first.len(),
            first,
            second.len(),
            second
        );
        let mut codec = LanguageServerCodec::<Value>::default();
        let mut buffer = BytesMut::from(encoded.as_str());
        let expected_first: Value = serde_json::from_str(first).unwrap();
        let expected_second: Value = serde_json::from_str(second).unwrap();
        assert_eq!(codec.decode(&mut buffer).unwrap(), Some(expected_first));
        assert_eq!(codec.decode(&mut buffer).unwrap(), Some(expected_second));
        assert!(buffer.is_empty());
    }

    #[test]
    fn encode_and_decode() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#.to_string();
        let encoded = format!("Content-Length: {}\r\n\r\n{}", decoded.len(), decoded);
        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::new();
        let item: Value = serde_json::from_str(&decoded).unwrap();
        codec.encode(item, &mut buffer).unwrap();
        assert_eq!(buffer, BytesMut::from(encoded.as_str()));
        let mut buffer = BytesMut::from(encoded.as_str());
        let message = codec.decode(&mut buffer).unwrap();
        let decoded = serde_json::from_str(&decoded).unwrap();
        assert_eq!(message, Some(decoded));
    }

    #[test]
    fn parse_error_from_io_error() {
        let error = std::io::Error::new(std::io::ErrorKind::Other, "test error");
        assert!(matches!(ParseError::from(error), ParseError::Encode(_)));
    }

    #[test]
    fn parse_error_from_json_error() {
        let error = serde_json::from_str::<Value>("{").unwrap_err();
        assert!(matches!(ParseError::from(error), ParseError::Body(_)));
    }

    #[test]
    fn parse_error_from_utf8_error() {
        // 0x9F is a continuation byte with no lead byte, so this is not valid UTF-8.
        let bytes: [u8; 4] = [0, 159, 146, 150];
        let error = std::str::from_utf8(&bytes).unwrap_err();
        assert!(matches!(ParseError::from(error), ParseError::Utf8(_)));
    }

    #[test]
    fn recovers_from_parse_error() {
        let decoded = r#"{"jsonrpc":"2.0","method":"exit"}"#.to_string();
        let encoded = format!("Content-Length: {}\r\n\r\n{}", decoded.len(), decoded);
        let mixed = format!("1234567890abcdefgh{}", encoded);
        let mut codec = LanguageServerCodec::default();
        let mut buffer = BytesMut::from(mixed.as_str());
        assert!(matches!(codec.decode(&mut buffer), Err(ParseError::MissingHeader)));
        let message = codec.decode(&mut buffer).unwrap();
        let decoded: Value = serde_json::from_str(&decoded).unwrap();
        assert_eq!(message, Some(decoded));
    }
}