Skip to main content

async_lsp_client/
message.rs

1use tokio::{
2    io::{AsyncReadExt, AsyncWriteExt},
3    process::{ChildStdin, ChildStdout},
4};
5
6use serde::{Deserialize, Serialize};
7use serde_json::{Map, Value};
8use tower_lsp::jsonrpc::{Request, Response};
9use tracing::debug;
10
11pub async fn get_message(stdout: &mut ChildStdout) -> Option<Message> {
12    let mut headers = Vec::new();
13    let mut content_length: Option<usize> = None;
14
15    loop {
16        let mut byte = [0];
17        if stdout.read_exact(&mut byte).await.is_err() {
18            return None;
19        }
20        headers.push(byte[0]);
21
22        // Check if we've reached the end of the headers (double CRLF)
23        if headers.ends_with(b"\r\n\r\n") {
24            let headers_str = String::from_utf8_lossy(&headers);
25            for line in headers_str.lines() {
26                if line.starts_with("Content-Length:") {
27                    let parts: Vec<&str> = line.splitn(2, ':').collect();
28                    if parts.len() > 1 {
29                        let length_str = parts[1].trim();
30                        content_length = Some(length_str.parse().unwrap());
31                        break;
32                    }
33                }
34            }
35            break;
36        }
37    }
38
39    let content_length = content_length.expect("Failed to find Content-Length header");
40
41    let mut body = vec![0u8; content_length];
42    stdout.read_exact(&mut body).await.unwrap();
43
44    let value: Map<String, Value> = serde_json::from_slice(&body).unwrap();
45    if cfg!(feature = "tracing") {
46        debug!("<==== {}", String::from_utf8(body).unwrap());
47    }
48    if value.contains_key("method") {
49        if value.contains_key("id") {
50            let request: Request = serde_json::from_value(Value::Object(value)).unwrap();
51            Some(Message::Request(request))
52        } else {
53            let notification: NotificationMessage =
54                serde_json::from_value(Value::Object(value)).unwrap();
55            Some(Message::Notification(notification))
56        }
57    } else {
58        let response: Response = serde_json::from_value(Value::Object(value)).unwrap();
59        Some(Message::Response(response))
60    }
61}
62
63pub async fn send_message(message: Value, stdin: &mut ChildStdin) {
64    let request_str = message.to_string();
65    if cfg!(feature = "tracing") {
66        debug!("====> {}", request_str);
67    }
68    let content_length = request_str.len();
69    let content = format!("Content-Length: {}\r\n\r\n{}", content_length, request_str);
70    stdin.write_all(content.as_bytes()).await.unwrap();
71    stdin.flush().await.unwrap();
72}
73
74#[derive(Serialize, Deserialize, Debug)]
75pub enum Message {
76    Request(Request),
77    Response(Response),
78    Notification(NotificationMessage),
79}
80
81#[derive(Serialize, Deserialize, Debug)]
82pub struct NotificationMessage {
83    pub jsonrpc: String,
84    pub method: String,
85    pub params: Option<serde_json::Value>,
86}