use crate::transport::rpc::{self, JsonRpcRequest, JsonRpcResponse};
use crate::AppState;
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
/// File name of the shared (production) daemon socket inside the runtime directory.
pub const SOCKET_FILE_NAME: &str = "trusty-memory.sock";
/// Name of the file, inside a data root, into which the bound socket path is written.
pub const UDS_ADDR_FILE: &str = "uds_addr";
/// Path of the shared production daemon socket: `<runtime dir>/trusty-memory.sock`.
///
/// The runtime directory is re-resolved from the environment on every call.
pub fn socket_path() -> PathBuf {
    let mut path = runtime_dir();
    path.push(SOCKET_FILE_NAME);
    path
}
/// Socket path for the daemon serving `data_root`.
///
/// The production data directory maps to the shared [`socket_path`]. That is
/// `resolve_data_dir("trusty-memory")` itself or its `palaces` subdirectory.
///
/// Any other root (tests, alternate installs) gets its own socket,
/// `<runtime dir>/trusty-memory-<16 hex digits>.sock`. The suffix is a hash of the
/// path, so daemons for different roots get different sockets (barring a 64-bit
/// hash collision).
///
/// Paths are compared component-wise without canonicalisation. `/a/b` and `/a/b/`
/// therefore match, but a relative spelling of the same directory does not.
pub fn socket_path_for(data_root: &Path) -> PathBuf {
    let production = trusty_common::resolve_data_dir("trusty-memory")
        .ok()
        .map(|d| {
            let with_palaces = d.join("palaces");
            (d, with_palaces)
        });
    if let Some((bare, with_palaces)) = production {
        if data_root == bare || data_root == with_palaces {
            return socket_path();
        }
    }
    let h = stable_path_hash(data_root);
    runtime_dir().join(format!("trusty-memory-{h:016x}.sock"))
}

/// 64-bit FNV-1a hash of `path`, identical across Rust releases and builds.
///
/// Neither `DefaultHasher` nor `Path`'s `Hash` impl is specified between Rust
/// releases. Using them would let a client and a daemon built with different
/// toolchains disagree on the per-root socket name.
///
/// Each component's bytes are hashed followed by a NUL separator. NUL cannot
/// occur inside a path component, and hashing components rather than raw bytes
/// keeps `Path`s that compare equal (e.g. with a trailing `/`) on the same socket.
fn stable_path_hash(path: &Path) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    path.components()
        .flat_map(|c| {
            c.as_os_str()
                .as_encoded_bytes()
                .iter()
                .copied()
                .chain(std::iter::once(0u8))
        })
        .fold(FNV_OFFSET_BASIS, |h, b| {
            (h ^ u64::from(b)).wrapping_mul(FNV_PRIME)
        })
}
/// Directory in which daemon sockets are placed.
///
/// Candidates are tried in this order:
/// 1. `$XDG_RUNTIME_DIR`
/// 2. `$TMPDIR`
/// 3. [`std::env::temp_dir`]
///
/// A variable that is unset, empty, or not valid Unicode is skipped.
fn runtime_dir() -> PathBuf {
    ["XDG_RUNTIME_DIR", "TMPDIR"]
        .iter()
        .find_map(|&key| std::env::var(key).ok().filter(|value| !value.is_empty()))
        .map(PathBuf::from)
        .unwrap_or_else(std::env::temp_dir)
}
/// Record `sock_path` (plus a trailing newline) in `<data_root>/uds_addr`.
///
/// The contents go first to a sibling `.addr.tmp` file. That file is fsynced and
/// then renamed over the target, so the address file is replaced in one step.
/// `data_root` is created if it does not exist yet.
///
/// # Errors
/// Any I/O error from creating the directory, writing/syncing the temp file, or
/// the final rename.
pub fn write_uds_addr_file(data_root: &Path, sock_path: &Path) -> std::io::Result<()> {
    use std::io::Write;

    let target = data_root.join(UDS_ADDR_FILE);
    target.parent().map(std::fs::create_dir_all).transpose()?;

    let staging = target.with_extension("addr.tmp");
    let mut file = std::fs::File::create(&staging)?;
    writeln!(file, "{}", sock_path.display())?;
    file.sync_all()?;
    // Close the handle before the rename, as the original scoped block did.
    drop(file);

    std::fs::rename(&staging, &target)
}
pub async fn clean_stale_socket(sock_path: &Path) -> Result<()> {
if !sock_path.exists() {
return Ok(());
}
match UnixStream::connect(sock_path).await {
Ok(_stream) => {
anyhow::bail!(
"another trusty-memory daemon is already listening on {}",
sock_path.display()
);
}
Err(_) => {
std::fs::remove_file(sock_path).with_context(|| {
format!(
"remove stale socket {} (no live owner)",
sock_path.display()
)
})?;
Ok(())
}
}
}
/// Bind a listening Unix domain socket at `sock_path` and restrict it to mode `0600`.
///
/// Steps:
/// 1. Remove a leftover socket file with no live owner via [`clean_stale_socket`].
/// 2. Create the parent directory if it is missing.
/// 3. Bind the socket.
///
/// # Errors
/// Fails in any of these cases:
/// * another daemon is already listening on `sock_path`;
/// * the stale file cannot be removed or the parent directory cannot be created;
/// * `bind` fails;
/// * the socket's permissions cannot be changed.
pub async fn bind_uds(sock_path: &Path) -> Result<UnixListener> {
    clean_stale_socket(sock_path).await?;
    if let Some(parent) = sock_path.parent() {
        std::fs::create_dir_all(parent)
            .with_context(|| format!("create socket parent {}", parent.display()))?;
    }
    let listener = UnixListener::bind(sock_path)
        .with_context(|| format!("bind UDS at {}", sock_path.display()))?;
    // Owner-only access. On Linux, connecting to a socket file requires write
    // permission on it (unix(7)), so clearing group/other bits limits clients to
    // the owning user.
    // NOTE(review): between `bind` and this chmod the file carries the
    // umask-derived mode; closing that window would need a private parent dir.
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        std::fs::set_permissions(sock_path, std::fs::Permissions::from_mode(0o600))
            .with_context(|| format!("restrict permissions on {}", sock_path.display()))?;
    }
    Ok(listener)
}
/// Accept loop for the Unix-socket transport. It loops forever and never
/// produces a value.
///
/// Each accepted connection is served on its own spawned task by
/// `handle_connection`, using a clone of `state`. A connection that ends with an
/// error is logged at debug level.
///
/// A failed `accept` is logged as a warning and followed by a 50 ms pause
/// before the next attempt, so a persistent failure does not busy-spin.
pub async fn run_uds(state: AppState, listener: UnixListener) -> Result<()> {
    tracing::info!("UDS listener accepting connections");
    let retry_delay = std::time::Duration::from_millis(50);
    loop {
        let stream = match listener.accept().await {
            Ok((stream, _peer)) => stream,
            Err(e) => {
                tracing::warn!("UDS accept error: {e:#}");
                tokio::time::sleep(retry_delay).await;
                continue;
            }
        };
        let conn_state = state.clone();
        tokio::spawn(async move {
            if let Err(e) = handle_connection(conn_state, stream).await {
                tracing::debug!("UDS connection ended: {e:#}");
            }
        });
    }
}
/// Serve one client connection speaking newline-delimited JSON-RPC.
///
/// Each `\n`-terminated frame is one request. Frames are handled one at a time,
/// in order:
///
/// * blank frames (only ASCII whitespace) are skipped;
/// * frames that do not decode as a [`JsonRpcRequest`] (malformed JSON, or bytes
///   that are not valid UTF-8) are answered with a `PARSE_ERROR` response whose
///   id is `null`, and the connection stays open;
/// * requests whose id is absent or `null` are notifications: they are
///   dispatched, but nothing is written back;
/// * every other request gets exactly one newline-terminated response line.
///
/// Frame size is not bounded here. A frame is buffered in full until its newline
/// (or EOF) arrives.
///
/// # Errors
/// Returns an error when reading from or writing to the socket fails, or when a
/// response cannot be serialised. A clean EOF from the client returns `Ok(())`.
async fn handle_connection(state: AppState, stream: UnixStream) -> Result<()> {
    let (read_half, mut write_half) = stream.into_split();
    let mut reader = BufReader::new(read_half);
    // Read frames as raw bytes. `read_line` would tear the whole connection down
    // on invalid UTF-8, whereas `serde_json::from_slice` reports it as an
    // ordinary parse error the client can be told about. The buffer is reused
    // across frames.
    let mut frame: Vec<u8> = Vec::new();
    loop {
        frame.clear();
        let n = reader
            .read_until(b'\n', &mut frame)
            .await
            .context("UDS read_until")?;
        if n == 0 {
            // EOF: the client closed its side of the connection.
            return Ok(());
        }
        // Strips the `\n` / `\r\n` terminator. Any other trailing ASCII
        // whitespace goes too, which JSON ignores anyway.
        let trimmed = frame.trim_ascii_end();
        if trimmed.is_empty() {
            continue;
        }
        let response = match serde_json::from_slice::<JsonRpcRequest>(trimmed) {
            Ok(req) => {
                let is_notification = req.id.is_none() || req.id == Some(serde_json::Value::Null);
                let resp = rpc::dispatch(&state, req).await;
                if is_notification {
                    // JSON-RPC notifications never receive a response.
                    continue;
                }
                resp
            }
            Err(e) => JsonRpcResponse::err(
                serde_json::Value::Null,
                rpc::error_codes::PARSE_ERROR,
                format!("parse error: {e}"),
            ),
        };
        let mut serialized = serde_json::to_string(&response).context("serialise response")?;
        // Framing relies on the response being a single line. The only `\n` on
        // the wire must be the terminator pushed below.
        debug_assert!(
            !serialized.contains('\n'),
            "response must not contain embedded newlines: {serialized}"
        );
        serialized.push('\n');
        write_half
            .write_all(serialized.as_bytes())
            .await
            .context("UDS write_all")?;
        write_half.flush().await.context("UDS flush")?;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;

    /// Snapshot of selected environment variables, restored when dropped.
    ///
    /// Restoring in `Drop`, rather than at the end of the test body, means a
    /// failing assertion cannot leak mutated variables into later tests.
    /// Construct it *after* taking `env_test_lock`. It is then dropped, and
    /// restores the env, while the lock is still held.
    struct EnvRestore {
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl EnvRestore {
        fn capture(keys: &[&'static str]) -> Self {
            Self {
                saved: keys.iter().map(|&k| (k, std::env::var_os(k))).collect(),
            }
        }
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            for (key, value) in &self.saved {
                // SAFETY: assumes every test that touches the environment holds
                // `env_test_lock` (as the env tests here do), so nothing else
                // reads or writes the env concurrently.
                unsafe {
                    match value {
                        Some(v) => std::env::set_var(key, v),
                        None => std::env::remove_var(key),
                    }
                }
            }
        }
    }

    #[tokio::test]
    async fn socket_path_uses_tmp_dir_on_macos() {
        let _guard = crate::commands::env_test_lock().lock().await;
        let _env = EnvRestore::capture(&["TMPDIR", "XDG_RUNTIME_DIR"]);
        // SAFETY: `env_test_lock` is held for the whole test; see `EnvRestore`.
        // This mirrors a typical macOS environment: XDG_RUNTIME_DIR unset,
        // TMPDIR set.
        unsafe {
            std::env::set_var("TMPDIR", "/tmp");
            std::env::remove_var("XDG_RUNTIME_DIR");
        }
        assert_eq!(socket_path(), PathBuf::from("/tmp").join(SOCKET_FILE_NAME));
    }

    #[tokio::test]
    async fn runtime_dir_uses_xdg_runtime_dir_first() {
        let _guard = crate::commands::env_test_lock().lock().await;
        let _env = EnvRestore::capture(&["TMPDIR", "XDG_RUNTIME_DIR"]);
        // SAFETY: `env_test_lock` is held for the whole test; see `EnvRestore`.
        unsafe {
            std::env::set_var("XDG_RUNTIME_DIR", "/tmp/xdg-test");
            std::env::set_var("TMPDIR", "/tmp/tmpdir-test");
        }
        assert_eq!(runtime_dir(), PathBuf::from("/tmp/xdg-test"));
    }

    #[test]
    fn write_uds_addr_file_round_trip() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let sock = PathBuf::from("/tmp/foo.sock");
        write_uds_addr_file(tmp.path(), &sock).expect("write");
        let raw = std::fs::read_to_string(tmp.path().join(UDS_ADDR_FILE)).expect("read");
        assert_eq!(raw.trim(), "/tmp/foo.sock");
        assert!(raw.ends_with('\n'));
    }

    #[tokio::test]
    async fn stale_socket_is_cleaned_up() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let sock = tmp.path().join("leftover.sock");
        std::fs::write(&sock, b"").expect("touch");
        assert!(sock.exists());
        clean_stale_socket(&sock).await.expect("clean");
        assert!(!sock.exists(), "stale socket must be removed");
    }

    #[tokio::test]
    async fn socket_path_for_data_root_returns_per_root_path() {
        // `socket_path_for` reads XDG_RUNTIME_DIR/TMPDIR. Without the lock, an
        // env-mutating test running concurrently could change them between the
        // calls below and make `a != a_again` spuriously.
        let _guard = crate::commands::env_test_lock().lock().await;
        let a = socket_path_for(Path::new("/tmp/a-test-root"));
        let b = socket_path_for(Path::new("/tmp/b-test-root"));
        let a_again = socket_path_for(Path::new("/tmp/a-test-root"));
        assert_ne!(a, b, "different data roots must yield different sockets");
        assert_eq!(a, a_again, "same data root must yield same socket");
        assert!(
            a.to_string_lossy().contains("trusty-memory-"),
            "per-root socket must carry the per-root prefix: {}",
            a.display()
        );
    }

    #[tokio::test]
    async fn bind_uds_creates_socket_file() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let sock = tmp.path().join("daemon.sock");
        let _listener = bind_uds(&sock).await.expect("bind");
        assert!(sock.exists(), "socket file must exist after bind");
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(&sock).unwrap().permissions().mode() & 0o777;
            assert_eq!(mode, 0o600, "socket must be owner-only");
        }
    }
}