ra_ap_lsp_server/
msg.rs

1use std::{
2    fmt,
3    io::{self, BufRead, Write},
4};
5
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7
8use crate::error::ExtractError;
9
10#[derive(Serialize, Deserialize, Debug, Clone)]
11#[serde(untagged)]
12pub enum Message {
13    Request(Request),
14    Response(Response),
15    Notification(Notification),
16}
17
18impl From<Request> for Message {
19    fn from(request: Request) -> Message {
20        Message::Request(request)
21    }
22}
23
24impl From<Response> for Message {
25    fn from(response: Response) -> Message {
26        Message::Response(response)
27    }
28}
29
30impl From<Notification> for Message {
31    fn from(notification: Notification) -> Message {
32        Message::Notification(notification)
33    }
34}
35
36#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
37#[serde(transparent)]
38pub struct RequestId(IdRepr);
39
40#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
41#[serde(untagged)]
42enum IdRepr {
43    I32(i32),
44    String(String),
45}
46
47impl From<i32> for RequestId {
48    fn from(id: i32) -> RequestId {
49        RequestId(IdRepr::I32(id))
50    }
51}
52
53impl From<String> for RequestId {
54    fn from(id: String) -> RequestId {
55        RequestId(IdRepr::String(id))
56    }
57}
58
59impl fmt::Display for RequestId {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        match &self.0 {
62            IdRepr::I32(it) => fmt::Display::fmt(it, f),
63            // Use debug here, to make it clear that `92` and `"92"` are
64            // different, and to reduce WTF factor if the sever uses `" "` as an
65            // ID.
66            IdRepr::String(it) => fmt::Debug::fmt(it, f),
67        }
68    }
69}
70
71#[derive(Debug, Serialize, Deserialize, Clone)]
72pub struct Request {
73    pub id: RequestId,
74    pub method: String,
75    #[serde(default = "serde_json::Value::default")]
76    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
77    pub params: serde_json::Value,
78}
79
80#[derive(Debug, Serialize, Deserialize, Clone)]
81pub struct Response {
82    // JSON RPC allows this to be null if it was impossible
83    // to decode the request's id. Ignore this special case
84    // and just die horribly.
85    pub id: RequestId,
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub result: Option<serde_json::Value>,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub error: Option<ResponseError>,
90}
91
92#[derive(Debug, Serialize, Deserialize, Clone)]
93pub struct ResponseError {
94    pub code: i32,
95    pub message: String,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub data: Option<serde_json::Value>,
98}
99
/// Error codes defined by JSON-RPC 2.0 and the Language Server Protocol.
///
/// Each variant's discriminant is its wire value, so `ErrorCode::X as i32`
/// gives the number to put in [`ResponseError::code`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ErrorCode {
    // Defined by JSON RPC:
    /// Invalid JSON was received.
    ParseError = -32700,
    /// The JSON sent is not a valid request object.
    InvalidRequest = -32600,
    /// The method does not exist or is not available.
    MethodNotFound = -32601,
    /// The method parameters are invalid.
    InvalidParams = -32602,
    /// Internal JSON-RPC error.
    InternalError = -32603,
    /// Start of the range reserved for JSON-RPC errors.
    ///
    /// This marks a range boundary; it does not denote a real error code.
    ServerErrorStart = -32099,
    /// End of the range reserved for JSON-RPC errors.
    ///
    /// This marks a range boundary; it does not denote a real error code.
    ServerErrorEnd = -32000,

    /// Error code indicating that a server received a notification or
    /// request before the server has received the `initialize` request.
    ///
    /// This code and `UnknownErrorCode` lie inside the JSON-RPC reserved
    /// range for backwards compatibility.
    ServerNotInitialized = -32002,
    UnknownErrorCode = -32001,

    // Defined by the Language Server Protocol:
    /// The client has canceled a request and a server has detected
    /// the cancel.
    RequestCanceled = -32800,

    /// The server detected that the content of a document got
    /// modified outside normal conditions. A server should
    /// NOT send this error code if it detects a content change
    /// in its unprocessed messages. The result, even if computed
    /// on an older state, might still be useful for the client.
    ///
    /// If a client decides that a result is not of any use anymore
    /// the client should cancel the request.
    ContentModified = -32801,

    /// The server cancelled the request. This error code should
    /// only be used for requests that explicitly support being
    /// server cancellable.
    ///
    /// @since 3.17.0
    ServerCancelled = -32802,

    /// A request failed but it was syntactically correct, e.g. the
    /// method name was known and the parameters were valid. The error
    /// message should contain human-readable information about why
    /// the request failed.
    ///
    /// @since 3.17.0
    RequestFailed = -32803,
}
147
148#[derive(Debug, Serialize, Deserialize, Clone)]
149pub struct Notification {
150    pub method: String,
151    #[serde(default = "serde_json::Value::default")]
152    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
153    pub params: serde_json::Value,
154}
155
156impl Message {
157    pub fn read(r: &mut impl BufRead) -> io::Result<Option<Message>> {
158        Message::_read(r)
159    }
160    fn _read(r: &mut dyn BufRead) -> io::Result<Option<Message>> {
161        let text = match read_msg_text(r)? {
162            None => return Ok(None),
163            Some(text) => text,
164        };
165        let msg = serde_json::from_str(&text)?;
166        Ok(Some(msg))
167    }
168    pub fn write(self, w: &mut impl Write) -> io::Result<()> {
169        self._write(w)
170    }
171    fn _write(self, w: &mut dyn Write) -> io::Result<()> {
172        #[derive(Serialize)]
173        struct JsonRpc {
174            jsonrpc: &'static str,
175            #[serde(flatten)]
176            msg: Message,
177        }
178        let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?;
179        write_msg_text(w, &text)
180    }
181}
182
183impl Response {
184    pub fn new_ok<R: Serialize>(id: RequestId, result: R) -> Response {
185        Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None }
186    }
187    pub fn new_err(id: RequestId, code: i32, message: String) -> Response {
188        let error = ResponseError { code, message, data: None };
189        Response { id, result: None, error: Some(error) }
190    }
191}
192
193impl Request {
194    pub fn new<P: Serialize>(id: RequestId, method: String, params: P) -> Request {
195        Request { id, method, params: serde_json::to_value(params).unwrap() }
196    }
197    pub fn extract<P: DeserializeOwned>(
198        self,
199        method: &str,
200    ) -> Result<(RequestId, P), ExtractError<Request>> {
201        if self.method != method {
202            return Err(ExtractError::MethodMismatch(self));
203        }
204        match serde_json::from_value(self.params) {
205            Ok(params) => Ok((self.id, params)),
206            Err(error) => Err(ExtractError::JsonError { method: self.method, error }),
207        }
208    }
209
210    pub(crate) fn is_shutdown(&self) -> bool {
211        self.method == "shutdown"
212    }
213    pub(crate) fn is_initialize(&self) -> bool {
214        self.method == "initialize"
215    }
216}
217
218impl Notification {
219    pub fn new(method: String, params: impl Serialize) -> Notification {
220        Notification { method, params: serde_json::to_value(params).unwrap() }
221    }
222    pub fn extract<P: DeserializeOwned>(
223        self,
224        method: &str,
225    ) -> Result<P, ExtractError<Notification>> {
226        if self.method != method {
227            return Err(ExtractError::MethodMismatch(self));
228        }
229        match serde_json::from_value(self.params) {
230            Ok(params) => Ok(params),
231            Err(error) => Err(ExtractError::JsonError { method: self.method, error }),
232        }
233    }
234    pub(crate) fn is_exit(&self) -> bool {
235        self.method == "exit"
236    }
237    pub(crate) fn is_initialized(&self) -> bool {
238        self.method == "initialized"
239    }
240}
241
242fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> {
243    fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
244        io::Error::new(io::ErrorKind::InvalidData, error)
245    }
246    macro_rules! invalid_data {
247        ($($tt:tt)*) => (invalid_data(format!($($tt)*)))
248    }
249
250    let mut size = None;
251    let mut buf = String::new();
252    loop {
253        buf.clear();
254        if inp.read_line(&mut buf)? == 0 {
255            return Ok(None);
256        }
257        if !buf.ends_with("\r\n") {
258            return Err(invalid_data!("malformed header: {:?}", buf));
259        }
260        let buf = &buf[..buf.len() - 2];
261        if buf.is_empty() {
262            break;
263        }
264        let mut parts = buf.splitn(2, ": ");
265        let header_name = parts.next().unwrap();
266        let header_value =
267            parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?;
268        if header_name == "Content-Length" {
269            size = Some(header_value.parse::<usize>().map_err(invalid_data)?);
270        }
271    }
272    let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?;
273    let mut buf = buf.into_bytes();
274    buf.resize(size, 0);
275    inp.read_exact(&mut buf)?;
276    let buf = String::from_utf8(buf).map_err(invalid_data)?;
277    log::debug!("< {}", buf);
278    Ok(Some(buf))
279}
280
281fn write_msg_text(out: &mut dyn Write, msg: &str) -> io::Result<()> {
282    log::debug!("> {}", msg);
283    write!(out, "Content-Length: {}\r\n\r\n", msg.len())?;
284    out.write_all(msg.as_bytes())?;
285    out.flush()?;
286    Ok(())
287}
288
#[cfg(test)]
mod tests {
    use super::{Message, Notification, Request, RequestId};

    /// Deserializes a raw JSON-RPC payload, panicking if it is not a valid `Message`.
    fn parse(text: &str) -> Message {
        serde_json::from_str(text).unwrap()
    }

    #[test]
    fn shutdown_with_explicit_null() {
        let msg = parse(r#"{"jsonrpc": "2.0","id": 3,"method": "shutdown", "params": null }"#);

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    #[test]
    fn shutdown_with_no_params() {
        let msg = parse(r#"{"jsonrpc": "2.0","id": 3,"method": "shutdown"}"#);

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    #[test]
    fn notification_with_explicit_null() {
        let msg = parse(r#"{"jsonrpc": "2.0","method": "exit", "params": null }"#);

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    #[test]
    fn notification_with_no_params() {
        let msg = parse(r#"{"jsonrpc": "2.0","method": "exit"}"#);

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    #[test]
    fn serialize_request_with_null_params() {
        let request = Request {
            id: RequestId::from(3),
            method: "shutdown".into(),
            params: serde_json::Value::Null,
        };
        let serialized = serde_json::to_string(&Message::Request(request)).unwrap();

        assert_eq!(serialized, r#"{"id":3,"method":"shutdown"}"#);
    }

    #[test]
    fn serialize_notification_with_null_params() {
        let notification =
            Notification { method: "exit".into(), params: serde_json::Value::Null };
        let serialized = serde_json::to_string(&Message::Notification(notification)).unwrap();

        assert_eq!(serialized, r#"{"method":"exit"}"#);
    }
}