mcpls_core/lsp/
transport.rs

1//! LSP transport layer for stdio communication.
2//!
3//! This module implements the LSP header-content message format over stdin/stdout.
4//! Messages follow the format:
5//! ```text
6//! Content-Length: 123\r\n
7//! \r\n
8//! {"jsonrpc":"2.0",...}
9//! ```
10
11use std::collections::HashMap;
12
13use serde_json::Value;
14use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
15use tokio::process::{ChildStdin, ChildStdout};
16use tracing::{trace, warn};
17
18use crate::error::{Error, Result};
19use crate::lsp::types::{InboundMessage, JsonRpcNotification, JsonRpcResponse};
20
21/// LSP transport layer handling header-content format.
22///
23/// This transport handles the LSP protocol's header-content message format,
24/// parsing Content-Length headers and reading exact message content.
25#[derive(Debug)]
26pub struct LspTransport {
27    stdin: ChildStdin,
28    stdout: BufReader<ChildStdout>,
29}
30
31impl LspTransport {
32    /// Create transport from child process stdio.
33    ///
34    /// # Arguments
35    ///
36    /// * `stdin` - The child process's stdin handle for sending messages
37    /// * `stdout` - The child process's stdout handle for receiving messages
38    #[must_use]
39    pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
40        Self {
41            stdin,
42            stdout: BufReader::new(stdout),
43        }
44    }
45
46    /// Send message to LSP server.
47    ///
48    /// Formats the message with proper Content-Length header and sends it
49    /// to the LSP server via stdin.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if:
54    /// - Message serialization fails
55    /// - Writing to stdin fails
56    /// - Flushing stdin fails
57    pub async fn send(&mut self, message: &Value) -> Result<()> {
58        let content = serde_json::to_string(message)?;
59        let header = format!("Content-Length: {}\r\n\r\n", content.len());
60
61        trace!("Sending LSP message: {}", content);
62
63        self.stdin.write_all(header.as_bytes()).await?;
64        self.stdin.write_all(content.as_bytes()).await?;
65        self.stdin.flush().await?;
66
67        Ok(())
68    }
69
70    /// Receive next message from LSP server.
71    ///
72    /// Reads headers, extracts Content-Length, reads exact message content,
73    /// and parses it as either a response or notification.
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if:
78    /// - Reading headers fails
79    /// - Content-Length header is missing or invalid
80    /// - Reading message content fails
81    /// - JSON parsing fails
82    /// - Message format is invalid
83    pub async fn receive(&mut self) -> Result<InboundMessage> {
84        let headers = self.read_headers().await?;
85
86        let content_length = headers
87            .get("content-length")
88            .ok_or_else(|| Error::LspProtocolError("Missing Content-Length header".to_string()))?
89            .parse::<usize>()
90            .map_err(|e| Error::LspProtocolError(format!("Invalid Content-Length: {e}")))?;
91
92        let content = self.read_content(content_length).await?;
93
94        trace!("Received LSP message: {}", content);
95
96        let value: Value = serde_json::from_str(&content)?;
97
98        if value.get("id").is_some() {
99            let response: JsonRpcResponse = serde_json::from_value(value)
100                .map_err(|e| Error::LspProtocolError(format!("Invalid response: {e}")))?;
101            Ok(InboundMessage::Response(response))
102        } else {
103            let notification: JsonRpcNotification = serde_json::from_value(value)
104                .map_err(|e| Error::LspProtocolError(format!("Invalid notification: {e}")))?;
105            Ok(InboundMessage::Notification(notification))
106        }
107    }
108
109    /// Read headers until blank line.
110    ///
111    /// Headers are in the format "Key: Value\r\n" and are terminated by
112    /// a blank line ("\r\n").
113    async fn read_headers(&mut self) -> Result<HashMap<String, String>> {
114        let mut headers = HashMap::new();
115        let mut line = String::new();
116
117        loop {
118            line.clear();
119            self.stdout.read_line(&mut line).await?;
120
121            // EOF - stream closed
122            if line.is_empty() {
123                return Err(Error::ServerTerminated);
124            }
125
126            if line == "\r\n" || line == "\n" {
127                break;
128            }
129
130            if let Some((key, value)) = line.trim_end().split_once(':') {
131                headers.insert(key.trim().to_lowercase(), value.trim().to_string());
132            } else {
133                warn!("Malformed header: {}", line.trim());
134            }
135        }
136
137        Ok(headers)
138    }
139
140    /// Read exact number of content bytes.
141    ///
142    /// Reads exactly `length` bytes from stdout and converts to UTF-8 string.
143    async fn read_content(&mut self, length: usize) -> Result<String> {
144        let mut buffer = vec![0u8; length];
145        self.stdout.read_exact(&mut buffer).await?;
146
147        String::from_utf8(buffer)
148            .map_err(|e| Error::LspProtocolError(format!("Invalid UTF-8 in content: {e}")))
149    }
150}
151
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Parse header lines with the same rules as `LspTransport::read_headers`.
    ///
    /// The rules are: strip trailing whitespace, split on the first `:`, trim
    /// both halves, and lowercase the key. Lines without a `:` are skipped.
    fn parse_headers(text: &str) -> HashMap<String, String> {
        text.lines()
            .filter_map(|line| line.trim_end().split_once(':'))
            .map(|(key, value)| (key.trim().to_lowercase(), value.trim().to_string()))
            .collect()
    }

    #[test]
    fn test_header_parsing() {
        let headers = parse_headers("Content-Length: 123\r\nContent-Type: application/json\r\n");

        assert_eq!(headers.get("content-length"), Some(&"123".to_string()));
        assert_eq!(
            headers.get("content-type"),
            Some(&"application/json".to_string())
        );
    }

    #[test]
    fn test_header_parsing_is_case_insensitive_and_skips_malformed_lines() {
        let headers = parse_headers("CONTENT-LENGTH:  42 \r\nnot a header\r\n");

        assert_eq!(headers.get("content-length"), Some(&"42".to_string()));
        assert_eq!(headers.len(), 1);
    }

    #[test]
    fn test_message_format() {
        let message = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {}
        });

        let content = serde_json::to_string(&message).unwrap();
        let header = format!("Content-Length: {}\r\n\r\n", content.len());

        assert!(header.starts_with("Content-Length:"));
        assert!(header.ends_with("\r\n\r\n"));
        assert!(content.contains("\"jsonrpc\":\"2.0\""));
    }

    #[test]
    fn test_content_length_counts_utf8_bytes() {
        // LSP `Content-Length` is a byte count, so multi-byte characters must
        // make it larger than the character count.
        let message = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "window/logMessage",
            "params": { "message": "héllo ✓" }
        });

        let content = serde_json::to_string(&message).unwrap();
        let header = format!("Content-Length: {}\r\n\r\n", content.len());
        let advertised: usize = parse_headers(&header)["content-length"].parse().unwrap();

        assert_eq!(advertised, content.as_bytes().len());
        assert_ne!(advertised, content.chars().count());
    }
}