Skip to main content

lsp_server_tokio/
codec.rs

1//! LSP wire protocol codec for Content-Length framed message I/O.
2//!
3//! This module provides [`LspCodec`], which implements the tokio-util codec traits
4//! for encoding and decoding LSP messages according to the wire protocol specification.
5//!
6//! # Wire Protocol Format
7//!
8//! LSP uses HTTP-style Content-Length framing:
9//! ```text
10//! Content-Length: {byte_count}\r\n\r\n{json_body}
11//! ```
12//!
13//! The Content-Length header specifies the byte length of the UTF-8 encoded JSON body.
14//! Headers are ASCII-encoded, followed by a blank line (CRLF CRLF), then the JSON body.
15//!
16//! # Example
17//!
18//! ```rust,no_run
19//! use bytes::BytesMut;
20//! use tokio_util::codec::{Decoder, Encoder};
21//! use lsp_server_tokio::{LspCodec, Message, Request};
22//!
23//! let mut codec = LspCodec::new();
24//! let mut buf = BytesMut::new();
25//!
26//! // Encode a message
27//! let msg = Message::Request(Request::new(1, "test", None));
28//! codec.encode(msg, &mut buf).unwrap();
29//!
30//! // Decode it back
31//! let decoded = codec.decode(&mut buf).unwrap().unwrap();
32//! assert!(decoded.is_request());
33//! ```
34
35use bytes::{Buf, BufMut, BytesMut};
36use std::io::{self, Write};
37use tokio_util::codec::{Decoder, Encoder};
38
39use crate::Message;
40
/// Byte sequence that ends the header section of an LSP message: the CRLF
/// that closes the last header line followed by an empty line (CRLF CRLF).
const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";
43
/// Codec that frames LSP [`Message`] values with a `Content-Length` header.
///
/// # Encoding
///
/// A message is serialized to JSON and written as:
/// ```text
/// Content-Length: {byte_length}\r\n\r\n{json_body}
/// ```
///
/// # Decoding
///
/// Input may arrive in arbitrary fragments. The decoder remembers how far it
/// got between calls and reports:
/// - `Ok(None)` while the header section is still incomplete,
/// - `Ok(None)` while the body is still incomplete,
/// - `Ok(Some(message))` once a whole message has been received.
///
/// # Thread Safety
///
/// Because the decoder keeps per-stream parsing state, a single `LspCodec`
/// must not be shared between concurrent readers. Use one instance per
/// direction (read/write), or wrap the transport in `Framed`.
#[derive(Debug, Default)]
pub struct LspCodec {
    /// Body length announced by the header block that was parsed last;
    /// `None` while the decoder is still waiting for a complete header.
    content_length: Option<usize>,
}

impl LspCodec {
    /// Returns a codec in its initial state (no header parsed yet).
    ///
    /// Equivalent to [`LspCodec::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
83
84impl Decoder for LspCodec {
85    type Item = Message;
86    type Error = io::Error;
87
88    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
89        // If we don't have content length yet, parse headers
90        if self.content_length.is_none() {
91            // Look for header terminator
92            let Some(header_end) = find_subsequence(src, HEADER_TERMINATOR) else {
93                return Ok(None); // Need more data
94            };
95
96            // Parse Content-Length from headers
97            let headers = &src[..header_end];
98            let content_length = parse_content_length(headers)?;
99
100            // Remove headers from buffer (including terminator)
101            src.advance(header_end + HEADER_TERMINATOR.len());
102            self.content_length = Some(content_length);
103        }
104
105        // Now we have content length, check if body is complete
106        let content_length = self.content_length.unwrap();
107        if src.len() < content_length {
108            return Ok(None); // Need more data
109        }
110
111        // Extract body and parse
112        let body = src.split_to(content_length);
113        self.content_length = None; // Reset for next message
114
115        let message: Message = serde_json::from_slice(&body)
116            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
117
118        Ok(Some(message))
119    }
120}
121
122impl Encoder<Message> for LspCodec {
123    type Error = io::Error;
124
125    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
126        let json =
127            serde_json::to_vec(&item).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
128
129        // Reserve space for header + body
130        // Header format: "Content-Length: {n}\r\n\r\n" (max ~30 bytes for reasonable sizes)
131        dst.reserve(32 + json.len());
132
133        // Write header using BufMut::writer() for std::io::Write compatibility
134        write!(dst.writer(), "Content-Length: {}\r\n\r\n", json.len())?;
135
136        // Write body
137        dst.extend_from_slice(&json);
138
139        Ok(())
140    }
141}
142
/// Finds the first occurrence of `needle` in `haystack`.
///
/// Returns the index of the first byte of the first match, or `None` when
/// `needle` does not occur. An empty `needle` matches at index `0`, including
/// in an empty `haystack`.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    // `slice::windows` panics when asked for zero-sized windows, so the
    // empty needle has to be answered before scanning.
    if needle.is_empty() {
        return Some(0);
    }

    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
152
/// Extracts the `Content-Length` value from an HTTP-style header section.
///
/// `headers` is the raw header section without the terminating blank line;
/// individual header lines are separated by CRLF. The header name is matched
/// ASCII case-insensitively (so `content-length:` and `CONTENT-LENGTH:` are
/// accepted too), whitespace around the value is ignored, and the first
/// matching line wins.
///
/// # Errors
///
/// Returns an [`io::Error`] of kind `InvalidData` when `headers` is not valid
/// UTF-8, when no `Content-Length` line is present, or when its value does
/// not parse as a `usize`.
fn parse_content_length(headers: &[u8]) -> io::Result<usize> {
    /// Header name including the colon; compared case-insensitively.
    const HEADER_NAME: &str = "content-length:";

    let headers_str =
        std::str::from_utf8(headers).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

    for line in headers_str.split("\r\n") {
        // `get` yields `None` for lines that are too short or that would be
        // cut inside a multi-byte character; neither can be our header.
        let Some(name) = line.get(..HEADER_NAME.len()) else {
            continue;
        };
        if !name.eq_ignore_ascii_case(HEADER_NAME) {
            continue;
        }

        // The prefix matched an all-ASCII name, so this index is a valid
        // character boundary.
        let value = &line[HEADER_NAME.len()..];
        return value
            .trim()
            .parse::<usize>()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e));
    }

    Err(io::Error::new(
        io::ErrorKind::InvalidData,
        "Missing Content-Length header",
    ))
}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ErrorCode, Notification, Request, Response, ResponseError};
    use serde_json::json;

    // ============== Encoder Tests ==============

    /// A request is framed with a `Content-Length` header whose value equals
    /// the byte length of the JSON body that follows the blank line.
    #[test]
    fn encode_request_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let req = Request::new(1, "test/method", None);
        let msg = Message::Request(req);
        codec.encode(msg, &mut buf).unwrap();

        let output = std::str::from_utf8(&buf).unwrap();

        // Verify header format
        assert!(output.starts_with("Content-Length: "));
        assert!(output.contains("\r\n\r\n"));

        // Split header and body
        let parts: Vec<&str> = output.splitn(2, "\r\n\r\n").collect();
        assert_eq!(parts.len(), 2);

        // Verify body is valid JSON
        let body = parts[1];
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(parsed["method"], "test/method");
        assert_eq!(parsed["id"], 1);
        assert_eq!(parsed["jsonrpc"], "2.0");

        // Verify Content-Length matches body byte length
        let header = parts[0];
        let content_length: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(content_length, body.len());
    }

    /// A successful response is framed and its body carries `id` and `result`.
    #[test]
    fn encode_response_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let resp = Response::ok(42, json!({"result": "value"}));
        let msg = Message::Response(resp);
        codec.encode(msg, &mut buf).unwrap();

        let output = std::str::from_utf8(&buf).unwrap();
        assert!(output.starts_with("Content-Length: "));
        assert!(output.contains("\r\n\r\n"));

        // Verify body
        let body = output.split("\r\n\r\n").nth(1).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(parsed["id"], 42);
        assert!(parsed.get("result").is_some());
    }

    /// A notification is framed and its body has a `method` but no `id`.
    #[test]
    fn encode_notification_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let notif = Notification::new("textDocument/didOpen", Some(json!({"uri": "file:///test"})));
        let msg = Message::Notification(notif);
        codec.encode(msg, &mut buf).unwrap();

        let output = std::str::from_utf8(&buf).unwrap();
        assert!(output.starts_with("Content-Length: "));

        // Verify body
        let body = output.split("\r\n\r\n").nth(1).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(parsed["method"], "textDocument/didOpen");
        assert!(parsed.get("id").is_none());
    }

    // ============== Decoder Tests ==============

    /// A buffer holding exactly one framed request decodes in a single call.
    #[test]
    fn decode_complete_message_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        let framed = format!("Content-Length: {}\r\n\r\n{}", json_body.len(), json_body);
        buf.extend_from_slice(framed.as_bytes());

        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());

        if let Message::Request(req) = msg {
            assert_eq!(req.method, "test");
        }
    }

    /// The decoder returns `Ok(None)` while the header arrives in pieces and
    /// yields the message only once the body has been appended as well.
    #[test]
    fn decode_partial_header_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        // Feed partial header
        buf.extend_from_slice(b"Content-Length: ");
        assert!(codec.decode(&mut buf).unwrap().is_none());

        // Feed more header
        buf.extend_from_slice(b"40\r\n");
        assert!(codec.decode(&mut buf).unwrap().is_none());

        // Complete header
        buf.extend_from_slice(b"\r\n");
        assert!(codec.decode(&mut buf).unwrap().is_none()); // Still no body

        // Now add body
        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        assert_eq!(json_body.len(), 40);
        buf.extend_from_slice(json_body.as_bytes());

        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
    }

    /// With a complete header but only half of the body, the decoder returns
    /// `Ok(None)`; it yields the message after the rest of the body arrives.
    #[test]
    fn decode_partial_body_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;

        // Feed complete header but partial body
        buf.extend_from_slice(format!("Content-Length: {}\r\n\r\n", json_body.len()).as_bytes());
        buf.extend_from_slice(&json_body.as_bytes()[..20]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        // Feed remaining body
        buf.extend_from_slice(&json_body.as_bytes()[20..]);
        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
    }

    /// Two back-to-back frames in one buffer are decoded one per call, in
    /// order, leaving the buffer empty afterwards.
    #[test]
    fn decode_multiple_messages_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        // Add two complete messages
        let json1 = r#"{"jsonrpc":"2.0","id":1,"method":"first"}"#;
        let json2 = r#"{"jsonrpc":"2.0","id":2,"method":"second"}"#;

        buf.extend_from_slice(
            format!("Content-Length: {}\r\n\r\n{}", json1.len(), json1).as_bytes(),
        );
        buf.extend_from_slice(
            format!("Content-Length: {}\r\n\r\n{}", json2.len(), json2).as_bytes(),
        );

        // Decode first
        let msg1 = codec.decode(&mut buf).unwrap().unwrap();
        if let Message::Request(req) = msg1 {
            assert_eq!(req.method, "first");
        } else {
            panic!("Expected request");
        }

        // Buffer should still contain second message
        assert!(!buf.is_empty());

        // Decode second
        let msg2 = codec.decode(&mut buf).unwrap().unwrap();
        if let Message::Request(req) = msg2 {
            assert_eq!(req.method, "second");
        } else {
            panic!("Expected request");
        }

        // Buffer should be empty
        assert!(buf.is_empty());
    }

    /// A request, a response and a notification encoded into one buffer
    /// decode back in the same order with the same message kinds.
    #[test]
    fn encode_decode_roundtrip_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        // Create various message types
        let request = Message::Request(Request::new(
            123,
            "textDocument/completion",
            Some(json!({"position": {"line": 10}})),
        ));
        let response = Message::Response(Response::ok(456, json!({"items": []})));
        let notification = Message::Notification(Notification::new("textDocument/didSave", None));

        // Encode all
        codec.encode(request.clone(), &mut buf).unwrap();
        codec.encode(response.clone(), &mut buf).unwrap();
        codec.encode(notification.clone(), &mut buf).unwrap();

        // Decode and verify
        let decoded_request = codec.decode(&mut buf).unwrap().unwrap();
        assert!(decoded_request.is_request());
        if let (Message::Request(orig), Message::Request(dec)) = (&request, &decoded_request) {
            assert_eq!(orig.id, dec.id);
            assert_eq!(orig.method, dec.method);
        }

        let decoded_response = codec.decode(&mut buf).unwrap().unwrap();
        assert!(decoded_response.is_response());

        let decoded_notification = codec.decode(&mut buf).unwrap().unwrap();
        assert!(decoded_notification.is_notification());

        assert!(buf.is_empty());
    }

    /// `Content-Length` counts bytes, not characters: a body containing
    /// multi-byte UTF-8 text must announce its encoded byte length.
    #[test]
    fn content_length_byte_count_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        // Create a message with Unicode content
        // The method name contains non-ASCII characters:
        // "test/" followed by two Japanese characters (3 bytes each in UTF-8)
        let req = Request::new(1, "test/\u{65E5}\u{672C}", None); // "test/日本" - 2 Japanese chars (3 bytes each, 6 bytes total)
        let msg = Message::Request(req);
        codec.encode(msg, &mut buf).unwrap();

        let output = std::str::from_utf8(&buf).unwrap();

        // Split header and body
        let parts: Vec<&str> = output.splitn(2, "\r\n\r\n").collect();
        let header = parts[0];
        let body = parts[1];

        // Extract Content-Length
        let content_length: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();

        // Content-Length should be BYTE count, not character count
        assert_eq!(content_length, body.len());

        // Verify the body contains more bytes than characters due to UTF-8 encoding
        assert!(body.len() > body.chars().count());
    }

    /// An all-lowercase `content-length:` header name is accepted.
    #[test]
    fn case_insensitive_header_parsing() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        // Use lowercase header (some clients might send this)
        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        let framed = format!("content-length: {}\r\n\r\n{}", json_body.len(), json_body);
        buf.extend_from_slice(framed.as_bytes());

        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
    }

    /// An error response survives an encode/decode round trip with its
    /// numeric error code intact.
    #[test]
    fn response_error_roundtrip() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let error = ResponseError::new(ErrorCode::MethodNotFound, "Method not found");
        let resp = Message::Response(Response::err(1, error));
        codec.encode(resp, &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        if let Message::Response(r) = decoded {
            assert!(r.error.is_some());
            assert_eq!(r.error.unwrap().code, -32601);
        } else {
            panic!("Expected response");
        }
    }

    /// A correctly framed body that is not valid JSON yields an
    /// `InvalidData` error.
    #[test]
    fn decode_invalid_json_returns_error() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let invalid_json = "{ not valid json }";
        let framed = format!(
            "Content-Length: {}\r\n\r\n{}",
            invalid_json.len(),
            invalid_json
        );
        buf.extend_from_slice(framed.as_bytes());

        let result = codec.decode(&mut buf);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    /// A header section without any `Content-Length` line is an error.
    #[test]
    fn decode_missing_content_length_returns_error() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        // No Content-Length header
        let framed = "Some-Other-Header: value\r\n\r\n{}";
        buf.extend_from_slice(framed.as_bytes());

        let result = codec.decode(&mut buf);
        assert!(result.is_err());
    }
}