mcpls_core/lsp/
transport.rs1use std::collections::HashMap;
12
13use serde_json::Value;
14use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
15use tokio::process::{ChildStdin, ChildStdout};
16use tracing::{trace, warn};
17
18use crate::error::{Error, Result};
19use crate::lsp::types::{InboundMessage, JsonRpcNotification, JsonRpcResponse};
20
21#[derive(Debug)]
26pub struct LspTransport {
27 stdin: ChildStdin,
28 stdout: BufReader<ChildStdout>,
29}
30
31impl LspTransport {
32 #[must_use]
39 pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
40 Self {
41 stdin,
42 stdout: BufReader::new(stdout),
43 }
44 }
45
46 pub async fn send(&mut self, message: &Value) -> Result<()> {
58 let content = serde_json::to_string(message)?;
59 let header = format!("Content-Length: {}\r\n\r\n", content.len());
60
61 trace!("Sending LSP message: {}", content);
62
63 self.stdin.write_all(header.as_bytes()).await?;
64 self.stdin.write_all(content.as_bytes()).await?;
65 self.stdin.flush().await?;
66
67 Ok(())
68 }
69
70 pub async fn receive(&mut self) -> Result<InboundMessage> {
84 let headers = self.read_headers().await?;
85
86 let content_length = headers
87 .get("content-length")
88 .ok_or_else(|| Error::LspProtocolError("Missing Content-Length header".to_string()))?
89 .parse::<usize>()
90 .map_err(|e| Error::LspProtocolError(format!("Invalid Content-Length: {e}")))?;
91
92 let content = self.read_content(content_length).await?;
93
94 trace!("Received LSP message: {}", content);
95
96 let value: Value = serde_json::from_str(&content)?;
97
98 if value.get("id").is_some() {
99 let response: JsonRpcResponse = serde_json::from_value(value)
100 .map_err(|e| Error::LspProtocolError(format!("Invalid response: {e}")))?;
101 Ok(InboundMessage::Response(response))
102 } else {
103 let notification: JsonRpcNotification = serde_json::from_value(value)
104 .map_err(|e| Error::LspProtocolError(format!("Invalid notification: {e}")))?;
105 Ok(InboundMessage::Notification(notification))
106 }
107 }
108
109 async fn read_headers(&mut self) -> Result<HashMap<String, String>> {
114 let mut headers = HashMap::new();
115 let mut line = String::new();
116
117 loop {
118 line.clear();
119 self.stdout.read_line(&mut line).await?;
120
121 if line.is_empty() {
123 return Err(Error::ServerTerminated);
124 }
125
126 if line == "\r\n" || line == "\n" {
127 break;
128 }
129
130 if let Some((key, value)) = line.trim_end().split_once(':') {
131 headers.insert(key.trim().to_lowercase(), value.trim().to_string());
132 } else {
133 warn!("Malformed header: {}", line.trim());
134 }
135 }
136
137 Ok(headers)
138 }
139
140 async fn read_content(&mut self, length: usize) -> Result<String> {
144 let mut buffer = vec![0u8; length];
145 self.stdout.read_exact(&mut buffer).await?;
146
147 String::from_utf8(buffer)
148 .map_err(|e| Error::LspProtocolError(format!("Invalid UTF-8 in content: {e}")))
149 }
150}
151
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Splitting each header line on its first `:` yields lower-cased,
    /// trimmed names and trimmed values.
    #[test]
    fn test_header_parsing() {
        let headers_text = "Content-Length: 123\r\nContent-Type: application/json\r\n";

        let parsed: HashMap<String, String> = headers_text
            .lines()
            .filter_map(|raw| raw.split_once(':'))
            .map(|(name, val)| (name.trim().to_lowercase(), val.trim().to_string()))
            .collect();

        assert_eq!(parsed.get("content-length").map(String::as_str), Some("123"));
        assert_eq!(
            parsed.get("content-type").map(String::as_str),
            Some("application/json")
        );
    }

    /// A serialized request is framed by a `Content-Length` header that ends
    /// with the CRLF CRLF separator.
    #[test]
    fn test_message_format() {
        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {}
        });

        let body = serde_json::to_string(&request).unwrap();
        let frame_header = format!("Content-Length: {}\r\n\r\n", body.len());

        assert!(frame_header.starts_with("Content-Length:"));
        assert!(frame_header.ends_with("\r\n\r\n"));
        assert!(body.contains("\"jsonrpc\":\"2.0\""));
    }
}