use std::{
path::{Path, PathBuf},
time::Duration,
};
use crate::util::OnceMap;
/// Where to reach a locally running daemon over a Unix-domain socket.
///
/// Built by `detect_local_daemon` from a probe whose status is
/// `LocalDaemonStatus::Running`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UdsTarget {
/// Socket path reported by the probe (`<heddle_dir>/sockets/grpc.sock`).
pub socket_path: PathBuf,
/// Process id parsed from the daemon's pid file.
pub pid: u32,
}
/// Key used by the per-directory caches in this module: the heddle directory
/// path exactly as the caller supplied it (no canonicalisation is applied here).
type ProbeCacheKey = PathBuf;
/// Per-directory store for the result of `detect_local_daemon`, including
/// negative (`None`) results.
///
/// NOTE(review): presumably `OnceMap::get_or_init_with` runs its closure at
/// most once per key, which would make a miss sticky for the rest of the
/// process — confirm against `crate::util::OnceMap`.
static DETECT_CACHE: OnceMap<ProbeCacheKey, Option<UdsTarget>> = OnceMap::new();
pub fn detect_local_daemon(heddle_dir: &Path) -> Option<UdsTarget> {
let key: ProbeCacheKey = heddle_dir.to_path_buf();
DETECT_CACHE.get_or_init_with(&key, || {
let probe = probe(heddle_dir);
match probe.status {
LocalDaemonStatus::Running { pid } => Some(UdsTarget {
socket_path: probe.socket_path,
pid,
}),
LocalDaemonStatus::Stale { .. } | LocalDaemonStatus::Absent => None,
}
})
}
/// Like [`detect_local_daemon`], but additionally requires that a Unix-socket
/// connection to the detected socket path succeeds within `timeout`.
///
/// Returns `None` when no daemon is detected, when the connect attempt fails,
/// or when it does not complete before `timeout` elapses. The probe connection
/// itself is discarded; only the target description is returned.
#[cfg(unix)]
pub async fn detect_local_daemon_with_connect_probe(
    heddle_dir: &Path,
    timeout: Duration,
) -> Option<UdsTarget> {
    let target = detect_local_daemon(heddle_dir)?;
    let dial = tokio::net::UnixStream::connect(&target.socket_path);
    // Outer layer is the timeout result, inner layer is the connect result.
    let reachable = matches!(tokio::time::timeout(timeout, dial).await, Ok(Ok(_)));
    if reachable {
        Some(target)
    } else {
        None
    }
}
/// Per-directory store of gRPC channels that `connect_local_daemon_channel`
/// has already built successfully; entries are inserted only after
/// `build_channel` returns `Ok`.
#[cfg(unix)]
static CHANNEL_CACHE: OnceMap<ProbeCacheKey, tonic::transport::Channel> = OnceMap::new();
/// A gRPC channel to the local daemon paired with the target it was opened for.
#[cfg(unix)]
#[derive(Debug, Clone)]
pub struct LocalDaemonChannel {
/// Socket path and pid of the daemon, as returned by `detect_local_daemon`.
pub target: UdsTarget,
/// tonic channel whose connector dials `target.socket_path`.
pub channel: tonic::transport::Channel,
}
/// Returns a gRPC channel to the daemon for `heddle_dir`, reusing a previously
/// built channel from [`CHANNEL_CACHE`] when one exists.
///
/// On a cache hit the channel is returned as-is (paired with the result of
/// [`detect_local_daemon`]); no new connection or health check is performed.
/// On a miss, [`build_channel`] is awaited with `connect_timeout`; a successful
/// channel is stored in the cache before being returned. Every failure —
/// including the reason carried by `ChannelError` — collapses to `None`.
#[cfg(unix)]
pub async fn connect_local_daemon_channel(
    heddle_dir: &Path,
    connect_timeout: Duration,
) -> Option<LocalDaemonChannel> {
    let cache_key: ProbeCacheKey = heddle_dir.to_path_buf();

    // Fast path: hand back the cached channel without awaiting anything.
    if let Some(channel) = CHANNEL_CACHE.get(&cache_key) {
        return detect_local_daemon(heddle_dir)
            .map(|target| LocalDaemonChannel { target, channel });
    }

    // Slow path: build, remember, return. The error detail is intentionally dropped.
    let fresh = build_channel(heddle_dir, connect_timeout).await.ok()?;
    CHANNEL_CACHE.insert(cache_key, fresh.channel.clone());
    Some(fresh)
}
/// Opens a fresh gRPC channel to the daemon for `heddle_dir` and verifies it
/// with a `Health.Check` call (empty service name).
///
/// `connect_timeout` is applied twice: once on the endpoint for establishing
/// the transport, and once more around the health RPC.
///
/// # Errors
/// * `NoDaemon` — [`detect_local_daemon`] found nothing.
/// * `EndpointBuild` / `Connect` — the endpoint could not be created or connected.
/// * `HealthNotServing` — the health RPC answered with a status other than `Serving`.
/// * `HealthRpc` — the health RPC failed (other than `Unimplemented`) or timed out.
///
/// A health reply of `Unimplemented` is accepted as success.
#[cfg(unix)]
async fn build_channel(
    heddle_dir: &Path,
    connect_timeout: Duration,
) -> std::result::Result<LocalDaemonChannel, ChannelError> {
    let target = match detect_local_daemon(heddle_dir) {
        Some(found) => found,
        None => return Err(ChannelError::NoDaemon),
    };

    // The connector below ignores the URI it is handed and always dials the
    // Unix socket, so this address only has to parse.
    let endpoint = tonic::transport::Endpoint::try_from("http://heddle-uds")
        .map_err(ChannelError::EndpointBuild)?
        .connect_timeout(connect_timeout);

    let uds_path = target.socket_path.clone();
    let dial_uds = tower::service_fn(move |_uri: tonic::transport::Uri| {
        // Each invocation needs its own owned copy of the path for the async block.
        let uds_path = uds_path.clone();
        async move {
            let stream = tokio::net::UnixStream::connect(&uds_path).await?;
            std::io::Result::Ok(hyper_util::rt::TokioIo::new(stream))
        }
    });
    let channel = endpoint
        .connect_with_connector(dial_uds)
        .await
        .map_err(ChannelError::Connect)?;

    let mut health_client = tonic_health::pb::health_client::HealthClient::new(channel.clone());
    let health_request = tonic::Request::new(tonic_health::pb::HealthCheckRequest {
        service: String::new(),
    });

    // First peel off the timeout layer, then interpret the RPC outcome.
    let rpc_outcome =
        match tokio::time::timeout(connect_timeout, health_client.check(health_request)).await {
            Ok(outcome) => outcome,
            Err(_elapsed) => {
                return Err(ChannelError::HealthRpc(tonic::Status::deadline_exceeded(
                    "Health.Check timed out",
                )));
            }
        };

    let serving = tonic_health::pb::health_check_response::ServingStatus::Serving as i32;
    match rpc_outcome {
        Ok(response) => {
            if response.into_inner().status == serving {
                Ok(LocalDaemonChannel { target, channel })
            } else {
                Err(ChannelError::HealthNotServing)
            }
        }
        Err(status) if status.code() == tonic::Code::Unimplemented => {
            Ok(LocalDaemonChannel { target, channel })
        }
        Err(status) => Err(ChannelError::HealthRpc(status)),
    }
}
/// Reasons `build_channel` can fail to produce a usable channel.
#[cfg(unix)]
#[derive(Debug)]
// The only caller visible in this file (`connect_local_daemon_channel`) discards
// the error, so the variant payloads are never read outside of `Debug`.
#[allow(dead_code)]
enum ChannelError {
/// `detect_local_daemon` returned `None`.
NoDaemon,
/// Constructing the tonic endpoint failed.
EndpointBuild(tonic::transport::Error),
/// Connecting the endpoint through the Unix-socket connector failed.
Connect(tonic::transport::Error),
/// The health RPC returned an error status, or timed out (reported as a
/// deadline-exceeded status).
HealthRpc(tonic::Status),
/// The health RPC succeeded but the reported status was not `Serving`.
HealthNotServing,
}
/// Result of inspecting a heddle directory for a local daemon; produced by `probe`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalDaemonProbe {
/// `<heddle_dir>/sockets/grpc.sock` — computed from the directory, not checked for existence.
pub socket_path: PathBuf,
/// `<heddle_dir>/sockets/grpc.pid` — the file the pid is read from.
pub pid_path: PathBuf,
/// What the pid file and the liveness check indicated.
pub status: LocalDaemonStatus,
}
/// Daemon state as derived by `probe` from the pid file alone.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalDaemonStatus {
/// The pid file held a pid and `pid_alive` reported it as alive.
Running { pid: u32 },
/// The pid file held a pid but `pid_alive` reported it as not alive.
Stale { pid: u32 },
/// The pid file could not be read or did not contain a parsable pid.
Absent,
}
/// Inspects `heddle_dir` for a daemon pid file and classifies what it finds.
///
/// The socket and pid paths are always `<heddle_dir>/sockets/grpc.sock` and
/// `<heddle_dir>/sockets/grpc.pid`. The status is `Absent` when no pid can be
/// read, otherwise `Running` or `Stale` depending on [`pid_alive`]. This
/// function does not touch the socket and performs no caching.
pub fn probe(heddle_dir: &Path) -> LocalDaemonProbe {
    let sockets_dir = heddle_dir.join("sockets");
    let socket_path = sockets_dir.join("grpc.sock");
    let pid_path = sockets_dir.join("grpc.pid");

    let status = match read_pid(&pid_path) {
        None => LocalDaemonStatus::Absent,
        Some(pid) => {
            if pid_alive(pid) {
                LocalDaemonStatus::Running { pid }
            } else {
                LocalDaemonStatus::Stale { pid }
            }
        }
    };

    LocalDaemonProbe {
        socket_path,
        pid_path,
        status,
    }
}
/// Reads a pid from the file at `path`.
///
/// The first line (trimmed) is tried first; if it is not a valid `u32`, the
/// whole file content (trimmed) is tried instead, which covers files that
/// start with blank lines. Returns `None` when the file cannot be read as
/// UTF-8 text or neither attempt parses.
fn read_pid(path: &Path) -> Option<u32> {
    let contents = std::fs::read_to_string(path).ok()?;
    let head = contents.lines().next().map_or("", str::trim);
    match head.parse::<u32>() {
        Ok(pid) => Some(pid),
        Err(_) => contents.trim().parse::<u32>().ok(),
    }
}
#[cfg(unix)]
fn pid_alive(pid: u32) -> bool {
unsafe { libc::kill(pid as libc::pid_t, 0) == 0 }
}
/// Non-Unix stand-in for the liveness check: every pid is reported as not
/// alive, so `probe` never yields `LocalDaemonStatus::Running` on these targets.
#[cfg(not(unix))]
fn pid_alive(_pid: u32) -> bool {
false
}
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;

    /// An empty directory has no pid file, so the probe reports `Absent`.
    #[test]
    fn absent_when_no_files() {
        let temp = TempDir::new().unwrap();
        let probe = probe(temp.path());
        assert_eq!(probe.status, LocalDaemonStatus::Absent);
    }

    /// A pid file naming a process that is not alive is reported as `Stale`.
    #[test]
    fn stale_when_pidfile_holds_dead_pid() {
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        // A pid near the top of the i32 range; assumed not to belong to a live process.
        std::fs::write(sockets.join("grpc.pid"), "2147483646").unwrap();
        let probe = probe(temp.path());
        assert!(matches!(probe.status, LocalDaemonStatus::Stale { .. }));
    }

    /// A pid file naming this test process is reported as `Running`.
    ///
    /// Unix-only: on other targets `pid_alive` always returns `false`, so
    /// `Running` can never be observed there.
    #[cfg(unix)]
    #[test]
    fn running_when_pidfile_holds_self_pid() {
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        std::fs::write(sockets.join("grpc.pid"), std::process::id().to_string()).unwrap();
        let probe = probe(temp.path());
        match probe.status {
            LocalDaemonStatus::Running { pid } => assert_eq!(pid, std::process::id()),
            other => panic!("expected Running, got {other:?}"),
        }
    }

    /// Detection surfaces the pid and socket path of a running daemon.
    ///
    /// Unix-only for the same reason as `running_when_pidfile_holds_self_pid`.
    #[cfg(unix)]
    #[test]
    fn detect_returns_target_when_running() {
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        std::fs::write(sockets.join("grpc.pid"), std::process::id().to_string()).unwrap();
        let target = detect_local_daemon(temp.path()).expect("daemon detected");
        assert_eq!(target.pid, std::process::id());
        assert!(
            target.socket_path.ends_with("sockets/grpc.sock"),
            "socket path was {:?}",
            target.socket_path
        );
    }

    /// Detection yields nothing for a directory without a pid file.
    #[test]
    fn detect_returns_none_when_absent() {
        let temp = TempDir::new().unwrap();
        assert!(detect_local_daemon(temp.path()).is_none());
    }

    /// A live pid without a bound socket must fail the connect probe.
    #[cfg(unix)]
    #[tokio::test]
    async fn connect_probe_rejects_socketless_pidfile() {
        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        std::fs::write(sockets.join("grpc.pid"), std::process::id().to_string()).unwrap();
        let result = detect_local_daemon_with_connect_probe(
            temp.path(),
            std::time::Duration::from_millis(50),
        )
        .await;
        assert!(
            result.is_none(),
            "connect probe should reject when no listener is bound"
        );
    }

    /// A live pid plus a bound listener passes the connect probe.
    #[cfg(unix)]
    #[tokio::test]
    async fn connect_probe_accepts_live_listener() {
        use tokio::net::UnixListener;

        let temp = TempDir::new().unwrap();
        let sockets = temp.path().join("sockets");
        std::fs::create_dir_all(&sockets).unwrap();
        let socket_path = sockets.join("grpc.sock");
        // Keep the listener bound for the duration of the probe.
        let _listener = UnixListener::bind(&socket_path).unwrap();
        std::fs::write(sockets.join("grpc.pid"), std::process::id().to_string()).unwrap();
        let result = detect_local_daemon_with_connect_probe(
            temp.path(),
            std::time::Duration::from_millis(200),
        )
        .await;
        assert!(
            result.is_some(),
            "connect probe should succeed when a listener is bound"
        );
    }

    /// Without a daemon there is no channel to hand out.
    #[cfg(unix)]
    #[tokio::test]
    async fn connect_channel_is_none_when_daemon_absent() {
        let temp = TempDir::new().unwrap();
        let result =
            connect_local_daemon_channel(temp.path(), std::time::Duration::from_millis(50)).await;
        assert!(result.is_none());
    }
}