use std::{os::unix::fs::PermissionsExt, path::Path, sync::Arc, time::Duration};
use tokio::{
net::UnixListener,
sync::{watch, Semaphore},
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use crate::{
config::DaemonConfig,
forwarder::Forwarder,
metrics::MetricsStore,
session::{run_session, SessionConfig},
validator::Validator,
};
/// How long `Listener::run` waits, after shutdown is requested, for in-flight
/// client sessions to release their connection permits before it stops waiting
/// and returns anyway (30 seconds).
const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Accepts client connections on a Unix domain socket and spawns one session task
/// per connection, capping the number of concurrent sessions with a semaphore.
///
/// Construct with [`Listener::bind`], then drive with [`Listener::run`].
#[allow(
clippy::struct_field_names,
reason = "listener field is the natural name for a UnixListener inside Listener"
)]
pub struct Listener {
// Bound socket that `run` accepts client connections from.
listener: UnixListener,
// One permit per concurrent session; created with `cfg.max_connections` permits in `bind`.
semaphore: Arc<Semaphore>,
// Daemon configuration; per-session settings are copied out of it for each connection.
cfg: DaemonConfig,
// Cloned into every spawned session task.
forwarder: Forwarder,
// Hostname detected once in `bind` and shared (via `Arc`) with every session.
local_hostname: Arc<str>,
}
impl Listener {
pub fn bind(cfg: DaemonConfig, forwarder: Forwarder) -> std::io::Result<Self> {
let path = Path::new(&cfg.listen_socket);
if path.exists() {
std::fs::remove_file(path)?;
}
let listener = UnixListener::bind(path)?;
apply_socket_permissions(path, &cfg.socket_mode)?;
let semaphore = Arc::new(Semaphore::new(cfg.max_connections));
let local_hostname = detect_hostname();
info!(socket = %cfg.listen_socket, "listening for client connections");
Ok(Self {
listener,
semaphore,
cfg,
forwarder,
local_hostname,
})
}
pub async fn run(
self,
shutdown: CancellationToken,
validator_rx: watch::Receiver<Arc<Validator>>,
metrics: Arc<MetricsStore>,
) {
let Self {
listener,
semaphore,
cfg,
forwarder,
local_hostname,
} = self;
loop {
let permit = tokio::select! {
biased;
() = shutdown.cancelled() => break,
result = semaphore.clone().acquire_owned() => if let Ok(p) = result { p } else {
error!("connection semaphore closed — shutting down accept loop");
return;
},
};
let (stream, addr) = tokio::select! {
biased;
() = shutdown.cancelled() => break,
result = listener.accept() => match result {
Ok(pair) => pair,
Err(e) => {
error!(error = %e, "accept() failed");
drop(permit);
continue;
}
},
};
let peer = addr
.as_pathname()
.map_or_else(|| "<anonymous>".to_owned(), |p| p.display().to_string());
info!(peer = %peer, "accepted connection");
let session_cfg = SessionConfig {
framing: cfg.framing,
max_message_size: cfg.max_message_size,
sender_mode: cfg.sender,
local_hostname: Arc::clone(&local_hostname),
peer: Arc::from(peer.as_str()),
};
let vr = validator_rx.clone();
let fwd = forwarder.clone();
let m = Arc::clone(&metrics);
tokio::spawn(async move {
run_session(stream, session_cfg, vr, fwd, m).await;
info!(peer = %peer, "connection closed");
drop(permit);
});
}
info!("shutting down; waiting for active sessions to finish");
let total = u32::try_from(cfg.max_connections).unwrap_or(u32::MAX);
match tokio::time::timeout(SHUTDOWN_DRAIN_TIMEOUT, semaphore.acquire_many(total)).await {
Ok(Ok(_)) => info!("all sessions finished; shutdown complete"),
Ok(Err(_)) => {} Err(_) => warn!("graceful shutdown timed out; forcing exit"),
};
}
}
/// Determines the local machine's hostname.
///
/// Reads `/etc/hostname` first, then falls back to `/proc/sys/kernel/hostname` (the
/// kernel's current hostname on Linux, useful in environments that do not ship
/// `/etc/hostname`). Returns `"-"` when neither source yields a usable name.
pub(crate) fn detect_hostname() -> Arc<str> {
    const SOURCES: [&str; 2] = ["/etc/hostname", "/proc/sys/kernel/hostname"];
    SOURCES
        .iter()
        .find_map(|path| {
            let contents = std::fs::read_to_string(path).ok()?;
            parse_hostname_file(&contents).map(str::to_owned)
        })
        .unwrap_or_else(|| "-".to_owned())
        .into()
}

/// Extracts the hostname from the contents of an `/etc/hostname`-style file.
///
/// Per hostname(5), the file holds a single hostname and lines starting with `#` are
/// comments. This function returns the first non-blank, non-comment line with
/// surrounding whitespace trimmed, or `None` if there is no such line.
fn parse_hostname_file(contents: &str) -> Option<&str> {
    contents
        .lines()
        .map(str::trim)
        .find(|line| !line.is_empty() && !line.starts_with('#'))
}
/// Sets the permission bits of the file at `path` from an octal mode string such as
/// `"660"`, `"0660"` or `"0o660"`.
///
/// An empty, non-octal, or out-of-range (above `0o7777`) mode string is logged and
/// replaced by `0o660` rather than treated as an error.
///
/// # Errors
///
/// Returns any error from `std::fs::set_permissions`.
pub(crate) fn apply_socket_permissions(path: &Path, mode_str: &str) -> std::io::Result<()> {
    let mode = parse_socket_mode(mode_str).unwrap_or_else(|| {
        warn!(mode = %mode_str, "invalid socket_mode; using 0660");
        0o660
    });
    std::fs::set_permissions(path, std::fs::Permissions::from_mode(mode))
}

/// Parses an octal permission string, accepting an optional `0o`/`0O` prefix and
/// surrounding whitespace.
///
/// Returns `None` for empty input, any non-octal character (including a sign), or a
/// value with bits outside `0o7777` (permission bits plus setuid/setgid/sticky).
fn parse_socket_mode(mode_str: &str) -> Option<u32> {
    let s = mode_str.trim();
    let digits = s
        .strip_prefix("0o")
        .or_else(|| s.strip_prefix("0O"))
        .unwrap_or(s);
    // Explicit digit check: `from_str_radix` alone would also accept a leading '+'.
    if digits.is_empty() || !digits.bytes().all(|b| matches!(b, b'0'..=b'7')) {
        return None;
    }
    u32::from_str_radix(digits, 8)
        .ok()
        .filter(|&mode| mode <= 0o7777)
}