use std::collections::HashMap;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{ChildStdin, ChildStdout};
use tracing::{trace, warn};
use crate::error::{Error, Result};
use crate::lsp::types::{InboundMessage, JsonRpcNotification, JsonRpcResponse};
/// Upper bound, in bytes, on an incoming message body (10 MiB). A larger
/// `Content-Length` is rejected before the body buffer is allocated.
const MAX_CONTENT_LENGTH: usize = 10 << 20;
/// Stdio transport to a language server child process using the LSP base
/// protocol framing: a `Content-Length` header block followed by a JSON body.
#[derive(Debug)]
pub struct LspTransport {
    /// The server's stdin; outgoing messages are written here.
    stdin: ChildStdin,
    /// The server's stdout, buffered so header lines can be read one at a time.
    stdout: BufReader<ChildStdout>,
}
impl LspTransport {
#[must_use]
pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
Self {
stdin,
stdout: BufReader::new(stdout),
}
}
pub async fn send(&mut self, message: &Value) -> Result<()> {
let content = serde_json::to_string(message)?;
let header = format!("Content-Length: {}\r\n\r\n", content.len());
trace!("Sending LSP message: {}", content);
self.stdin.write_all(header.as_bytes()).await?;
self.stdin.write_all(content.as_bytes()).await?;
self.stdin.flush().await?;
Ok(())
}
pub async fn receive(&mut self) -> Result<InboundMessage> {
let headers = self.read_headers().await?;
let content_length = headers
.get("content-length")
.ok_or_else(|| Error::LspProtocolError("Missing Content-Length header".to_string()))?
.parse::<usize>()
.map_err(|e| Error::LspProtocolError(format!("Invalid Content-Length: {e}")))?;
if content_length > MAX_CONTENT_LENGTH {
return Err(Error::LspProtocolError(format!(
"Content-Length {content_length} exceeds maximum allowed size of {MAX_CONTENT_LENGTH} bytes"
)));
}
let content = self.read_content(content_length).await?;
trace!("Received LSP message: {}", content);
let value: Value = serde_json::from_str(&content)?;
if value.get("id").is_some() {
let response: JsonRpcResponse = serde_json::from_value(value)
.map_err(|e| Error::LspProtocolError(format!("Invalid response: {e}")))?;
Ok(InboundMessage::Response(response))
} else {
let notification: JsonRpcNotification = serde_json::from_value(value)
.map_err(|e| Error::LspProtocolError(format!("Invalid notification: {e}")))?;
Ok(InboundMessage::Notification(notification))
}
}
async fn read_headers(&mut self) -> Result<HashMap<String, String>> {
let mut headers = HashMap::new();
let mut line = String::new();
loop {
line.clear();
let bytes_read = self.stdout.read_line(&mut line).await?;
if bytes_read == 0 || line.is_empty() {
trace!(
"EOF detected in read_headers: bytes_read={}, line_len={}",
bytes_read,
line.len()
);
return Err(Error::ServerTerminated);
}
if line == "\r\n" || line == "\n" {
break;
}
if let Some((key, value)) = line.trim_end().split_once(':') {
headers.insert(key.trim().to_lowercase(), value.trim().to_string());
} else {
warn!("Malformed header: {}", line.trim());
}
}
Ok(headers)
}
async fn read_content(&mut self, length: usize) -> Result<String> {
let mut buffer = vec![0u8; length];
self.stdout.read_exact(&mut buffer).await?;
String::from_utf8(buffer)
.map_err(|e| Error::LspProtocolError(format!("Invalid UTF-8 in content: {e}")))
}
}
#[cfg(test)]
// Tests unwrap freely: a panic here is simply a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Parses `headers_text` line by line with the same key/value
    /// normalization `read_headers` applies: lines without `:` are skipped,
    /// keys are trimmed and lowercased, values are trimmed.
    fn parse_headers(headers_text: &str) -> HashMap<String, String> {
        headers_text
            .lines()
            .filter_map(|line| line.split_once(':'))
            .map(|(key, value)| (key.trim().to_lowercase(), value.trim().to_string()))
            .collect()
    }

    #[test]
    fn test_header_parsing() {
        let headers = parse_headers("Content-Length: 123\r\nContent-Type: application/json\r\n");
        assert_eq!(headers.get("content-length"), Some(&"123".to_string()));
        assert_eq!(
            headers.get("content-type"),
            Some(&"application/json".to_string())
        );
    }

    #[test]
    fn test_message_format() {
        let message = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {}
        });
        let content = serde_json::to_string(&message).unwrap();
        let header = format!("Content-Length: {}\r\n\r\n", content.len());
        assert!(header.starts_with("Content-Length:"));
        assert!(header.ends_with("\r\n\r\n"));
        assert!(content.contains("\"jsonrpc\":\"2.0\""));
    }

    #[test]
    fn test_header_case_insensitive() {
        let headers = parse_headers("CONTENT-LENGTH: 123\r\nContent-Type: application/json\r\n");
        assert_eq!(headers.get("content-length"), Some(&"123".to_string()));
    }

    #[test]
    fn test_max_content_length_constant() {
        assert_eq!(MAX_CONTENT_LENGTH, 10 * 1024 * 1024);
    }

    #[test]
    fn test_header_format_with_multiple_headers() {
        let headers = parse_headers(
            "Content-Length: 42\r\nContent-Type: application/json\r\nX-Custom: value\r\n",
        );
        assert_eq!(headers.len(), 3);
        assert_eq!(headers.get("content-length"), Some(&"42".to_string()));
        assert_eq!(headers.get("x-custom"), Some(&"value".to_string()));
    }

    #[test]
    fn test_message_serialization_response() {
        let response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"key": "value"}
        });
        let content = serde_json::to_string(&response).unwrap();
        assert!(content.contains("\"jsonrpc\":\"2.0\""));
        assert!(content.contains("\"id\":1"));
        assert!(content.contains("\"result\""));
    }

    #[test]
    fn test_message_serialization_notification() {
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "window/showMessage",
            "params": {"type": 1, "message": "Hello"}
        });
        let content = serde_json::to_string(&notification).unwrap();
        assert!(content.contains("\"method\""));
        assert!(!content.contains("\"id\""));
    }

    #[test]
    fn test_message_serialization_error_response() {
        let error_response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 2,
            "error": {
                "code": -32601,
                "message": "Method not found"
            }
        });
        let content = serde_json::to_string(&error_response).unwrap();
        assert!(content.contains("\"error\""));
        assert!(content.contains("-32601"));
        assert!(content.contains("Method not found"));
    }

    #[test]
    fn test_content_length_calculation() {
        let message = serde_json::json!({"test": "data"});
        let content = serde_json::to_string(&message).unwrap();
        let expected_len = content.len();
        let header = format!("Content-Length: {}\r\n\r\n", content.len());
        assert!(header.contains(&expected_len.to_string()));
    }

    #[test]
    fn test_content_length_counts_bytes_not_chars() {
        // "é" is one char but two UTF-8 bytes; Content-Length must count bytes.
        let message = serde_json::json!({"text": "héllo"});
        let content = serde_json::to_string(&message).unwrap();
        assert_eq!(content.len(), content.chars().count() + 1);
    }

    #[test]
    fn test_header_without_colon() {
        let malformed_line = "Malformed header without colon";
        let result = malformed_line.split_once(':');
        assert!(result.is_none(), "Should not parse malformed header");
    }

    #[test]
    fn test_header_with_whitespace() {
        let header_line = " Content-Length : 456 ";
        // Unwrap so a parse failure fails the test instead of passing vacuously.
        let (key, value) = header_line.split_once(':').unwrap();
        assert_eq!(key.trim().to_lowercase(), "content-length");
        assert_eq!(value.trim(), "456");
    }
}