use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;
use anyhow::{Context, Result};
use interprocess::local_socket::{
GenericFilePath, ListenerOptions, ToFsName,
traits::{Listener as _, Stream as _},
};
use serde::{Deserialize, Serialize};
/// Sub-directory of `<project_root>/.agent-doc` that holds per-session supervisor sockets.
const SUPERVISOR_DIR: &str = "supervisor";

/// Longest socket path (in bytes) that fits in `sockaddr_un.sun_path` with room left for
/// the terminating NUL.
///
/// `sun_path` is 104 bytes on Apple platforms and the BSDs.
#[cfg(any(
    target_vendor = "apple",
    target_os = "freebsd",
    target_os = "openbsd",
    target_os = "netbsd",
    target_os = "dragonfly"
))]
const SUN_PATH_MAX: usize = 103;

/// Longest socket path (in bytes) that fits in `sockaddr_un.sun_path` with room left for
/// the terminating NUL.
///
/// `sun_path` is 108 bytes on Linux and most other platforms.
#[cfg(not(any(
    target_vendor = "apple",
    target_os = "freebsd",
    target_os = "openbsd",
    target_os = "netbsd",
    target_os = "dragonfly"
)))]
const SUN_PATH_MAX: usize = 107;
/// Returns the Unix-socket path for the supervisor of `session_uuid` under `project_root`.
///
/// The preferred location is `<project_root>/.agent-doc/supervisor/<session_uuid>.sock`.
/// When that path is longer than [`SUN_PATH_MAX`] bytes (too long for `sockaddr_un`),
/// a shorter path is returned instead:
/// `<runtime_dir>/agent-doc/<hash12>-<uuid8>.sock`.
///
/// - `runtime_dir` is `$XDG_RUNTIME_DIR` when it is set to an absolute path, otherwise the
///   system temp dir.
/// - `hash12` is the first 12 hex digits of SHA-256 over the project root.
/// - `uuid8` is the first 8 characters of `session_uuid`.
///
/// The result is deterministic for a given root, uuid and environment. The fallback is
/// only short if `runtime_dir` itself is short; no further truncation is attempted.
pub fn socket_path(project_root: &Path, session_uuid: &str) -> PathBuf {
    let preferred = project_root
        .join(".agent-doc")
        .join(SUPERVISOR_DIR)
        .join(format!("{session_uuid}.sock"));
    if preferred.as_os_str().len() <= SUN_PATH_MAX {
        return preferred;
    }

    use sha2::{Digest, Sha256};
    let mut hasher = Sha256::new();
    hasher.update(project_root.as_os_str().as_encoded_bytes());
    let hash = format!("{:x}", hasher.finalize());
    // Hex output is ASCII, so byte slicing is safe here.
    let short_hash = &hash[..12];

    // Cut on a char boundary: a plain `&session_uuid[..8]` byte slice would panic if
    // byte 8 fell inside a multi-byte UTF-8 character.
    let short_uuid = session_uuid
        .char_indices()
        .nth(8)
        .map_or(session_uuid, |(end, _)| &session_uuid[..end]);

    // Read the variable as an OsString so a non-UTF-8 directory is kept as-is rather than
    // lossily rewritten. Per the XDG Base Directory spec, an empty or relative value is
    // invalid and ignored; otherwise the socket would land relative to the CWD.
    let runtime_dir = std::env::var_os("XDG_RUNTIME_DIR")
        .filter(|dir| Path::new(dir).is_absolute())
        .map(PathBuf::from)
        .unwrap_or_else(std::env::temp_dir);
    runtime_dir
        .join("agent-doc")
        .join(format!("{short_hash}-{short_uuid}.sock"))
}
/// A request sent by a client to the supervisor over the IPC socket.
///
/// On the wire this is an internally tagged JSON object. Its `method` field holds the
/// snake_case variant name, e.g. `{"method":"state"}` or
/// `{"method":"restart","mode":"fresh"}`. Each request is sent as a single line.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "method", rename_all = "snake_case")]
pub enum IpcMethod {
/// Restart request. `mode` defaults to `"continue"` when omitted (see
/// `default_restart_mode`). The handler passed to `SupervisorIpc::start` decides which
/// mode values are meaningful.
Restart {
#[serde(default = "default_restart_mode")]
mode: String,
},
/// Request to inject `bytes`. What "inject" means is up to the handler passed to
/// `SupervisorIpc::start`.
Inject {
bytes: String,
},
/// State query; the response payload is defined by the handler.
State,
/// PID query; the response payload is defined by the handler.
Pid,
/// Stop request. `graceful` defaults to `false` when omitted.
Stop {
#[serde(default)]
graceful: bool,
},
}
/// Serde default for the `mode` field of `IpcMethod::Restart` when a request omits it.
fn default_restart_mode() -> String {
    String::from("continue")
}
/// The supervisor's reply to one [`IpcMethod`] request, written back as one JSON line.
///
/// `data` and `error` are left out of the serialized JSON when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IpcResponse {
/// Whether the request succeeded.
pub ok: bool,
/// Optional success payload.
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
/// Human-readable error message, set on failure.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
impl IpcResponse {
    /// Builds a successful response that carries `data` as its payload.
    pub fn ok(data: serde_json::Value) -> Self {
        Self {
            data: Some(data),
            ..Self::ok_empty()
        }
    }

    /// Builds a successful response with neither payload nor error.
    pub fn ok_empty() -> Self {
        Self { ok: true, data: None, error: None }
    }

    /// Builds a failed response carrying the message `msg`.
    pub fn err(msg: impl Into<String>) -> Self {
        let error = Some(msg.into());
        Self { ok: false, data: None, error }
    }
}
/// Handle to a running supervisor IPC listener, created by `SupervisorIpc::start`.
///
/// Call `stop()` to shut the listener down; dropping the handle also stops it if that has
/// not happened yet.
pub struct SupervisorIpc {
/// Path of the bound socket file.
socket_path: PathBuf,
/// Shutdown flag shared with the listener thread.
stop: Arc<AtomicBool>,
/// The listener thread; `None` once `stop()` has taken and joined it.
handle: Option<JoinHandle<()>>,
}
impl SupervisorIpc {
pub fn start<F>(
project_root: &Path,
session_uuid: &str,
handler: F,
) -> Result<Self>
where
F: Fn(IpcMethod) -> IpcResponse + Send + 'static,
{
let sock = socket_path(project_root, session_uuid);
if let Some(parent) = sock.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("create supervisor dir: {}", parent.display()))?;
}
if sock.exists() {
let _ = std::fs::remove_file(&sock);
}
let name = sock.clone().to_fs_name::<GenericFilePath>()?;
let opts = ListenerOptions::new().name(name);
let listener = opts
.create_sync()
.with_context(|| format!("bind supervisor socket: {}", sock.display()))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let perms = std::fs::Permissions::from_mode(0o600);
let _ = std::fs::set_permissions(&sock, perms);
}
let stop = Arc::new(AtomicBool::new(false));
let stop_clone = stop.clone();
let sock_clone = sock.clone();
let handle = thread::Builder::new()
.name("supervisor-ipc".into())
.spawn(move || {
loop {
if stop_clone.load(Ordering::Relaxed) {
break;
}
match listener.accept() {
Ok(stream) => {
let (reader_half, mut writer_half) = stream.split();
let mut reader = BufReader::new(reader_half);
let mut line = String::new();
if reader.read_line(&mut line).unwrap_or(0) > 0 {
let trimmed = line.trim();
if !trimmed.is_empty() {
let response =
match serde_json::from_str::<IpcMethod>(trimmed) {
Ok(method) => handler(method),
Err(e) => {
IpcResponse::err(format!("parse error: {e}"))
}
};
let mut resp_json = serde_json::to_string(&response)
.unwrap_or_else(|e| {
format!(
r#"{{"ok":false,"error":"serialize error: {e}"}}"#
)
});
resp_json.push('\n');
if let Err(e) = writer_half.write_all(resp_json.as_bytes()) {
eprintln!("[supervisor::ipc] write error: {e}");
}
let _ = writer_half.flush();
}
}
}
Err(e) => {
if stop_clone.load(Ordering::Relaxed) {
break;
}
eprintln!("[supervisor::ipc] accept error: {e}");
}
}
}
let _ = std::fs::remove_file(&sock_clone);
})
.context("spawn supervisor-ipc thread")?;
Ok(Self {
socket_path: sock,
stop,
handle: Some(handle),
})
}
pub fn stop(&mut self) {
self.stop.store(true, Ordering::Relaxed);
if self.socket_path.exists() {
let _ = try_connect(&self.socket_path);
}
if let Some(h) = self.handle.take() {
let _ = h.join();
}
let _ = std::fs::remove_file(&self.socket_path);
}
#[allow(dead_code)] pub fn path(&self) -> &Path {
&self.socket_path
}
}
impl Drop for SupervisorIpc {
    /// Stops the listener unless `stop()` has already run.
    fn drop(&mut self) {
        // `stop()` takes the join handle, so `None` means shutdown already happened.
        if self.handle.is_none() {
            return;
        }
        self.stop();
    }
}
/// Opens a client connection to the supervisor socket at `sock`.
///
/// # Errors
/// Fails if `sock` cannot be turned into a socket name or the connection is refused.
fn try_connect(sock: &Path) -> Result<interprocess::local_socket::Stream> {
    let fs_name = sock.to_fs_name::<GenericFilePath>()?;
    interprocess::local_socket::ConnectOptions::new()
        .name(fs_name)
        .connect_sync()
        .context("failed to connect to supervisor socket")
}
#[allow(dead_code)] pub fn send_command(sock: &Path, method: &IpcMethod) -> Result<IpcResponse> {
let stream = try_connect(sock)?;
let (reader_half, mut writer_half) = stream.split();
let mut msg = serde_json::to_string(method)?;
msg.push('\n');
writer_half.write_all(msg.as_bytes())?;
writer_half.flush()?;
let (tx, rx) = std::sync::mpsc::channel();
std::thread::spawn(move || {
let mut reader = BufReader::new(reader_half);
let mut line = String::new();
let result = reader.read_line(&mut line);
let _ = tx.send((result, line));
});
match rx.recv_timeout(Duration::from_secs(2)) {
Ok((Ok(0), _)) => anyhow::bail!("supervisor closed connection without responding"),
Ok((Ok(_), line)) => {
let resp: IpcResponse = serde_json::from_str(line.trim())
.context("failed to parse supervisor response")?;
Ok(resp)
}
Ok((Err(e), _)) => anyhow::bail!("supervisor read error: {e}"),
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
anyhow::bail!("supervisor response timeout (2s)")
}
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
anyhow::bail!("supervisor reader thread disconnected")
}
}
}
/// Reports whether a supervisor is accepting connections for `session_uuid` under
/// `project_root`.
///
/// Returns `false` immediately if the socket file does not exist. Otherwise it returns
/// whether a connection attempt succeeds; the probe connection is closed right away.
#[allow(dead_code)] // NOTE(review): presumably used only by some callers/tests — confirm
pub fn is_active(project_root: &Path, session_uuid: &str) -> bool {
    let candidate = socket_path(project_root, session_uuid);
    candidate.exists() && try_connect(&candidate).is_ok()
}
#[cfg(test)]
mod tests {
// Each test binds a real supervisor socket under a fresh temporary project root, talks
// to it over the public client helpers (or by hand), and calls `stop()` when done.
use super::*;
use std::sync::atomic::AtomicU32;
// Starts a listener with canned replies:
// - `State` and `Pid` return fixed payloads;
// - `Stop` returns an empty success;
// - `Restart` echoes the requested mode;
// - `Inject` reports the byte length of its payload.
fn start_echo_handler(root: &Path, uuid: &str) -> SupervisorIpc {
SupervisorIpc::start(root, uuid, |method| match method {
IpcMethod::State => IpcResponse::ok(serde_json::json!({
"running": true,
"state": "healthy",
"restart_count": 0,
})),
IpcMethod::Pid => IpcResponse::ok(serde_json::json!({ "pid": 12345 })),
IpcMethod::Stop { .. } => IpcResponse::ok_empty(),
IpcMethod::Restart { mode } => {
IpcResponse::ok(serde_json::json!({ "pid": 99999, "mode": mode }))
}
IpcMethod::Inject { bytes } => {
IpcResponse::ok(serde_json::json!({ "n": bytes.len() }))
}
})
.expect("start test handler")
}
// A `State` query returns the handler's JSON payload intact.
#[test]
fn roundtrip_state_query() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
std::fs::create_dir_all(root.join(".agent-doc")).unwrap();
let mut ipc = start_echo_handler(root, "test-state");
// NOTE(review): `start` has already bound the socket before it returns, so these
// short pauses (here and in the other tests) look precautionary — confirm.
std::thread::sleep(Duration::from_millis(50));
let sock = socket_path(root, "test-state");
let resp = send_command(&sock, &IpcMethod::State).unwrap();
assert!(resp.ok);
let data = resp.data.unwrap();
assert_eq!(data["state"], "healthy");
assert_eq!(data["running"], true);
ipc.stop();
}
#[test]
fn roundtrip_pid() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
std::fs::create_dir_all(root.join(".agent-doc")).unwrap();
let mut ipc = start_echo_handler(root, "test-pid");
std::thread::sleep(Duration::from_millis(50));
let sock = socket_path(root, "test-pid");
let resp = send_command(&sock, &IpcMethod::Pid).unwrap();
assert!(resp.ok);
assert_eq!(resp.data.unwrap()["pid"], 12345);
ipc.stop();
}
// The `mode` field survives serialization to the supervisor and back.
#[test]
fn roundtrip_restart() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
std::fs::create_dir_all(root.join(".agent-doc")).unwrap();
let mut ipc = start_echo_handler(root, "test-restart");
std::thread::sleep(Duration::from_millis(50));
let sock = socket_path(root, "test-restart");
let resp = send_command(
&sock,
&IpcMethod::Restart {
mode: "fresh".to_string(),
},
)
.unwrap();
assert!(resp.ok);
assert_eq!(resp.data.unwrap()["mode"], "fresh");
ipc.stop();
}
// "/agent-doc plan.md\r" is 19 bytes, including the trailing carriage return.
#[test]
fn roundtrip_inject() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
std::fs::create_dir_all(root.join(".agent-doc")).unwrap();
let mut ipc = start_echo_handler(root, "test-inject");
std::thread::sleep(Duration::from_millis(50));
let sock = socket_path(root, "test-inject");
let resp = send_command(
&sock,
&IpcMethod::Inject {
bytes: "/agent-doc plan.md\r".to_string(),
},
)
.unwrap();
assert!(resp.ok);
assert_eq!(resp.data.unwrap()["n"], 19);
ipc.stop();
}
// Speaks the wire protocol by hand. A JSON object with no `method` tag must get an
// error response ("parse error: ..."), not a dropped connection.
#[test]
fn malformed_json_returns_error() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
std::fs::create_dir_all(root.join(".agent-doc")).unwrap();
let mut ipc = start_echo_handler(root, "test-malformed");
std::thread::sleep(Duration::from_millis(50));
let sock = socket_path(root, "test-malformed");
let stream = try_connect(&sock).unwrap();
let (reader_half, mut writer_half) = stream.split();
writer_half
.write_all(b"{\"not_a_method\": true}\n")
.unwrap();
writer_half.flush().unwrap();
let mut reader = BufReader::new(reader_half);
let mut line = String::new();
reader.read_line(&mut line).unwrap();
let resp: IpcResponse = serde_json::from_str(line.trim()).unwrap();
assert!(!resp.ok);
assert!(resp.error.unwrap().contains("parse error"));
ipc.stop();
}
// The socket file must have owner read/write and no group/other bits.
#[cfg(unix)]
#[test]
fn socket_permissions_are_0600() {
use std::os::unix::fs::PermissionsExt;
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
std::fs::create_dir_all(root.join(".agent-doc")).unwrap();
let mut ipc = start_echo_handler(root, "test-perms");
std::thread::sleep(Duration::from_millis(50));
let sock = socket_path(root, "test-perms");
let meta = std::fs::metadata(&sock).unwrap();
let mode = meta.permissions().mode() & 0o777;
assert_eq!(
mode & 0o600,
0o600,
"owner should have rw, got mode: {mode:o}"
);
assert_eq!(
mode & 0o077,
0,
"group/other should have no access, got mode: {mode:o}"
);
ipc.stop();
}
// A regular file at the socket path stands in for a leftover socket; `start` must
// remove it and bind anyway.
// NOTE(review): this pre-creates only the preferred `.agent-doc/supervisor` dir. If the
// temp root is long enough to trigger the fallback path, the `write` targets a
// directory that may not exist — confirm.
#[test]
fn stale_socket_cleaned_on_start() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
let supervisor_dir = root.join(".agent-doc").join("supervisor");
std::fs::create_dir_all(&supervisor_dir).unwrap();
let sock = socket_path(root, "test-stale");
std::fs::write(&sock, "stale").unwrap();
assert!(sock.exists());
let mut ipc = start_echo_handler(root, "test-stale");
std::thread::sleep(Duration::from_millis(50));
let resp = send_command(&sock, &IpcMethod::Pid).unwrap();
assert!(resp.ok);
ipc.stop();
}
// Five clients send at once; each must get a reply, and the handler must have run
// at least once per client.
#[test]
fn concurrent_clients() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
std::fs::create_dir_all(root.join(".agent-doc")).unwrap();
let call_count = Arc::new(AtomicU32::new(0));
let count_clone = call_count.clone();
let mut ipc = SupervisorIpc::start(root, "test-concurrent", move |method| {
count_clone.fetch_add(1, Ordering::Relaxed);
match method {
IpcMethod::Pid => IpcResponse::ok(serde_json::json!({ "pid": 1 })),
_ => IpcResponse::ok_empty(),
}
})
.unwrap();
std::thread::sleep(Duration::from_millis(50));
let sock = socket_path(root, "test-concurrent");
let mut handles = Vec::new();
for _ in 0..5 {
let s = sock.clone();
handles.push(std::thread::spawn(move || {
send_command(&s, &IpcMethod::Pid).unwrap()
}));
}
for h in handles {
let resp = h.join().unwrap();
assert!(resp.ok);
}
assert!(call_count.load(Ordering::Relaxed) >= 5);
ipc.stop();
}
// There is no socket before `start`; `stop` removes it again, so `is_active` flips
// false -> true -> false.
#[test]
fn is_active_detects_running_listener() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
std::fs::create_dir_all(root.join(".agent-doc")).unwrap();
assert!(!is_active(root, "test-active"));
let mut ipc = start_echo_handler(root, "test-active");
std::thread::sleep(Duration::from_millis(50));
assert!(is_active(root, "test-active"));
ipc.stop();
assert!(!is_active(root, "test-active"));
}
#[test]
fn stop_removes_socket_file() {
let dir = tempfile::tempdir().unwrap();
let root = dir.path();
std::fs::create_dir_all(root.join(".agent-doc")).unwrap();
let mut ipc = start_echo_handler(root, "test-cleanup");
std::thread::sleep(Duration::from_millis(50));
let sock = socket_path(root, "test-cleanup");
assert!(sock.exists());
ipc.stop();
assert!(!sock.exists(), "socket should be removed after stop");
}
// "/tmp/" + 80 chars + "/nested" pushes the preferred path well past SUN_PATH_MAX,
// forcing the fallback. The path must be short, mention "agent-doc", and be
// deterministic.
// NOTE(review): the length assertion assumes XDG_RUNTIME_DIR / the temp dir is short.
#[test]
fn socket_path_falls_back_for_long_paths() {
let long_root = PathBuf::from("/tmp")
.join("a".repeat(80))
.join("nested");
let uuid = "12345678-abcd-ef01-2345-6789abcdef01";
let path = socket_path(&long_root, uuid);
assert!(
path.as_os_str().len() <= SUN_PATH_MAX,
"fallback path {} is {} bytes, exceeds {SUN_PATH_MAX}",
path.display(),
path.as_os_str().len()
);
assert!(
path.to_string_lossy().contains("agent-doc"),
"fallback path should contain agent-doc: {}",
path.display()
);
let path2 = socket_path(&long_root, uuid);
assert_eq!(path, path2);
}
#[test]
fn socket_path_prefers_project_dir_for_short_paths() {
let short_root = PathBuf::from("/tmp/proj");
let uuid = "abcd1234";
let path = socket_path(&short_root, uuid);
assert_eq!(
path,
short_root
.join(".agent-doc")
.join("supervisor")
.join("abcd1234.sock")
);
}
// `long_root` is never created on disk; only the fallback socket's parent directory is.
// NOTE(review): uses a fixed uuid in a shared runtime/temp dir, so concurrent runs of
// this test on one machine could collide.
#[test]
fn long_path_fallback_binds_successfully() {
let long_root = PathBuf::from("/tmp")
.join("a".repeat(80))
.join("nested");
let uuid = "test-long-path";
let sock = socket_path(&long_root, uuid);
if let Some(parent) = sock.parent() {
std::fs::create_dir_all(parent).unwrap();
}
let mut ipc = SupervisorIpc::start(&long_root, uuid, |method| match method {
IpcMethod::State => IpcResponse::ok(serde_json::json!({"running": true})),
_ => IpcResponse::err("not implemented"),
})
.expect("should bind despite long project root");
std::thread::sleep(Duration::from_millis(50));
assert!(is_active(&long_root, uuid));
ipc.stop();
assert!(!is_active(&long_root, uuid));
}
}