use crate::error::{LspMcpError, Result};
use serde::{de::DeserializeOwned, Serialize};
use std::io::{BufRead, BufReader, Read, Write};
use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio};
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::Mutex;
use tracing::{debug, trace};
#[cfg(windows)]
use std::os::windows::process::CommandExt;
/// Any JSON-RPC message that can travel over the transport.
///
/// The enum is `untagged`: serde tries the variants in declaration order and
/// keeps the first one whose fields match, so the order below matters.
/// `Request` (needs `id` and `method`) is tried before `Response` (needs `id`)
/// and `Notification` (needs `method`).
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
#[serde(untagged)]
pub enum Message {
/// A call that carries both an `id` and a `method`.
Request(Request),
/// A reply that carries an `id` plus an optional `result` or `error`.
Response(Response),
/// A message with a `method` but no `id` field of its own.
Notification(Notification),
}
/// A JSON-RPC request: a method call identified by a numeric `id`.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct Request {
/// Protocol version string.
pub jsonrpc: String,
/// Identifier used to pair this request with its response.
pub id: i32,
/// Name of the method being invoked.
pub method: String,
/// Method parameters; left out of the serialized JSON when `None` and
/// defaulted to `None` when absent on input.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub params: Option<serde_json::Value>,
}
/// A JSON-RPC response to a previously issued request.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct Response {
/// Protocol version string.
pub jsonrpc: String,
/// Identifier of the request this response answers.
pub id: i32,
/// Result payload; `None` when the field is absent or JSON `null`.
/// Left out of the serialized JSON when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
/// Error payload; left out of the serialized JSON when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<ResponseError>,
}
/// The error object carried by a failed [`Response`].
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct ResponseError {
/// Numeric error code.
pub code: i32,
/// Human-readable description of the error.
pub message: String,
/// Optional additional error data; left out of the serialized JSON when
/// `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
/// A JSON-RPC notification: a method call without an `id` field.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct Notification {
/// Protocol version string.
pub jsonrpc: String,
/// Name of the method being invoked.
pub method: String,
/// Method parameters; left out of the serialized JSON when `None` and
/// defaulted to `None` when absent on input.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub params: Option<serde_json::Value>,
}
/// A transport that exchanges [`Message`]s with a child process over its
/// standard input and output.
pub struct StdioTransport {
// Handle to the spawned child process.
process: Child,
// Write end of the child's stdin, behind a mutex so it can be used via `&self`.
stdin: Mutex<ChildStdin>,
// Buffered read end of the child's stdout, behind a mutex for the same reason.
stdout: Mutex<BufReader<ChildStdout>>,
// Counter that hands out request ids.
next_id: AtomicI32,
}
impl StdioTransport {
pub fn spawn(
command: &str,
args: &[String],
env: &std::collections::HashMap<String, String>,
working_dir: &std::path::Path,
) -> Result<Self> {
debug!(
"Spawning language server: {} {:?} in {:?}",
command, args, working_dir
);
let mut cmd = Command::new(command);
cmd.args(args)
.current_dir(working_dir)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped());
for (key, value) in env {
cmd.env(key, value);
}
#[cfg(windows)]
{
cmd.creation_flags(0x08000000); }
let mut process = cmd.spawn().map_err(|e| {
LspMcpError::ServerStartFailed(format!("Failed to spawn {}: {}", command, e))
})?;
let stdin = process.stdin.take().ok_or_else(|| {
LspMcpError::ServerStartFailed("Failed to get stdin handle".to_string())
})?;
let stdout = process.stdout.take().ok_or_else(|| {
LspMcpError::ServerStartFailed("Failed to get stdout handle".to_string())
})?;
Ok(Self {
process,
stdin: Mutex::new(stdin),
stdout: Mutex::new(BufReader::new(stdout)),
next_id: AtomicI32::new(1),
})
}
pub fn next_id(&self) -> i32 {
self.next_id.fetch_add(1, Ordering::SeqCst)
}
pub fn request<P: Serialize, R: DeserializeOwned>(
&self,
method: &str,
params: P,
) -> Result<R> {
let id = self.next_id();
let request = Request {
jsonrpc: "2.0".to_string(),
id,
method: method.to_string(),
params: Some(serde_json::to_value(params)?),
};
self.send_message(&Message::Request(request))?;
let response = self.receive_response(id)?;
if let Some(error) = response.error {
return Err(LspMcpError::ProtocolError(format!(
"{}: {}",
error.code, error.message
)));
}
match response.result {
Some(result) => Ok(serde_json::from_value(result)?),
None => Err(LspMcpError::ProtocolError("Empty response".to_string())),
}
}
pub fn notify<P: Serialize>(&self, method: &str, params: P) -> Result<()> {
let notification = Notification {
jsonrpc: "2.0".to_string(),
method: method.to_string(),
params: Some(serde_json::to_value(params)?),
};
self.send_message(&Message::Notification(notification))
}
fn send_message(&self, message: &Message) -> Result<()> {
let json = serde_json::to_string(message)?;
let content_length = json.len();
let mut stdin = self.stdin.lock().map_err(|_| LspMcpError::ChannelSend)?;
trace!("Sending: {}", json);
write!(stdin, "Content-Length: {}\r\n\r\n{}", content_length, json)
.map_err(|e| LspMcpError::Io(e))?;
stdin.flush().map_err(|e| LspMcpError::Io(e))?;
Ok(())
}
fn receive_response(&self, expected_id: i32) -> Result<Response> {
loop {
let message = self.receive_message()?;
match message {
Message::Response(response) if response.id == expected_id => {
return Ok(response);
}
Message::Response(response) => {
debug!("Ignoring response with id {}, expected {}", response.id, expected_id);
}
Message::Notification(notif) => {
debug!("Received notification while waiting for response: {}", notif.method);
}
Message::Request(req) => {
debug!("Received server request while waiting for response: {}", req.method);
}
}
}
}
fn receive_message(&self) -> Result<Message> {
let mut stdout = self.stdout.lock().map_err(|_| LspMcpError::ChannelRecv)?;
let mut content_length: Option<usize> = None;
loop {
let mut line = String::new();
stdout.read_line(&mut line).map_err(|e| LspMcpError::Io(e))?;
let line = line.trim();
if line.is_empty() {
break;
}
if let Some(value) = line.strip_prefix("Content-Length:") {
content_length = Some(
value
.trim()
.parse()
.map_err(|_| LspMcpError::ProtocolError("Invalid Content-Length".to_string()))?,
);
}
}
let content_length = content_length
.ok_or_else(|| LspMcpError::ProtocolError("Missing Content-Length header".to_string()))?;
let mut buffer = vec![0u8; content_length];
stdout.read_exact(&mut buffer).map_err(|e| LspMcpError::Io(e))?;
let json = String::from_utf8(buffer)
.map_err(|_| LspMcpError::ProtocolError("Invalid UTF-8".to_string()))?;
trace!("Received: {}", json);
let message: Message = serde_json::from_str(&json)?;
Ok(message)
}
pub fn kill(&mut self) -> Result<()> {
self.process.kill().map_err(|e| LspMcpError::Io(e))?;
Ok(())
}
pub fn is_running(&mut self) -> bool {
matches!(self.process.try_wait(), Ok(None))
}
}
impl Drop for StdioTransport {
    /// Kills the child process when the transport is dropped.
    fn drop(&mut self) {
        // `drop` has no way to report a failure, so the outcome of the
        // best-effort kill is discarded.
        self.kill().ok();
    }
}