use std::collections::HashMap;
use std::io::BufRead;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::Duration;
use crate::app::Editor;
use crate::server::ipc::{ServerConnection, ServerListener, SocketPaths, StreamWrapper};
use crate::server::protocol::{
ClientControl, FileRequest, ServerControl, ServerHello, VersionMismatch, PROTOCOL_VERSION,
};
/// How long `accept_loop` sleeps when no connection is pending (or after an
/// accept error) before polling the listener again. Since the shutdown flag
/// is only checked between polls, this also roughly bounds how long the loop
/// takes to notice a shutdown request.
const ACCEPT_POLL_INTERVAL: Duration = Duration::from_millis(20);
/// A request forwarded from a connection thread to the editor thread, where
/// it is applied by `pump`.
enum LocalControlRequest {
    /// Open `files` in the editor. When `wait_id` is `Some`, a connection
    /// thread is blocked until that id is signalled through the waiter table.
    OpenFiles {
        files: Vec<FileRequest>,
        wait_id: Option<u64>,
    },
    /// Open a new window rooted at `path`. `pump` ignores (and logs) relative
    /// paths.
    OpenWindow { path: PathBuf },
}
/// Editor-side half of the local control channel, published once by `start`.
struct Shared {
    /// Requests from connection threads; drained without blocking by `pump`.
    req_rx: Mutex<Receiver<LocalControlRequest>>,
    /// Waiting `OpenFiles` clients keyed by wait id. Sending `()` on the
    /// stored sender releases the connection thread blocked on that id.
    waiters: Arc<Mutex<HashMap<u64, Sender<()>>>>,
}
/// Process-wide control state: set by `start`, read by `pump`.
static GLOBAL: OnceLock<Shared> = OnceLock::new();
/// Session name of this process's local control socket, recorded by `start`.
static SESSION_ID: OnceLock<String> = OnceLock::new();
/// Returns the session id recorded by `start`, or `None` when the local
/// control socket has not been started in this process.
pub fn local_session_id() -> Option<&'static str> {
    let recorded: Option<&'static String> = SESSION_ID.get();
    recorded.map(|id| id.as_str())
}
pub fn start() -> std::io::Result<&'static str> {
if let Some(id) = SESSION_ID.get() {
return Ok(id.as_str());
}
let session_id = generate_session_id();
let paths = SocketPaths::for_session_name(&session_id)?;
let bound = bind_and_spawn(paths)?;
#[allow(clippy::let_underscore_must_use)]
let _ = GLOBAL.set(Shared {
req_rx: Mutex::new(bound.req_rx),
waiters: bound.waiters,
});
let id = SESSION_ID.get_or_init(|| session_id);
tracing::info!("Local control socket listening as session {}", id);
Ok(id.as_str())
}
/// Handles produced by `bind_and_spawn` for a freshly bound control socket.
struct BoundControl {
    /// Receiving end of the request channel fed by connection threads.
    req_rx: Receiver<LocalControlRequest>,
    /// Waiter table shared with the accept and connection threads.
    waiters: Arc<Mutex<HashMap<u64, Sender<()>>>>,
    /// Storing `true` makes `accept_loop` exit before its next accept.
    shutdown: Arc<AtomicBool>,
}
/// Binds the control socket at `paths`, writes this process's pid for it, and
/// spawns the `fresh-local-control` thread running `accept_loop`.
///
/// `paths.cleanup_if_stale()` runs before binding (presumably clearing
/// leftovers from a dead process; see `SocketPaths`).
///
/// # Errors
/// Propagates failures from binding the listener, writing the pid, or
/// spawning the accept thread.
fn bind_and_spawn(paths: SocketPaths) -> std::io::Result<BoundControl> {
    paths.cleanup_if_stale();
    let listener = ServerListener::bind(paths.clone())?;
    paths.write_pid(std::process::id())?;

    let (tx, rx) = mpsc::channel::<LocalControlRequest>();
    let waiter_table = Arc::new(Mutex::new(HashMap::<u64, Sender<()>>::new()));
    let stop_flag = Arc::new(AtomicBool::new(false));

    // The accept thread gets its own handles; the originals go back to the caller.
    let thread_waiters = Arc::clone(&waiter_table);
    let thread_stop = Arc::clone(&stop_flag);
    std::thread::Builder::new()
        .name(String::from("fresh-local-control"))
        .spawn(move || accept_loop(listener, tx, thread_waiters, thread_stop))?;

    Ok(BoundControl {
        req_rx: rx,
        waiters: waiter_table,
        shutdown: stop_flag,
    })
}
/// Applies all pending local-control requests to `editor` and releases any
/// waiting clients whose waits have completed.
///
/// Drains the request channel without blocking. `OpenFiles` requests are
/// queued on the editor and processed. `OpenWindow` requests with an absolute
/// path create and activate a new window with the file explorer shown. Then
/// every completed wait id is signalled through the shared waiter table. This
/// includes the ids from `editor.take_completed_waits()` and the ids of
/// waited requests that contained no files.
///
/// Returns `false` immediately if `start` has not been called. Otherwise
/// returns `true` if any request changed editor state.
pub fn pump(editor: &mut Editor) -> bool {
    let Some(shared) = GLOBAL.get() else {
        return false;
    };
    let mut changed = false;
    // Wait ids to release in addition to those the editor reports.
    let mut completed: Vec<u64> = Vec::new();
    loop {
        // Hold the receiver lock only for the poll, not while the editor
        // handles the request.
        let req = {
            let rx = shared.req_rx.lock().unwrap();
            rx.try_recv()
        };
        match req {
            Ok(LocalControlRequest::OpenFiles { files, wait_id }) => {
                if files.is_empty() {
                    // No file carries the wait id, so the editor would never
                    // report it and the waiting client would block forever.
                    completed.extend(wait_id);
                }
                // Only the last file of the batch is tagged with the wait id.
                let last = files.len().saturating_sub(1);
                for (i, fr) in files.into_iter().enumerate() {
                    let file_wait_id = if i == last { wait_id } else { None };
                    editor.queue_file_open(
                        PathBuf::from(fr.path),
                        fr.line,
                        fr.column,
                        fr.end_line,
                        fr.end_column,
                        fr.message,
                        file_wait_id,
                    );
                }
                editor.process_pending_file_opens();
                changed = true;
            }
            Ok(LocalControlRequest::OpenWindow { path }) => {
                if path.is_absolute() {
                    let label = path
                        .file_name()
                        .map(|s| s.to_string_lossy().into_owned())
                        .unwrap_or_else(|| path.to_string_lossy().into_owned());
                    let id = editor.create_window_at(path, label);
                    editor.set_active_window(id);
                    editor.show_file_explorer();
                    changed = true;
                } else {
                    tracing::warn!("OpenWindow ignored: path must be absolute: {:?}", path);
                }
            }
            // Channel empty (or disconnected): nothing more to apply now.
            Err(_) => break,
        }
    }
    completed.extend(editor.take_completed_waits());
    if !completed.is_empty() {
        let mut waiters = shared.waiters.lock().unwrap();
        for wait_id in completed {
            if let Some(notifier) = waiters.remove(&wait_id) {
                // The connection thread may already be gone; nothing to do then.
                #[allow(clippy::let_underscore_must_use)]
                let _ = notifier.send(());
            }
        }
    }
    changed
}
/// Accepts control connections until `shutdown` becomes `true`, serving each
/// one on its own `fresh-local-control-conn` thread via `handle_connection`.
///
/// When no connection is pending, or `accept` fails, the loop sleeps for
/// `ACCEPT_POLL_INTERVAL` before polling again. If spawning a handler thread
/// fails, the error is logged and that connection is dropped.
fn accept_loop(
    mut listener: ServerListener,
    req_tx: Sender<LocalControlRequest>,
    waiters: Arc<Mutex<HashMap<u64, Sender<()>>>>,
    shutdown: Arc<AtomicBool>,
) {
    // One counter for all connections, so wait ids are unique per listener.
    let next_wait_id = Arc::new(AtomicU64::new(1));
    while !shutdown.load(Ordering::SeqCst) {
        match listener.accept() {
            Ok(Some(conn)) => {
                let req_tx = req_tx.clone();
                let waiters = Arc::clone(&waiters);
                let next_wait_id = Arc::clone(&next_wait_id);
                let spawned = std::thread::Builder::new()
                    .name("fresh-local-control-conn".to_string())
                    .spawn(move || {
                        handle_connection(conn, req_tx, waiters, next_wait_id);
                    });
                if let Err(e) = spawned {
                    // The closure, and with it `conn`, is dropped on failure.
                    tracing::warn!("Local control: failed to spawn connection thread: {}", e);
                }
            }
            Ok(None) => std::thread::sleep(ACCEPT_POLL_INTERVAL),
            Err(e) => {
                tracing::warn!("Local control accept error: {}", e);
                std::thread::sleep(ACCEPT_POLL_INTERVAL);
            }
        }
    }
}
/// Serves one control connection until the client disconnects, a read fails,
/// or the editor side of the request channel is gone.
///
/// After a successful `handshake`, each message is forwarded on `req_tx` as a
/// `LocalControlRequest`. For `OpenFiles` with `wait: true`, this function:
/// 1. takes a wait id from `next_wait_id`;
/// 2. registers the id in `waiters` *before* forwarding the request;
/// 3. blocks until the id is signalled;
/// 4. writes `ServerControl::WaitComplete` back to the client.
///
/// Other message kinds are logged and ignored.
fn handle_connection(
    conn: ServerConnection,
    req_tx: Sender<LocalControlRequest>,
    waiters: Arc<Mutex<HashMap<u64, Sender<()>>>>,
    next_wait_id: Arc<AtomicU64>,
) {
    let mut reader = std::io::BufReader::new(&conn.control);
    if let Err(e) = handshake(&conn, &mut reader) {
        tracing::debug!("Local control handshake failed: {}", e);
        return;
    }
    loop {
        // Re-assert blocking mode before every read so `read_msg` waits for
        // the next line. NOTE(review): presumably guards against the socket
        // being left non-blocking elsewhere; confirm before removing.
        #[cfg(not(windows))]
        #[allow(clippy::let_underscore_must_use)]
        let _ = conn.control.set_nonblocking(false);
        let msg = match read_msg(&mut reader) {
            Ok(Some(m)) => m,
            // End of stream: the client closed the connection.
            Ok(None) => return,
            Err(e) => {
                tracing::debug!("Local control read error: {}", e);
                return;
            }
        };
        match msg {
            ClientControl::OpenWindow { path } => {
                let req = LocalControlRequest::OpenWindow {
                    path: PathBuf::from(path),
                };
                if req_tx.send(req).is_err() {
                    // The receiver was dropped; no later request can be applied.
                    tracing::debug!("Local control: editor channel closed; dropping connection");
                    return;
                }
            }
            ClientControl::OpenFiles { files, wait } => {
                // Register before forwarding so the editor cannot signal an id
                // that is not yet in the table.
                let wait_slot = if wait {
                    let id = next_wait_id.fetch_add(1, Ordering::SeqCst);
                    let (tx, rx) = mpsc::channel::<()>();
                    waiters.lock().unwrap().insert(id, tx);
                    Some((id, rx))
                } else {
                    None
                };
                let req = LocalControlRequest::OpenFiles {
                    files,
                    wait_id: wait_slot.as_ref().map(|(id, _)| *id),
                };
                if req_tx.send(req).is_err() {
                    // The request never reaches the editor, so its wait could
                    // never complete. Unregister it instead of blocking forever.
                    if let Some((id, _)) = &wait_slot {
                        waiters.lock().unwrap().remove(id);
                    }
                    tracing::debug!("Local control: editor channel closed; dropping connection");
                    return;
                }
                if let Some((id, rx)) = wait_slot {
                    // `Err` means the notifier was dropped without sending; the
                    // client is released either way.
                    #[allow(clippy::let_underscore_must_use)]
                    let _ = rx.recv();
                    let done =
                        serde_json::to_string(&ServerControl::WaitComplete).unwrap_or_default();
                    // Best effort: the client may already have disconnected.
                    #[allow(clippy::let_underscore_must_use)]
                    let _ = conn.write_control(&done);
                    // Normally already removed by the signaller; clean up regardless.
                    waiters.lock().unwrap().remove(&id);
                }
            }
            other => {
                tracing::debug!("Local control ignoring unexpected message: {:?}", other);
            }
        }
    }
}
fn handshake(
conn: &ServerConnection,
reader: &mut std::io::BufReader<&StreamWrapper>,
) -> std::io::Result<()> {
#[cfg(not(windows))]
conn.control.set_nonblocking(false)?;
let hello_json = read_msg(reader)?
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no hello"))?;
let hello = match hello_json {
ClientControl::Hello(h) => h,
_ => return Err(std::io::Error::other("expected Hello")),
};
if hello.protocol_version != PROTOCOL_VERSION {
let mismatch = VersionMismatch {
server_version: env!("CARGO_PKG_VERSION").to_string(),
client_version: hello.client_version.clone(),
action: if hello.protocol_version > PROTOCOL_VERSION {
"upgrade_server".to_string()
} else {
"restart_server".to_string()
},
message: format!(
"Protocol version mismatch: server={}, client={}",
PROTOCOL_VERSION, hello.protocol_version
),
};
let response = serde_json::to_string(&ServerControl::VersionMismatch(mismatch))
.map_err(std::io::Error::other)?;
conn.write_control(&response)?;
return Err(std::io::Error::other("version mismatch"));
}
let session_id = local_session_id().unwrap_or("local").to_string();
let response = serde_json::to_string(&ServerControl::Hello(ServerHello::new(session_id)))
.map_err(std::io::Error::other)?;
conn.write_control(&response)
}
/// Reads the next newline-delimited JSON control message from `reader`.
///
/// Blank (whitespace-only) lines are skipped rather than treated as end of
/// stream, so a stray empty line from a client does not tear down the
/// connection or fail the handshake.
///
/// Returns `Ok(None)` only when the stream is exhausted (`read_line` returned
/// 0), or `Ok(Some(msg))` for a parsed message.
///
/// # Errors
/// Returns I/O errors from the underlying reader, and an error containing
/// "invalid control message" when a line is not a valid `ClientControl`.
fn read_msg(
    reader: &mut std::io::BufReader<&StreamWrapper>,
) -> std::io::Result<Option<ClientControl>> {
    let mut line = String::new();
    loop {
        line.clear();
        if reader.read_line(&mut line)? == 0 {
            return Ok(None);
        }
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        let msg = serde_json::from_str::<ClientControl>(trimmed)
            .map_err(|e| std::io::Error::other(format!("invalid control message: {}", e)))?;
        return Ok(Some(msg));
    }
}
/// Builds a session name of the form `local-<pid>-<nanos>`.
///
/// `<nanos>` is the wall-clock time since the Unix epoch in nanoseconds, or
/// `0` if the system clock reads earlier than the epoch.
fn generate_session_id() -> String {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    let nanos = match since_epoch {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    let pid = std::process::id();
    format!("local-{pid}-{nanos}")
}
// Tests that bind a real control socket in a temp dir and drive it with a
// `ClientConnection`. They read forwarded requests directly from
// `BoundControl::req_rx` instead of going through `pump`, so no `Editor` is
// needed.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::ipc::ClientConnection;
    use crate::server::protocol::{ClientHello, TermSize};
    use tempfile::TempDir;

    /// Connects to `paths`, sends a `Hello`, and panics unless the server
    /// answers with `ServerControl::Hello`.
    fn connect_client(paths: &SocketPaths) -> ClientConnection {
        let conn = ClientConnection::connect(paths).expect("client connect");
        let hello = ClientHello::new(TermSize::new(80, 24));
        conn.write_control(&serde_json::to_string(&ClientControl::Hello(hello)).unwrap())
            .expect("write hello");
        let resp = conn
            .read_control()
            .expect("read server hello")
            .expect("server hello present");
        match serde_json::from_str::<ServerControl>(&resp).expect("parse server hello") {
            ServerControl::Hello(_) => {}
            other => panic!("expected ServerControl::Hello, got {:?}", other),
        }
        conn
    }

    /// Builds a `FileRequest` for `path` with no position or message.
    fn file_req(path: &str) -> FileRequest {
        FileRequest {
            path: path.to_string(),
            line: None,
            column: None,
            end_line: None,
            end_column: None,
            message: None,
        }
    }

    /// An `OpenWindow` message is forwarded to the editor side with its path.
    #[test]
    fn open_window_request_is_received() {
        let dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("open-window-test", dir.path());
        let bound = bind_and_spawn(paths.clone()).expect("bind");
        let conn = connect_client(&paths);
        let msg = ClientControl::OpenWindow {
            path: "/abs/project".to_string(),
        };
        conn.write_control(&serde_json::to_string(&msg).unwrap())
            .unwrap();
        match bound.req_rx.recv().expect("request forwarded") {
            LocalControlRequest::OpenWindow { path } => {
                assert_eq!(path, PathBuf::from("/abs/project"));
            }
            other => panic!("expected OpenWindow, got {:?}", req_kind(&other)),
        }
        // Stop the accept loop at its next poll.
        bound.shutdown.store(true, Ordering::SeqCst);
    }

    /// A waited `OpenFiles` gets a wait id, and the client receives
    /// `WaitComplete` once that id's waiter is signalled.
    #[test]
    fn waited_open_files_completes_after_signal() {
        let dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("wait-open-test", dir.path());
        let bound = bind_and_spawn(paths.clone()).expect("bind");
        let conn = connect_client(&paths);
        let msg = ClientControl::OpenFiles {
            files: vec![file_req("/abs/COMMIT_EDITMSG")],
            wait: true,
        };
        conn.write_control(&serde_json::to_string(&msg).unwrap())
            .unwrap();
        let wait_id = match bound.req_rx.recv().expect("request forwarded") {
            LocalControlRequest::OpenFiles { files, wait_id } => {
                assert_eq!(files.len(), 1);
                assert_eq!(files[0].path, "/abs/COMMIT_EDITMSG");
                wait_id.expect("wait id assigned for waited open")
            }
            other => panic!("expected OpenFiles, got {:?}", req_kind(&other)),
        };
        // Stand in for `pump`: take the waiter out of the table and signal it.
        let notifier = bound
            .waiters
            .lock()
            .unwrap()
            .remove(&wait_id)
            .expect("waiter registered before request was forwarded");
        notifier.send(()).unwrap();
        let line = conn
            .read_control()
            .expect("read wait complete")
            .expect("wait complete present");
        match serde_json::from_str::<ServerControl>(&line).expect("parse") {
            ServerControl::WaitComplete => {}
            other => panic!("expected WaitComplete, got {:?}", other),
        }
        bound.shutdown.store(true, Ordering::SeqCst);
    }

    /// A non-waited `OpenFiles` carries no wait id and registers no waiter.
    #[test]
    fn unwaited_open_files_has_no_wait_id() {
        let dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("nowait-open-test", dir.path());
        let bound = bind_and_spawn(paths.clone()).expect("bind");
        let conn = connect_client(&paths);
        let msg = ClientControl::OpenFiles {
            files: vec![file_req("/abs/file.txt")],
            wait: false,
        };
        conn.write_control(&serde_json::to_string(&msg).unwrap())
            .unwrap();
        match bound.req_rx.recv().expect("request forwarded") {
            LocalControlRequest::OpenFiles { wait_id, .. } => {
                assert!(wait_id.is_none());
            }
            other => panic!("expected OpenFiles, got {:?}", req_kind(&other)),
        }
        assert!(bound.waiters.lock().unwrap().is_empty());
        bound.shutdown.store(true, Ordering::SeqCst);
    }

    /// Several commands on a single connection are all forwarded, in order.
    #[test]
    fn multiple_commands_on_one_connection_all_received() {
        let dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("multi-cmd-test", dir.path());
        let bound = bind_and_spawn(paths.clone()).expect("bind");
        let conn = connect_client(&paths);
        let msgs = [
            ClientControl::OpenWindow {
                path: "/abs/a".to_string(),
            },
            ClientControl::OpenWindow {
                path: "/abs/b".to_string(),
            },
            ClientControl::OpenFiles {
                files: vec![file_req("/abs/file.txt")],
                wait: false,
            },
        ];
        for m in &msgs {
            conn.write_control(&serde_json::to_string(m).unwrap())
                .unwrap();
        }
        let r1 = bound.req_rx.recv().expect("first request");
        let r2 = bound.req_rx.recv().expect("second request");
        let r3 = bound.req_rx.recv().expect("third request");
        match r1 {
            LocalControlRequest::OpenWindow { path } => assert_eq!(path, PathBuf::from("/abs/a")),
            other => panic!("expected OpenWindow a, got {}", req_kind(&other)),
        }
        match r2 {
            LocalControlRequest::OpenWindow { path } => assert_eq!(path, PathBuf::from("/abs/b")),
            other => panic!("expected OpenWindow b, got {}", req_kind(&other)),
        }
        match r3 {
            LocalControlRequest::OpenFiles { files, wait_id } => {
                assert_eq!(files[0].path, "/abs/file.txt");
                assert!(wait_id.is_none());
            }
            other => panic!("expected OpenFiles, got {}", req_kind(&other)),
        }
        bound.shutdown.store(true, Ordering::SeqCst);
    }

    /// A command sent some time after the handshake is still read and
    /// forwarded. NOTE(review): presumably guards against the connection being
    /// left in non-blocking mode after the handshake.
    #[test]
    fn command_sent_after_a_delay_is_still_received() {
        let dir = TempDir::new().unwrap();
        let paths = SocketPaths::for_session_name_in_dir("delayed-cmd-test", dir.path());
        let bound = bind_and_spawn(paths.clone()).expect("bind");
        let conn = connect_client(&paths);
        std::thread::sleep(Duration::from_millis(150));
        let msg = ClientControl::OpenWindow {
            path: "/abs/delayed".to_string(),
        };
        conn.write_control(&serde_json::to_string(&msg).unwrap())
            .unwrap();
        match bound
            .req_rx
            .recv_timeout(Duration::from_secs(5))
            .expect("delayed request must still be forwarded")
        {
            LocalControlRequest::OpenWindow { path } => {
                assert_eq!(path, PathBuf::from("/abs/delayed"));
            }
            other => panic!("expected OpenWindow, got {}", req_kind(&other)),
        }
        bound.shutdown.store(true, Ordering::SeqCst);
    }

    /// Names a request variant for panic messages; `LocalControlRequest` does
    /// not derive `Debug`.
    fn req_kind(req: &LocalControlRequest) -> &'static str {
        match req {
            LocalControlRequest::OpenFiles { .. } => "OpenFiles",
            LocalControlRequest::OpenWindow { .. } => "OpenWindow",
        }
    }
}