use std::collections::HashMap;
use std::io;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use cellos_core::{
CloudEventV1, DnsRebindingPolicy, DnsRefreshPolicy, DnsResolver, DnsResolverDnssecPolicy,
};
use super::{
RebindingState, ResolvedAnswer, ResolverRefresh, ResolverState, TrustAnchors,
ValidatedResolvedAnswer,
};
/// Plain hostname resolver shared with the ticker: maps a hostname to its
/// [`ResolvedAnswer`].
///
/// The ticker in this module calls it from `tokio::task::spawn_blocking`, not
/// directly on the async task, so implementations may block.
pub type SharedResolverFn = Arc<dyn Fn(&str) -> io::Result<ResolvedAnswer> + Send + Sync>;
/// DNSSEC-aware hostname resolver: maps a hostname to a
/// [`ValidatedResolvedAnswer`] (the answer together with its DNSSEC validation
/// result).
///
/// The ticker uses it instead of [`SharedResolverFn`] only when its config
/// carries both a DNSSEC policy and this resolver; it is likewise called from
/// `tokio::task::spawn_blocking`.
pub type SharedValidatedResolverFn =
Arc<dyn Fn(&str) -> io::Result<ValidatedResolvedAnswer> + Send + Sync>;
/// Sink for the CloudEvents produced by the continuous ticker.
///
/// The ticker calls `emit` inline on its async task, once per event, after
/// each tick has been evaluated; a slow implementation therefore delays the
/// next tick.
pub trait DriftEmitter: Send + Sync + 'static {
/// Accepts one event. There is no return value, so any delivery failure must
/// be handled (or deliberately dropped) by the implementation itself.
fn emit(&self, event: CloudEventV1);
}
/// Counters accumulated by the ticker loop and returned when its task ends.
///
/// The loop increments every counter with saturating addition, so values stop
/// at `u64::MAX` instead of wrapping.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct TickerStats {
/// Ticks that completed resolution, evaluation and event emission.
pub tick_count: u64,
/// Total CloudEvents handed to the `DriftEmitter`.
pub events_emitted: u64,
/// Failed per-hostname resolutions observed across all ticks.
pub resolver_errors: u64,
}
/// Handle to a ticker task started by [`spawn_continuous_ticker`].
///
/// To stop the ticker, store `true` into [`shutdown`](Self::shutdown), then
/// await [`task`](Self::task) for the final [`TickerStats`]. The flag is only
/// checked between ticks (and while waiting for the next one), so a tick that
/// is already resolving runs to completion first.
///
/// Dropping the handle does not set the flag, and tokio detaches a dropped
/// `JoinHandle`, so the ticker keeps running in the background.
#[derive(Debug)]
pub struct TickerHandle {
    /// Cooperative stop flag shared with the ticker loop.
    pub shutdown: Arc<AtomicBool>,
    /// Join handle resolving to the stats accumulated by the loop.
    pub task: tokio::task::JoinHandle<TickerStats>,
}
/// Configuration for [`spawn_continuous_ticker`].
///
/// Apart from `interval`, `hostnames`, `cell_id`, `run_id` and
/// `validated_resolver`, every field is forwarded to `ResolverRefresh` on each
/// tick.
pub struct TickerConfig {
/// Wait between the end of one tick and the start of the next. The ticker
/// uses this value as-is; callers that want the minimum floor should pass
/// the seconds through `clamp_tick_interval_secs` first.
pub interval: Duration,
/// DNS refresh policy forwarded to `ResolverRefresh`.
pub policy: Option<DnsRefreshPolicy>,
/// Rebinding policy forwarded to `ResolverRefresh`; with `None` no rebind
/// threshold/rejected events are emitted.
pub rebinding_policy: Option<DnsRebindingPolicy>,
/// Resolver descriptors forwarded to `ResolverRefresh`.
pub resolvers: Vec<DnsResolver>,
/// Hostnames resolved on every tick. If empty, the ticker task returns
/// immediately with zeroed stats.
pub hostnames: Vec<String>,
/// Forwarded to `ResolverRefresh` (presumably stamped into emitted events —
/// confirm in resolver_refresh).
pub keyset_id: Option<String>,
/// Forwarded to `ResolverRefresh` alongside `keyset_id`.
pub issuer_kid: Option<String>,
/// Forwarded to `ResolverRefresh`; the rebinding tests set it whenever a
/// rebinding policy is configured.
pub policy_digest: Option<String>,
/// Forwarded to `ResolverRefresh`.
pub correlation_id: Option<String>,
/// Event source, forwarded to `ResolverRefresh` as `Some(source)`.
pub source: String,
/// Cell identifier passed to every tick evaluation.
pub cell_id: String,
/// Run identifier passed to every tick evaluation.
pub run_id: String,
/// DNSSEC policy. DNSSEC mode (`tick_with_dnssec`) is used only when both
/// this and `validated_resolver` are `Some`.
pub dnssec_policy: Option<DnsResolverDnssecPolicy>,
/// Trust anchors forwarded to `ResolverRefresh`.
pub trust_anchors: Option<TrustAnchors>,
/// DNSSEC-aware resolver; used instead of the plain resolver only when
/// `dnssec_policy` is also `Some`.
pub validated_resolver: Option<SharedValidatedResolverFn>,
}
/// Spawns the continuous DNS refresh ticker onto the current tokio runtime.
///
/// `resolver` is used for plain resolution; DNSSEC resolution uses
/// `cfg.validated_resolver` when configured. Every produced event is passed to
/// `emitter`. The returned [`TickerHandle`] holds the shutdown flag (initially
/// `false`) and the task's join handle, which resolves to the final
/// [`TickerStats`].
///
/// # Panics
///
/// Panics when called outside a tokio runtime, as `tokio::spawn` does.
pub fn spawn_continuous_ticker(
    cfg: TickerConfig,
    emitter: Arc<dyn DriftEmitter>,
    resolver: SharedResolverFn,
) -> TickerHandle {
    let shutdown = Arc::new(AtomicBool::new(false));
    // The loop gets its own reference to the flag; the handle keeps the other.
    let task = tokio::spawn(run_ticker_loop(
        cfg,
        emitter,
        resolver,
        Arc::clone(&shutdown),
    ));
    TickerHandle { shutdown, task }
}
/// Body of the ticker task: resolve, evaluate, emit, sleep, until shutdown.
///
/// Each tick does the following:
/// 1. Resolves every configured hostname on the blocking pool, using the
///    DNSSEC resolver when both `cfg.dnssec_policy` and
///    `cfg.validated_resolver` are set, and the plain `resolver` otherwise.
/// 2. Feeds the answers to `ResolverRefresh`.
/// 3. Hands each resulting event to `emitter`.
///
/// `shutdown` is checked before each tick and while waiting between ticks.
///
/// Returns the accumulated [`TickerStats`]. An empty `cfg.hostnames` returns
/// zeroed stats immediately. If the blocking resolve task fails (for example
/// because the resolver panicked), every hostname in that batch is counted in
/// `resolver_errors`, no tick is evaluated, and the loop retries after
/// `cfg.interval`.
async fn run_ticker_loop(
    cfg: TickerConfig,
    emitter: Arc<dyn DriftEmitter>,
    resolver: SharedResolverFn,
    shutdown: Arc<AtomicBool>,
) -> TickerStats {
    let mut stats = TickerStats::default();
    let mut state = ResolverState::new();
    let mut rebinding_state = RebindingState::new();
    if cfg.hostnames.is_empty() {
        return stats;
    }
    // The config is fixed for the task's lifetime, so pick the resolution path
    // and snapshot the hostnames once instead of on every tick.
    let validated_resolver = match (&cfg.dnssec_policy, &cfg.validated_resolver) {
        (Some(_), Some(validated)) => Some(Arc::clone(validated)),
        _ => None,
    };
    let hostnames: Arc<[String]> = Arc::from(cfg.hostnames.as_slice());
    // usize -> u64 is lossless on every platform Rust supports.
    let batch_size = hostnames.len() as u64;

    loop {
        if shutdown.load(Ordering::SeqCst) {
            break;
        }
        let answers = match &validated_resolver {
            Some(validated) => {
                let validated = Arc::clone(validated);
                resolve_all_blocking(Arc::clone(&hostnames), move |hostname: &str| {
                    validated(hostname)
                })
                .await
                .map(TickAnswers::Dnssec)
            }
            None => {
                let plain = Arc::clone(&resolver);
                resolve_all_blocking(Arc::clone(&hostnames), move |hostname: &str| {
                    plain(hostname)
                })
                .await
                .map(TickAnswers::Plain)
            }
        };
        let Some(answers) = answers else {
            // The blocking task died, so no hostname got an answer. Record the
            // failure instead of hiding it from the stats, then retry later.
            stats.resolver_errors = stats.resolver_errors.saturating_add(batch_size);
            if !sleep_or_shutdown(cfg.interval, &shutdown).await {
                break;
            }
            continue;
        };
        stats.resolver_errors = stats
            .resolver_errors
            .saturating_add(answers.error_count());

        let refresher = ResolverRefresh {
            policy: cfg.policy.as_ref(),
            rebinding_policy: cfg.rebinding_policy.as_ref(),
            resolvers: cfg.resolvers.as_slice(),
            hostnames: cfg.hostnames.as_slice(),
            keyset_id: cfg.keyset_id.as_deref(),
            issuer_kid: cfg.issuer_kid.as_deref(),
            policy_digest: cfg.policy_digest.as_deref(),
            correlation_id: cfg.correlation_id.as_deref(),
            source: Some(cfg.source.as_str()),
            dnssec_policy: cfg.dnssec_policy.as_ref(),
            trust_anchors: cfg.trust_anchors.as_ref(),
        };
        let events = match &answers {
            TickAnswers::Dnssec(validated) => refresher.tick_with_dnssec(
                &mut state,
                &mut rebinding_state,
                validated,
                SystemTime::now(),
                &cfg.cell_id,
                &cfg.run_id,
            ),
            TickAnswers::Plain(resolved) => {
                // Serve the tick from the pre-resolved map, so no blocking
                // resolution happens on the async task itself.
                let resolver_for_tick = |hostname: &str| -> io::Result<ResolvedAnswer> {
                    match resolved.get(hostname) {
                        Some(Ok(answer)) => Ok(answer.clone()),
                        Some(Err(e)) => Err(io::Error::new(e.kind(), e.to_string())),
                        None => Err(io::Error::other(
                            "ticker: hostname missing from pre-resolved map",
                        )),
                    }
                };
                refresher.tick_with_rebinding(
                    &mut state,
                    &mut rebinding_state,
                    &resolver_for_tick,
                    SystemTime::now(),
                    &cfg.cell_id,
                    &cfg.run_id,
                )
            }
        };
        for ev in events {
            stats.events_emitted = stats.events_emitted.saturating_add(1);
            emitter.emit(ev);
        }
        stats.tick_count = stats.tick_count.saturating_add(1);
        if !sleep_or_shutdown(cfg.interval, &shutdown).await {
            break;
        }
    }
    stats
}

/// One tick's resolution results, tagged by the path that produced them.
enum TickAnswers {
    /// Answers from the DNSSEC-aware resolver, evaluated by `tick_with_dnssec`.
    Dnssec(HashMap<String, io::Result<ValidatedResolvedAnswer>>),
    /// Answers from the plain resolver, evaluated by `tick_with_rebinding`.
    Plain(HashMap<String, io::Result<ResolvedAnswer>>),
}

impl TickAnswers {
    /// Number of hostnames whose resolution returned an error this tick.
    fn error_count(&self) -> u64 {
        let failed = match self {
            Self::Dnssec(answers) => answers.values().filter(|r| r.is_err()).count(),
            Self::Plain(answers) => answers.values().filter(|r| r.is_err()).count(),
        };
        failed as u64
    }
}

/// Calls `resolve` for every hostname on tokio's blocking pool and collects the
/// results keyed by hostname.
///
/// The resolvers are synchronous `Fn`s that may block, which is why they must
/// not run on the async task. Returns `None` when the blocking task did not
/// complete (for example, `resolve` panicked).
async fn resolve_all_blocking<T, F>(
    hostnames: Arc<[String]>,
    resolve: F,
) -> Option<HashMap<String, io::Result<T>>>
where
    T: Send + 'static,
    F: Fn(&str) -> io::Result<T> + Send + 'static,
{
    tokio::task::spawn_blocking(move || {
        hostnames
            .iter()
            .map(|hostname| (hostname.clone(), resolve(hostname.as_str())))
            .collect::<HashMap<_, _>>()
    })
    .await
    .ok()
}
/// Raises a tick interval, given in seconds, to a minimum of 5 seconds.
///
/// Values at or above the floor pass through unchanged. Anything below it,
/// including `0`, becomes the floor.
pub fn clamp_tick_interval_secs(secs: u64) -> u64 {
    const FLOOR_SECS: u64 = 5;
    if secs < FLOOR_SECS {
        FLOOR_SECS
    } else {
        secs
    }
}
/// Waits for `interval` while polling `shutdown`.
///
/// Returns `true` once the full interval has elapsed with no shutdown request.
/// Returns `false` as soon as `shutdown` is seen set; it is checked before
/// every sleep step of at most 50 ms.
///
/// An `interval` too large to add to the current `Instant` is treated as "wait
/// until shutdown" instead of panicking on the overflowing addition.
async fn sleep_or_shutdown(interval: Duration, shutdown: &AtomicBool) -> bool {
    const POLL_STEP: Duration = Duration::from_millis(50);
    // `Instant + Duration` panics on overflow. `None` here means the deadline
    // cannot be represented, i.e. it is effectively never reached.
    let deadline = std::time::Instant::now().checked_add(interval);
    loop {
        if shutdown.load(Ordering::SeqCst) {
            return false;
        }
        let step = match deadline {
            Some(deadline) => {
                let remaining = deadline.saturating_duration_since(std::time::Instant::now());
                if remaining.is_zero() {
                    return true;
                }
                remaining.min(POLL_STEP)
            }
            None => POLL_STEP,
        };
        tokio::time::sleep(step).await;
    }
}
#[cfg(test)]
mod tests {
    // Timing-based tests. Each one spawns the ticker with a short interval and
    // lets it run for a few hundred milliseconds. It then stops the ticker with
    // `stop_ticker`, which fails the test if the task panicked or hung, and
    // inspects the events captured by `CollectingEmitter`. Counts use `>=`
    // wherever the number of ticks in the window depends on scheduling.
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::sync::atomic::AtomicU64;
    use std::sync::Mutex;

    /// Placeholder policy digest for tests that configure a rebinding policy.
    const TEST_POLICY_DIGEST: &str =
        "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

    /// Plain answer with TTL 0 from a loopback resolver address.
    fn answer(targets: Vec<String>) -> ResolvedAnswer {
        ResolvedAnswer {
            targets,
            ttl_seconds: 0,
            resolver_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53),
        }
    }

    /// Emitter that records every event for later inspection.
    #[derive(Default)]
    struct CollectingEmitter {
        events: Mutex<Vec<CloudEventV1>>,
    }

    impl DriftEmitter for CollectingEmitter {
        fn emit(&self, event: CloudEventV1) {
            self.events.lock().unwrap().push(event);
        }
    }

    /// Signals shutdown and joins the ticker. Fails the test if the task does
    /// not finish within 1s or if it panicked.
    async fn stop_ticker(handle: TickerHandle) -> TickerStats {
        handle.shutdown.store(true, Ordering::SeqCst);
        tokio::time::timeout(Duration::from_secs(1), handle.task)
            .await
            .expect("ticker join timeout")
            .expect("ticker task panicked")
    }

    fn one_resolver() -> Vec<DnsResolver> {
        vec![DnsResolver {
            resolver_id: "resolver-doh-cloudflare".into(),
            endpoint: "https://1.1.1.1/dns-query".into(),
            protocol: cellos_core::DnsResolverProtocol::Doh,
            trust_kid: None,
            dnssec: None,
        }]
    }

    /// Minimal plain-resolution config with no rebinding or DNSSEC policy.
    fn base_cfg(hostnames: Vec<String>, interval: Duration) -> TickerConfig {
        TickerConfig {
            interval,
            policy: Some(DnsRefreshPolicy {
                min_ttl_seconds: Some(0),
                max_stale_seconds: None,
                strategy: None,
            }),
            rebinding_policy: None,
            resolvers: one_resolver(),
            hostnames,
            keyset_id: Some("keyset-test".into()),
            issuer_kid: Some("kid-test".into()),
            policy_digest: None,
            correlation_id: None,
            source: "cellos-supervisor-test".into(),
            cell_id: "cell-A".into(),
            run_id: "run-A".into(),
            dnssec_policy: None,
            trust_anchors: None,
            validated_resolver: None,
        }
    }

    /// Resolver that returns `sequence[i]` on its i-th call (counted across all
    /// hostnames) and then keeps repeating the last entry. An empty sequence
    /// yields empty targets.
    fn cycling_resolver(sequence: Vec<Vec<String>>) -> SharedResolverFn {
        let counter = Arc::new(AtomicU64::new(0));
        let seq = Arc::new(sequence);
        Arc::new(move |_h: &str| {
            let idx = counter.fetch_add(1, Ordering::SeqCst) as usize;
            let pick = if idx >= seq.len() {
                seq.last().cloned().unwrap_or_default()
            } else {
                seq[idx].clone()
            };
            Ok(answer(pick))
        })
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_emits_drift_when_targets_change_between_ticks() {
        let cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(100));
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver = cycling_resolver(vec![
            vec!["1.1.1.1".into()],
            vec!["1.0.0.1".into()],
            vec!["1.0.0.1".into()],
        ]);
        let handle = spawn_continuous_ticker(cfg, emitter.clone(), resolver);
        tokio::time::sleep(Duration::from_millis(250)).await;
        let stats = stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        assert!(
            events.len() >= 2,
            "expected baseline + change drift events, got {}",
            events.len()
        );
        assert!(stats.tick_count >= 2);
        assert!(stats.events_emitted >= 2);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_silent_when_targets_stable() {
        let cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(50));
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver: SharedResolverFn =
            Arc::new(|_h: &str| Ok(answer(vec!["203.0.113.10".into()])));
        let handle = spawn_continuous_ticker(cfg, emitter.clone(), resolver);
        tokio::time::sleep(Duration::from_millis(250)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        assert_eq!(
            events.len(),
            1,
            "stable targets must emit exactly one baseline event, got {}",
            events.len()
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_respects_shutdown_promptly() {
        let cfg = base_cfg(vec!["api.example.com".into()], Duration::from_secs(10));
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver: SharedResolverFn = Arc::new(|_h: &str| Ok(answer(vec!["1.1.1.1".into()])));
        let handle = spawn_continuous_ticker(cfg, emitter, resolver);
        tokio::time::sleep(Duration::from_millis(80)).await;
        // Timed by hand rather than via `stop_ticker` so the clock starts right
        // after the flag is set.
        handle.shutdown.store(true, Ordering::SeqCst);
        let started = std::time::Instant::now();
        tokio::time::timeout(Duration::from_secs(1), handle.task)
            .await
            .expect("ticker did not honour shutdown within 1s")
            .expect("ticker task panicked");
        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_millis(500),
            "shutdown took {elapsed:?}, expected <500ms"
        );
    }

    // Pure function: needs no async runtime.
    #[test]
    fn ticker_respects_floor_interval_min() {
        let floor = clamp_tick_interval_secs(1);
        assert!(floor >= 5, "tick interval floor must be >=5s; got {floor}");
        let unbounded = clamp_tick_interval_secs(120);
        assert_eq!(unbounded, 120, "values >=floor must pass through untouched");
        let zero = clamp_tick_interval_secs(0);
        assert!(zero >= 5, "zero must clamp up to floor");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_handles_resolver_failure_gracefully() {
        let cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(50));
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver: SharedResolverFn = Arc::new(|_h: &str| Err(io::Error::other("transient")));
        let handle = spawn_continuous_ticker(cfg, emitter.clone(), resolver);
        tokio::time::sleep(Duration::from_millis(250)).await;
        let stats = stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        assert!(
            events.is_empty(),
            "resolver failures must not emit drift, got {} events",
            events.len()
        );
        assert!(
            stats.resolver_errors >= 1,
            "resolver_errors counter should reflect the failures, got {}",
            stats.resolver_errors
        );
        assert!(
            stats.tick_count >= 1,
            "ticker must keep running across resolver errors"
        );
    }

    /// Number of events whose CloudEvent type ends with `suffix`.
    fn count_events_of(events: &[CloudEventV1], suffix: &str) -> usize {
        events.iter().filter(|e| e.ty.ends_with(suffix)).count()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_no_rebinding_events_when_policy_is_none() {
        let mut cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(80));
        cfg.rebinding_policy = None;
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver = cycling_resolver(vec![
            vec!["1.0.0.1".into()],
            vec!["1.0.0.2".into()],
            vec!["1.0.0.3".into()],
            vec!["1.0.0.4".into()],
            vec!["1.0.0.5".into()],
        ]);
        let handle = spawn_continuous_ticker(cfg, emitter.clone(), resolver);
        tokio::time::sleep(Duration::from_millis(250)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        assert_eq!(
            count_events_of(&events, "dns_authority_rebind_threshold"),
            0,
            "no rebinding policy → no threshold events"
        );
        assert_eq!(
            count_events_of(&events, "dns_authority_rebind_rejected"),
            0,
            "no rebinding policy → no rejected events"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_threshold_only_audit_mode_emits_threshold_events() {
        let mut cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(60));
        cfg.rebinding_policy = Some(DnsRebindingPolicy {
            response_ip_allowlist: Vec::new(),
            max_novel_ips_per_hostname: 2,
            reject_on_rebind: false,
        });
        cfg.policy_digest = Some(TEST_POLICY_DIGEST.into());
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver = cycling_resolver(vec![
            vec!["1.0.0.1".into()],
            vec!["1.0.0.2".into()],
            vec!["1.0.0.3".into()],
            vec!["1.0.0.4".into()],
        ]);
        let handle = spawn_continuous_ticker(cfg, emitter.clone(), resolver);
        tokio::time::sleep(Duration::from_millis(280)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        let threshold = count_events_of(&events, "dns_authority_rebind_threshold");
        let rejected = count_events_of(&events, "dns_authority_rebind_rejected");
        assert!(
            threshold >= 2,
            "audit-only mode must emit at least 2 threshold events when cap is breached over multiple ticks; got {threshold}"
        );
        assert_eq!(
            rejected, 0,
            "no allowlist set → no rejected events; got {rejected}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_reject_on_rebind_filters_drift_targets() {
        let mut cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(60));
        cfg.rebinding_policy = Some(DnsRebindingPolicy {
            response_ip_allowlist: Vec::new(),
            max_novel_ips_per_hostname: 2,
            reject_on_rebind: true,
        });
        cfg.policy_digest = Some(TEST_POLICY_DIGEST.into());
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver = cycling_resolver(vec![
            vec!["1.0.0.1".into(), "1.0.0.2".into()],
            vec!["1.0.0.1".into(), "1.0.0.2".into(), "198.51.100.7".into()],
        ]);
        let handle = spawn_continuous_ticker(cfg, emitter.clone(), resolver);
        tokio::time::sleep(Duration::from_millis(220)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        let drift_events: Vec<_> = events
            .iter()
            .filter(|e| e.ty.ends_with("dns_authority_drift"))
            .collect();
        assert!(
            !drift_events.is_empty(),
            "drift events must still fire (rejection only filters targets, not the drift signal)"
        );
        let last = drift_events.last().unwrap();
        let data = last.data.as_ref().expect("data");
        let current: Vec<&str> = data["currentTargets"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect();
        assert!(
            current.contains(&"1.0.0.1"),
            "first legitimate IP must survive rejection: {current:?}"
        );
        assert!(
            current.contains(&"1.0.0.2"),
            "second legitimate IP must survive rejection: {current:?}"
        );
        assert!(
            !current.contains(&"198.51.100.7"),
            "attacker IP beyond cap=2 must be filtered when reject_on_rebind=true: {current:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_allowlist_only_emits_rejected_events() {
        let mut cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(60));
        cfg.rebinding_policy = Some(DnsRebindingPolicy {
            response_ip_allowlist: vec!["api.example.com:1.1.1.1".into()],
            max_novel_ips_per_hostname: 100,
            reject_on_rebind: false,
        });
        cfg.policy_digest = Some(TEST_POLICY_DIGEST.into());
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver: SharedResolverFn =
            Arc::new(|_h: &str| Ok(answer(vec!["198.51.100.7".into()])));
        let handle = spawn_continuous_ticker(cfg, emitter.clone(), resolver);
        tokio::time::sleep(Duration::from_millis(220)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        let threshold = count_events_of(&events, "dns_authority_rebind_threshold");
        let rejected = count_events_of(&events, "dns_authority_rebind_rejected");
        assert!(
            rejected >= 1,
            "allowlist violation must fire at least one rejected event"
        );
        assert_eq!(
            threshold, 0,
            "cap is far above the IP count → no threshold events; got {threshold}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_combined_threshold_and_allowlist_both_fire() {
        let mut cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(60));
        cfg.rebinding_policy = Some(DnsRebindingPolicy {
            response_ip_allowlist: vec!["api.example.com:1.1.1.1".into()],
            max_novel_ips_per_hostname: 1,
            reject_on_rebind: false,
        });
        cfg.policy_digest = Some(TEST_POLICY_DIGEST.into());
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver = cycling_resolver(vec![vec!["1.1.1.1".into()], vec!["198.51.100.7".into()]]);
        let handle = spawn_continuous_ticker(cfg, emitter.clone(), resolver);
        tokio::time::sleep(Duration::from_millis(220)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        let threshold = count_events_of(&events, "dns_authority_rebind_threshold");
        let rejected = count_events_of(&events, "dns_authority_rebind_rejected");
        assert!(
            threshold >= 1,
            "second tick exceeds cap=1 → threshold event expected"
        );
        assert!(
            rejected >= 1,
            "second tick IP is not in allowlist → rejected event expected"
        );
    }

    /// DNSSEC counterpart of `cycling_resolver`. An empty sequence yields an
    /// empty, `Unsigned` answer.
    fn cycling_validated_resolver(
        sequence: Vec<crate::resolver_refresh::ValidatedResolvedAnswer>,
    ) -> SharedValidatedResolverFn {
        let counter = Arc::new(AtomicU64::new(0));
        let seq = Arc::new(sequence);
        Arc::new(move |_h: &str| {
            let idx = counter.fetch_add(1, Ordering::SeqCst) as usize;
            Ok(if idx >= seq.len() {
                seq.last().cloned().unwrap_or_else(|| {
                    crate::resolver_refresh::ValidatedResolvedAnswer {
                        answer: ResolvedAnswer {
                            targets: Vec::new(),
                            ttl_seconds: 0,
                            resolver_addr: SocketAddr::new(
                                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                                53,
                            ),
                        },
                        validation: crate::resolver_refresh::DnssecValidationResult::Unsigned,
                    }
                })
            } else {
                seq[idx].clone()
            })
        })
    }

    /// Validated answer with TTL 60 from a loopback resolver address.
    fn validated_with(
        targets: Vec<&str>,
        validation: crate::resolver_refresh::DnssecValidationResult,
    ) -> crate::resolver_refresh::ValidatedResolvedAnswer {
        crate::resolver_refresh::ValidatedResolvedAnswer {
            answer: ResolvedAnswer {
                targets: targets.into_iter().map(String::from).collect(),
                ttl_seconds: 60,
                resolver_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53),
            },
            validation,
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn dnssec_failed_event_emitted_when_validate_true_failclosed_false() {
        let mut cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(70));
        cfg.dnssec_policy = Some(DnsResolverDnssecPolicy {
            validate: true,
            fail_closed: false,
            trust_anchors_path: None,
        });
        cfg.trust_anchors = Some(crate::resolver_refresh::TrustAnchors::iana_default());
        cfg.validated_resolver = Some(cycling_validated_resolver(vec![validated_with(
            vec!["1.0.0.1"],
            crate::resolver_refresh::DnssecValidationResult::Failed {
                reason: "synthetic-bogus".to_string(),
            },
        )]));
        let emitter = Arc::new(CollectingEmitter::default());
        let handle = spawn_continuous_ticker(
            cfg,
            emitter.clone(),
            Arc::new(|_h: &str| Ok(answer(vec!["1.0.0.1".into()]))),
        );
        tokio::time::sleep(Duration::from_millis(220)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        let dnssec_failed = count_events_of(&events, "dns_authority_dnssec_failed");
        assert!(
            dnssec_failed >= 1,
            "audit-only DNSSEC failure must fire dns_authority_dnssec_failed; got {dnssec_failed}"
        );
        let drift_events: Vec<_> = events
            .iter()
            .filter(|e| e.ty.ends_with("dns_authority_drift"))
            .collect();
        assert!(
            !drift_events.is_empty(),
            "audit-only mode keeps the answer; drift must still fire"
        );
        let last_drift = drift_events.last().unwrap();
        let data = last_drift.data.as_ref().expect("data");
        let current: Vec<&str> = data["currentTargets"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect();
        assert!(
            current.contains(&"1.0.0.1"),
            "audit-only mode preserves the unvalidated answer in the drift event: {current:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn dnssec_drops_answer_when_failclosed_true() {
        let mut cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(70));
        cfg.dnssec_policy = Some(DnsResolverDnssecPolicy {
            validate: true,
            fail_closed: true,
            trust_anchors_path: None,
        });
        cfg.trust_anchors = Some(crate::resolver_refresh::TrustAnchors::iana_default());
        cfg.validated_resolver = Some(cycling_validated_resolver(vec![validated_with(
            vec!["198.51.100.7"],
            crate::resolver_refresh::DnssecValidationResult::Failed {
                reason: "synthetic-bogus".to_string(),
            },
        )]));
        let emitter = Arc::new(CollectingEmitter::default());
        let handle = spawn_continuous_ticker(
            cfg,
            emitter.clone(),
            Arc::new(|_h: &str| Ok(answer(vec!["198.51.100.7".into()]))),
        );
        tokio::time::sleep(Duration::from_millis(220)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        let dnssec_failed = count_events_of(&events, "dns_authority_dnssec_failed");
        assert!(
            dnssec_failed >= 1,
            "enforce DNSSEC failure must fire dns_authority_dnssec_failed; got {dnssec_failed}"
        );
        let drift_events: Vec<_> = events
            .iter()
            .filter(|e| e.ty.ends_with("dns_authority_drift"))
            .collect();
        // Fail-closed may suppress drift entirely. If drift does fire, it must
        // not carry the bogus target.
        if let Some(last) = drift_events.last() {
            let data = last.data.as_ref().expect("data");
            let current: Vec<&str> = data["currentTargets"]
                .as_array()
                .unwrap()
                .iter()
                .map(|v| v.as_str().unwrap())
                .collect();
            assert!(
                !current.contains(&"198.51.100.7"),
                "failClosed=true MUST drop the attacker IP from drift currentTargets: {current:?}"
            );
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn dnssec_status_field_set_in_drift_event() {
        let mut cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(70));
        cfg.dnssec_policy = Some(DnsResolverDnssecPolicy {
            validate: true,
            fail_closed: false,
            trust_anchors_path: None,
        });
        cfg.trust_anchors = Some(crate::resolver_refresh::TrustAnchors::iana_default());
        cfg.validated_resolver = Some(cycling_validated_resolver(vec![validated_with(
            vec!["1.1.1.1"],
            crate::resolver_refresh::DnssecValidationResult::Validated {
                algorithm: "RSASHA256".to_string(),
                key_tag: 19036,
            },
        )]));
        let emitter = Arc::new(CollectingEmitter::default());
        let handle = spawn_continuous_ticker(
            cfg,
            emitter.clone(),
            Arc::new(|_h: &str| Ok(answer(vec!["1.1.1.1".into()]))),
        );
        tokio::time::sleep(Duration::from_millis(180)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        assert_eq!(
            count_events_of(&events, "dns_authority_dnssec_failed"),
            0,
            "Validated path must not emit dns_authority_dnssec_failed"
        );
        let drift_events: Vec<_> = events
            .iter()
            .filter(|e| e.ty.ends_with("dns_authority_drift"))
            .collect();
        assert!(
            !drift_events.is_empty(),
            "drift must fire on first observation"
        );
        let data = drift_events[0].data.as_ref().expect("data");
        assert_eq!(
            data["dnssecStatus"], "validated",
            "drift in DNSSEC mode must stamp dnssecStatus=validated"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn ticker_novel_ip_within_cap_does_not_emit_threshold() {
        let mut cfg = base_cfg(vec!["api.example.com".into()], Duration::from_millis(60));
        cfg.rebinding_policy = Some(DnsRebindingPolicy {
            response_ip_allowlist: Vec::new(),
            max_novel_ips_per_hostname: 4,
            reject_on_rebind: false,
        });
        cfg.policy_digest = Some(TEST_POLICY_DIGEST.into());
        let emitter = Arc::new(CollectingEmitter::default());
        let resolver = cycling_resolver(vec![
            vec!["1.0.0.1".into()],
            vec!["1.0.0.2".into()],
            vec!["1.0.0.3".into()],
        ]);
        let handle = spawn_continuous_ticker(cfg, emitter.clone(), resolver);
        tokio::time::sleep(Duration::from_millis(220)).await;
        stop_ticker(handle).await;
        let events = emitter.events.lock().unwrap();
        let threshold = count_events_of(&events, "dns_authority_rebind_threshold");
        assert_eq!(
            threshold, 0,
            "3 distinct IPs under cap=4 must not fire threshold; got {threshold}"
        );
    }
}