1use anyhow::{anyhow, Result};
2use std::pin::Pin;
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4use tokio::net::{TcpStream, UnixStream};
5
6use crate::protocol::{RequestMessage, ResponseMessage};
7
8pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin {}
9impl<T: AsyncRead + AsyncWrite + Unpin + ?Sized> AsyncReadWrite for T {}
10
11type Stream = Pin<Box<dyn AsyncReadWrite + Send>>;
12
13pub struct LspClient {
14 stream: Stream,
15}
16
17impl LspClient {
18 pub async fn new(addr: &str) -> Result<Self> {
19 let scheme = addr.split(':').next().ok_or(anyhow!(
20 "Invalid address format. Expected format: <scheme:address:port> or <scheme:path> for UNIX sockets."
21 ))?;
22
23 let stream: Stream = match scheme {
24 "tcp" => {
25 let addr = addr
27 .splitn(2, ':')
28 .nth(1)
29 .ok_or(anyhow!("Invalid TCP address format."))?;
30 let tcp_stream = TcpStream::connect(addr).await?;
31 Box::pin(tcp_stream) as Stream
32 }
33 "unix" => {
34 let path = addr
36 .splitn(2, ':')
37 .nth(1)
38 .ok_or(anyhow!("Invalid UNIX socket path format."))?;
39 let unix_stream = UnixStream::connect(path).await?;
40 Box::pin(unix_stream) as Stream
41 }
42 _ => {
43 return Err(anyhow!(
44 "Unsupported scheme '{}'. Use 'tcp' or 'unix'.",
45 scheme
46 ))
47 }
48 };
49
50 Ok(Self { stream })
51 }
52
53 pub async fn send_request(&mut self, request: RequestMessage) -> Result<()> {
54 let request_str = serde_json::to_string(&request)?;
55 let content_length = request_str.len();
56 let header = format!("Content-Length: {}\r\n\r\n{}", content_length, request_str);
57 self.stream.write_all(header.as_bytes()).await?;
58 self.stream.flush().await?;
59 Ok(())
60 }
61
62 pub async fn handle_response(&mut self) -> Result<ResponseMessage> {
63 let mut headers = Vec::new();
64 let mut content_length: Option<usize> = None;
65
66 loop {
68 let mut byte = [0];
69 self.stream.read_exact(&mut byte).await?;
70 headers.push(byte[0]);
71
72 if headers.ends_with(b"\r\n\r\n") {
74 let headers_str = String::from_utf8_lossy(&headers);
75 for line in headers_str.lines() {
76 if line.starts_with("Content-Length:") {
77 let parts: Vec<&str> = line.splitn(2, ':').collect();
78 if parts.len() > 1 {
79 let length_str = parts[1].trim();
80 content_length = Some(length_str.parse()?);
81 break;
82 }
83 }
84 }
85 break;
86 }
87 }
88
89 let content_length =
90 content_length.ok_or_else(|| anyhow!("Failed to find Content-Length header"))?;
91
92 let mut body = vec![0u8; content_length];
93 self.stream.read_exact(&mut body).await?;
94
95 let response: ResponseMessage = serde_json::from_slice(&body)
96 .map_err(|e| anyhow!("Failed to parse response body: {}", e))?;
97
98 Ok(response)
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio_test::io::Builder;

    /// Frames `body` as an LSP message: a `Content-Length` header (byte length), a blank line, then the body.
    fn frame(body: &str) -> String {
        format!("Content-Length: {}\r\n\r\n{}", body.len(), body)
    }

    #[tokio::test]
    async fn test_send_request_and_response() {
        let request = RequestMessage::new_initialize(
            1,
            std::process::id(),
            "file:///tmp".into(),
            "unit_test_client".into(),
            "0.1.0".into(),
            vec![],
        );
        let request_json = serde_json::to_string(&request).unwrap();
        let response_json = json!({ "jsonrpc": "2.0", "id": 1, "result": {} }).to_string();

        // The mock expects the client to write exactly the framed request,
        // then serves back one framed response.
        let mock_server = Builder::new()
            .write(frame(&request_json).as_bytes())
            .read(frame(&response_json).as_bytes())
            .build();

        let mut lsp_client = LspClient {
            stream: Box::pin(mock_server),
        };

        // `expect` (rather than `assert!(is_ok())`) surfaces the underlying error on failure.
        lsp_client
            .send_request(request)
            .await
            .expect("send_request failed");
        let response = lsp_client
            .handle_response()
            .await
            .expect("handle_response failed");
        assert_eq!(response.result, json!({}));
    }
}