use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::time::Instant;
use anyhow::{Context, Result};
use fs4::fs_std::FileExt;
use serde_json::{json, Value};
use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::net::UnixListener;
use tokio::process::Command;
use tokio::sync::Mutex;
use tracing::{debug, error, info, warn};
use crate::lsp::mux::protocol::{self, ClientTag, DocumentState};
use crate::lsp::transport::{read_message, write_message};
type SharedWriter = Arc<Mutex<Box<dyn AsyncWrite + Unpin + Send>>>;
struct MuxState {
clients: HashMap<ClientTag, SharedWriter>,
doc_state: DocumentState,
cached_init_result: Value,
cached_capabilities: Vec<Value>,
edit_lock_owner: Option<ClientTag>,
next_tag: u32,
idle_since: Option<Instant>,
}
impl MuxState {
fn new(init_result: Value) -> Self {
Self {
clients: HashMap::new(),
doc_state: DocumentState::new(),
cached_init_result: init_result,
cached_capabilities: Vec::new(),
edit_lock_owner: None,
next_tag: 0,
idle_since: Some(Instant::now()),
}
}
fn next_tag(&mut self) -> ClientTag {
let tag = char::from(b'a' + (self.next_tag % 26) as u8).to_string();
self.next_tag += 1;
tag
}
}
pub async fn run(
socket_path: &Path,
lock_path: &Path,
workspace_root: &Path,
idle_timeout_secs: u64,
server_command: &str,
server_args: &[String],
server_env: &[(String, String)],
) -> Result<()> {
let lock_file = std::fs::File::create(lock_path)
.with_context(|| format!("failed to create lock file: {}", lock_path.display()))?;
lock_file
.try_lock_exclusive()
.context("another mux instance holds the lock")?;
use std::io::Write;
writeln!(&lock_file, "{}", std::process::id())?;
let mut child = Command::new(server_command)
.args(server_args)
.envs(server_env.iter().map(|(k, v)| (k, v)))
.current_dir(workspace_root)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.process_group(0) .kill_on_drop(true)
.spawn()
.with_context(|| format!("failed to spawn LSP server: {server_command}"))?;
let child_pgid: Option<libc::pid_t> = child.id().map(|id| id as libc::pid_t);
let server_stdin = child.stdin.take().context("no stdin on child")?;
let server_stdout = child.stdout.take().context("no stdout on child")?;
let server_writer: SharedWriter = Arc::new(Mutex::new(
Box::new(server_stdin) as Box<dyn AsyncWrite + Unpin + Send>
));
let mut server_reader = BufReader::new(server_stdout);
if let Some(stderr) = child.stderr.take() {
tokio::spawn(async move {
let mut reader = BufReader::new(stderr);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) | Err(_) => break,
Ok(_) => info!(target: "mux::server_stderr", "{}", line.trim_end()),
}
}
});
}
if let Some(pid) = child.id() {
tokio::spawn(watch_memory(pid));
}
let init_request = json!({
"jsonrpc": "2.0",
"id": 0,
"method": "initialize",
"params": {
"processId": null,
"capabilities": {
"textDocument": {
"synchronization": {
"dynamicRegistration": true,
"didSave": true
},
"definition": { "dynamicRegistration": true },
"references": { "dynamicRegistration": true },
"hover": { "dynamicRegistration": true },
"rename": { "dynamicRegistration": true },
"documentSymbol": {
"dynamicRegistration": true,
"hierarchicalDocumentSymbolSupport": true
},
"completion": { "dynamicRegistration": true }
},
"workspace": {
"workspaceFolders": true,
"applyEdit": true
}
},
"rootUri": url::Url::from_file_path(workspace_root).map(|u| u.to_string()).unwrap_or_default()
}
});
{
let mut w = server_writer.lock().await;
write_message(&mut *w, &init_request).await?;
}
let init_result = loop {
let msg = read_message(&mut server_reader)
.await
.context("failed to read message during initialize handshake")?;
if msg.get("id").and_then(|v| v.as_i64()) == Some(0) && msg.get("method").is_none() {
break msg
.get("result")
.cloned()
.context("initialize response missing 'result'")?;
}
if let Some(id) = msg.get("id") {
if msg.get("method").is_some() {
debug!(
"auto-responding to server request during init: {}",
msg.get("method").unwrap()
);
let response = json!({
"jsonrpc": "2.0",
"id": id,
"result": null,
});
let mut w = server_writer.lock().await;
write_message(&mut *w, &response).await?;
}
}
};
info!("LSP server initialized successfully");
let initialized_notif = json!({
"jsonrpc": "2.0",
"method": "initialized",
"params": {}
});
{
let mut w = server_writer.lock().await;
write_message(&mut *w, &initialized_notif).await?;
}
if socket_path.exists() {
std::fs::remove_file(socket_path).ok();
}
let listener = UnixListener::bind(socket_path)
.with_context(|| format!("failed to bind socket: {}", socket_path.display()))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(0o600));
}
{
let mut stdout = tokio::io::stdout();
stdout.write_all(b"ready\n").await?;
stdout.flush().await?;
}
let state = Arc::new(Mutex::new(MuxState::new(init_result)));
let result = event_loop(
&listener,
&mut server_reader,
&server_writer,
&state,
idle_timeout_secs,
)
.await;
let _ = child.kill().await;
if let Some(pgid) = child_pgid {
unsafe {
libc::killpg(pgid, libc::SIGTERM);
}
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
unsafe {
libc::killpg(pgid, libc::SIGKILL);
}
}
std::fs::remove_file(socket_path).ok();
reclaim_kotlin_analyzer_home(server_env);
info!("mux process shutting down");
result
}
/// Deletes the codescout-managed kotlin-lsp analyzer home named by `server_env`, if any.
///
/// Only a directory accepted by `kotlin_home_from_env` is ever removed. A home that
/// is already gone is ignored silently. Any other removal failure is logged and never
/// propagated, because this runs during best-effort teardown.
fn reclaim_kotlin_analyzer_home(server_env: &[(String, String)]) {
    if let Some(home) = kotlin_home_from_env(server_env) {
        if let Err(e) = std::fs::remove_dir_all(&home) {
            if e.kind() != std::io::ErrorKind::NotFound {
                warn!("failed to reclaim kotlin-lsp home {}: {e}", home.display());
            }
        } else {
            info!("reclaimed kotlin-lsp analyzer home {}", home.display());
        }
    }
}
/// Extracts the `-Duser.home=` directory from `JAVA_TOOL_OPTIONS` in `server_env`.
///
/// If the flag appears more than once, the last occurrence wins. The value ends at
/// the first whitespace. The path is returned only if
/// `crate::lsp::servers::is_codescout_kotlin_home` accepts it, so callers never get
/// (and never delete) an arbitrary home directory.
fn kotlin_home_from_env(server_env: &[(String, String)]) -> Option<std::path::PathBuf> {
    let (_, java_tool_options) = server_env
        .iter()
        .find(|(key, _)| key == "JAVA_TOOL_OPTIONS")?;
    let (_, after_flag) = java_tool_options.rsplit_once("-Duser.home=")?;
    let home = std::path::PathBuf::from(after_flag.split_whitespace().next()?);
    if crate::lsp::servers::is_codescout_kotlin_home(&home) {
        Some(home)
    } else {
        None
    }
}
async fn event_loop(
listener: &UnixListener,
server_reader: &mut BufReader<tokio::process::ChildStdout>,
server_writer: &SharedWriter,
state: &Arc<Mutex<MuxState>>,
idle_timeout_secs: u64,
) -> Result<()> {
let idle_timeout = std::time::Duration::from_secs(idle_timeout_secs);
let watchdog_interval = tokio::time::Duration::from_secs(10);
let mut watchdog_tick = tokio::time::interval(watchdog_interval);
watchdog_tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.context("install SIGTERM handler")?;
let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())
.context("install SIGINT handler")?;
loop {
tokio::select! {
accept_result = listener.accept() => {
match accept_result {
Ok((stream, _addr)) => {
let (read_half, write_half) = stream.into_split();
let writer: SharedWriter = Arc::new(Mutex::new(
Box::new(write_half) as Box<dyn AsyncWrite + Unpin + Send>,
));
let mut st = state.lock().await;
let tag = st.next_tag();
st.clients.insert(tag.clone(), writer.clone());
st.idle_since = None;
let init_msg = json!({
"type": "init",
"result": st.cached_init_result,
"registered_capabilities": st.cached_capabilities,
});
drop(st);
let w = writer.clone();
let tag_clone = tag.clone();
tokio::spawn(async move {
let mut w = w.lock().await;
if let Err(e) = write_message(&mut *w, &init_msg).await {
warn!(tag = %tag_clone, "failed to send init to client: {e}");
}
});
let reader = BufReader::new(read_half);
let sw = server_writer.clone();
let st_clone = state.clone();
tokio::spawn(client_reader_task(tag, reader, sw, st_clone));
info!("client connected");
}
Err(e) => {
warn!("failed to accept client connection: {e}");
}
}
}
server_msg = read_message(server_reader) => {
match server_msg {
Ok(msg) => {
handle_server_message(msg, state, server_writer).await;
}
Err(e) => {
info!("LSP server disconnected: {e}");
break;
}
}
}
_ = watchdog_tick.tick() => {
let st = state.lock().await;
if let Some(since) = st.idle_since {
if since.elapsed() >= idle_timeout {
info!("idle timeout reached ({idle_timeout_secs}s), shutting down");
break;
}
}
}
_ = sigterm.recv() => {
info!("mux received SIGTERM, exiting event loop");
break;
}
_ = sigint.recv() => {
info!("mux received SIGINT, exiting event loop");
break;
}
}
}
Ok(())
}
/// Per-client read loop.
///
/// Forwards every message from client `tag` to the server until reading fails (EOF
/// or an unreadable frame), then runs disconnect cleanup. A failure while handling a
/// single message is logged and does not end the loop.
async fn client_reader_task(
    tag: ClientTag,
    mut reader: BufReader<tokio::net::unix::OwnedReadHalf>,
    server_writer: SharedWriter,
    state: Arc<Mutex<MuxState>>,
) {
    loop {
        let Ok(mut msg) = read_message(&mut reader).await else {
            break;
        };
        let outcome = handle_client_message(&tag, &mut msg, &server_writer, &state).await;
        if let Err(e) = outcome {
            warn!(tag = %tag, "error handling client message: {e}");
        }
    }
    handle_client_disconnect(&tag, &server_writer, &state).await;
}
async fn handle_client_message(
tag: &str,
msg: &mut Value,
server_writer: &SharedWriter,
state: &Arc<Mutex<MuxState>>,
) -> Result<()> {
let method = msg.get("method").and_then(|m| m.as_str()).map(String::from);
if method.is_some() {
if let Some(id) = msg.get("id") {
let tagged = protocol::tag_request_id(id, tag);
msg["id"] = tagged;
}
}
if let Some(ref method) = method {
let mut st = state.lock().await;
match method.as_str() {
"textDocument/didOpen" => {
if let Some(uri) = extract_text_document_uri(msg) {
let forward = st.doc_state.open(&uri, tag);
if !forward {
debug!(tag = %tag, uri = %uri, "didOpen suppressed (already open)");
return Ok(());
}
}
}
"textDocument/didClose" => {
if let Some(uri) = extract_text_document_uri(msg) {
let forward = st.doc_state.close(&uri, tag);
if !forward {
debug!(tag = %tag, uri = %uri, "didClose suppressed (other clients still have it open)");
return Ok(());
}
}
}
"textDocument/didChange" => {
if let Some(uri) = extract_text_document_uri(msg) {
let version = st.doc_state.next_version(&uri);
if let Some(td) = msg
.get_mut("params")
.and_then(|p| p.get_mut("textDocument"))
{
td["version"] = json!(version);
}
}
}
"textDocument/rename" => {
st.edit_lock_owner = Some(tag.to_string());
}
_ => {}
}
}
let mut w = server_writer.lock().await;
write_message(&mut *w, msg).await?;
Ok(())
}
/// Dispatches one message from the LSP server by its JSON-RPC shape.
///
/// * `id` without a string `method`: a response, routed to the originating client.
/// * `id` with a string `method`: a server request, handled by `handle_server_request`.
/// * String `method` without `id`: a notification, broadcast to every client.
/// * Anything else is logged at debug level and dropped.
async fn handle_server_message(
    mut msg: Value,
    state: &Arc<Mutex<MuxState>>,
    server_writer: &SharedWriter,
) {
    let carries_id = msg.get("id").is_some();
    let carries_method = msg.get("method").and_then(|m| m.as_str()).is_some();
    match (carries_id, carries_method) {
        (true, false) => handle_server_response(&mut msg, state).await,
        (true, true) => handle_server_request(&msg, state, server_writer).await,
        (false, true) => broadcast_to_clients(&msg, state).await,
        (false, false) => {
            tracing::debug!(
                ?msg,
                "mux: dropping server message with no id and no method"
            );
        }
    }
}
/// Routes a server response back to the client whose tagged request it answers.
///
/// The tag is recovered with `protocol::tag_request_id`'s inverse,
/// `protocol::untag_response_id`, and `msg["id"]` is restored to the client's original
/// id. Responses without an `id` or with an untagged id are dropped (the latter is
/// logged at debug level), as are responses for clients that have disconnected.
///
/// Any response addressed to the current edit-lock owner releases the lock.
/// NOTE(review): this fires for *every* response to that client, not specifically
/// the `textDocument/rename` reply. Confirm that is intended.
async fn handle_server_response(msg: &mut Value, state: &Arc<Mutex<MuxState>>) {
    let Some(id) = msg.get("id").cloned() else {
        return;
    };
    let Some((tag, original_id)) = protocol::untag_response_id(&id) else {
        debug!("server response with untagged id: {id}");
        return;
    };
    msg["id"] = original_id;

    let destination = {
        let mut st = state.lock().await;
        if st.edit_lock_owner.as_deref() == Some(&tag) {
            st.edit_lock_owner = None;
        }
        st.clients.get(&tag).cloned()
    };

    match destination {
        Some(writer) => {
            let mut sink = writer.lock().await;
            if let Err(e) = write_message(&mut *sink, msg).await {
                warn!(tag = %tag, "failed to send response to client: {e}");
            }
        }
        None => {
            debug!(tag = %tag, "response for disconnected client, dropping");
        }
    }
}
/// Answers (or routes) a request initiated by the LSP server.
///
/// * `workspace/applyEdit` is forwarded, with the server's id untouched, to the client
///   holding the edit lock. That client's reply has no `method`, so it travels back
///   to the server unchanged. If forwarding fails the server gets
///   `{ "applied": false }`. If nobody holds the lock (or the owner has disconnected)
///   the server gets `{ "applied": true }` although no client saw the edit.
///   NOTE(review): confirm that optimistic answer is intended.
/// * `client/registerCapability` is cached for replay to future clients, then
///   acknowledged with `null`.
/// * `client/unregisterCapability` removes the matching registrations from that cache,
///   so new clients are not told about capabilities the server withdrew, then is
///   acknowledged with `null`.
/// * Every other request is acknowledged with a `null` result.
async fn handle_server_request(
    msg: &Value,
    state: &Arc<Mutex<MuxState>>,
    server_writer: &SharedWriter,
) {
    let method = msg
        .get("method")
        .and_then(|m| m.as_str())
        .unwrap_or_default();
    let id = msg.get("id").cloned().unwrap_or(Value::Null);
    match method {
        "workspace/applyEdit" => {
            let writer = {
                let st = state.lock().await;
                st.edit_lock_owner
                    .as_ref()
                    .and_then(|tag| st.clients.get(tag).cloned())
            };
            if let Some(writer) = writer {
                let mut w = writer.lock().await;
                if let Err(e) = write_message(&mut *w, msg).await {
                    warn!("failed to forward applyEdit to client: {e}");
                    send_auto_response(&id, server_writer, false).await;
                }
            } else {
                send_auto_response(&id, server_writer, true).await;
            }
        }
        "client/registerCapability" => {
            state.lock().await.cached_capabilities.push(msg.clone());
            respond_null(&id, server_writer).await;
        }
        "client/unregisterCapability" => {
            {
                let mut st = state.lock().await;
                prune_unregistered_capabilities(&mut st.cached_capabilities, msg);
            }
            respond_null(&id, server_writer).await;
        }
        _ => respond_null(&id, server_writer).await,
    }
}

/// Drops from `cached` every registration whose `id` is listed in the
/// `client/unregisterCapability` request `unregister_request`. Cached requests left
/// with no registrations are removed entirely.
///
/// LSP spells the list `params.unregisterations` (a historical typo kept for
/// compatibility). The corrected `unregistrations` is accepted as well. Cached
/// entries whose shape is not understood are kept as-is.
fn prune_unregistered_capabilities(cached: &mut Vec<Value>, unregister_request: &Value) {
    let withdrawn: std::collections::HashSet<&str> = unregister_request
        .get("params")
        .and_then(|p| p.get("unregisterations").or_else(|| p.get("unregistrations")))
        .and_then(Value::as_array)
        .into_iter()
        .flatten()
        .filter_map(|u| u.get("id").and_then(Value::as_str))
        .collect();
    if withdrawn.is_empty() {
        return;
    }
    cached.retain_mut(|request| {
        let Some(registrations) = request
            .pointer_mut("/params/registrations")
            .and_then(Value::as_array_mut)
        else {
            return true;
        };
        registrations.retain(|reg| {
            !matches!(
                reg.get("id").and_then(Value::as_str),
                Some(id) if withdrawn.contains(id)
            )
        });
        !registrations.is_empty()
    });
}

/// Acknowledges server request `id` with a `null` result; write failures are logged.
async fn respond_null(id: &Value, server_writer: &SharedWriter) {
    let response = json!({
        "jsonrpc": "2.0",
        "id": id,
        "result": null
    });
    let mut w = server_writer.lock().await;
    if let Err(e) = write_message(&mut *w, &response).await {
        error!("failed to send auto-response to server: {e}");
    }
}
/// Answers a server `workspace/applyEdit` request `id` on the mux's behalf with
/// `{ "applied": success }`. Write failures are logged, not returned.
async fn send_auto_response(id: &Value, server_writer: &SharedWriter, success: bool) {
    let response = json!({
        "jsonrpc": "2.0",
        "id": id,
        "result": { "applied": success },
    });
    let outcome = {
        let mut sink = server_writer.lock().await;
        write_message(&mut *sink, &response).await
    };
    if let Err(e) = outcome {
        error!("failed to send auto-response to server: {e}");
    }
}
/// Sends `msg` (a server notification) to every connected client, one at a time.
///
/// Recipients are snapshotted first so the state lock is not held while writing. A
/// failed write to one client is logged at debug level and does not stop the others.
async fn broadcast_to_clients(msg: &Value, state: &Arc<Mutex<MuxState>>) {
    let recipients: Vec<(ClientTag, SharedWriter)> = state
        .lock()
        .await
        .clients
        .iter()
        .map(|(tag, writer)| (tag.clone(), Arc::clone(writer)))
        .collect();
    for (tag, writer) in &recipients {
        let mut sink = writer.lock().await;
        if let Err(e) = write_message(&mut *sink, msg).await {
            debug!(tag = %tag, "failed to broadcast to client: {e}");
        }
    }
}
/// Cleans up after client `tag`'s connection closes.
///
/// Under the state lock it removes the client's writer, releases the edit lock if the
/// client held it, and collects the URIs that `DocumentState::disconnect` reports for
/// it. If this was the last client, it also starts the idle timer. After the lock is
/// released, a `textDocument/didClose` is sent to the server for each reported URI;
/// write failures are logged.
async fn handle_client_disconnect(
    tag: &str,
    server_writer: &SharedWriter,
    state: &Arc<Mutex<MuxState>>,
) {
    info!(tag = %tag, "client disconnected");
    let closing_uris = {
        let mut st = state.lock().await;
        st.clients.remove(tag);
        let held_edit_lock = st.edit_lock_owner.as_deref() == Some(tag);
        if held_edit_lock {
            st.edit_lock_owner = None;
        }
        let closing = st.doc_state.disconnect(tag);
        if st.clients.is_empty() {
            st.idle_since = Some(Instant::now());
            info!("no clients connected, starting idle timer");
        }
        closing
    };
    for uri in closing_uris {
        let did_close = json!({
            "jsonrpc": "2.0",
            "method": "textDocument/didClose",
            "params": {
                "textDocument": { "uri": uri }
            }
        });
        let mut sink = server_writer.lock().await;
        if let Err(e) = write_message(&mut *sink, &did_close).await {
            warn!("failed to send didClose to server for {uri}: {e}");
        }
    }
}
fn extract_text_document_uri(msg: &Value) -> Option<String> {
msg.get("params")?
.get("textDocument")?
.get("uri")?
.as_str()
.map(String::from)
}
/// Periodically logs the combined resident and swapped memory of process `pid`.
///
/// Samples every 60 s; tokio's `interval` completes its first tick immediately. The
/// log level depends on the total: `debug` below 4 GiB, `warn` from 4 GiB, and `error`
/// from 8 GiB. Stops as soon as `read_proc_memory` cannot report on `pid`, e.g. the
/// process is gone or the platform has no `/proc`.
async fn watch_memory(pid: u32) {
    const KIB_PER_GIB: f64 = 1024.0 * 1024.0;
    const WARN_KB: u64 = 4 * 1024 * 1024;
    const ERROR_KB: u64 = 8 * 1024 * 1024;
    let to_gib = |kb: u64| kb as f64 / KIB_PER_GIB;

    let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60));
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    loop {
        ticker.tick().await;
        let Some((rss_kb, swap_kb)) = read_proc_memory(pid) else {
            break;
        };
        let total_kb = rss_kb + swap_kb;
        let (rss_gib, swap_gib, total_gib) = (to_gib(rss_kb), to_gib(swap_kb), to_gib(total_kb));
        match total_kb {
            kb if kb >= ERROR_KB => {
                error!(
                    target: "mux::memory",
                    "LSP server memory CRITICAL (pid={}): {:.1} GiB total (rss={:.1} GiB swap={:.1} GiB)",
                    pid, total_gib, rss_gib, swap_gib
                );
            }
            kb if kb >= WARN_KB => {
                warn!(
                    target: "mux::memory",
                    "LSP server memory high (pid={}): {:.1} GiB total (rss={:.1} GiB swap={:.1} GiB)",
                    pid, total_gib, rss_gib, swap_gib
                );
            }
            _ => {
                debug!(
                    target: "mux::memory",
                    "LSP server memory (pid={}): rss={:.1} GiB swap={:.1} GiB",
                    pid, rss_gib, swap_gib
                );
            }
        }
    }
}
/// Reads `(VmRSS, VmSwap)` for `pid` from `/proc/<pid>/status`, in kB.
///
/// Returns `None` if the file cannot be read or has no parseable `VmRSS` line. A
/// missing or unparseable `VmSwap` line counts as 0.
#[cfg(target_os = "linux")]
fn read_proc_memory(pid: u32) -> Option<(u64, u64)> {
    let status = std::fs::read_to_string(format!("/proc/{pid}/status")).ok()?;
    // First numeric token after `key` on the first line starting with `key`.
    let field_kb = |key: &str| -> Option<u64> {
        status
            .lines()
            .find_map(|line| line.strip_prefix(key))?
            .split_whitespace()
            .next()?
            .parse()
            .ok()
    };
    let rss_kb = field_kb("VmRSS:")?;
    Some((rss_kb, field_kb("VmSwap:").unwrap_or(0)))
}

/// Fallback for platforms without `/proc`: memory is never reported.
#[cfg(not(target_os = "linux"))]
fn read_proc_memory(_pid: u32) -> Option<(u64, u64)> {
    Option::None
}
#[cfg(test)]
mod kotlin_home_tests {
    use super::*;

    /// Builds a `JAVA_TOOL_OPTIONS` environment entry holding `value`.
    fn java_tool_options(value: String) -> (String, String) {
        ("JAVA_TOOL_OPTIONS".to_string(), value)
    }

    /// A guarded analyzer home next to unrelated variables is found.
    #[test]
    fn kotlin_home_from_env_extracts_guarded_home() {
        let home = crate::lsp::servers::kotlin_analyzer_home("abc123");
        let env = [
            ("GRADLE_USER_HOME".to_string(), "/tmp/g".to_string()),
            java_tool_options(format!("-Duser.home={}", home.display())),
        ];
        assert_eq!(kotlin_home_from_env(&env), Some(home));
    }

    /// With several `-Duser.home=` flags, the last one wins.
    #[test]
    fn kotlin_home_from_env_takes_last_user_home() {
        let home = crate::lsp::servers::kotlin_analyzer_home("ws9");
        let env = [java_tool_options(format!(
            "-Xmx2g -Duser.home=/home/real -Duser.home={}",
            home.display()
        ))];
        assert_eq!(kotlin_home_from_env(&env), Some(home));
    }

    /// Homes outside the guarded location, and a missing variable, both yield `None`.
    #[test]
    fn kotlin_home_from_env_rejects_foreign_and_absent() {
        let foreign = [java_tool_options("-Duser.home=/home/victim".to_string())];
        assert_eq!(kotlin_home_from_env(&foreign), None);

        let absent = [("GRADLE_USER_HOME".to_string(), "/tmp/g".to_string())];
        assert_eq!(kotlin_home_from_env(&absent), None);
    }
}
#[cfg(all(test, unix))]
mod process_group_reaping_tests {
    use std::time::Duration;
    use tokio::process::Command;

    /// Whether `pid` currently names a process we may signal.
    ///
    /// `kill(pid, 0)` also succeeds for a zombie, so a killed process only reads as
    /// dead once whoever became its parent has reaped it.
    fn pid_alive(pid: libc::pid_t) -> bool {
        // SAFETY: signal 0 delivers nothing; kill only performs existence/permission checks.
        unsafe { libc::kill(pid, 0) == 0 }
    }

    /// Polls [`pid_alive`] until it reports false or `timeout` elapses.
    /// Returns `true` if the process disappeared in time.
    async fn wait_until_gone(pid: libc::pid_t, timeout: Duration) -> bool {
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            if !pid_alive(pid) {
                return true;
            }
            if tokio::time::Instant::now() >= deadline {
                return false;
            }
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
    }

    /// Mirrors production teardown (group SIGTERM, pause, group SIGKILL) and checks
    /// that a backgrounded grandchild in the leader's process group dies with it.
    #[tokio::test]
    async fn killpg_reaps_grandchild_in_child_process_group() {
        let mut child = Command::new("sh")
            .arg("-c")
            // Background a long sleep, report its pid, and keep the leader alive.
            .arg("sleep 60 & echo $! ; wait")
            .stdout(std::process::Stdio::piped())
            .process_group(0)
            .kill_on_drop(true)
            .spawn()
            .expect("spawn group leader");
        let pgid = child.id().expect("child has a pid") as libc::pid_t;

        use tokio::io::AsyncReadExt as _;
        let mut stdout = child.stdout.take().expect("child stdout");
        let grandchild_pid = {
            let mut buf = Vec::new();
            let read = tokio::time::timeout(Duration::from_secs(5), async {
                let mut byte = [0u8; 1];
                loop {
                    let n = stdout.read(&mut byte).await.expect("read stdout");
                    if n == 0 || byte[0] == b'\n' {
                        break;
                    }
                    buf.push(byte[0]);
                }
            })
            .await;
            read.expect("timed out reading grandchild pid");
            String::from_utf8(buf)
                .expect("utf8 pid")
                .trim()
                .parse::<libc::pid_t>()
                .expect("parse grandchild pid")
        };
        assert!(
            pid_alive(grandchild_pid),
            "grandchild should be alive before killpg"
        );

        // SAFETY: killpg takes plain integers and only sends a signal.
        unsafe {
            libc::killpg(pgid, libc::SIGTERM);
        }
        tokio::time::sleep(Duration::from_millis(500)).await;
        // SAFETY: as above.
        unsafe {
            libc::killpg(pgid, libc::SIGKILL);
        }
        let _ = child.wait().await;

        // The killed grandchild lingers as a zombie until it is reaped; poll instead of
        // sleeping a fixed amount so the check is not timing-dependent.
        assert!(
            wait_until_gone(grandchild_pid, Duration::from_secs(5)).await,
            "grandchild (pid {grandchild_pid}) should be reaped by killpg on the group; \
             it orphaned instead — process_group(0) wiring is broken"
        );
    }

    /// Signalling a group whose only member has already exited must not panic or hang.
    #[tokio::test]
    async fn killpg_on_dead_group_is_harmless() {
        let mut child = Command::new("sh")
            .arg("-c")
            .arg("exit 0")
            .process_group(0)
            .kill_on_drop(true)
            .spawn()
            .expect("spawn short-lived child");
        let pgid = child.id().expect("child has a pid") as libc::pid_t;
        let _ = child.wait().await;
        // SAFETY: signalling an empty group just returns ESRCH, which is ignored here.
        unsafe {
            libc::killpg(pgid, libc::SIGTERM);
            libc::killpg(pgid, libc::SIGKILL);
        }
    }
}