use std::{
io,
path::{Path, PathBuf},
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
#[cfg(unix)]
use daemon::local_daemon::is_heddle_process;
use crate::util::OnceMap;
/// A local daemon that [`detect_local_daemon`] found running.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UdsTarget {
    /// Path of the daemon's gRPC Unix socket, `<heddle_dir>/sockets/grpc.sock`.
    pub socket_path: PathBuf,
    /// Pid read from `<heddle_dir>/sockets/grpc.pid`; it was alive and passed
    /// the Heddle identity check when probed.
    pub pid: u32,
}
/// Key for the per-directory caches in this module: the heddle directory path
/// exactly as the caller passed it. It is not canonicalised, so different
/// spellings of the same directory get separate entries.
type ProbeCacheKey = PathBuf;
/// Memoised results of [`detect_local_daemon`], one entry per heddle directory.
///
/// NOTE(review): assuming `OnceMap::get_or_init_with` keeps the first value it
/// computes for a key, a `None` (no daemon) result is cached as well. A daemon
/// started later in this process's lifetime would then go unnoticed. Confirm
/// that this is intended.
static DETECT_CACHE: OnceMap<ProbeCacheKey, Option<UdsTarget>> = OnceMap::new();
pub fn detect_local_daemon(heddle_dir: &Path) -> Option<UdsTarget> {
let key: ProbeCacheKey = heddle_dir.to_path_buf();
DETECT_CACHE.get_or_init_with(&key, || {
let probe = probe(heddle_dir);
match probe.status {
LocalDaemonStatus::Running { pid } => Some(UdsTarget {
socket_path: probe.socket_path,
pid,
}),
LocalDaemonStatus::Stale { .. } | LocalDaemonStatus::Absent => None,
}
})
}
/// Like [`detect_local_daemon`], but additionally requires that the daemon's
/// socket accepts a connection within `timeout`. The peer on that socket must
/// also run as this process's effective uid.
///
/// Returns `None` in all of these cases:
/// - no daemon was detected;
/// - the connect attempt failed;
/// - the connect attempt timed out;
/// - the peer-uid check rejected the connection.
///
/// The probe connection is dropped before returning.
#[cfg(unix)]
pub async fn detect_local_daemon_with_connect_probe(
    heddle_dir: &Path,
    timeout: Duration,
) -> Option<UdsTarget> {
    let target = detect_local_daemon(heddle_dir)?;
    let dial = tokio::net::UnixStream::connect(&target.socket_path);
    // Outer `ok()` discards the elapsed-timeout error, inner one the I/O error.
    let stream = tokio::time::timeout(timeout, dial).await.ok()?.ok()?;
    check_peer_uid_matches_self(&stream).ok()?;
    Some(target)
}
/// Channels produced by `build_channel`. `connect_local_daemon_channel`
/// reuses them for the same heddle directory.
#[cfg(unix)]
static CHANNEL_CACHE: OnceMap<ProbeCacheKey, tonic::transport::Channel> = OnceMap::new();
/// Per-directory count of `build_channel` invocations. It is incremented on
/// entry, so attempts that later fail are counted too. Read through
/// `channel_build_count`.
#[cfg(unix)]
#[doc(hidden)]
static CHANNEL_BUILD_COUNT: OnceMap<ProbeCacheKey, Arc<AtomicU64>> = OnceMap::new();
#[cfg(unix)]
#[doc(hidden)]
pub fn channel_build_count(heddle_dir: &Path) -> u64 {
CHANNEL_BUILD_COUNT
.get(&heddle_dir.to_path_buf())
.map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0)
}
/// A gRPC channel to the local daemon, together with the target it belongs to.
#[cfg(unix)]
#[derive(Debug, Clone)]
pub struct LocalDaemonChannel {
    /// Socket path and pid of the daemon, as reported by `detect_local_daemon`.
    pub target: UdsTarget,
    /// Channel whose connector dials `target.socket_path` over a Unix socket.
    pub channel: tonic::transport::Channel,
}
/// Returns a gRPC channel to the local daemon for `heddle_dir`, reusing a
/// cached one when possible.
///
/// - On a cache hit, the cached channel is paired with the (memoised) result
///   of `detect_local_daemon`.
/// - On a miss, a new channel is built, with connecting and health checking
///   each bounded by `connect_timeout`, and is cached only if the build
///   succeeds.
///
/// Any failure yields `None`; the underlying `ChannelError` is discarded.
#[cfg(unix)]
pub async fn connect_local_daemon_channel(
    heddle_dir: &Path,
    connect_timeout: Duration,
) -> Option<LocalDaemonChannel> {
    let key: ProbeCacheKey = heddle_dir.to_path_buf();
    let cached = CHANNEL_CACHE.get(&key);

    match cached {
        Some(channel) => {
            let target = detect_local_daemon(heddle_dir)?;
            Some(LocalDaemonChannel { target, channel })
        }
        None => {
            let fresh = build_channel(heddle_dir, connect_timeout).await.ok()?;
            CHANNEL_CACHE.insert(key, fresh.channel.clone());
            Some(fresh)
        }
    }
}
#[cfg(unix)]
async fn build_channel(
heddle_dir: &Path,
connect_timeout: Duration,
) -> std::result::Result<LocalDaemonChannel, ChannelError> {
CHANNEL_BUILD_COUNT
.get_or_init_with(&heddle_dir.to_path_buf(), || Arc::new(AtomicU64::new(0)))
.fetch_add(1, Ordering::Relaxed);
let target = detect_local_daemon(heddle_dir).ok_or(ChannelError::NoDaemon)?;
let endpoint = tonic::transport::Endpoint::try_from("http://heddle-uds")
.map_err(ChannelError::EndpointBuild)?
.connect_timeout(connect_timeout);
let socket_path = target.socket_path.clone();
let connector = tower::service_fn(move |_uri: tonic::transport::Uri| {
let socket_path = socket_path.clone();
async move {
let stream = tokio::net::UnixStream::connect(&socket_path).await?;
check_peer_uid_matches_self(&stream)?;
std::io::Result::Ok(hyper_util::rt::TokioIo::new(stream))
}
});
let channel = endpoint
.connect_with_connector(connector)
.await
.map_err(ChannelError::Connect)?;
let mut health = tonic_health::pb::health_client::HealthClient::new(channel.clone());
let request = tonic::Request::new(tonic_health::pb::HealthCheckRequest {
service: String::new(),
});
match tokio::time::timeout(connect_timeout, health.check(request)).await {
Ok(Ok(response)) => {
let status = response.into_inner().status;
if status == tonic_health::pb::health_check_response::ServingStatus::Serving as i32 {
Ok(LocalDaemonChannel { target, channel })
} else {
Err(ChannelError::HealthNotServing)
}
}
Ok(Err(status)) if status.code() == tonic::Code::Unimplemented => {
Ok(LocalDaemonChannel { target, channel })
}
Ok(Err(status)) => Err(ChannelError::HealthRpc(status)),
Err(_elapsed) => Err(ChannelError::HealthRpc(tonic::Status::deadline_exceeded(
"Health.Check timed out",
))),
}
}
/// Reasons `build_channel` can fail.
///
/// `connect_local_daemon_channel` currently discards these errors
/// (`Err(_) => None`), so the variant payloads are never read. That is why
/// `dead_code` is allowed here. Presumably the payloads are kept for `Debug`
/// diagnostics.
#[cfg(unix)]
#[derive(Debug)]
#[allow(dead_code)]
enum ChannelError {
    /// `detect_local_daemon` found no running daemon.
    NoDaemon,
    /// The placeholder endpoint URI could not be turned into an `Endpoint`.
    EndpointBuild(tonic::transport::Error),
    /// Connecting over the Unix socket failed. Presumably this includes
    /// rejections from the connector's peer-uid check.
    Connect(tonic::transport::Error),
    /// `Health.Check` failed with a code other than `Unimplemented`, or timed out.
    HealthRpc(tonic::Status),
    /// `Health.Check` answered with a status other than `Serving`.
    HealthNotServing,
}
/// Result of [`probe`]: the expected daemon file locations and what the
/// pidfile revealed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalDaemonProbe {
    /// `<heddle_dir>/sockets/grpc.sock`. The probe does not check that it exists.
    pub socket_path: PathBuf,
    /// `<heddle_dir>/sockets/grpc.pid`.
    pub pid_path: PathBuf,
    /// Classification derived from the pidfile.
    pub status: LocalDaemonStatus,
}
/// What the pidfile says about the local daemon.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalDaemonStatus {
    /// The recorded pid is alive and passed the Heddle identity check.
    Running { pid: u32 },
    /// A pid was recorded, but it is not alive or failed the identity check.
    Stale { pid: u32 },
    /// No pid could be read: the pidfile is missing, unreadable or unparseable.
    Absent,
}
/// Per-directory count of [`probe`] invocations, read through `probe_run_count`.
#[doc(hidden)]
static PROBE_RUN_COUNT: OnceMap<ProbeCacheKey, Arc<AtomicU64>> = OnceMap::new();
#[doc(hidden)]
pub fn probe_run_count(heddle_dir: &Path) -> u64 {
PROBE_RUN_COUNT
.get(&heddle_dir.to_path_buf())
.map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0)
}
/// Inspects `<heddle_dir>/sockets/` and classifies the local daemon from its
/// pidfile `grpc.pid`:
/// - no readable/parseable pid: [`LocalDaemonStatus::Absent`];
/// - the pid is alive *and* identified as a Heddle process: `Running`;
/// - any other pid: `Stale`.
///
/// The socket file is not examined. Each call bumps the counter reported by
/// [`probe_run_count`].
pub fn probe(heddle_dir: &Path) -> LocalDaemonProbe {
    PROBE_RUN_COUNT
        .get_or_init_with(&heddle_dir.to_path_buf(), || Arc::new(AtomicU64::new(0)))
        .fetch_add(1, Ordering::Relaxed);

    let sockets_dir = heddle_dir.join("sockets");
    let socket_path = sockets_dir.join("grpc.sock");
    let pid_path = sockets_dir.join("grpc.pid");

    let status = match read_pid(&pid_path) {
        None => LocalDaemonStatus::Absent,
        Some(pid) => {
            // `&&` short-circuits: only live pids get the identity check.
            let verified_live = pid_alive(pid) && pid_identity_verified(pid);
            if verified_live {
                LocalDaemonStatus::Running { pid }
            } else {
                LocalDaemonStatus::Stale { pid }
            }
        }
    };

    LocalDaemonProbe {
        socket_path,
        pid_path,
        status,
    }
}
/// Reads a pid from the pidfile at `path`.
///
/// The first line, trimmed, is tried first. If that fails to parse, the whole
/// trimmed file is tried instead, which handles files with leading blank lines.
/// Returns `None` if the file cannot be read or neither attempt parses as `u32`.
fn read_pid(path: &Path) -> Option<u32> {
    let contents = std::fs::read_to_string(path).ok()?;
    let head = contents.lines().next().map_or("", str::trim);
    match head.parse::<u32>() {
        Ok(pid) => Some(pid),
        Err(_) => contents.trim().parse::<u32>().ok(),
    }
}
/// Checks, via `is_heddle_process`, that `pid` belongs to a Heddle process
/// rather than an unrelated process that reused the recorded pid.
///
/// Pids that do not fit in an `i32` cannot be checked and are rejected.
#[cfg(unix)]
fn pid_identity_verified(pid: u32) -> bool {
    match i32::try_from(pid) {
        Ok(signed_pid) => is_heddle_process(signed_pid),
        Err(_) => false,
    }
}

/// Non-unix fallback. No identity check is available, so no pid is ever
/// considered verified.
#[cfg(not(unix))]
fn pid_identity_verified(_: u32) -> bool {
    false
}
/// Verifies that the process on the other end of `stream` runs under this
/// process's effective uid.
///
/// # Errors
/// Propagates any failure to read the peer credentials. Returns
/// `PermissionDenied` (from `enforce_peer_uid`) when the uids differ.
#[cfg(unix)]
fn check_peer_uid_matches_self(stream: &tokio::net::UnixStream) -> io::Result<()> {
    let peer_uid = stream.peer_cred()?.uid();
    // SAFETY: `geteuid` takes no arguments, touches no caller memory and is
    // specified to always succeed.
    let own_uid = unsafe { libc::geteuid() };
    enforce_peer_uid(peer_uid, own_uid)
}
/// Accepts the connection only when `peer_uid` equals `our_uid`.
///
/// # Errors
/// Returns an `io::ErrorKind::PermissionDenied` error naming both uids when
/// they differ.
#[cfg(unix)]
fn enforce_peer_uid(peer_uid: u32, our_uid: u32) -> io::Result<()> {
    if peer_uid == our_uid {
        Ok(())
    } else {
        Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            format!("daemon peer uid {peer_uid} does not match client uid {our_uid}"),
        ))
    }
}
/// Returns `true` when a process with `pid` exists and this process may
/// signal it.
///
/// Uses `kill(pid, 0)`, which runs the existence and permission checks
/// without delivering a signal. A process owned by another user that we may
/// not signal makes `kill` fail (EPERM), so it is reported as not alive.
///
/// `0` and pids that do not fit in a positive `pid_t` are rejected up front.
/// `kill` treats `0` as "our own process group" and negative values as process
/// groups, with `-1` meaning "every process we may signal". A pidfile holding
/// `0` or a value ≥ 2^31 (which a bare `as` cast would wrap negative) would
/// otherwise look alive.
#[cfg(unix)]
fn pid_alive(pid: u32) -> bool {
    let Ok(pid) = libc::pid_t::try_from(pid) else {
        return false;
    };
    if pid <= 0 {
        return false;
    }
    // SAFETY: signal 0 delivers nothing; `kill` only checks whether the target
    // exists and may be signalled. `pid` is strictly positive, so it names a
    // single process rather than a process group.
    unsafe { libc::kill(pid, 0) == 0 }
}

/// Non-unix fallback. Liveness cannot be checked, so every pid is reported
/// as not alive.
#[cfg(not(unix))]
fn pid_alive(_pid: u32) -> bool {
    false
}
#[cfg(test)]
mod tests {
    use tempfile::TempDir;
    use super::*;
    // Every test creates its own `TempDir`, so each one probes a distinct heddle
    // directory. The process-wide caches (`DETECT_CACHE`, `CHANNEL_CACHE`, the
    // counters) persist across tests, but are keyed by path, so tests do not
    // observe each other's cached entries.

    // With no pidfile at all, the probe reports `Absent`.
    #[test]
    fn absent_when_no_files() {
        let temp = TempDir::new().unwrap();
        let probe = probe(temp.path());
        assert_eq!(probe.status, LocalDaemonStatus::Absent);
    }

    // A pidfile whose pid does not name a live process gives `Stale`.
    // 2147483646 (i32::MAX - 1) is assumed not to be in use on the test host.
    #[test]
    fn stale_when_pidfile_holds_dead_pid() {
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        std::fs::write(sockets.join("grpc.pid"), "2147483646").unwrap();
        let probe = probe(temp.path());
        assert!(matches!(probe.status, LocalDaemonStatus::Stale { .. }));
    }

    // The test process's own pid is expected to count as `Running`.
    // NOTE(review): this relies on `is_heddle_process` accepting the test
    // binary itself. Confirm against the daemon crate.
    #[test]
    fn running_when_pidfile_holds_self_pid() {
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        std::fs::write(sockets.join("grpc.pid"), std::process::id().to_string()).unwrap();
        let probe = probe(temp.path());
        match probe.status {
            LocalDaemonStatus::Running { pid } => assert_eq!(pid, std::process::id()),
            other => panic!("expected Running, got {other:?}"),
        }
    }

    // A live process that is not Heddle (`/bin/sleep`, with the environment
    // cleared) must fail the identity check and be reported as `Stale`. The
    // child is killed and reaped before asserting, so a failing assertion does
    // not leak it.
    #[cfg(unix)]
    #[test]
    fn stale_when_pidfile_holds_live_non_heddle_pid() {
        let mut child = std::process::Command::new("/bin/sleep")
            .arg("30")
            .env_clear()
            .spawn()
            .expect("spawn sleep");
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        std::fs::write(sockets.join("grpc.pid"), child.id().to_string()).unwrap();
        let probe = probe(temp.path());
        let _ = child.kill();
        let _ = child.wait();
        match probe.status {
            LocalDaemonStatus::Stale { pid } => assert_eq!(pid, child.id()),
            other => panic!("expected Stale for live non-Heddle pid, got {other:?}"),
        }
    }

    // A `Running` probe becomes a target pointing at `sockets/grpc.sock`.
    #[test]
    fn detect_returns_target_when_running() {
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        std::fs::write(sockets.join("grpc.pid"), std::process::id().to_string()).unwrap();
        let target = detect_local_daemon(temp.path()).expect("daemon detected");
        assert_eq!(target.pid, std::process::id());
        assert!(
            target.socket_path.ends_with("sockets/grpc.sock"),
            "socket path was {:?}",
            target.socket_path
        );
    }

    // Without a pidfile, detection yields no target.
    #[test]
    fn detect_returns_none_when_absent() {
        let temp = TempDir::new().unwrap();
        assert!(detect_local_daemon(temp.path()).is_none());
    }

    #[cfg(unix)]
    #[test]
    fn enforce_peer_uid_accepts_matching_uid() {
        assert!(enforce_peer_uid(1000, 1000).is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn enforce_peer_uid_rejects_mismatched_uid() {
        let err = enforce_peer_uid(1001, 1000).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::PermissionDenied);
    }

    // The pidfile names a "running" daemon (this process), but nothing is
    // bound at `grpc.sock`. The connect step must therefore reject it.
    #[cfg(unix)]
    #[tokio::test]
    async fn connect_probe_rejects_socketless_pidfile() {
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        std::fs::write(sockets.join("grpc.pid"), std::process::id().to_string()).unwrap();
        let result = detect_local_daemon_with_connect_probe(
            temp.path(),
            std::time::Duration::from_millis(50),
        )
        .await;
        assert!(
            result.is_none(),
            "connect probe should reject when no listener is bound"
        );
    }

    // With a listener bound by this same process, both the connect and the
    // peer-uid check succeed. `_listener` (not `_`) keeps the socket open for
    // the whole test. If the sandbox forbids binding a Unix socket, the test
    // is skipped.
    #[cfg(unix)]
    #[tokio::test]
    async fn connect_probe_accepts_live_listener() {
        use tokio::net::UnixListener;
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        let socket_path = sockets.join("grpc.sock");
        let _listener = match UnixListener::bind(&socket_path) {
            Ok(listener) => listener,
            Err(err) if err.kind() == std::io::ErrorKind::PermissionDenied => {
                eprintln!("skipping live-listener connect probe test: UDS bind denied: {err}");
                return;
            }
            Err(err) => panic!("bind local daemon socket: {err}"),
        };
        std::fs::write(sockets.join("grpc.pid"), std::process::id().to_string()).unwrap();
        let result = detect_local_daemon_with_connect_probe(
            temp.path(),
            std::time::Duration::from_millis(200),
        )
        .await;
        assert!(
            result.is_some(),
            "connect probe should succeed when a listener is bound"
        );
    }

    // Both ends of a socketpair belong to this process, so the uids match.
    #[cfg(unix)]
    #[tokio::test]
    async fn check_peer_uid_matches_self_accepts_socketpair() {
        let (peer, _local) = tokio::net::UnixStream::pair().expect("socketpair");
        assert!(check_peer_uid_matches_self(&peer).is_ok());
    }

    // No daemon: channel construction fails fast and yields `None`.
    #[cfg(unix)]
    #[tokio::test]
    async fn connect_channel_is_none_when_daemon_absent() {
        let temp = TempDir::new().unwrap();
        let result =
            connect_local_daemon_channel(temp.path(), std::time::Duration::from_millis(50)).await;
        assert!(result.is_none());
    }
}