use std::{
net::Ipv4Addr,
net::Ipv6Addr,
sync::{
Arc,
atomic::{AtomicI64, Ordering},
},
};
use arc_swap::ArcSwap;
use crate::{
codec::synth::BlockMode,
resolver::{
self,
cache::DnsCache,
forward_zone::ForwardZoneSet,
local::{LocalMatcher, LocalRecords, RecordData},
matchset::{AttributedSet, MatchSet},
},
storage::{
Db,
lists::{AllowlistRepository, BlacklistRepository},
local_records::{LocalRecordRepository, RecordType},
settings::{BlockingMode, SelectionStrategy, Settings, SettingsRepository},
},
time::Clock,
};
/// Resolver-facing snapshot of the persisted [`Settings`].
///
/// Built via `From<&Settings>` and published through an `ArcSwap` in
/// `ResolverState`, so query handling can read it without touching storage.
#[derive(Debug, Clone, PartialEq)]
pub struct RuntimeSettings {
    /// Copied from `Settings::cache_min_ttl`.
    pub cache_min_ttl: u32,
    /// Copied from `Settings::cache_max_ttl`.
    pub cache_max_ttl: u32,
    /// Copied from `Settings::cache_capacity`.
    pub cache_capacity: u64,
    /// Copied from `Settings::cache_negative_ttl_cap` (note the shorter name here).
    pub negative_ttl_cap: u32,
    /// Codec-level blocking response, derived from `Settings::blocking_mode`
    /// and, for the custom mode, the configured IPv4/IPv6 block addresses.
    pub block_mode: BlockMode,
    /// Copied from `Settings::blocklist_refresh_interval`.
    pub blocklist_refresh_interval: u32,
    /// Copied from `Settings::query_log_enabled`.
    pub query_log_enabled: bool,
    /// Copied from `Settings::query_log_retention_days`.
    pub query_log_retention_days: u32,
    /// Copied from `Settings::upstream_selection_strategy`.
    pub upstream_selection_strategy: SelectionStrategy,
    /// Copied from `Settings::upstream_parallel_fanout`.
    pub upstream_parallel_fanout: u32,
}
impl From<&Settings> for RuntimeSettings {
    /// Projects persisted settings onto the runtime view used by the resolver.
    ///
    /// Scalar settings are copied verbatim (`cache_negative_ttl_cap` is exposed
    /// as `negative_ttl_cap`). The storage-level [`BlockingMode`] is translated
    /// into a codec [`BlockMode`]. In `Custom` mode, a missing IPv4 or IPv6
    /// override falls back to the unspecified address of that family.
    fn from(settings: &Settings) -> Self {
        let block_mode = match settings.blocking_mode {
            BlockingMode::Custom => {
                let v4 = settings.custom_block_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED);
                let v6 = settings.custom_block_ipv6.unwrap_or(Ipv6Addr::UNSPECIFIED);
                BlockMode::Address { v4, v6 }
            }
            BlockingMode::NullIp => BlockMode::null_ip(),
            BlockingMode::NxDomain => BlockMode::NxDomain,
        };

        Self {
            block_mode,
            negative_ttl_cap: settings.cache_negative_ttl_cap,
            cache_capacity: settings.cache_capacity,
            cache_min_ttl: settings.cache_min_ttl,
            cache_max_ttl: settings.cache_max_ttl,
            blocklist_refresh_interval: settings.blocklist_refresh_interval,
            query_log_enabled: settings.query_log_enabled,
            query_log_retention_days: settings.query_log_retention_days,
            upstream_parallel_fanout: settings.upstream_parallel_fanout,
            upstream_selection_strategy: settings.upstream_selection_strategy,
        }
    }
}
/// Shared, concurrently-readable state consulted by the resolver on every query.
///
/// Constructed by `ResolverState::hydrate`. The settings and forward-zone sets
/// are swappable at runtime through `ArcSwap`. The pause deadline lives in an
/// atomic so that pausing and resuming need only `&self`.
pub struct ResolverState {
    /// Blacklist entries loaded from `db.blacklist()` during hydration.
    blacklist: MatchSet,
    /// Allowlist entries loaded from `db.allowlist()` during hydration.
    allowlist: MatchSet,
    /// Attributed blocklist set; starts empty at hydration.
    blocklist: AttributedSet,
    /// Matcher over locally-defined records (see `build_local_records`).
    local: LocalMatcher,
    /// Response cache, sized from the persisted cache settings at hydration.
    cache: DnsCache,
    /// Current runtime settings; replaced wholesale via `store_settings`.
    settings: ArcSwap<RuntimeSettings>,
    /// Current forward-zone set; starts empty and is replaced via `store_forward_zones`.
    forward_zones: ArcSwap<ForwardZoneSet>,
    /// Pause deadline in `Clock::now_secs()` units; `0` means "not paused".
    paused_until: AtomicI64,
}
impl ResolverState {
    /// Returns the blacklist loaded from storage by [`Self::hydrate`].
    #[must_use]
    pub fn blacklist(&self) -> &MatchSet {
        &self.blacklist
    }

    /// Returns the allowlist loaded from storage by [`Self::hydrate`].
    #[must_use]
    pub fn allowlist(&self) -> &MatchSet {
        &self.allowlist
    }

    /// Returns the attributed blocklist set (empty right after [`Self::hydrate`]).
    #[must_use]
    pub fn blocklist(&self) -> &AttributedSet {
        &self.blocklist
    }

    /// Returns the matcher over locally-defined records.
    #[must_use]
    pub fn local(&self) -> &LocalMatcher {
        &self.local
    }

    /// Returns the response cache.
    #[must_use]
    pub fn cache(&self) -> &DnsCache {
        &self.cache
    }

    /// Returns a guard over the current runtime settings.
    ///
    /// Cheap to obtain and intended for short-lived reads (see the `arc_swap`
    /// `Guard` documentation). Prefer [`Self::settings_full`] when the snapshot
    /// needs to be held for a long time.
    #[must_use]
    pub fn settings(&self) -> arc_swap::Guard<Arc<RuntimeSettings>> {
        self.settings.load()
    }

    /// Returns an owned `Arc` to the current runtime settings.
    ///
    /// The returned snapshot keeps its values even after a later
    /// [`Self::store_settings`] publishes a replacement.
    #[must_use]
    pub fn settings_full(&self) -> Arc<RuntimeSettings> {
        self.settings.load_full()
    }

    /// Atomically publishes `new_settings`; subsequent reads observe it.
    pub fn store_settings(&self, new_settings: RuntimeSettings) {
        self.settings.store(Arc::new(new_settings));
    }

    /// Returns the current forward-zone set as an owned `Arc`.
    #[must_use]
    pub fn forward_zones(&self) -> Arc<ForwardZoneSet> {
        self.forward_zones.load_full()
    }

    /// Atomically publishes a new forward-zone set.
    pub fn store_forward_zones(&self, set: ForwardZoneSet) {
        self.forward_zones.store(Arc::new(set));
    }

    /// Suspends blocking for `secs` seconds, measured from `Clock::now_secs()`.
    ///
    /// A non-positive `secs` clears any pause, matching [`Self::resume`].
    pub fn pause_for_secs(&self, secs: i64) {
        let deadline = if secs > 0 {
            // Saturate rather than overflow. A plain `+` panics in debug builds.
            // In release builds it wraps to a negative deadline, which
            // `paused_until` would read as "already expired" and so silently
            // ignore the pause.
            Clock::now_secs().saturating_add(secs)
        } else {
            0
        };
        self.paused_until.store(deadline, Ordering::Relaxed);
    }

    /// Clears any pause so blocking applies again immediately.
    pub fn resume(&self) {
        self.paused_until.store(0, Ordering::Relaxed);
    }

    /// Returns `true` while a pause deadline is set and has not yet passed.
    #[must_use]
    pub fn blocking_paused(&self) -> bool {
        self.paused_until().is_some()
    }

    /// Returns the active pause deadline (in `Clock::now_secs()` units).
    ///
    /// Returns `None` when blocking is in effect: the state was never paused,
    /// was resumed, or the deadline has already passed.
    #[must_use]
    pub fn paused_until(&self) -> Option<i64> {
        let deadline = self.paused_until.load(Ordering::Relaxed);
        (deadline != 0 && Clock::now_secs() < deadline).then_some(deadline)
    }

    /// Builds resolver state from the database.
    ///
    /// Loads settings, the blacklist, the allowlist and the local records. The
    /// blocklist and forward-zone set start empty, and blocking starts un-paused.
    ///
    /// # Errors
    ///
    /// Returns an error if any storage read fails, or if a stored local record
    /// is rejected by [`build_local_records`].
    pub async fn hydrate(db: &Db) -> resolver::Result<Arc<Self>> {
        let settings = db.settings().get().await?;
        let cache = DnsCache::new(
            settings.cache_capacity,
            settings.cache_min_ttl,
            settings.cache_max_ttl,
        );
        let runtime_settings = RuntimeSettings::from(&settings);
        let blacklist = db.blacklist().load_all().await?.into_iter().collect();
        let allowlist = db.allowlist().load_all().await?.into_iter().collect();
        let blocklist = AttributedSet::empty();
        let local_rows = db.local_records().load_all().await?;
        let local = LocalMatcher::new(build_local_records(local_rows)?);
        Ok(Arc::new(Self {
            blacklist,
            allowlist,
            blocklist,
            local,
            cache,
            settings: ArcSwap::from_pointee(runtime_settings),
            forward_zones: ArcSwap::from_pointee(ForwardZoneSet::empty()),
            paused_until: AtomicI64::new(0),
        }))
    }
}
pub fn build_local_records(
rows: Vec<crate::storage::local_records::LocalRecord>,
) -> resolver::Result<LocalRecords> {
let mut builder = LocalRecords::builder();
for row in rows {
let data = match row.record_type {
RecordType::A => {
let addr: Ipv4Addr = row.value.parse().map_err(|e| {
resolver::Error::InvalidLocalRecord(format!(
"record {:?} has invalid A value {:?}: {e}",
row.name, row.value
))
})?;
RecordData::A(addr)
}
RecordType::Aaaa => {
let addr: Ipv6Addr = row.value.parse().map_err(|e| {
resolver::Error::InvalidLocalRecord(format!(
"record {:?} has invalid AAAA value {:?}: {e}",
row.name, row.value
))
})?;
RecordData::Aaaa(addr)
}
};
let name = row.name.trim_end_matches('.');
builder.add(name, data, row.ttl).map_err(|e| {
resolver::Error::InvalidLocalRecord(format!(
"could not add local record {:?}: {e}",
row.name
))
})?;
}
Ok(builder.build())
}
impl std::fmt::Debug for ResolverState {
    /// Shows only the four matcher sets.
    ///
    /// The cache, settings, forward zones and pause deadline are elided, and
    /// the output is marked non-exhaustive (`..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            blacklist,
            allowlist,
            blocklist,
            local,
            ..
        } = self;
        let mut dbg = f.debug_struct("ResolverState");
        dbg.field("blacklist", blacklist);
        dbg.field("allowlist", allowlist);
        dbg.field("blocklist", blocklist);
        dbg.field("local", local);
        dbg.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr};
    use super::*;
    use crate::{
        codec::{message::Qtype, name::Name},
        resolver::local::LocalMatch,
        storage::{
            lists::{AllowlistRepository, BlacklistRepository},
            local_records::{LocalRecordRepository, NewLocalRecord, RecordType},
            settings::{BlockingMode, Settings},
        },
    };

    /// Parses a test domain name, panicking on malformed input.
    fn name(s: &str) -> Name {
        s.parse().expect("valid domain name")
    }

    /// Settings mirroring the seeded defaults, in null-IP blocking mode.
    fn base_settings() -> Settings {
        Settings {
            cache_min_ttl: 1,
            cache_max_ttl: 86400,
            cache_negative_ttl_cap: 3600,
            cache_capacity: 100_000,
            blocking_mode: BlockingMode::NullIp,
            custom_block_ipv4: None,
            custom_block_ipv6: None,
            blocklist_refresh_interval: 86400,
            ui_theme: "auto".to_owned(),
            query_log_enabled: true,
            query_log_retention_days: 30,
            upstream_selection_strategy: SelectionStrategy::Random,
            upstream_parallel_fanout: 2,
        }
    }

    #[test]
    fn runtime_settings_from_nxdomain() {
        let mut s = base_settings();
        s.blocking_mode = BlockingMode::NxDomain;
        let rs = RuntimeSettings::from(&s);
        assert_eq!(rs.block_mode, BlockMode::NxDomain);
        assert_eq!(rs.negative_ttl_cap, 3600);
        assert_eq!(rs.cache_max_ttl, 86400);
        assert_eq!(rs.blocklist_refresh_interval, 86400);
    }

    #[test]
    fn runtime_settings_from_null_ip() {
        let s = base_settings();
        let rs = RuntimeSettings::from(&s);
        assert_eq!(rs.block_mode, BlockMode::null_ip());
    }

    #[test]
    fn runtime_settings_carries_query_log_fields() {
        let mut s = base_settings();
        s.query_log_enabled = false;
        s.query_log_retention_days = 7;
        let rs = RuntimeSettings::from(&s);
        assert!(!rs.query_log_enabled);
        assert_eq!(rs.query_log_retention_days, 7);
    }

    #[test]
    fn runtime_settings_from_custom_with_ips() {
        let mut s = base_settings();
        s.blocking_mode = BlockingMode::Custom;
        s.custom_block_ipv4 = Some("203.0.113.1".parse().unwrap());
        s.custom_block_ipv6 = Some("2001:db8::1".parse().unwrap());
        let rs = RuntimeSettings::from(&s);
        assert_eq!(
            rs.block_mode,
            BlockMode::Address {
                v4: "203.0.113.1".parse().unwrap(),
                v6: "2001:db8::1".parse().unwrap(),
            }
        );
    }

    #[test]
    fn runtime_settings_from_custom_none_ips_falls_back_to_unspecified() {
        let mut s = base_settings();
        s.blocking_mode = BlockingMode::Custom;
        s.custom_block_ipv4 = None;
        s.custom_block_ipv6 = None;
        let rs = RuntimeSettings::from(&s);
        assert_eq!(
            rs.block_mode,
            BlockMode::Address {
                v4: Ipv4Addr::UNSPECIFIED,
                v6: Ipv6Addr::UNSPECIFIED,
            }
        );
    }

    #[test]
    fn runtime_settings_from_custom_partial_ips() {
        let mut s = base_settings();
        s.blocking_mode = BlockingMode::Custom;
        s.custom_block_ipv4 = Some("10.0.0.1".parse().unwrap());
        s.custom_block_ipv6 = None;
        let rs = RuntimeSettings::from(&s);
        assert_eq!(
            rs.block_mode,
            BlockMode::Address {
                v4: "10.0.0.1".parse().unwrap(),
                v6: Ipv6Addr::UNSPECIFIED,
            }
        );
    }

    #[tokio::test]
    async fn hydration_reflects_blacklist_and_allowlist() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let bl = db.blacklist();
        bl.add("ads.example.com").await.expect("add to blacklist");
        bl.add("tracker.evil.net").await.expect("add to blacklist");
        let al = db.allowlist();
        al.add("safe.example.com").await.expect("add to allowlist");
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        assert!(state.blacklist().contains(&name("ads.example.com")));
        assert!(state.blacklist().contains(&name("tracker.evil.net")));
        assert!(!state.blacklist().contains(&name("safe.example.com")));
        assert!(state.allowlist().contains(&name("safe.example.com")));
        assert!(!state.allowlist().contains(&name("ads.example.com")));
    }

    #[tokio::test]
    async fn hydration_blocklist_is_empty() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        assert!(
            state.blocklist().is_empty(),
            "blocklist must be empty right after hydration"
        );
    }

    #[tokio::test]
    async fn hydration_reflects_local_records() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let repo = db.local_records();
        repo.add(NewLocalRecord {
            name: "router.home.lan".to_owned(),
            record_type: RecordType::A,
            value: "192.168.1.1".to_owned(),
            ttl: 300,
        })
        .await
        .expect("add A record");
        repo.add(NewLocalRecord {
            name: "router.home.lan".to_owned(),
            record_type: RecordType::Aaaa,
            value: "fd00::1".to_owned(),
            ttl: 600,
        })
        .await
        .expect("add AAAA record");
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let a_match = state.local().lookup(&name("router.home.lan"), Qtype::A);
        assert!(
            matches!(a_match, LocalMatch::Answer { data: crate::resolver::local::RecordData::A(addr), ttl: 300 } if addr == "192.168.1.1".parse::<Ipv4Addr>().unwrap()),
            "expected A answer, got: {a_match:?}"
        );
        let aaaa_match = state.local().lookup(&name("router.home.lan"), Qtype::Aaaa);
        assert!(
            matches!(aaaa_match, LocalMatch::Answer { data: crate::resolver::local::RecordData::Aaaa(addr), ttl: 600 } if addr == "fd00::1".parse::<Ipv6Addr>().unwrap()),
            "expected AAAA answer, got: {aaaa_match:?}"
        );
    }

    #[tokio::test]
    async fn hydration_settings_reflect_seed() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let s = state.settings();
        assert_eq!(s.block_mode, BlockMode::null_ip(), "seed uses null-ip mode");
        assert_eq!(s.negative_ttl_cap, 3600);
        assert_eq!(s.cache_max_ttl, 86400);
        assert_eq!(s.blocklist_refresh_interval, 86400);
        assert_eq!(s.cache_min_ttl, 1);
        assert_eq!(s.cache_capacity, 100_000);
    }

    #[tokio::test]
    async fn settings_swap_is_visible_to_subsequent_readers() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        assert_eq!(state.settings().block_mode, BlockMode::null_ip());
        let new_settings = RuntimeSettings {
            block_mode: BlockMode::NxDomain,
            ..(*state.settings_full()).clone()
        };
        state.store_settings(new_settings);
        assert_eq!(state.settings().block_mode, BlockMode::NxDomain);
    }

    #[tokio::test]
    async fn settings_swap_concurrent_reader_observes_new_value() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let (_dir, db) = crate::test_support::temp_db().await;
        // `hydrate` already returns an `Arc<ResolverState>`; share it directly
        // instead of wrapping it in a second `Arc`.
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let state_r = Arc::clone(&state);
        let seen_nxdomain = Arc::new(AtomicBool::new(false));
        let seen_r = Arc::clone(&seen_nxdomain);
        let reader = tokio::spawn(async move {
            loop {
                if state_r.settings().block_mode == BlockMode::NxDomain {
                    seen_r.store(true, Ordering::Relaxed);
                    break;
                }
                tokio::task::yield_now().await;
            }
        });
        let new_settings = RuntimeSettings {
            block_mode: BlockMode::NxDomain,
            ..(*state.settings_full()).clone()
        };
        state.store_settings(new_settings);
        reader.await.expect("reader task panicked");
        assert!(
            seen_nxdomain.load(Ordering::Relaxed),
            "reader must have observed NxDomain after swap"
        );
    }

    #[tokio::test]
    async fn pause_sets_future_deadline_and_resume_clears_it() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        assert!(!state.blocking_paused());
        assert_eq!(state.paused_until(), None);
        state.pause_for_secs(300);
        assert!(state.blocking_paused());
        let deadline = state.paused_until().expect("paused deadline");
        assert!(
            deadline > Clock::now_secs(),
            "deadline must be in the future"
        );
        state.resume();
        assert!(!state.blocking_paused());
        assert_eq!(state.paused_until(), None);
    }

    /// A deadline in the past means blocking is active again.
    #[tokio::test]
    async fn expired_deadline_reads_as_active() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        state
            .paused_until
            .store(Clock::now_secs() - 1, Ordering::Relaxed);
        assert!(
            !state.blocking_paused(),
            "past deadline must read as active"
        );
        assert_eq!(state.paused_until(), None);
    }

    #[tokio::test]
    async fn pause_with_nonpositive_secs_is_immediate_resume() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        state.pause_for_secs(300);
        assert!(state.blocking_paused());
        state.pause_for_secs(0);
        assert!(!state.blocking_paused());
        assert_eq!(state.paused_until(), None);
    }

    #[tokio::test]
    async fn hydration_empty_db_succeeds() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate empty db");
        assert!(state.blacklist().is_empty());
        assert!(state.allowlist().is_empty());
        assert!(state.blocklist().is_empty());
        assert_eq!(
            state.local().lookup(&name("any.example.com"), Qtype::A),
            LocalMatch::Miss
        );
    }

    #[tokio::test]
    async fn resolver_state_debug_does_not_panic() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let s = format!("{state:?}");
        assert!(!s.is_empty());
    }

    /// An `Arc` obtained via `settings_full` must keep the old values after a swap.
    #[tokio::test]
    async fn settings_full_arc_outlives_swap() {
        let (_dir, db) = crate::test_support::temp_db().await;
        let state = ResolverState::hydrate(&db).await.expect("hydrate");
        let old_settings = state.settings_full();
        let new_settings = RuntimeSettings {
            block_mode: BlockMode::NxDomain,
            ..(*old_settings).clone()
        };
        state.store_settings(new_settings);
        // Inspect the retained Arc only *after* the swap, so the assertion
        // actually proves the old snapshot survived.
        assert_eq!(
            old_settings.block_mode,
            BlockMode::null_ip(),
            "retained snapshot must be unaffected by the swap"
        );
        assert_eq!(state.settings().block_mode, BlockMode::NxDomain);
    }
}