use std::path::{Path, PathBuf};
use std::time::Duration;
use directories::ProjectDirs;
use super::protocol::{CommsRequest, CommsResponse, PROTO_VER, StatusReport};
/// Sub-directory of the per-user data directory that holds the comms state.
const COMMS_SUBDIR: &str = "comms";
/// File name of the daemon socket inside the comms directory (non-Windows only;
/// on Windows the endpoint is a named pipe and no file is created).
#[cfg(not(windows))]
const SOCKET_FILE: &str = "comms.sock";
/// Permission bits applied to the comms directory: owner read/write/execute only.
#[cfg(unix)]
const OWNER_ONLY_DIR: u32 = 0o700;
/// Permission bits applied to the socket file: owner read/write only.
#[cfg(unix)]
const OWNER_ONLY_FILE: u32 = 0o600;
/// How long `ensure_daemon_with` waits for a freshly spawned daemon to answer.
const SPAWN_READY_TIMEOUT: Duration = Duration::from_secs(5);
/// Pause between liveness polls while waiting for a daemon to start or to stop.
const SPAWN_POLL_INTERVAL: Duration = Duration::from_millis(50);
/// How long `ensure_daemon` waits for an incompatible predecessor to go away
/// after it has been asked to stop.
const TAKEOVER_DRAIN_TIMEOUT: Duration = Duration::from_secs(3);
/// Filesystem locations used by the comms daemon and its clients.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CommsPaths {
/// Directory that holds the comms state (created by `resolve_paths`).
pub comms_dir: PathBuf,
/// Endpoint the daemon listens on: a socket file inside `comms_dir` on
/// non-Windows targets, a `\\.\pipe\...` name on Windows.
pub socket_path: PathBuf,
}
/// Errors produced while locating, binding, probing or spawning the singleton
/// comms daemon.
#[derive(Debug, thiserror::Error)]
pub enum SingletonError {
/// `ProjectDirs` could not determine a per-user data directory.
#[error("could not resolve a per-user data directory for basemind")]
NoDataDir,
/// A filesystem, socket or process operation on `path` failed.
#[error("io error on {path}: {source}")]
Io {
// Path (directory, socket or pipe name) the failing operation targeted.
path: PathBuf,
// Underlying OS error.
#[source]
source: std::io::Error,
},
/// The endpoint is already held by a daemon that answered the liveness probe.
#[error("a comms daemon is already running at {0}")]
AlreadyRunning(PathBuf),
/// A daemon was spawned but did not answer within `SPAWN_READY_TIMEOUT`.
#[error("spawned comms daemon did not become ready within the timeout")]
SpawnTimeout,
/// An incompatible predecessor was asked to stop but was still answering
/// after `TAKEOVER_DRAIN_TIMEOUT`.
#[error(
"a previous basemind comms daemon (v{version}, pid {pid}) is still running and did not \
stop; run `basemind comms stop` or terminate pid {pid}, then retry"
)]
StalePredecessor {
// Version string reported by the predecessor's status reply.
version: String,
// Process id reported by the predecessor's status reply.
pid: u32,
},
}
/// Environment variable that, when set to a non-empty value, overrides the
/// comms directory chosen by `resolve_paths`.
pub const COMMS_DIR_ENV: &str = "BASEMIND_COMMS_DIR";
/// Determines the comms directory and socket path, creating the directory.
///
/// The directory is taken from the `BASEMIND_COMMS_DIR` environment variable
/// when it is set and non-empty; otherwise it is the `comms` sub-directory of
/// the per-user data directory reported by `ProjectDirs`. On Unix the
/// directory is then restricted to its owner.
///
/// # Errors
///
/// * [`SingletonError::NoDataDir`] when no per-user data directory can be
///   resolved and no override is set.
/// * [`SingletonError::Io`] when the directory cannot be created or (on Unix)
///   its permissions cannot be changed.
pub fn resolve_paths() -> Result<CommsPaths, SingletonError> {
    // An empty override is treated exactly like an unset one.
    let override_dir = std::env::var_os(COMMS_DIR_ENV).filter(|dir| !dir.is_empty());
    let comms_dir = if let Some(dir) = override_dir {
        PathBuf::from(dir)
    } else {
        ProjectDirs::from("", "", "basemind")
            .ok_or(SingletonError::NoDataDir)?
            .data_dir()
            .join(COMMS_SUBDIR)
    };

    if let Err(source) = std::fs::create_dir_all(&comms_dir) {
        return Err(SingletonError::Io {
            path: comms_dir,
            source,
        });
    }

    #[cfg(unix)]
    set_mode(&comms_dir, OWNER_ONLY_DIR)?;

    Ok(CommsPaths {
        socket_path: comms_socket_path(&comms_dir),
        comms_dir,
    })
}
/// Returns the endpoint the daemon listens on for the given comms directory.
///
/// * Windows: a per-user named pipe, `\\.\pipe\basemind-comms-<USERNAME>`
///   (`default` when `USERNAME` is unavailable); `comms_dir` is not used.
/// * Everywhere else: the `comms.sock` file inside `comms_dir`.
pub fn comms_socket_path(comms_dir: &Path) -> PathBuf {
    #[cfg(windows)]
    {
        let _ = comms_dir;
        let user = match std::env::var("USERNAME") {
            Ok(name) => name,
            Err(_) => String::from("default"),
        };
        let pipe_name = format!(r"\\.\pipe\basemind-comms-{user}");
        PathBuf::from(pipe_name)
    }
    #[cfg(not(windows))]
    {
        let mut socket = comms_dir.to_path_buf();
        socket.push(SOCKET_FILE);
        socket
    }
}
/// Sets the Unix permission bits of `path` to `mode`.
///
/// # Errors
///
/// Returns [`SingletonError::Io`] carrying `path` when the permissions cannot
/// be changed.
#[cfg(unix)]
fn set_mode(path: &Path, mode: u32) -> Result<(), SingletonError> {
    use std::os::unix::fs::PermissionsExt;

    let permissions = std::fs::Permissions::from_mode(mode);
    match std::fs::set_permissions(path, permissions) {
        Ok(()) => Ok(()),
        Err(source) => Err(SingletonError::Io {
            path: path.to_path_buf(),
            source,
        }),
    }
}
#[cfg(unix)]
pub fn bind_listener(
socket_path: &Path,
probe: impl Fn(&Path) -> bool,
) -> Result<tokio::net::UnixListener, SingletonError> {
use std::os::unix::fs::PermissionsExt;
match std::os::unix::net::UnixListener::bind(socket_path) {
Ok(std_listener) => {
std_listener
.set_nonblocking(true)
.map_err(|source| SingletonError::Io {
path: socket_path.to_path_buf(),
source,
})?;
let _ = std::fs::set_permissions(
socket_path,
std::fs::Permissions::from_mode(OWNER_ONLY_FILE),
);
tokio::net::UnixListener::from_std(std_listener).map_err(|source| SingletonError::Io {
path: socket_path.to_path_buf(),
source,
})
}
Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => {
if probe(socket_path) {
return Err(SingletonError::AlreadyRunning(socket_path.to_path_buf()));
}
std::fs::remove_file(socket_path).map_err(|source| SingletonError::Io {
path: socket_path.to_path_buf(),
source,
})?;
let std_listener =
std::os::unix::net::UnixListener::bind(socket_path).map_err(|source| {
SingletonError::Io {
path: socket_path.to_path_buf(),
source,
}
})?;
std_listener
.set_nonblocking(true)
.map_err(|source| SingletonError::Io {
path: socket_path.to_path_buf(),
source,
})?;
let _ = std::fs::set_permissions(
socket_path,
std::fs::Permissions::from_mode(OWNER_ONLY_FILE),
);
tokio::net::UnixListener::from_std(std_listener).map_err(|source| SingletonError::Io {
path: socket_path.to_path_buf(),
source,
})
}
Err(source) => Err(SingletonError::Io {
path: socket_path.to_path_buf(),
source,
}),
}
}
/// Creates the daemon's named pipe, using "first pipe instance" as the
/// singleton lock.
///
/// When creation is refused with `PermissionDenied`, `probe` decides whether
/// a live daemon owns the pipe. A live owner yields
/// [`SingletonError::AlreadyRunning`]; otherwise creation is retried once.
///
/// # Errors
///
/// * [`SingletonError::AlreadyRunning`] when `probe` reports a live owner.
/// * [`SingletonError::Io`] for any other creation failure.
#[cfg(windows)]
pub fn bind_listener(
    socket_path: &Path,
    probe: impl Fn(&Path) -> bool,
) -> Result<tokio::net::windows::named_pipe::NamedPipeServer, SingletonError> {
    use tokio::net::windows::named_pipe::ServerOptions;

    let pipe_name = socket_path.as_os_str();
    let io_err = |source: std::io::Error| SingletonError::Io {
        path: socket_path.to_path_buf(),
        source,
    };
    // One attempt at becoming the first (and only) instance of the pipe.
    let create_first_instance = || {
        ServerOptions::new()
            .first_pipe_instance(true)
            .create(pipe_name)
    };

    match create_first_instance() {
        Ok(server) => Ok(server),
        Err(e) if e.kind() == std::io::ErrorKind::PermissionDenied => {
            if probe(socket_path) {
                Err(SingletonError::AlreadyRunning(socket_path.to_path_buf()))
            } else {
                create_first_instance().map_err(io_err)
            }
        }
        Err(source) => Err(io_err(source)),
    }
}
/// Makes sure a daemon answers on `paths.socket_path`, spawning one if needed.
///
/// `is_alive` is the liveness check and `spawn` launches the daemon; both are
/// injected so the logic can be exercised without real processes. If the
/// daemon is not alive, `spawn` is called once and the socket is then polled
/// every `SPAWN_POLL_INTERVAL` until it answers or `SPAWN_READY_TIMEOUT`
/// elapses.
///
/// # Errors
///
/// * [`SingletonError::Io`] when `spawn` fails.
/// * [`SingletonError::SpawnTimeout`] when the daemon never becomes ready.
pub async fn ensure_daemon_with(
    paths: &CommsPaths,
    is_alive: impl Fn(&Path) -> bool,
    spawn: impl FnOnce(&CommsPaths) -> std::io::Result<()>,
) -> Result<(), SingletonError> {
    let socket = paths.socket_path.as_path();
    if is_alive(socket) {
        return Ok(());
    }

    if let Err(source) = spawn(paths) {
        return Err(SingletonError::Io {
            path: socket.to_path_buf(),
            source,
        });
    }

    let deadline = std::time::Instant::now() + SPAWN_READY_TIMEOUT;
    loop {
        if std::time::Instant::now() >= deadline {
            break Err(SingletonError::SpawnTimeout);
        }
        if is_alive(socket) {
            break Ok(());
        }
        tokio::time::sleep(SPAWN_POLL_INTERVAL).await;
    }
}
/// Makes sure a compatible daemon is running, replacing an incompatible one.
///
/// If a daemon answers the status request and both speaks `PROTO_VER` and is
/// not older than this binary, nothing is done. Otherwise it is asked to stop
/// and given `TAKEOVER_DRAIN_TIMEOUT` to go away. After that (or when nothing
/// answered the status request in the first place) the call falls through to
/// [`ensure_daemon_with`] using the real probe and spawner.
///
/// # Errors
///
/// * [`SingletonError::StalePredecessor`] when the old daemon keeps answering
///   after the drain timeout.
/// * Anything [`ensure_daemon_with`] returns.
pub async fn ensure_daemon(paths: &CommsPaths) -> Result<(), SingletonError> {
    let socket = paths.socket_path.as_path();

    if let Some(report) = daemon_status(socket) {
        let ours = env!("CARGO_PKG_VERSION");
        let needs_takeover =
            report.proto_ver != PROTO_VER || version_is_older(&report.version, ours);
        if !needs_takeover {
            return Ok(());
        }

        tracing::warn!(
            daemon_version = %report.version,
            daemon_pid = report.pid,
            ours,
            "comms: a previous/incompatible daemon holds the socket; taking over"
        );
        request_stop(socket);

        // Poll until the predecessor stops answering or the drain window closes.
        let deadline = std::time::Instant::now() + TAKEOVER_DRAIN_TIMEOUT;
        loop {
            if std::time::Instant::now() >= deadline || !probe_alive(socket) {
                break;
            }
            tokio::time::sleep(SPAWN_POLL_INTERVAL).await;
        }

        // Re-check after the loop: it may have ended because of the deadline.
        if probe_alive(socket) {
            return Err(SingletonError::StalePredecessor {
                version: report.version,
                pid: report.pid,
            });
        }
    }

    ensure_daemon_with(paths, probe_alive, spawn_detached_daemon).await
}
/// Returns `true` when `daemon`'s `major.minor.patch` is strictly lower than
/// `ours`.
///
/// Only the numeric core of each version is compared: everything from the
/// first `-` (pre-release) or `+` (build metadata) onward is dropped before
/// parsing, so `0.10.0-rc.1` and `0.10.0+build.5` both compare equal to
/// `0.10.0`. Missing or non-numeric components count as `0`.
fn version_is_older(daemon: &str, ours: &str) -> bool {
    /// Parses the leading `major.minor.patch` of `v` into a comparable tuple.
    fn triple(v: &str) -> (u64, u64, u64) {
        // Cut at whichever suffix marker comes first. Without the `+` case a
        // patch such as `1+build` fails to parse and silently becomes 0.
        let core = v
            .split(|c: char| c == '-' || c == '+')
            .next()
            .unwrap_or(v);
        let mut parts = core.split('.').map(|p| p.parse::<u64>().unwrap_or(0));
        (
            parts.next().unwrap_or(0),
            parts.next().unwrap_or(0),
            parts.next().unwrap_or(0),
        )
    }
    triple(daemon) < triple(ours)
}
/// Asks the daemon at `socket_path` for its status.
///
/// Returns `None` when the round trip fails or the daemon replies with
/// anything other than a status report.
fn daemon_status(socket_path: &Path) -> Option<StatusReport> {
    let CommsResponse::Status(report) = roundtrip(socket_path, &CommsRequest::Status)? else {
        return None;
    };
    Some(report)
}
/// Sends a stop request to the daemon at `socket_path`.
///
/// Best effort: the reply, or the lack of one, is discarded.
fn request_stop(socket_path: &Path) {
    drop(roundtrip(socket_path, &CommsRequest::Stop));
}
/// Sends one request to the daemon and reads back one response.
///
/// Both directions use the same framing: a 4-byte big-endian length followed
/// by a MessagePack body. Any failure — connecting, encoding, writing,
/// reading, decoding, or a reply announcing more than 64 KiB — yields `None`.
fn roundtrip(socket_path: &Path, req: &CommsRequest) -> Option<CommsResponse> {
    use std::io::{Read, Write};

    // Upper bound on the reply body this client is willing to allocate.
    const MAX_REPLY_BYTES: usize = 64 * 1024;

    let mut conn = open_endpoint(socket_path)?;

    let payload = rmp_serde::to_vec_named(req).ok()?;
    let frame_len = u32::try_from(payload.len()).ok()?;
    conn.write_all(&frame_len.to_be_bytes()).ok()?;
    conn.write_all(&payload).ok()?;

    let mut header = [0u8; 4];
    conn.read_exact(&mut header).ok()?;
    let reply_len =
        Some(u32::from_be_bytes(header) as usize).filter(|len| *len <= MAX_REPLY_BYTES)?;

    let mut reply = vec![0u8; reply_len];
    conn.read_exact(&mut reply).ok()?;
    rmp_serde::from_slice::<CommsResponse>(&reply).ok()
}
/// Connects to the daemon's Unix socket for a blocking request/response
/// exchange.
///
/// Read and write timeouts of 800 ms are applied on a best-effort basis.
/// Returns `None` when the connection cannot be established.
#[cfg(unix)]
fn open_endpoint(socket_path: &Path) -> Option<impl std::io::Read + std::io::Write> {
    const IO_TIMEOUT: Duration = Duration::from_millis(800);

    match std::os::unix::net::UnixStream::connect(socket_path) {
        Ok(conn) => {
            // A failure to set a timeout is ignored; the stream is still usable.
            let _ = conn.set_read_timeout(Some(IO_TIMEOUT));
            let _ = conn.set_write_timeout(Some(IO_TIMEOUT));
            Some(conn)
        }
        Err(_) => None,
    }
}
/// Opens the daemon's named pipe for a blocking request/response exchange.
///
/// The pipe is opened like a file, for both reading and writing. Returns
/// `None` when it cannot be opened. No timeouts are configured here.
#[cfg(windows)]
fn open_endpoint(socket_path: &Path) -> Option<impl std::io::Read + std::io::Write> {
    let mut options = std::fs::OpenOptions::new();
    options.read(true).write(true);
    match options.open(socket_path) {
        Ok(pipe) => Some(pipe),
        Err(_) => None,
    }
}
/// Number of ping attempts `probe_alive` makes before declaring the daemon dead.
const PROBE_ATTEMPTS: u32 = 4;
/// Pause between consecutive ping attempts in `probe_alive`.
const PROBE_RETRY_BACKOFF: Duration = Duration::from_millis(100);
/// Reports whether a daemon answers on `socket_path`.
///
/// Up to `PROBE_ATTEMPTS` pings are sent, sleeping `PROBE_RETRY_BACKOFF`
/// between attempts (not after the last one). The first successful ping
/// returns `true`. The sleeps block the calling thread.
#[cfg(any(unix, windows))]
pub fn probe_alive(socket_path: &Path) -> bool {
    let mut attempts_left = PROBE_ATTEMPTS;
    while attempts_left > 0 {
        if probe_once(socket_path) {
            return true;
        }
        attempts_left -= 1;
        if attempts_left > 0 {
            std::thread::sleep(PROBE_RETRY_BACKOFF);
        }
    }
    false
}
/// Sends a single ping over the Unix socket and waits for a reply.
///
/// Returns `true` as soon as the 4-byte length prefix of a reply has been
/// read; the reply body is not read or decoded. Any failure along the way
/// (connect, encode, write, read) returns `false`. Read and write timeouts of
/// 500 ms are applied on a best-effort basis.
#[cfg(unix)]
fn probe_once(socket_path: &Path) -> bool {
    use std::io::{Read, Write};
    use std::os::unix::net::UnixStream;

    const PROBE_IO_TIMEOUT: Duration = Duration::from_millis(500);

    // `None` from any step means "not alive".
    let ping = || -> Option<()> {
        let mut conn = UnixStream::connect(socket_path).ok()?;
        let _ = conn.set_read_timeout(Some(PROBE_IO_TIMEOUT));
        let _ = conn.set_write_timeout(Some(PROBE_IO_TIMEOUT));

        let payload = rmp_serde::to_vec_named(&super::protocol::CommsRequest::Ping).ok()?;
        let frame_len = u32::try_from(payload.len()).ok()?;
        conn.write_all(&frame_len.to_be_bytes()).ok()?;
        conn.write_all(&payload).ok()?;

        let mut reply_prefix = [0u8; 4];
        conn.read_exact(&mut reply_prefix).ok()
    };
    ping().is_some()
}
/// Sends a single ping over the named pipe and waits for a reply.
///
/// Returns `true` as soon as the 4-byte length prefix of a reply has been
/// read; the reply body is not read or decoded. Any failure along the way
/// (open, encode, write, read) returns `false`. No timeouts are configured.
#[cfg(windows)]
fn probe_once(socket_path: &Path) -> bool {
    use std::io::{Read, Write};

    // `None` from any step means "not alive".
    let ping = || -> Option<()> {
        let mut pipe = std::fs::OpenOptions::new()
            .read(true)
            .write(true)
            .open(socket_path)
            .ok()?;

        let payload = rmp_serde::to_vec_named(&super::protocol::CommsRequest::Ping).ok()?;
        let frame_len = u32::try_from(payload.len()).ok()?;
        pipe.write_all(&frame_len.to_be_bytes()).ok()?;
        pipe.write_all(&payload).ok()?;

        let mut reply_prefix = [0u8; 4];
        pipe.read_exact(&mut reply_prefix).ok()
    };
    ping().is_some()
}
/// Fallback for targets that are neither Unix nor Windows: no probe is
/// implemented, so the daemon is always reported as not alive.
#[cfg(not(any(unix, windows)))]
pub fn probe_alive(_socket_path: &Path) -> bool {
false
}
/// Launches `<current exe> comms daemon` as a detached background process.
///
/// All three standard streams are redirected to the null device. On Unix the
/// child attempts to start its own session before exec (a failure to do so is
/// ignored). On Windows it is created detached, in a new process group, and
/// without a console window. The child handle is dropped without waiting.
///
/// `_paths` is unused; the parameter exists so the function fits the spawner
/// signature expected by `ensure_daemon_with`.
///
/// # Errors
///
/// Returns the underlying I/O error when the current executable cannot be
/// located or the process cannot be spawned.
pub fn spawn_detached_daemon(_paths: &CommsPaths) -> std::io::Result<()> {
    use std::process::{Command, Stdio};

    let mut daemon = Command::new(std::env::current_exe()?);
    daemon
        .args(["comms", "daemon"])
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null());

    #[cfg(unix)]
    {
        use std::os::unix::process::CommandExt;
        // SAFETY: the hook runs in the forked child before exec. It only calls
        // `setsid` (through `detach_session`) and discards the outcome; it
        // does not touch state shared with the parent.
        unsafe {
            daemon.pre_exec(|| {
                let _ = detach_session();
                Ok(())
            });
        }
    }
    #[cfg(windows)]
    {
        use std::os::windows::process::CommandExt;
        const DETACHED_PROCESS: u32 = 0x0000_0008;
        const CREATE_NEW_PROCESS_GROUP: u32 = 0x0000_0200;
        const CREATE_NO_WINDOW: u32 = 0x0800_0000;
        let flags = DETACHED_PROCESS | CREATE_NEW_PROCESS_GROUP | CREATE_NO_WINDOW;
        daemon.creation_flags(flags);
    }

    // The child handle is dropped immediately; the daemon is not waited on.
    daemon.spawn().map(|_child| ())
}
/// Calls `setsid` for the current process.
///
/// # Errors
///
/// Returns the last OS error when `setsid` reports failure (`-1`).
#[cfg(unix)]
fn detach_session() -> std::io::Result<()> {
    // SAFETY: `setsid` takes no arguments, so there are no pointers or buffers
    // whose validity has to be upheld by the caller.
    match unsafe { setsid() } {
        -1 => Err(std::io::Error::last_os_error()),
        _ => Ok(()),
    }
}
// Hand-written declaration of the C library's `setsid`, used by
// `detach_session`. A return value of -1 is treated as failure there.
#[cfg(unix)]
unsafe extern "C" {
fn setsid() -> i32;
}
// Unit tests for version comparison, the bind-as-lock race, and the
// spawn-and-wait logic of `ensure_daemon_with`.
#[cfg(test)]
mod tests {
use super::*;
// Release ordering is numeric per component (0.9 < 0.10), and a pre-release
// suffix is ignored, so `0.10.0-rc.1` is not considered older than `0.10.0`.
#[test]
fn version_is_older_orders_releases_and_ignores_prerelease() {
assert!(version_is_older("0.6.3", "0.10.0"));
assert!(version_is_older("0.9.0", "0.10.0"));
assert!(version_is_older("0.10.0", "0.10.1"));
assert!(!version_is_older("0.10.0", "0.10.0"));
assert!(!version_is_older("0.11.0", "0.10.0"));
assert!(!version_is_older("1.0.0", "0.10.0"));
assert!(!version_is_older("0.10.0-rc.1", "0.10.0"));
assert!(version_is_older("0.9.0-rc.2", "0.10.0"));
}
// Sixteen threads race to bind the same socket path; exactly one must win.
#[cfg(unix)]
#[test]
fn bind_as_lock_admits_exactly_one_winner_in_a_race() {
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
let dir = tempfile::tempdir().expect("tempdir");
let socket = dir.path().join("race.sock");
let winners = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::new();
const N: usize = 16;
// Winning listeners (and the runtimes they were created on) are parked here
// so they stay alive until every thread has finished its attempt.
let listeners = Arc::new(std::sync::Mutex::new(Vec::new()));
for _ in 0..N {
let socket = socket.clone();
let winners = winners.clone();
let listeners = listeners.clone();
handles.push(std::thread::spawn(move || {
// Stand-in liveness probe: the owner counts as alive while the socket file exists.
let probe = |p: &std::path::Path| p.exists();
// Each thread builds its own runtime and binds from inside `block_on`;
// presumably the tokio listener has to be created within a runtime — confirm.
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("rt");
let result = rt.block_on(async { bind_listener(&socket, probe) });
if let Ok(listener) = result {
winners.fetch_add(1, Ordering::SeqCst);
listeners.lock().expect("lock").push((listener, rt));
}
}));
}
for h in handles {
h.join().expect("join");
}
assert_eq!(
winners.load(Ordering::SeqCst),
1,
"exactly one binder may win the singleton lock"
);
}
// When the liveness check already succeeds, the spawner must never be called.
#[tokio::test]
async fn ensure_daemon_noops_when_already_alive() {
let paths = CommsPaths {
comms_dir: PathBuf::from("/tmp/x"),
socket_path: PathBuf::from("/tmp/x/comms.sock"),
};
let spawned = std::cell::Cell::new(false);
let res = ensure_daemon_with(
&paths,
|_| true,
|_| {
spawned.set(true);
Ok(())
},
)
.await;
assert!(res.is_ok());
assert!(
!spawned.get(),
"must not spawn when a daemon already answers"
);
}
// The fake spawner flips `alive`, so the first poll after spawning succeeds.
#[tokio::test]
async fn ensure_daemon_spawns_then_waits_for_ready() {
let paths = CommsPaths {
comms_dir: PathBuf::from("/tmp/x"),
socket_path: PathBuf::from("/tmp/x/comms.sock"),
};
let alive = std::sync::atomic::AtomicBool::new(false);
let res = ensure_daemon_with(
&paths,
|_| alive.load(std::sync::atomic::Ordering::SeqCst),
|_| {
alive.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(())
},
)
.await;
assert!(res.is_ok(), "daemon became ready after spawn");
}
}