use std::{
fs,
path::{Path, PathBuf},
process::{Command, Stdio},
time::Duration,
};
use anyhow::{Context, Result, anyhow};
use repo::daemon::{
ERR_MOUNT_UNSUPPORTED, EndpointState, MOUNT_PROTOCOL_VERSION, MountDaemonRequest,
MountDaemonResponse, MountRegistryFile, load_endpoint, mount_daemon_endpoint_path,
mount_daemon_registry_path, pid_alive, remove_endpoint, send_json_request,
};
use tracing::{debug, warn};
use crate::cli::commands::RecoveryAdvice;
/// Classified failure from [`mount_via_daemon_classified`].
///
/// `Unavailable` covers cases where the daemon path could not be used at all
/// (no endpoint, spawn/RPC failure, or the daemon reports it cannot mount on
/// this host); `Fatal` covers a reachable daemon whose reply was an error,
/// had the wrong protocol version, or was not understood.
#[derive(Debug)]
pub enum DaemonMountError {
    /// Human-readable reason the daemon could not be used.
    Unavailable(String),
    /// Error produced after the daemon answered.
    Fatal(anyhow::Error),
}
impl std::fmt::Display for DaemonMountError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Unavailable(reason) => write!(f, "{reason}"),
Self::Fatal(err) => write!(f, "{err}"),
}
}
}
/// How many times to poll for the endpoint file after spawning a daemon.
const SPAWN_RETRIES: usize = 10;
/// Pause between consecutive endpoint polls, in milliseconds.
const SPAWN_RETRY_DELAY_MS: u64 = 50;
/// Returns the live mount-daemon endpoint for `repo_root`, optionally spawning a daemon.
///
/// If the endpoint file already describes a live daemon (see
/// [`read_live_endpoint`]), it is returned immediately. Otherwise, mounts left
/// behind by a previous daemon are swept and the stale endpoint file is removed.
/// After that:
/// - with `spawn_if_missing == false`, returns `Ok(None)`;
/// - otherwise a detached daemon is spawned and the endpoint file is polled
///   until it appears.
///
/// # Errors
///
/// Fails if the daemon cannot be spawned, or if no live endpoint appears within
/// `SPAWN_RETRIES` polls spaced `SPAWN_RETRY_DELAY_MS` apart.
pub fn ensure_daemon_endpoint(
    repo_root: &Path,
    spawn_if_missing: bool,
) -> Result<Option<EndpointState>> {
    let endpoint_path = mount_daemon_endpoint_path(repo_root);
    if let Some(endpoint) = read_live_endpoint(&endpoint_path)? {
        return Ok(Some(endpoint));
    }
    // No live daemon: clean up whatever a dead one left behind before
    // deciding whether to start a fresh one.
    sweep_stale_mounts(repo_root);
    remove_endpoint(&endpoint_path);
    if !spawn_if_missing {
        return Ok(None);
    }
    spawn_daemon_detached(repo_root)?;
    // Check immediately, then once after each delay. The final check comes
    // *after* the last sleep so that wait is not wasted before giving up.
    for attempt in 0..=SPAWN_RETRIES {
        if attempt > 0 {
            std::thread::sleep(Duration::from_millis(SPAWN_RETRY_DELAY_MS));
        }
        if let Some(endpoint) = read_live_endpoint(&endpoint_path)? {
            return Ok(Some(endpoint));
        }
    }
    Err(anyhow!(
        "daemon endpoint never appeared at {}; check `heddle daemon serve` for errors",
        endpoint_path.display()
    ))
}
fn read_live_endpoint(endpoint_path: &Path) -> Result<Option<EndpointState>> {
let endpoint = match load_endpoint(endpoint_path) {
Ok(endpoint) => endpoint,
Err(objects::error::HeddleError::Io(error))
if error.kind() == std::io::ErrorKind::NotFound =>
{
return Ok(None);
}
Err(error) => {
warn!(%error, path = %endpoint_path.display(), "ignoring unreadable daemon endpoint");
return Ok(None);
}
};
if endpoint.version != MOUNT_PROTOCOL_VERSION {
warn!(
recorded = endpoint.version,
expected = MOUNT_PROTOCOL_VERSION,
"daemon version mismatch on endpoint file; treating as stale"
);
return Ok(None);
}
if let Some(pid) = endpoint.pid
&& !pid_alive(pid)
{
warn!(pid, "daemon PID is dead; treating endpoint as stale");
return Ok(None);
}
Ok(Some(endpoint))
}
pub fn sweep_stale_mounts(repo_root: &Path) {
let registry_path = mount_daemon_registry_path(repo_root);
let Ok(contents) = fs::read_to_string(®istry_path) else {
return;
};
let registry: MountRegistryFile = match serde_json::from_str(&contents) {
Ok(registry) => registry,
Err(error) => {
warn!(%error, path = %registry_path.display(), "stale mount registry was unparseable; removing");
let _ = fs::remove_file(®istry_path);
return;
}
};
for entry in ®istry.mounts {
debug!(thread = %entry.thread_id, path = %entry.mount_path.display(), "sweeping stale mount");
attempt_fusermount_unmount(&entry.mount_path);
}
let _ = fs::remove_file(®istry_path);
}
#[cfg(target_os = "linux")]
fn attempt_fusermount_unmount(mount_path: &Path) {
if let Err(error) = Command::new("fusermount")
.arg("-u")
.arg(mount_path)
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null())
.status()
{
warn!(%error, path = %mount_path.display(), "fusermount -u failed during sweep");
}
}
/// On non-Linux targets no unmount is attempted; stale entries are only dropped
/// from the registry by the caller.
#[cfg(not(target_os = "linux"))]
fn attempt_fusermount_unmount(_mount_path: &Path) {}
pub fn spawn_daemon_detached(repo_root: &Path) -> Result<()> {
let current_exe = std::env::current_exe().context("locate current heddle executable")?;
let mut command = Command::new(current_exe);
command
.arg("--repo")
.arg(repo_root)
.arg("daemon")
.arg("serve")
.stdin(Stdio::null())
.stdout(Stdio::null())
.stderr(Stdio::null());
#[cfg(target_os = "linux")]
{
use std::os::unix::process::CommandExt;
unsafe {
command.pre_exec(|| {
if libc::setsid() == -1 {
return Err(std::io::Error::last_os_error());
}
Ok(())
});
}
}
command
.spawn()
.with_context(|| format!("spawn heddle daemon for {}", repo_root.display()))?;
Ok(())
}
pub fn rpc(
repo_root: &Path,
request: &MountDaemonRequest,
spawn_if_missing: bool,
) -> Result<Option<MountDaemonResponse>> {
let Some(endpoint) = ensure_daemon_endpoint(repo_root, spawn_if_missing)? else {
return Ok(None);
};
let response: MountDaemonResponse = match send_json_request(&endpoint, request) {
Ok(response) => response,
Err(error) => {
return Err(refine_rpc_error(repo_root, &endpoint, error));
}
};
if response.version() != MOUNT_PROTOCOL_VERSION {
return Err(anyhow!(
"daemon responded with protocol version {}, expected {}",
response.version(),
MOUNT_PROTOCOL_VERSION
));
}
Ok(Some(response))
}
fn refine_rpc_error(
repo_root: &Path,
endpoint: &EndpointState,
error: objects::error::HeddleError,
) -> anyhow::Error {
let endpoint_path = mount_daemon_endpoint_path(repo_root);
if let Ok(recorded) = load_endpoint(&endpoint_path)
&& recorded.version < MOUNT_PROTOCOL_VERSION
{
return anyhow!(error).context(RecoveryAdvice::stale_daemon_protocol(
recorded.version,
MOUNT_PROTOCOL_VERSION,
));
}
anyhow!(error).context(format!(
"RPC to daemon at {}:{}",
endpoint.host, endpoint.port
))
}
/// Asks the mount daemon (spawning it if needed) to mount `thread_id` at `mount_path`.
///
/// On success, returns the mount path reported by the daemon.
///
/// # Errors
///
/// Returns [`DaemonMountError::Unavailable`] when:
/// - the daemon cannot be started or reached;
/// - the daemon replies with `ERR_MOUNT_UNSUPPORTED`.
///
/// Returns [`DaemonMountError::Fatal`] when the daemon answered but:
/// - with the wrong protocol version;
/// - with a mount failure (an `ok: false` mount reply or any other error code);
/// - with an unexpected response kind.
pub fn mount_via_daemon_classified(
    repo_root: &Path,
    thread_id: &str,
    mount_path: &Path,
) -> std::result::Result<PathBuf, DaemonMountError> {
    let endpoint = match ensure_daemon_endpoint(repo_root, true) {
        Ok(Some(endpoint)) => endpoint,
        Ok(None) => {
            return Err(DaemonMountError::Unavailable(
                "daemon endpoint not available and spawn was disabled".to_string(),
            ));
        }
        Err(error) => {
            return Err(DaemonMountError::Unavailable(format!(
                "could not start daemon: {error:#}"
            )));
        }
    };
    let request = MountDaemonRequest::Mount {
        thread_id: thread_id.to_string(),
        mount_path: mount_path.to_path_buf(),
        repo_root: repo_root.to_path_buf(),
    };
    let response: MountDaemonResponse = match send_json_request(&endpoint, &request) {
        Ok(response) => response,
        Err(error) => {
            return Err(DaemonMountError::Unavailable(format!(
                "RPC to daemon at {}:{} failed: {error}",
                endpoint.host, endpoint.port
            )));
        }
    };
    if response.version() != MOUNT_PROTOCOL_VERSION {
        return Err(DaemonMountError::Fatal(anyhow!(
            "daemon responded with protocol version {}, expected {}",
            response.version(),
            MOUNT_PROTOCOL_VERSION
        )));
    }
    match response {
        MountDaemonResponse::Mount {
            ok: true,
            mount_path,
            ..
        } => Ok(mount_path),
        // A mount reply with `ok: false` is the daemon reporting a failure,
        // not a protocol surprise; report it as such instead of letting it
        // fall into the "unexpected response" arm.
        failed @ MountDaemonResponse::Mount { ok: false, .. } => Err(DaemonMountError::Fatal(
            anyhow!("daemon mount failed: {failed:?}"),
        )),
        MountDaemonResponse::Error { code, message, .. } if code == ERR_MOUNT_UNSUPPORTED => Err(
            DaemonMountError::Unavailable(format!("daemon cannot mount on this host: {message}")),
        ),
        MountDaemonResponse::Error { code, message, .. } => Err(DaemonMountError::Fatal(anyhow!(
            "daemon mount failed: [{code}] {message}"
        ))),
        other => Err(DaemonMountError::Fatal(anyhow!(
            "daemon returned unexpected response: {other:?}"
        ))),
    }
}
pub fn unmount_via_daemon(repo_root: &Path, thread_id: &str) -> Result<bool> {
let request = MountDaemonRequest::Unmount {
thread_id: thread_id.to_string(),
};
let response = rpc(repo_root, &request, false)?;
match response {
Some(MountDaemonResponse::Unmount { was_mounted, .. }) => Ok(was_mounted),
Some(MountDaemonResponse::Error { code, message, .. }) => {
Err(anyhow!("daemon unmount failed: [{code}] {message}"))
}
Some(other) => Err(anyhow!("daemon returned unexpected response: {other:?}")),
None => Ok(false),
}
}
#[cfg(test)]
mod tests {
    use std::{io::Write, net::TcpListener, path::PathBuf};

    use repo::daemon::{
        EndpointState, MOUNT_PROTOCOL_VERSION, MountDaemonRequest, MountRegistryFile,
        PersistedMount, mount_daemon_endpoint_path, mount_daemon_registry_path, persist_endpoint,
    };
    use tempfile::TempDir;

    use super::*;

    /// Persists `endpoint` at the repo's endpoint path, creating parent dirs.
    fn write_endpoint(repo_root: &Path, endpoint: &EndpointState) {
        let path = mount_daemon_endpoint_path(repo_root);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent).unwrap();
        }
        persist_endpoint(&path, endpoint).unwrap();
    }

    #[test]
    fn read_live_endpoint_treats_version_skew_as_stale() {
        let tmp = TempDir::new().unwrap();
        write_endpoint(
            tmp.path(),
            &EndpointState {
                version: MOUNT_PROTOCOL_VERSION + 99,
                host: "127.0.0.1".to_string(),
                port: 1,
                pid: Some(1),
            },
        );
        let endpoint_path = mount_daemon_endpoint_path(tmp.path());
        let result = read_live_endpoint(&endpoint_path).unwrap();
        assert!(
            result.is_none(),
            "version-skewed endpoint must be treated as stale"
        );
    }

    #[cfg(unix)]
    #[test]
    fn read_live_endpoint_returns_alive_endpoint() {
        let tmp = TempDir::new().unwrap();
        // PID 1 (init) is assumed to be alive on any unix host running tests.
        write_endpoint(
            tmp.path(),
            &EndpointState {
                version: MOUNT_PROTOCOL_VERSION,
                host: "127.0.0.1".to_string(),
                port: 9999,
                pid: Some(1),
            },
        );
        let endpoint_path = mount_daemon_endpoint_path(tmp.path());
        let result = read_live_endpoint(&endpoint_path).unwrap();
        assert!(result.is_some(), "alive endpoint must be returned as-is");
    }

    #[cfg(unix)]
    #[test]
    fn read_live_endpoint_detects_dead_pid() {
        let tmp = TempDir::new().unwrap();
        write_endpoint(
            tmp.path(),
            &EndpointState {
                version: MOUNT_PROTOCOL_VERSION,
                host: "127.0.0.1".to_string(),
                port: 9999,
                pid: Some(0x7fff_fffe),
            },
        );
        let endpoint_path = mount_daemon_endpoint_path(tmp.path());
        let result = read_live_endpoint(&endpoint_path).unwrap();
        assert!(result.is_none(), "endpoint with dead PID must be stale");
    }

    #[test]
    fn read_live_endpoint_handles_missing_file() {
        let tmp = TempDir::new().unwrap();
        let endpoint_path = mount_daemon_endpoint_path(tmp.path());
        let result = read_live_endpoint(&endpoint_path).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn sweep_stale_mounts_clears_registry_file() {
        let tmp = TempDir::new().unwrap();
        let registry_path = mount_daemon_registry_path(tmp.path());
        std::fs::create_dir_all(registry_path.parent().unwrap()).unwrap();
        let registry = MountRegistryFile {
            mounts: vec![PersistedMount {
                thread_id: "ghost".to_string(),
                mount_path: PathBuf::from("/nonexistent-mount-point"),
                pid: 1,
                since_ms: 0,
            }],
        };
        std::fs::write(
            &registry_path,
            serde_json::to_vec_pretty(&registry).unwrap(),
        )
        .unwrap();
        sweep_stale_mounts(tmp.path());
        assert!(
            !registry_path.exists(),
            "sweep must remove the registry file even when entries can't be unmounted"
        );
    }

    #[test]
    fn sweep_stale_mounts_is_noop_when_registry_absent() {
        let tmp = TempDir::new().unwrap();
        sweep_stale_mounts(tmp.path());
        assert!(
            !mount_daemon_registry_path(tmp.path()).exists(),
            "sweep must not create a registry file when none existed"
        );
    }

    #[test]
    fn rpc_hints_at_stale_daemon_when_endpoint_version_is_older() {
        let tmp = TempDir::new().unwrap();
        let repo_root: PathBuf = tmp.path().to_path_buf();
        let listener = match TcpListener::bind("127.0.0.1:0") {
            Ok(listener) => listener,
            Err(err) if err.kind() == std::io::ErrorKind::PermissionDenied => {
                eprintln!("skipping stale-daemon RPC test: loopback bind denied: {err}");
                return;
            }
            Err(err) => panic!("bind loopback listener: {err}"),
        };
        let port = listener.local_addr().unwrap().port();
        let server_repo = repo_root.clone();
        // Fake an old daemon: on connect, rewrite the endpoint file with an
        // older protocol version and answer with garbage.
        let server = std::thread::spawn(move || {
            if let Ok((mut stream, _)) = listener.accept() {
                let endpoint_path = mount_daemon_endpoint_path(&server_repo);
                persist_endpoint(
                    &endpoint_path,
                    &EndpointState {
                        version: MOUNT_PROTOCOL_VERSION - 1,
                        host: "127.0.0.1".to_string(),
                        port,
                        pid: Some(1),
                    },
                )
                .unwrap();
                let _ = stream.write_all(b"this is not json\n");
            }
        });
        write_endpoint(
            &repo_root,
            &EndpointState {
                version: MOUNT_PROTOCOL_VERSION,
                host: "127.0.0.1".to_string(),
                port,
                pid: Some(1),
            },
        );
        let err = rpc(&repo_root, &MountDaemonRequest::Health {}, false)
            .expect_err("v1 daemon reply must surface as an error");
        let _ = server.join();
        let chain = format!("{err:#}");
        assert!(
            chain.contains("heddled daemon is older"),
            "expected stale-daemon hint in error chain, got: {chain}"
        );
        assert!(
            chain.contains("heddle daemon stop"),
            "expected remediation hint in error chain, got: {chain}"
        );
        assert!(
            chain.contains(&format!("v{}", MOUNT_PROTOCOL_VERSION - 1)),
            "expected recorded daemon version in error chain, got: {chain}"
        );
        assert!(
            chain.contains(&format!("v{MOUNT_PROTOCOL_VERSION}")),
            "expected CLI version in error chain, got: {chain}"
        );
    }
}