use std::path::PathBuf;
use astrid_core::session_token::SessionToken;
use tokio::net::UnixListener;
use tracing::warn;
/// Returns the filesystem path of the kernel's Unix domain socket.
///
/// The path comes from the resolved `AstridHome`. If `ASTRID_HOME` cannot be
/// resolved, a warning is logged and the fixed fallback
/// `/tmp/.astrid/run/system.sock` is returned instead, so this never fails.
#[must_use]
pub(crate) fn kernel_socket_path() -> PathBuf {
    use astrid_core::dirs::AstridHome;
    AstridHome::resolve()
        .map(|home| home.socket_path())
        .unwrap_or_else(|e| {
            warn!(error = %e, "Failed to resolve ASTRID_HOME; falling back to /tmp/.astrid/run/system.sock");
            PathBuf::from("/tmp/.astrid/run/system.sock")
        })
}
/// Size in bytes of `sockaddr_un.sun_path` on this platform.
///
/// Apple platforms and the BSD family declare `sun_path` as 104 bytes. A
/// socket path must be strictly shorter than this value (Rust's std rejects
/// any path whose length is `>=` the `sun_path` size), which is what
/// `prepare_socket_path` checks up front to give a clearer error.
#[cfg(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "freebsd",
    target_os = "openbsd",
    target_os = "netbsd",
    target_os = "dragonfly"
))]
const MAX_SOCKET_PATH_LEN: usize = 104;
/// Size in bytes of `sockaddr_un.sun_path` on targets not listed above
/// (e.g. Linux), where it is 108 bytes.
#[cfg(not(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "freebsd",
    target_os = "openbsd",
    target_os = "netbsd",
    target_os = "dragonfly"
)))]
const MAX_SOCKET_PATH_LEN: usize = 108;
/// Binds the kernel's Unix domain socket at the path from [`kernel_socket_path`].
///
/// Steps, in order:
/// 1. Validate and clear the socket path via `prepare_socket_path` (over-long
///    paths and a live kernel on the same path are errors; stale sockets and
///    stray symlinks are removed).
/// 2. Attempt to remove any existing readiness marker.
/// 3. Create the socket's parent directory and, on Unix, set its mode to `0o700`.
/// 4. Bind the listener.
///
/// # Errors
///
/// Returns an error if any step above fails. Errors from `set_permissions` and
/// from the bind itself are annotated with the offending path but keep their
/// original [`std::io::ErrorKind`], so callers matching on the kind still work.
pub(crate) fn bind_session_socket() -> Result<UnixListener, std::io::Error> {
    let path = kernel_socket_path();
    prepare_socket_path(&path)?;
    // No live kernel owns this path, so any readiness marker left behind is stale.
    remove_readiness_file();
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent).map_err(|e| {
            std::io::Error::other(format!(
                "Failed to create socket parent directory {}: {e}",
                parent.display()
            ))
        })?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            // Owner-only directory so other users cannot traverse to the socket.
            // This fails (e.g. EPERM) if the directory belongs to someone else,
            // which is why the path is included in the error.
            std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o700)).map_err(
                |e| {
                    std::io::Error::new(
                        e.kind(),
                        format!(
                            "Failed to restrict permissions on socket parent directory {}: {e}",
                            parent.display()
                        ),
                    )
                },
            )?;
        }
    }
    UnixListener::bind(&path).map_err(|e| {
        std::io::Error::new(
            e.kind(),
            format!("Failed to bind kernel socket {}: {e}", path.display()),
        )
    })
}
/// Generates a fresh session token and persists it to `AstridHome`'s token path.
///
/// The token is generated first, then `ASTRID_HOME` is resolved, and the token
/// is written to `home.token_path()` via `SessionToken::write_to_file`.
///
/// # Returns
///
/// The generated token together with the path it was written to.
///
/// # Errors
///
/// Fails if `ASTRID_HOME` cannot be resolved (there is deliberately no `/tmp`
/// fallback here, unlike the socket path), or if writing the token file fails.
pub(crate) fn generate_session_token() -> Result<(SessionToken, PathBuf), std::io::Error> {
    use astrid_core::dirs::AstridHome;
    let token = SessionToken::generate();
    let home = match AstridHome::resolve() {
        Ok(home) => home,
        Err(e) => {
            return Err(std::io::Error::other(format!(
                "Cannot generate session token: failed to resolve ASTRID_HOME: {e}"
            )));
        },
    };
    let token_path = home.token_path();
    token.write_to_file(&token_path)?;
    Ok((token, token_path))
}
/// Validates `path` as a socket location and clears anything stale left there.
///
/// Checks run in this order:
/// 1. The encoded path must be strictly shorter than `MAX_SOCKET_PATH_LEN`.
/// 2. A symlink at `path` is removed with a warning; its target is untouched.
/// 3. If anything else exists at `path`, a connection is attempted. A
///    successful connect means another kernel is serving the socket (error).
///    `ConnectionRefused` means the socket is stale, so it is removed. Any
///    other connect error is reported.
///
/// # Errors
///
/// Returns an error for an over-long path, a live socket, a failed probe, or a
/// failed removal.
fn prepare_socket_path(path: &std::path::Path) -> Result<(), std::io::Error> {
    let path_len = path.as_os_str().as_encoded_bytes().len();
    if path_len >= MAX_SOCKET_PATH_LEN {
        return Err(std::io::Error::other(format!(
            "Socket path is {path_len} bytes, exceeding the platform limit of {MAX_SOCKET_PATH_LEN} bytes: {}",
            path.display()
        )));
    }

    if path.is_symlink() {
        warn!(path = %path.display(), "Removing unexpected symlink at socket path");
        return std::fs::remove_file(path).map_err(|e| {
            std::io::Error::other(format!(
                "Failed to remove symlink at socket path {}: {e}",
                path.display()
            ))
        });
    }

    // Nothing on disk: the subsequent bind can create the socket fresh.
    if !path.exists() {
        return Ok(());
    }

    // Probe the existing entry to tell a live listener from a leftover file.
    let probe = std::os::unix::net::UnixStream::connect(path);
    match probe {
        Ok(_stream) => Err(std::io::Error::other(format!(
            "Another kernel instance is already running on this socket: {}",
            path.display()
        ))),
        Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
            std::fs::remove_file(path).map_err(|e| {
                std::io::Error::other(format!(
                    "Failed to remove stale socket {}: {e}",
                    path.display()
                ))
            })
        },
        Err(e) => Err(std::io::Error::other(format!(
            "Failed to probe existing socket {}: {e}",
            path.display()
        ))),
    }
}
/// Returns the filesystem path of the kernel's readiness marker file.
///
/// The path comes from the resolved `AstridHome`. If `ASTRID_HOME` cannot be
/// resolved, a warning is logged and the fixed fallback
/// `/tmp/.astrid/run/system.ready` is returned instead, so this never fails.
#[must_use]
pub fn readiness_path() -> PathBuf {
    use astrid_core::dirs::AstridHome;
    AstridHome::resolve()
        .map(|home| home.ready_path())
        .unwrap_or_else(|e| {
            warn!(
                error = %e,
                "Failed to resolve ASTRID_HOME; falling back to /tmp/.astrid/run/system.ready"
            );
            PathBuf::from("/tmp/.astrid/run/system.ready")
        })
}
/// Creates (or truncates) the empty readiness marker at [`readiness_path`].
///
/// Missing parent directories are created first. On Unix, a newly created
/// marker gets mode `0o600`. The file handle is closed immediately, because
/// only the file's existence matters.
///
/// # Errors
///
/// Propagates any error from creating the parent directory or opening the file.
pub fn write_readiness_file() -> Result<(), std::io::Error> {
    use std::fs::OpenOptions;
    let ready_path = readiness_path();
    if let Some(dir) = ready_path.parent() {
        std::fs::create_dir_all(dir)?;
    }
    let mut options = OpenOptions::new();
    options.write(true).create(true).truncate(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        options.mode(0o600);
    }
    // Dropping the handle right away closes it; the marker itself is the signal.
    options.open(&ready_path).map(drop)
}
pub fn remove_readiness_file() {
let _ = std::fs::remove_file(readiness_path());
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn path_too_long_is_rejected() {
        let long_name = "a".repeat(MAX_SOCKET_PATH_LEN + 10);
        let path = PathBuf::from(format!("/tmp/{long_name}.sock"));
        let err = prepare_socket_path(&path).unwrap_err();
        assert!(
            err.to_string().contains("exceeding the platform limit"),
            "unexpected error: {err}"
        );
    }

    /// Pins the exact boundary of the length check: a path of exactly
    /// `MAX_SOCKET_PATH_LEN` bytes is rejected (no room for the NUL), and one
    /// byte shorter is accepted.
    #[test]
    fn path_length_boundary_matches_limit() {
        let dir = tempfile::tempdir().unwrap();
        let dir_len = dir.path().as_os_str().len();
        // `join` inserts one separator byte between the directory and the name;
        // the extra byte makes the total one short of the limit.
        let name_below = MAX_SOCKET_PATH_LEN
            .checked_sub(dir_len + 2)
            .expect("temp dir path is too long for this test");

        let at_limit = dir.path().join("a".repeat(name_below + 1));
        assert_eq!(at_limit.as_os_str().len(), MAX_SOCKET_PATH_LEN);
        let err = prepare_socket_path(&at_limit).unwrap_err();
        assert!(
            err.to_string().contains("exceeding the platform limit"),
            "unexpected error: {err}"
        );

        let below = dir.path().join("a".repeat(name_below));
        assert_eq!(below.as_os_str().len(), MAX_SOCKET_PATH_LEN - 1);
        prepare_socket_path(&below).unwrap();
    }

    #[test]
    fn stale_socket_is_removed() {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("test.sock");
        let listener = std::os::unix::net::UnixListener::bind(&sock).unwrap();
        // Dropping a std listener closes it but leaves the socket file behind.
        drop(listener);
        assert!(sock.exists(), "socket file should exist after bind");
        prepare_socket_path(&sock).unwrap();
        assert!(!sock.exists(), "stale socket should have been removed");
    }

    #[test]
    fn live_socket_is_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("test.sock");
        // Held for the whole test so the probe connects successfully.
        let _listener = std::os::unix::net::UnixListener::bind(&sock).unwrap();
        let err = prepare_socket_path(&sock).unwrap_err();
        assert!(
            err.to_string().contains("already running"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn symlink_is_removed() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("target");
        std::fs::write(&target, "not a socket").unwrap();
        let sock = dir.path().join("test.sock");
        std::os::unix::fs::symlink(&target, &sock).unwrap();
        assert!(sock.is_symlink());
        prepare_socket_path(&sock).unwrap();
        assert!(!sock.exists(), "symlink should have been removed");
        assert!(target.exists(), "target should be untouched");
    }

    #[test]
    fn nonexistent_path_succeeds() {
        let dir = tempfile::tempdir().unwrap();
        let sock = dir.path().join("does_not_exist.sock");
        prepare_socket_path(&sock).unwrap();
    }
}