use std::{future::Future, num::NonZeroUsize, sync::Arc, time::Duration};
use tokio::time::timeout;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tracing::{info, warn};
use crate::{
blocklist::{fetch::Fetcher, scheduler::BlocklistScheduler},
config::Config,
error::Result,
resolver::{
pipeline::{engine::build_engine, listener::DnsListeners, middleware::ProtectiveConfig},
state::ResolverState,
upstream::{SharedUpstreamPool, UpstreamConfig},
},
storage::{Db, forward_zones::ForwardZoneRepository, upstreams::UpstreamRepository},
telemetry::{
LiveLog, QUERY_LOG_CHANNEL_CAPACITY, QueryLogPurger, QueryLogWriter, Stats, TelemetrySink,
},
web::{AdminServer, AppState},
};
/// Default upper bound on how long shutdown waits for tracked tasks to finish.
///
/// [`App::new`] installs this value; [`App::with_drain_timeout`] overrides it.
pub const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_secs(10);
/// Socket addresses the runtime actually bound, handed to the `on_ready`
/// callback of [`App::run_until_ready`].
///
/// Useful when the configuration asks for port `0` and the real port is only
/// known after binding.
#[derive(Debug, Clone)]
pub struct RuntimeAddrs {
/// Local addresses reported by the DNS listeners' UDP sockets.
///
/// NOTE(review): the same address may appear more than once (presumably one
/// entry per socket); `run_until_ready` de-duplicates before logging.
pub dns_udp: Vec<std::net::SocketAddr>,
/// Local address reported by the admin server after binding.
pub admin: std::net::SocketAddr,
}
/// Top-level runtime: owns the configuration and the shutdown machinery shared
/// by every spawned subsystem.
pub struct App {
/// Startup settings read by `run_until_ready` (database path, listen
/// addresses, session-cookie policy).
config: Config,
/// Cancelled once when shutdown begins; each subsystem receives a clone.
shutdown_token: CancellationToken,
/// Tracks spawned tasks so shutdown can wait for them to finish.
tracker: TaskTracker,
/// Maximum time shutdown waits on `tracker` before returning anyway.
drain_timeout: Duration,
}
impl App {
/// Creates an application around `config` with a fresh shutdown token, an
/// empty task tracker and [`DEFAULT_DRAIN_TIMEOUT`] as the drain timeout.
pub fn new(config: Config) -> Self {
    let shutdown_token = CancellationToken::new();
    let tracker = TaskTracker::new();
    Self {
        config,
        shutdown_token,
        tracker,
        drain_timeout: DEFAULT_DRAIN_TIMEOUT,
    }
}
/// Builder-style setter: returns the same application with its drain timeout
/// replaced by `timeout`.
pub fn with_drain_timeout(self, timeout: Duration) -> Self {
    let mut app = self;
    app.drain_timeout = timeout;
    app
}
pub fn drain_timeout(&self) -> Duration {
self.drain_timeout
}
pub fn cancellation_token(&self) -> CancellationToken {
self.shutdown_token.clone()
}
/// Spawns a named subsystem on the task tracker.
///
/// `f` receives a clone of the shutdown token and is invoked inside the
/// spawned task (not on the caller's stack). Debug events tagged with
/// `name` are emitted immediately before it starts and after it finishes.
fn spawn_subsystem<F, Fut>(&self, name: &'static str, f: F)
where
    F: FnOnce(CancellationToken) -> Fut + Send + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    let shutdown = self.shutdown_token.clone();
    let lifecycle = async move {
        tracing::debug!(subsystem = name, "subsystem started");
        let body = f(shutdown);
        body.await;
        tracing::debug!(subsystem = name, "subsystem stopped");
    };
    // The join handle is not needed: completion is observed via the tracker.
    let _ = self.tracker.spawn(lifecycle);
}
pub async fn run_until_ready(
self,
shutdown: impl Future<Output = ()>,
on_ready: impl FnOnce(RuntimeAddrs),
) -> Result<()> {
let db = Db::connect(&self.config.db_path).await?;
let state = ResolverState::hydrate(&db).await?;
let scheduler =
BlocklistScheduler::new(db.blocklists(), Arc::clone(&state), Fetcher::new());
scheduler.load_from_cache().await;
let rows = db.upstreams().list_enabled().await?;
let configs: Vec<_> = rows
.iter()
.filter_map(|r| match UpstreamConfig::try_from(r) {
Ok(c) => Some(c),
Err(_) => {
tracing::warn!(addr = %r.address, "skipping unmappable upstream");
None
}
})
.collect();
let settings = state.settings();
let pool = Arc::new(SharedUpstreamPool::new(
crate::resolver::pipeline::engine::build_upstream_pool(
&configs,
&self.tracker,
settings.upstream_selection_strategy,
settings.upstream_parallel_fanout,
)
.await,
));
drop(settings);
let forward_zone_rows = db.forward_zones().list_enabled().await?;
let forward_zones =
crate::resolver::forward_zone::ForwardZoneSet::build(&forward_zone_rows, &self.tracker)
.await;
state.store_forward_zones(forward_zones);
let query_log_state = Arc::clone(&state);
let (query_log_tx, query_log_rx) = tokio::sync::mpsc::channel(QUERY_LOG_CHANNEL_CAPACITY);
let telemetry = Arc::new(
TelemetrySink::new(Arc::new(LiveLog::default()), Arc::new(Stats::new()))
.with_query_log(query_log_tx, Arc::clone(&query_log_state)),
);
let reverse = Arc::new(crate::resolver::reverse::ReverseResolver::new(
crate::resolver::pipeline::engine::build_internal_service(
Arc::clone(&state),
Arc::clone(&pool),
),
));
let app_state = AppState {
db: db.clone(),
resolver: Arc::clone(&state),
telemetry: Arc::clone(&telemetry),
refresh: scheduler.trigger(),
cookie_policy: self.config.session_cookie_secure,
csrf_key: crate::web::random_csrf_key(),
setup_done: Arc::new(std::sync::atomic::AtomicBool::new(false)),
upstream_pool: Arc::clone(&pool),
tracker: self.tracker.clone(),
started_at: std::time::Instant::now(),
reverse,
};
let engine = build_engine(state, pool, telemetry, &ProtectiveConfig::default());
let udp_sockets_per_addr = std::thread::available_parallelism()
.map(NonZeroUsize::get)
.unwrap_or(1);
let listeners = DnsListeners::bind(&self.config.dns_addrs, udp_sockets_per_addr)?;
let dns_udp = listeners.udp_local_addrs();
listeners.serve(engine, self.shutdown_token.clone(), &self.tracker);
let admin = AdminServer::bind(self.config.admin_addr, app_state).await?;
let admin_addr = admin.local_addr()?;
self.spawn_subsystem("web-admin", move |token| async move {
admin.serve(token).await;
});
self.spawn_subsystem("blocklist-refresh", move |token| async move {
scheduler.run(token).await;
});
let query_log_writer =
QueryLogWriter::new(query_log_rx, db.query_log(), Arc::clone(&query_log_state));
self.spawn_subsystem("query-log-writer", move |token| async move {
query_log_writer.run(token).await;
});
let query_log_purger = QueryLogPurger::new(db.query_log(), query_log_state);
self.spawn_subsystem("query-log-purge", move |token| async move {
query_log_purger.run(token).await;
});
self.tracker.close();
let unique_dns: std::collections::BTreeSet<_> = dns_udp.iter().copied().collect();
info!(
dns_addrs = ?unique_dns,
admin_addr = %admin_addr,
"runtime ready, awaiting shutdown signal",
);
on_ready(RuntimeAddrs {
dns_udp,
admin: admin_addr,
});
shutdown.await;
info!("shutdown signal received, draining…");
self.shutdown_token.cancel();
match timeout(self.drain_timeout, self.tracker.wait()).await {
Ok(()) => {
info!("all tasks drained cleanly");
}
Err(_elapsed) => {
warn!(
drain_timeout = ?self.drain_timeout,
"drain timeout elapsed; some tasks may still be running — forcing exit",
);
}
}
Ok(())
}
async fn run_until_shutdown(self, shutdown: impl Future<Output = ()>) -> Result<()> {
self.run_until_ready(shutdown, |_| {}).await
}
/// Runs the application until the process receives its shutdown signal
/// (see `make_shutdown_signal`), then drains and returns.
pub async fn run(self) -> Result<()> {
    self.run_until_shutdown(make_shutdown_signal()).await
}
}
impl From<Config> for App {
fn from(config: Config) -> Self {
Self::new(config)
}
}
/// Resolves when the process receives SIGTERM or SIGINT, whichever comes
/// first, logging which one it was.
///
/// # Panics
///
/// Panics if either signal handler cannot be registered.
#[cfg(unix)]
async fn make_shutdown_signal() {
    use tokio::signal::unix::{SignalKind, signal as register};

    let mut terminate =
        register(SignalKind::terminate()).expect("failed to register SIGTERM handler");
    let mut interrupt =
        register(SignalKind::interrupt()).expect("failed to register SIGINT handler");

    tokio::select! {
        _ = terminate.recv() => info!("received SIGTERM"),
        _ = interrupt.recv() => info!("received SIGINT"),
    }
}
/// Resolves when the process receives Ctrl-C.
///
/// # Panics
///
/// Panics if listening for Ctrl-C fails.
#[cfg(not(unix))]
async fn make_shutdown_signal() {
    let outcome = tokio::signal::ctrl_c().await;
    outcome.expect("failed to listen for Ctrl-C");
    info!("received Ctrl-C");
}
#[cfg(test)]
mod tests {
use std::{net::SocketAddr, path::PathBuf, time::Duration};
use tempfile::TempDir;
use super::*;
use crate::config::{Config, SessionCookieSecurePolicy};
/// Builds a config suitable for actually running the app: a freshly created
/// database file inside a temp dir and port-0 listeners on loopback.
///
/// The `TempDir` is returned so the caller keeps the directory alive.
async fn run_config() -> (TempDir, Config) {
    let dir = TempDir::new().expect("temp dir");
    let db_path = dir.path().join("test.db");
    // Connect once so the database file exists, then release the handle.
    drop(Db::connect(&db_path).await.expect("create test db"));

    let any_loopback_port: SocketAddr = "127.0.0.1:0".parse().unwrap();
    let config = Config {
        dns_addrs: vec![any_loopback_port],
        admin_addr: any_loopback_port,
        db_path,
        session_cookie_secure: SessionCookieSecurePolicy::Never,
    };
    (dir, config)
}
/// Config for tests that only construct an `App` and never run it.
fn test_config() -> Config {
    let dns: SocketAddr = "127.0.0.1:5353".parse().unwrap();
    let admin: SocketAddr = "127.0.0.1:18080".parse().unwrap();
    Config {
        dns_addrs: vec![dns],
        admin_addr: admin,
        db_path: PathBuf::from(":memory:"),
        session_cookie_secure: SessionCookieSecurePolicy::Never,
    }
}
/// A freshly constructed app uses the default drain timeout.
#[test]
fn new_sets_default_drain_timeout() {
    let observed = App::new(test_config()).drain_timeout();
    assert_eq!(observed, DEFAULT_DRAIN_TIMEOUT);
}
/// The builder setter replaces the default drain timeout.
#[test]
fn with_drain_timeout_overrides() {
    let custom = Duration::from_millis(50);
    let app = App::new(test_config()).with_drain_timeout(custom);
    assert_eq!(app.drain_timeout(), custom);
}
/// Converting a config via `From` yields an app with default settings.
#[test]
fn from_config_builds_app() {
    let app: App = test_config().into();
    assert_eq!(app.drain_timeout(), DEFAULT_DRAIN_TIMEOUT);
}
/// Tokens handed out by the app share state: cancelling one cancels all.
#[test]
fn cancellation_token_is_cloneable() {
    let app = App::new(test_config());
    let (first, second) = (app.cancellation_token(), app.cancellation_token());
    first.cancel();
    assert!(second.is_cancelled());
}
/// A shutdown future that is already complete still yields a clean run.
#[tokio::test]
async fn run_with_immediate_shutdown_returns_ok() {
    let (_dir, config) = run_config().await;
    let outcome = App::new(config)
        .with_drain_timeout(Duration::from_millis(500))
        .run_until_shutdown(async {})
        .await;
    assert!(outcome.is_ok());
}
/// Cancelling the shared token from inside the shutdown future is harmless:
/// the run still finishes with `Ok`.
#[tokio::test]
async fn run_with_token_shutdown_returns_ok() {
    let (_dir, config) = run_config().await;
    let app = App::new(config).with_drain_timeout(Duration::from_millis(500));
    let shared_token = app.cancellation_token();

    let cancel_after_yield = async move {
        tokio::task::yield_now().await;
        shared_token.cancel();
    };
    let outcome = app.run_until_shutdown(cancel_after_yield).await;

    assert!(outcome.is_ok());
}
/// A task that ignores the shutdown token must not hang shutdown: the run
/// waits for the drain timeout and then returns `Ok` anyway.
#[tokio::test]
async fn run_returns_ok_even_when_drain_times_out() {
    use std::time::Instant;

    let (_dir, config) = run_config().await;
    let drain_timeout = Duration::from_millis(80);
    let app = App::new(config).with_drain_timeout(drain_timeout);

    // Sleeps far longer than the drain timeout and never checks the token.
    app.spawn_subsystem("rogue", |_token| async {
        tokio::time::sleep(Duration::from_secs(60)).await;
    });

    let started = Instant::now();
    let outcome = app.run_until_shutdown(async {}).await;
    let elapsed = started.elapsed();

    assert!(outcome.is_ok());
    assert!(
        elapsed >= drain_timeout,
        "should have waited for the drain timeout, waited {elapsed:?}"
    );
    assert!(
        elapsed < Duration::from_secs(5),
        "should not have waited for the rogue task, waited {elapsed:?}"
    );
}
/// End-to-end check: with a seeded local record and no admin account, the
/// running app redirects `/` to the setup wizard, renders the wizard, answers
/// a DNS `A` query for the seeded name over UDP, and shuts down cleanly.
#[tokio::test]
async fn full_stack_serves_dns_and_wizard_before_admin_exists() {
    use crate::codec::{header::Header, name::Name, writer::Writer};
    use crate::storage::local_records::{LocalRecordRepository, NewLocalRecord, RecordType};
    use tokio::net::UdpSocket;
    use tokio::sync::oneshot;

    // Values written after the question name: QTYPE A, QCLASS IN.
    const QTYPE_A: u16 = 1u16;
    const QCLASS_IN: u16 = 1u16;

    // --- Seed the database, then release it before the app opens it --------
    let dir = TempDir::new().expect("temp dir");
    let db_path = dir.path().join("test.db");
    let seed_db = Db::connect(&db_path).await.expect("create db");
    let record = NewLocalRecord {
        name: "router.home.lan".to_string(),
        record_type: RecordType::A,
        value: "192.168.1.1".to_string(),
        ttl: 300,
    };
    seed_db
        .local_records()
        .add(record)
        .await
        .expect("seed local record");
    drop(seed_db);

    // --- Start the app in the background -----------------------------------
    let config = Config {
        dns_addrs: vec!["127.0.0.1:0".parse::<SocketAddr>().unwrap()],
        admin_addr: "127.0.0.1:0".parse::<SocketAddr>().unwrap(),
        db_path,
        session_cookie_secure: SessionCookieSecurePolicy::Never,
    };
    let app = App::new(config).with_drain_timeout(Duration::from_secs(2));
    let (ready_tx, ready_rx) = oneshot::channel::<RuntimeAddrs>();
    let (stop_tx, stop_rx) = oneshot::channel::<()>();
    let wait_for_stop = async move {
        let _ = stop_rx.await;
    };
    let announce_ready = move |addrs: RuntimeAddrs| {
        let _ = ready_tx.send(addrs);
    };
    let runtime =
        tokio::spawn(async move { app.run_until_ready(wait_for_stop, announce_ready).await });
    let addrs = timeout(Duration::from_secs(5), ready_rx)
        .await
        .expect("startup within 5s")
        .expect("ready signal");

    // --- Admin HTTP: redirect to the wizard, wizard renders ----------------
    // Redirects are disabled so the 303 itself can be asserted.
    let http = reqwest::Client::builder()
        .redirect(reqwest::redirect::Policy::none())
        .build()
        .unwrap();
    let base = format!("http://{}", addrs.admin);

    let root = http.get(format!("{base}/")).send().await.expect("GET /");
    assert_eq!(root.status(), 303, "root redirects while no admin exists");
    assert_eq!(root.headers().get("location").unwrap(), "/setup");

    let setup = http
        .get(format!("{base}/setup"))
        .send()
        .await
        .expect("GET /setup");
    assert_eq!(setup.status(), 200, "wizard must render");
    assert!(
        setup.text().await.unwrap().contains("Welcome"),
        "wizard page content",
    );

    // --- DNS over UDP: query the seeded record -----------------------------
    let server = addrs.dns_udp[0];
    let sock = UdpSocket::bind("127.0.0.1:0").await.expect("client socket");
    sock.connect(server).await.expect("connect");

    let mut query = Writer::with_capacity(64);
    Header::new(0xBEEF)
        .with_qdcount(1)
        .with_rd(true)
        .write(&mut query);
    let qname: Name = "router.home.lan.".parse().expect("name");
    qname.write(&mut query);
    query.write_u16(QTYPE_A);
    query.write_u16(QCLASS_IN);
    sock.send(&query.finish()).await.expect("send query");

    let mut buf = vec![0u8; 512];
    let received = timeout(Duration::from_secs(5), sock.recv(&mut buf))
        .await
        .expect("response within 5s")
        .expect("recv");
    let resp = &buf[..received];

    // Header checks on the raw bytes: ID, QR bit, RCODE, ANCOUNT.
    assert_eq!(u16::from_be_bytes([resp[0], resp[1]]), 0xBEEF, "txn id");
    assert_ne!(resp[2] & 0x80, 0, "QR bit must be set (response)");
    assert_eq!(resp[3] & 0x0f, 0, "RCODE must be NOERROR");
    let ancount = u16::from_be_bytes([resp[6], resp[7]]);
    assert!(ancount >= 1, "at least one answer record");
    assert!(
        resp.windows(4).any(|bytes| bytes == [192, 168, 1, 1]),
        "answer must carry the seeded A address",
    );

    // --- Shutdown -----------------------------------------------------------
    let _ = stop_tx.send(());
    let outcome = timeout(Duration::from_secs(5), runtime)
        .await
        .expect("App shuts down within 5s")
        .expect("join");
    assert!(outcome.is_ok());
}
}