1use bytes::{Buf, BufMut, BytesMut};
36use std::io::{self, Write};
37use tokio_util::codec::{Decoder, Encoder};
38
39use crate::Message;
40
/// Blank line (`CR LF CR LF`) that ends an LSP header block and precedes the JSON body.
const HEADER_TERMINATOR: &[u8] = &[b'\r', b'\n', b'\r', b'\n'];
43
/// Tokio codec for the LSP base protocol: each [`Message`] travels as a
/// `Content-Length: <bytes>\r\n\r\n` header followed by that many bytes of JSON.
///
/// The codec keeps a little state between `decode` calls, so a frame whose body
/// arrives over several reads has its header parsed only once.
#[derive(Default, Debug)]
pub struct LspCodec {
    /// Body length taken from an already-consumed header whose body is not yet
    /// fully buffered; `None` while the codec is waiting for the next header.
    content_length: Option<usize>,
}
73
74impl LspCodec {
75 #[must_use]
77 pub fn new() -> Self {
78 Self {
79 content_length: None,
80 }
81 }
82}
83
84impl Decoder for LspCodec {
85 type Item = Message;
86 type Error = io::Error;
87
88 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
89 if self.content_length.is_none() {
91 let Some(header_end) = find_subsequence(src, HEADER_TERMINATOR) else {
93 return Ok(None); };
95
96 let headers = &src[..header_end];
98 let content_length = parse_content_length(headers)?;
99
100 src.advance(header_end + HEADER_TERMINATOR.len());
102 self.content_length = Some(content_length);
103 }
104
105 let content_length = self.content_length.unwrap();
107 if src.len() < content_length {
108 return Ok(None); }
110
111 let body = src.split_to(content_length);
113 self.content_length = None; let message: Message = serde_json::from_slice(&body)
116 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
117
118 Ok(Some(message))
119 }
120}
121
122impl Encoder<Message> for LspCodec {
123 type Error = io::Error;
124
125 fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> {
126 let json =
127 serde_json::to_vec(&item).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
128
129 dst.reserve(32 + json.len());
132
133 write!(dst.writer(), "Content-Length: {}\r\n\r\n", json.len())?;
135
136 dst.extend_from_slice(&json);
138
139 Ok(())
140 }
141}
142
/// Returns the byte offset of the first occurrence of `needle` in `haystack`,
/// or `None` if it does not occur.
///
/// An empty `needle` matches at offset 0, as `str::find("")` does. The early
/// return is required because `slice::windows(0)` panics.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
152
/// Extracts the `Content-Length` value from an LSP header block.
///
/// `headers` is the header block without its trailing blank line, with fields
/// separated by `\r\n`. The field name is matched ASCII-case-insensitively,
/// and whitespace around the value is ignored. If the field appears more than
/// once, the first occurrence wins.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidData`] error if `headers` is not valid
/// UTF-8, has no `Content-Length` field, or the value does not parse as a `usize`.
fn parse_content_length(headers: &[u8]) -> io::Result<usize> {
    const FIELD: &str = "content-length:";

    let headers_str =
        std::str::from_utf8(headers).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

    for line in headers_str.split("\r\n") {
        // `is_char_boundary` is false for lines shorter than the field name and for
        // lines with a multibyte char straddling that offset, so `split_at` cannot panic.
        if !line.is_char_boundary(FIELD.len()) {
            continue;
        }
        let (name, value) = line.split_at(FIELD.len());

        // Compare in place instead of allocating a lowercased copy of every line.
        if name.eq_ignore_ascii_case(FIELD) {
            let value = value.trim();
            return value.parse::<usize>().map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("invalid Content-Length value {value:?}: {e}"),
                )
            });
        }
    }

    Err(io::Error::new(
        io::ErrorKind::InvalidData,
        "Missing Content-Length header",
    ))
}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ErrorCode, Notification, Request, Response, ResponseError};
    use serde_json::json;

    /// Wraps `body` in an LSP header whose `Content-Length` is its byte length.
    fn frame(body: &str) -> String {
        format!("Content-Length: {}\r\n\r\n{}", body.len(), body)
    }

    #[test]
    fn encode_request_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let req = Request::new(1, "test/method", None);
        let msg = Message::Request(req);
        codec.encode(msg, &mut buf).unwrap();

        let output = std::str::from_utf8(&buf).unwrap();

        // The frame must start with the header and contain the blank-line separator.
        assert!(output.starts_with("Content-Length: "));
        assert!(output.contains("\r\n\r\n"));

        let parts: Vec<&str> = output.splitn(2, "\r\n\r\n").collect();
        assert_eq!(parts.len(), 2);

        let body = parts[1];
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(parsed["method"], "test/method");
        assert_eq!(parsed["id"], 1);
        assert_eq!(parsed["jsonrpc"], "2.0");

        // The advertised length must equal the body's size in bytes.
        let header = parts[0];
        let content_length: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();
        assert_eq!(content_length, body.len());
    }

    #[test]
    fn encode_response_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let resp = Response::ok(42, json!({"result": "value"}));
        let msg = Message::Response(resp);
        codec.encode(msg, &mut buf).unwrap();

        let output = std::str::from_utf8(&buf).unwrap();
        assert!(output.starts_with("Content-Length: "));
        assert!(output.contains("\r\n\r\n"));

        let body = output.split("\r\n\r\n").nth(1).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(parsed["id"], 42);
        assert!(parsed.get("result").is_some());
    }

    #[test]
    fn encode_notification_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let notif = Notification::new("textDocument/didOpen", Some(json!({"uri": "file:///test"})));
        let msg = Message::Notification(notif);
        codec.encode(msg, &mut buf).unwrap();

        let output = std::str::from_utf8(&buf).unwrap();
        assert!(output.starts_with("Content-Length: "));

        let body = output.split("\r\n\r\n").nth(1).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(body).unwrap();
        assert_eq!(parsed["method"], "textDocument/didOpen");
        // Notifications carry no id.
        assert!(parsed.get("id").is_none());
    }

    #[test]
    fn decode_complete_message_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        buf.extend_from_slice(frame(json_body).as_bytes());

        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());

        let Message::Request(req) = msg else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "test");
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_empty_buffer_returns_none() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn decode_partial_header_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        // Header split across several reads: nothing decodes until the blank line arrives.
        buf.extend_from_slice(b"Content-Length: ");
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(b"40\r\n");
        assert!(codec.decode(&mut buf).unwrap().is_none());

        // Header now complete, but the body has not arrived yet.
        buf.extend_from_slice(b"\r\n");
        assert!(codec.decode(&mut buf).unwrap().is_none());

        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        assert_eq!(json_body.len(), 40);
        buf.extend_from_slice(json_body.as_bytes());

        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
    }

    #[test]
    fn decode_partial_body_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;

        buf.extend_from_slice(format!("Content-Length: {}\r\n\r\n", json_body.len()).as_bytes());
        buf.extend_from_slice(&json_body.as_bytes()[..20]);
        assert!(codec.decode(&mut buf).unwrap().is_none());

        buf.extend_from_slice(&json_body.as_bytes()[20..]);
        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
    }

    #[test]
    fn decode_with_extra_header_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        // Fields other than Content-Length must be skipped.
        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        let framed = format!(
            "Content-Length: {}\r\nContent-Type: application/vscode-jsonrpc; charset=utf-8\r\n\r\n{}",
            json_body.len(),
            json_body
        );
        buf.extend_from_slice(framed.as_bytes());

        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_multiple_messages_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let json1 = r#"{"jsonrpc":"2.0","id":1,"method":"first"}"#;
        let json2 = r#"{"jsonrpc":"2.0","id":2,"method":"second"}"#;

        buf.extend_from_slice(frame(json1).as_bytes());
        buf.extend_from_slice(frame(json2).as_bytes());

        let Message::Request(req) = codec.decode(&mut buf).unwrap().unwrap() else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "first");

        // The second frame must still be buffered after the first decode.
        assert!(!buf.is_empty());

        let Message::Request(req) = codec.decode(&mut buf).unwrap().unwrap() else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "second");

        assert!(buf.is_empty());
    }

    #[test]
    fn encode_decode_roundtrip_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let request = Message::Request(Request::new(
            123,
            "textDocument/completion",
            Some(json!({"position": {"line": 10}})),
        ));
        let response = Message::Response(Response::ok(456, json!({"items": []})));
        let notification = Message::Notification(Notification::new("textDocument/didSave", None));

        codec.encode(request.clone(), &mut buf).unwrap();
        codec.encode(response.clone(), &mut buf).unwrap();
        codec.encode(notification.clone(), &mut buf).unwrap();

        let decoded_request = codec.decode(&mut buf).unwrap().unwrap();
        assert!(decoded_request.is_request());
        if let (Message::Request(orig), Message::Request(dec)) = (&request, &decoded_request) {
            assert_eq!(orig.id, dec.id);
            assert_eq!(orig.method, dec.method);
        }

        let decoded_response = codec.decode(&mut buf).unwrap().unwrap();
        assert!(decoded_response.is_response());

        let decoded_notification = codec.decode(&mut buf).unwrap().unwrap();
        assert!(decoded_notification.is_notification());

        assert!(buf.is_empty());

        // Compare full JSON forms so every field survives the trip, not just id and method.
        for (original, decoded) in [
            (&request, &decoded_request),
            (&response, &decoded_response),
            (&notification, &decoded_notification),
        ] {
            assert_eq!(
                serde_json::to_value(original).unwrap(),
                serde_json::to_value(decoded).unwrap()
            );
        }
    }

    #[test]
    fn content_length_byte_count_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        // Multibyte method name: the byte length differs from the char count.
        let req = Request::new(1, "test/\u{65E5}\u{672C}", None);
        let msg = Message::Request(req);
        codec.encode(msg, &mut buf).unwrap();

        let output = std::str::from_utf8(&buf).unwrap();

        let parts: Vec<&str> = output.splitn(2, "\r\n\r\n").collect();
        let header = parts[0];
        let body = parts[1];

        let content_length: usize = header
            .strip_prefix("Content-Length: ")
            .unwrap()
            .parse()
            .unwrap();

        assert_eq!(content_length, body.len());
        assert!(body.len() > body.chars().count());
    }

    #[test]
    fn decode_non_ascii_body_test() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let json_body = "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"test/\u{65E5}\u{672C}\"}";
        // Content-Length counts bytes, which exceeds the char count here.
        assert!(json_body.len() > json_body.chars().count());
        buf.extend_from_slice(frame(json_body).as_bytes());

        let Message::Request(req) = codec.decode(&mut buf).unwrap().unwrap() else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "test/\u{65E5}\u{672C}");
        assert!(buf.is_empty());
    }

    #[test]
    fn case_insensitive_header_parsing() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let json_body = r#"{"jsonrpc":"2.0","id":1,"method":"test"}"#;
        let framed = format!("content-length: {}\r\n\r\n{}", json_body.len(), json_body);
        buf.extend_from_slice(framed.as_bytes());

        let msg = codec.decode(&mut buf).unwrap().unwrap();
        assert!(msg.is_request());
    }

    #[test]
    fn response_error_roundtrip() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let error = ResponseError::new(ErrorCode::MethodNotFound, "Method not found");
        let resp = Message::Response(Response::err(1, error));
        codec.encode(resp, &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().unwrap();
        if let Message::Response(r) = decoded {
            assert!(r.error.is_some());
            assert_eq!(r.error.unwrap().code, -32601);
        } else {
            panic!("Expected response");
        }
    }

    #[test]
    fn decode_invalid_json_returns_error() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        buf.extend_from_slice(frame("{ not valid json }").as_bytes());

        let result = codec.decode(&mut buf);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn decode_recovers_after_invalid_json() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        buf.extend_from_slice(frame("{ not valid json }").as_bytes());
        buf.extend_from_slice(frame(r#"{"jsonrpc":"2.0","id":2,"method":"next"}"#).as_bytes());

        // The bad body is consumed, so the codec must be back to expecting a header.
        assert!(codec.decode(&mut buf).is_err());

        let Message::Request(req) = codec.decode(&mut buf).unwrap().unwrap() else {
            panic!("Expected request");
        };
        assert_eq!(req.method, "next");
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_missing_content_length_returns_error() {
        let mut codec = LspCodec::new();
        let mut buf = BytesMut::new();

        let framed = "Some-Other-Header: value\r\n\r\n{}";
        buf.extend_from_slice(framed.as_bytes());

        let result = codec.decode(&mut buf);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }
}