use lsp_types::NumberOrString;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;
use std::{
borrow::Cow,
io::{self, BufRead, Write},
};
/// Identifier correlating a request with its response (an `lsp_types::NumberOrString`,
/// i.e. a number or a string on the wire).
pub type RequestId = NumberOrString;
/// Wraps `error` in an [`io::Error`] whose kind is [`io::ErrorKind::InvalidData`].
fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, error)
}
// `format!`-style shorthand for the `invalid_data` function: formats the
// arguments into a `String` and wraps it in an `InvalidData` `io::Error`.
macro_rules! invalid_data {
    ($($fmt_args:tt)*) => {
        invalid_data(format!($($fmt_args)*))
    };
}
/// A single JSON-RPC message exchanged over the LSP base protocol.
///
/// Deserialization is `untagged`: serde tries each variant in declaration order
/// and keeps the first one that deserializes successfully. The order therefore
/// matters. A [`Request`] carries both `id` and `method` and must be tried before
/// a [`Notification`], which carries only `method`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Message {
    /// A message with both an `id` and a `method`.
    Request(Request),
    /// A message with an `id` and either an `error` or a `result`.
    Response(Response),
    /// A message with a `method` but no `id`.
    Notification(Notification),
}
impl Message {
    /// Reads one framed message from `r` and decodes its JSON body.
    ///
    /// Returns `Ok(None)` when `read_msg` reports end of input instead of a body.
    ///
    /// # Errors
    ///
    /// Propagates framing and I/O errors from `read_msg`. Returns an
    /// `InvalidData` error if the body is not valid JSON for any [`Message`]
    /// variant.
    pub fn read(r: &mut impl BufRead) -> io::Result<Option<Self>> {
        read_msg(r)?
            .map(|bytes| {
                serde_json::from_slice::<Self>(&bytes)
                    .map_err(|e| invalid_data!("malformed LSP message: {e:?}"))
            })
            .transpose()
    }

    /// Serializes `self` with a `"jsonrpc": "2.0"` member, writes it to `w`
    /// behind a `Content-Length` header, then flushes `w`.
    ///
    /// # Errors
    ///
    /// Returns any serialization error (converted to `io::Error`) or any error
    /// from writing to or flushing `w`.
    pub fn write(self, w: &mut impl Write) -> io::Result<()> {
        // Envelope that emits `jsonrpc` first, then the message's own fields
        // inlined at the same level.
        #[derive(Serialize)]
        struct JsonRpc {
            jsonrpc: &'static str,
            #[serde(flatten)]
            msg: Message,
        }

        let envelope = JsonRpc {
            jsonrpc: "2.0",
            msg: self,
        };
        let body = serde_json::to_vec(&envelope)?;
        // Content-Length counts bytes, which is exactly `body.len()`.
        write!(w, "Content-Length: {}\r\n\r\n", body.len())?;
        w.write_all(&body)?;
        w.flush()
    }
}
/// Reads the raw body of one LSP base-protocol message from `r`.
///
/// A message is a block of `Name: value\r\n` header lines terminated by an empty
/// `\r\n` line, followed by exactly `Content-Length` bytes of body. Header names
/// are matched case-insensitively. Headers other than `Content-Length` are
/// accepted and ignored. If `Content-Length` appears more than once, the last
/// value wins.
///
/// Returns `Ok(None)` only when the stream is at EOF before any header line of a
/// new message has been read, i.e. a clean end of input between messages.
///
/// # Errors
///
/// * `InvalidData` in any of these cases:
///   - a header line is not `\r\n`-terminated;
///   - a header line lacks the `": "` separator;
///   - the `Content-Length` value does not parse as a `usize`;
///   - no `Content-Length` header is present.
/// * `UnexpectedEof` if the stream ends in either of these places:
///   - inside the header block, after at least one line;
///   - before the full body has been read (from `read_exact`).
fn read_msg(r: &mut dyn BufRead) -> io::Result<Option<Vec<u8>>> {
    let mut content_length: Option<usize> = None;
    let mut buf = String::new();
    // Set once any header line has been consumed. It distinguishes a clean EOF
    // between messages from a message truncated inside its header block.
    let mut in_headers = false;
    loop {
        buf.clear();
        if r.read_line(&mut buf)? == 0 {
            if in_headers {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "unexpected EOF while reading LSP message headers",
                ));
            }
            return Ok(None);
        }
        in_headers = true;
        if !buf.ends_with("\r\n") {
            return Err(invalid_data!("malformed header: {buf:?}"));
        }
        if buf == "\r\n" {
            // Blank line: end of the header block.
            break;
        }
        match buf.trim().split_once(": ") {
            Some((name, value)) if name.eq_ignore_ascii_case("Content-Length") => {
                content_length = Some(value.parse().map_err(invalid_data)?);
            }
            // Other well-formed headers (e.g. Content-Type) are ignored.
            Some(_) => {}
            None => return Err(invalid_data!("malformed header: {buf:?}")),
        }
    }
    let len = content_length.ok_or_else(|| invalid_data!("no Content-Length header"))?;
    // Reuse the header buffer's allocation for the body; `read_exact` overwrites
    // every byte after the resize.
    // NOTE(review): `len` comes from the peer and is allocated up front; consider
    // capping it if the peer cannot be trusted.
    let mut body = buf.into_bytes();
    body.resize(len, 0);
    r.read_exact(&mut body)?;
    Ok(Some(body))
}
/// A JSON-RPC notification: a `method` with optional `params` and no `id`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Notification {
    /// Name of the method being invoked.
    pub method: Cow<'static, str>,
    /// Method parameters. A missing `params` member deserializes as
    /// `Value::Null`, and a null value is left out when serializing.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub params: Value,
}
/// A JSON-RPC request: a `method` call identified by `id`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Request {
    /// Identifier echoed back in the matching [`Response`].
    pub id: RequestId,
    /// Name of the method being invoked.
    pub method: Cow<'static, str>,
    /// Method parameters. A missing `params` member deserializes as
    /// `Value::Null`, and a null value is left out when serializing.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub params: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(untagged)]
pub enum Response {
Error { id: RequestId, error: ResponseError },
Result { id: RequestId, result: Value },
}
impl Response {
    /// Returns a clone of the id this response carries, whichever variant it is.
    pub fn id(&self) -> RequestId {
        match self {
            Self::Error { id, .. } | Self::Result { id, .. } => id.clone(),
        }
    }

    /// Builds a successful response for `id` whose `result` is `Value::Null`.
    pub fn null_resp(id: RequestId) -> Self {
        let result = Value::Null;
        Self::Result { id, result }
    }
}
/// The `error` object of a failed JSON-RPC response.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ResponseError {
    /// Error code, encoded on the wire as a bare integer.
    pub code: ErrorCode,
    /// The error's message text.
    pub message: String,
    /// Optional extra payload; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
/// A JSON-RPC / LSP error code.
///
/// Each named variant is encoded as one fixed integer (listed on the variant).
/// Integers without a named variant decode to [`ErrorCode::Custom`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorCode {
    /// Wire code `-32700`.
    ParseError,
    /// Wire code `-32600`.
    InvalidRequest,
    /// Wire code `-32601`.
    MethodNotFound,
    /// Wire code `-32602`.
    InvalidParams,
    /// Wire code `-32603`.
    InternalError,
    /// Wire code `-32099`.
    ServerErrorStart,
    /// Wire code `-32000`.
    ServerErrorEnd,
    /// Wire code `-32002`.
    ServerNotInitialized,
    /// Wire code `-32001`.
    Unknown,
    /// Wire code `-32800`.
    RequestCancelled,
    /// Wire code `-32801`.
    ContentModified,
    /// Wire code `-32802`.
    ServerCancelled,
    /// Wire code `-32803`.
    RequestFailed,
    /// Any other code, carried verbatim.
    Custom(i32),
}
impl Serialize for ErrorCode {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let val = match self {
Self::ParseError => -32700,
Self::InvalidRequest => -32600,
Self::MethodNotFound => -32601,
Self::InvalidParams => -32602,
Self::InternalError => -32603,
Self::ServerErrorStart => -32099,
Self::ServerErrorEnd => -32000,
Self::ServerNotInitialized => -32002,
Self::Unknown => -32001,
Self::RequestCancelled => -32800,
Self::ContentModified => -32801,
Self::ServerCancelled => -32802,
Self::RequestFailed => -32803,
Self::Custom(i) => *i,
};
serializer.serialize_i32(val)
}
}
impl<'de> Deserialize<'de> for ErrorCode {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let val = match i32::deserialize(deserializer)? {
-32700 => Self::ParseError,
-32600 => Self::InvalidRequest,
-32601 => Self::MethodNotFound,
-32602 => Self::InvalidParams,
-32603 => Self::InternalError,
-32099 => Self::ServerErrorStart,
-32000 => Self::ServerErrorEnd,
-32002 => Self::ServerNotInitialized,
-32001 => Self::Unknown,
-32800 => Self::RequestCancelled,
-32801 => Self::ContentModified,
-32802 => Self::ServerCancelled,
-32803 => Self::RequestFailed,
i => Self::Custom(i),
};
Ok(val)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use simple_test_case::test_case;
    use std::io::BufReader;

    /// Frames `val` as a base-protocol message with a `Content-Length` header.
    fn rpc(val: serde_json::Value) -> String {
        let s = serde_json::to_string(&val).unwrap();
        let len = s.len();
        format!("Content-Length: {len}\r\n\r\n{s}")
    }

    /// Runs [`Message::read`] over the bytes of `s`.
    fn read_str(s: &str) -> io::Result<Option<Message>> {
        Message::read(&mut BufReader::new(s.as_bytes()))
    }

    #[test_case(
        rpc(json!({ "id": 1, "method": "foo", "params": {"foo": "bar"} })),
        Message::Request(Request {
            id: RequestId::Number(1),
            method: Cow::Borrowed("foo"),
            params: json!({"foo": "bar"}),
        });
        "simple request"
    )]
    #[test_case(
        rpc(json!({ "id": 1, "result": {"foo": "bar"} })),
        Message::Response(Response::Result {
            id: RequestId::Number(1),
            result: json!({"foo": "bar"}),
        });
        "simple successful response"
    )]
    #[test_case(
        rpc(json!({
            "id": 1,
            "error": {
                "code": -32600,
                "message": "invalid",
                "data": {"foo": "bar"}
            }
        })),
        Message::Response(Response::Error {
            id: RequestId::Number(1),
            error: ResponseError {
                code: ErrorCode::InvalidRequest,
                message: "invalid".to_owned(),
                data: Some(json!({"foo": "bar"})),
            }
        });
        "simple error response with data"
    )]
    #[test_case(
        rpc(json!({
            "id": 1,
            "error": {
                "code": -32600,
                "message": "invalid",
            }
        })),
        Message::Response(Response::Error {
            id: RequestId::Number(1),
            error: ResponseError {
                code: ErrorCode::InvalidRequest,
                message: "invalid".to_owned(),
                data: None,
            }
        });
        "simple error response without data"
    )]
    #[test_case(
        rpc(json!({ "method": "foo", "params": {"foo": "bar"} })),
        Message::Notification(Notification {
            method: Cow::Borrowed("foo"),
            params: json!({"foo": "bar"}),
        });
        "simple notification"
    )]
    #[test]
    fn read_message(s: String, expected: Message) {
        let m = read_str(&s).unwrap().unwrap();
        assert_eq!(m, expected);
    }

    #[test]
    fn read_returns_none_on_empty_input() {
        assert!(read_str("").unwrap().is_none());
    }

    #[test]
    fn header_names_are_case_insensitive_and_unknown_headers_ignored() {
        let body = serde_json::to_string(&json!({ "method": "foo" })).unwrap();
        let s = format!(
            "content-length: {}\r\nContent-Type: application/vscode-jsonrpc; charset=utf-8\r\n\r\n{body}",
            body.len()
        );
        let m = read_str(&s).unwrap().unwrap();
        assert_eq!(
            m,
            Message::Notification(Notification {
                method: Cow::Borrowed("foo"),
                params: Value::Null,
            })
        );
    }

    #[test]
    fn missing_content_length_is_invalid_data() {
        let err = read_str("Content-Type: foo\r\n\r\n{}").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn header_without_crlf_is_invalid_data() {
        let err = read_str("Content-Length: 2\n\n{}").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn header_without_separator_is_invalid_data() {
        let err = read_str("Content-Length 2\r\n\r\n{}").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn non_numeric_content_length_is_invalid_data() {
        let err = read_str("Content-Length: abc\r\n\r\n").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn truncated_body_is_unexpected_eof() {
        let err = read_str("Content-Length: 10\r\n\r\n{}").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn body_that_is_not_a_message_is_invalid_data() {
        let err = read_str(&rpc(json!({ "foo": "bar" }))).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test_case(
        Message::Request(Request {
            id: RequestId::String("abc".to_owned()),
            method: Cow::Borrowed("foo"),
            params: json!({"foo": "bar"}),
        });
        "round trip request with string id"
    )]
    #[test_case(
        Message::Response(Response::null_resp(RequestId::Number(7)));
        "round trip null response"
    )]
    #[test_case(
        Message::Response(Response::Error {
            id: RequestId::Number(1),
            error: ResponseError {
                code: ErrorCode::Custom(42),
                message: "boom".to_owned(),
                data: None,
            },
        });
        "round trip error response"
    )]
    #[test_case(
        Message::Notification(Notification {
            method: Cow::Borrowed("foo"),
            params: Value::Null,
        });
        "round trip notification without params"
    )]
    #[test]
    fn write_then_read_round_trips(msg: Message) {
        let mut out = Vec::new();
        msg.clone().write(&mut out).unwrap();
        let parsed = Message::read(&mut BufReader::new(out.as_slice()))
            .unwrap()
            .unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn write_frames_body_with_content_length_and_jsonrpc_version() {
        let mut out = Vec::new();
        Message::Notification(Notification {
            method: Cow::Borrowed("foo"),
            params: Value::Null,
        })
        .write(&mut out)
        .unwrap();
        let s = String::from_utf8(out).unwrap();
        let (header, body) = s.split_once("\r\n\r\n").unwrap();
        assert_eq!(header, format!("Content-Length: {}", body.len()));
        let v: Value = serde_json::from_str(body).unwrap();
        assert_eq!(v, json!({ "jsonrpc": "2.0", "method": "foo" }));
    }

    #[test]
    fn null_resp_serializes_explicit_null_result() {
        let v = serde_json::to_value(Response::null_resp(RequestId::Number(1))).unwrap();
        assert_eq!(v, json!({ "id": 1, "result": null }));
    }

    #[test_case(ErrorCode::ParseError; "parse error")]
    #[test_case(ErrorCode::InvalidRequest; "invalid request")]
    #[test_case(ErrorCode::MethodNotFound; "method not found")]
    #[test_case(ErrorCode::InvalidParams; "invalid params")]
    #[test_case(ErrorCode::InternalError; "internal error")]
    #[test_case(ErrorCode::ServerErrorStart; "server error start")]
    #[test_case(ErrorCode::ServerErrorEnd; "server error end")]
    #[test_case(ErrorCode::ServerNotInitialized; "server not initialized")]
    #[test_case(ErrorCode::Unknown; "unknown")]
    #[test_case(ErrorCode::RequestCancelled; "request cancelled")]
    #[test_case(ErrorCode::ContentModified; "content modified")]
    #[test_case(ErrorCode::ServerCancelled; "server cancelled")]
    #[test_case(ErrorCode::RequestFailed; "request failed")]
    #[test_case(ErrorCode::Custom(42); "custom")]
    #[test]
    fn error_code_serde_round_trip(e: ErrorCode) {
        let s = serde_json::to_string(&e).unwrap();
        let parsed: ErrorCode = serde_json::from_str(&s).unwrap();
        assert_eq!(parsed, e);
    }

    // A round trip alone cannot catch codes that are swapped consistently in
    // both the Serialize and Deserialize impls, so pin the actual wire values.
    #[test_case(ErrorCode::ParseError, -32700; "parse error code")]
    #[test_case(ErrorCode::InvalidRequest, -32600; "invalid request code")]
    #[test_case(ErrorCode::MethodNotFound, -32601; "method not found code")]
    #[test_case(ErrorCode::InvalidParams, -32602; "invalid params code")]
    #[test_case(ErrorCode::InternalError, -32603; "internal error code")]
    #[test_case(ErrorCode::ServerErrorStart, -32099; "server error start code")]
    #[test_case(ErrorCode::ServerErrorEnd, -32000; "server error end code")]
    #[test_case(ErrorCode::ServerNotInitialized, -32002; "server not initialized code")]
    #[test_case(ErrorCode::Unknown, -32001; "unknown code")]
    #[test_case(ErrorCode::RequestCancelled, -32800; "request cancelled code")]
    #[test_case(ErrorCode::ContentModified, -32801; "content modified code")]
    #[test_case(ErrorCode::ServerCancelled, -32802; "server cancelled code")]
    #[test_case(ErrorCode::RequestFailed, -32803; "request failed code")]
    #[test_case(ErrorCode::Custom(42), 42; "custom code")]
    #[test]
    fn error_code_wire_value(e: ErrorCode, code: i32) {
        assert_eq!(serde_json::to_value(e).unwrap(), json!(code));
        let parsed: ErrorCode = serde_json::from_value(json!(code)).unwrap();
        assert_eq!(parsed, e);
    }
}