// lsp_client_rs/client.rs

use std::pin::Pin;

use anyhow::{anyhow, Context, Result};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::{TcpStream, UnixStream};

use crate::protocol::{RequestMessage, ResponseMessage};

8pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}
9impl<T: AsyncRead + AsyncWrite + Unpin + ?Sized> AsyncReadWrite for T {}
10
11type Stream = Pin<Box<dyn AsyncReadWrite + Send>>;
12
13pub struct LspClient {
14    stream: Stream,
15}
16
17impl LspClient {
18    pub async fn new(addr: &str) -> Result<Self> {
19        let scheme = addr.split(':').next().ok_or(anyhow!(
20            "Invalid address format. Expected format: <scheme:address:port> or <scheme:path> for UNIX sockets."
21        ))?;
22
23        let stream: Stream = match scheme {
24            "tcp" => {
25                // Skip the scheme part and rejoin the rest (address and port)
26                let addr = addr
27                    .splitn(2, ':')
28                    .nth(1)
29                    .ok_or(anyhow!("Invalid TCP address format."))?;
30                let tcp_stream = TcpStream::connect(addr).await?;
31                Box::pin(tcp_stream) as Stream
32            }
33            "unix" => {
34                // Skip the scheme part for UNIX domain socket path
35                let path = addr
36                    .splitn(2, ':')
37                    .nth(1)
38                    .ok_or(anyhow!("Invalid UNIX socket path format."))?;
39                let unix_stream = UnixStream::connect(path).await?;
40                Box::pin(unix_stream) as Stream
41            }
42            _ => {
43                return Err(anyhow!(
44                    "Unsupported scheme '{}'. Use 'tcp' or 'unix'.",
45                    scheme
46                ))
47            }
48        };
49
50        Ok(Self { stream })
51    }
52
53    pub async fn send_request(&mut self, request: RequestMessage) -> Result<()> {
54        let request_str = serde_json::to_string(&request)?;
55        let content_length = request_str.len();
56        let header = format!("Content-Length: {}\r\n\r\n{}", content_length, request_str);
57        self.stream.write_all(header.as_bytes()).await?;
58        self.stream.flush().await?;
59        Ok(())
60    }
61
62    pub async fn handle_response(&mut self) -> Result<ResponseMessage> {
63        let mut headers = Vec::new();
64        let mut content_length: Option<usize> = None;
65
66        // Read headers
67        loop {
68            let mut byte = [0];
69            self.stream.read_exact(&mut byte).await?;
70            headers.push(byte[0]);
71
72            // Check if we've reached the end of the headers (double CRLF)
73            if headers.ends_with(b"\r\n\r\n") {
74                let headers_str = String::from_utf8_lossy(&headers);
75                for line in headers_str.lines() {
76                    if line.starts_with("Content-Length:") {
77                        let parts: Vec<&str> = line.splitn(2, ':').collect();
78                        if parts.len() > 1 {
79                            let length_str = parts[1].trim();
80                            content_length = Some(length_str.parse()?);
81                            break;
82                        }
83                    }
84                }
85                break;
86            }
87        }
88
89        let content_length =
90            content_length.ok_or_else(|| anyhow!("Failed to find Content-Length header"))?;
91
92        let mut body = vec![0u8; content_length];
93        self.stream.read_exact(&mut body).await?;
94
95        let response: ResponseMessage = serde_json::from_slice(&body)
96            .map_err(|e| anyhow!("Failed to parse response body: {}", e))?;
97
98        Ok(response)
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio_test::io::Builder;

    /// Frames `payload` with a `Content-Length` header, as the client expects.
    fn frame(payload: &str) -> String {
        format!("Content-Length: {}\r\n\r\n{}", payload.len(), payload)
    }

    /// Builds a client whose mock stream only yields `data` when read.
    fn client_reading(data: &[u8]) -> LspClient {
        LspClient {
            stream: Box::pin(Builder::new().read(data).build()),
        }
    }

    #[tokio::test]
    async fn test_send_request_and_response() {
        // The mock asserts that the client writes exactly this framed request.
        let request = RequestMessage::new_initialize(
            1,
            std::process::id(),
            "file:///tmp".into(),
            "unit_test_client".into(),
            "0.1.0".into(),
            vec![],
        );
        let request_json = serde_json::to_string(&request).unwrap();

        // The server's reply; its id matches the request's id.
        let response_payload = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {}
        })
        .to_string();

        let mock_server = Builder::new()
            .write(frame(&request_json).as_bytes())
            .read(frame(&response_payload).as_bytes())
            .build();

        let mut lsp_client = LspClient {
            stream: Box::pin(mock_server),
        };

        lsp_client
            .send_request(request)
            .await
            .expect("sending the request should succeed");

        let response = lsp_client
            .handle_response()
            .await
            .expect("the framed response should parse");
        assert_eq!(response.result, json!({}));
    }

    #[tokio::test]
    async fn test_response_with_extra_header() {
        let payload = json!({ "jsonrpc": "2.0", "id": 1, "result": {} }).to_string();
        let message = format!(
            "Content-Length: {}\r\nContent-Type: application/vscode-jsonrpc; charset=utf-8\r\n\r\n{}",
            payload.len(),
            payload
        );
        let mut lsp_client = client_reading(message.as_bytes());

        let response = lsp_client
            .handle_response()
            .await
            .expect("an extra header field should be ignored");
        assert_eq!(response.result, json!({}));
    }

    #[tokio::test]
    async fn test_missing_content_length_is_an_error() {
        let mut lsp_client = client_reading(b"Content-Type: application/vscode-jsonrpc\r\n\r\n");

        let err = lsp_client
            .handle_response()
            .await
            .expect_err("a header without Content-Length must be rejected");
        assert!(err.to_string().contains("Content-Length"));
    }

    #[tokio::test]
    async fn test_unsupported_scheme_is_rejected() {
        // Rejected before any connection is attempted, so no network is needed.
        let err = LspClient::new("http://localhost:8080")
            .await
            .err()
            .expect("an unknown scheme must be rejected");
        assert!(err.to_string().contains("Unsupported scheme 'http'"));
    }
}