use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tokio::sync::watch;
use tracing::{debug, instrument};
use super::records::{DnsRecord, RecordType};
use super::resolver::DnsResolver;
use crate::error::{Result, SeerError};
/// Largest iteration count that `FollowConfig::new` accepts; larger requests are
/// rejected as invalid input.
const MAX_FOLLOW_ITERATIONS: usize = 10_000;
/// Settings for a follow run: how many queries to issue and how far apart.
#[derive(Debug, Clone)]
pub struct FollowConfig {
/// Number of queries to perform.
pub iterations: usize,
/// Seconds to wait between consecutive queries.
pub interval_secs: u64,
/// When `true`, the progress callback is only invoked for the first iteration
/// and for iterations whose record values changed.
pub changes_only: bool,
}
impl Default for FollowConfig {
fn default() -> Self {
Self {
iterations: 10,
interval_secs: 60,
changes_only: false,
}
}
}
impl FollowConfig {
    /// Builds a validated configuration from an iteration count and an interval
    /// expressed in minutes.
    ///
    /// The interval is converted to whole seconds (fractional seconds are
    /// dropped). When more than one iteration is requested the interval is
    /// raised to at least one second. `changes_only` starts out `false`.
    ///
    /// # Errors
    ///
    /// Returns `SeerError::InvalidInput` when `iterations` is zero or exceeds
    /// `MAX_FOLLOW_ITERATIONS`, or when `interval_minutes` is not finite, is
    /// negative, or is greater than 60.
    pub fn new(iterations: usize, interval_minutes: f64) -> Result<Self> {
        match iterations {
            0 => {
                return Err(SeerError::InvalidInput(
                    "iterations must be at least 1".into(),
                ))
            }
            n if n > MAX_FOLLOW_ITERATIONS => {
                return Err(SeerError::InvalidInput(format!(
                    "iterations must be at most {MAX_FOLLOW_ITERATIONS}"
                )))
            }
            _ => {}
        }

        // Finiteness is checked first so NaN/inf never reach the range comparisons.
        let interval_problem = if !interval_minutes.is_finite() {
            Some("interval_minutes must be a finite number")
        } else if interval_minutes < 0.0 {
            Some("interval_minutes must be non-negative")
        } else if interval_minutes > 60.0 {
            Some("interval_minutes must be at most 60")
        } else {
            None
        };
        if let Some(message) = interval_problem {
            return Err(SeerError::InvalidInput(message.into()));
        }

        // The value is within [0, 3600] here, so the cast only truncates the
        // fractional part.
        let whole_secs = (interval_minutes * 60.0) as u64;
        let interval_secs = if iterations > 1 {
            whole_secs.max(1)
        } else {
            whole_secs
        };

        Ok(Self {
            iterations,
            interval_secs,
            changes_only: false,
        })
    }

    /// Returns the configuration with `changes_only` replaced by the given value.
    pub fn with_changes_only(self, changes_only: bool) -> Self {
        Self {
            changes_only,
            ..self
        }
    }
}
/// Outcome of a single query within a follow run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FollowIteration {
/// 1-based position of this query within the run.
pub iteration: usize,
/// Number of iterations the run was configured for.
pub total_iterations: usize,
/// UTC time captured just before the query was issued.
pub timestamp: DateTime<Utc>,
/// Records returned by the query; empty when the query failed.
pub records: Vec<DnsRecord>,
/// Whether the set of record values differs from the previous iteration.
/// Always `false` for the first iteration.
pub changed: bool,
/// Record values present now that were absent in the previous iteration.
pub added: Vec<String>,
/// Record values present in the previous iteration that are absent now.
pub removed: Vec<String>,
/// Sanitized error message when the query failed, otherwise `None`.
pub error: Option<String>,
}
impl FollowIteration {
    /// Reports whether this iteration finished without a recorded error.
    pub fn success(&self) -> bool {
        let Self { error, .. } = self;
        error.is_none()
    }

    /// Number of records captured by this iteration.
    pub fn record_count(&self) -> usize {
        let Self { records, .. } = self;
        records.len()
    }
}
/// Summary of a complete (or interrupted) follow run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FollowResult {
/// Domain that was queried, after normalization.
pub domain: String,
/// Record type that was queried.
pub record_type: RecordType,
/// Nameserver passed by the caller, if any.
pub nameserver: Option<String>,
/// Number of iterations the run was configured for.
pub iterations_requested: usize,
/// Configured pause between iterations, in seconds.
pub interval_secs: u64,
/// Every iteration that was actually performed, in order.
pub iterations: Vec<FollowIteration>,
/// `true` when the run stopped early because cancellation was signalled.
pub interrupted: bool,
/// Number of iterations flagged as changed.
pub total_changes: usize,
/// UTC time at which the run began.
pub started_at: DateTime<Utc>,
/// UTC time at which the run finished.
pub ended_at: DateTime<Utc>,
}
impl FollowResult {
    /// Number of iterations that were actually performed.
    pub fn completed_iterations(&self) -> usize {
        let performed = &self.iterations;
        performed.len()
    }

    /// Number of performed iterations that finished without an error.
    pub fn successful_iterations(&self) -> usize {
        self.iterations
            .iter()
            .map(|entry| usize::from(entry.success()))
            .sum()
    }

    /// Number of performed iterations that recorded an error.
    pub fn failed_iterations(&self) -> usize {
        self.iterations
            .iter()
            .map(|entry| usize::from(!entry.success()))
            .sum()
    }
}
/// Shared (`Arc`), `Send + Sync` callback that receives a reference to each
/// reported `FollowIteration` during a follow run.
pub type FollowProgressCallback = Arc<dyn Fn(&FollowIteration) + Send + Sync>;
/// Repeatedly queries a DNS record through a `DnsResolver` and reports how the
/// answers change between queries.
#[derive(Clone)]
pub struct DnsFollower {
// Resolver used for every query issued by this follower.
resolver: DnsResolver,
}
impl Default for DnsFollower {
fn default() -> Self {
Self::new()
}
}
impl DnsFollower {
pub fn new() -> Self {
Self {
resolver: DnsResolver::new(),
}
}
pub fn with_resolver(resolver: DnsResolver) -> Self {
Self { resolver }
}
#[instrument(skip(self, config, callback, cancel_rx))]
pub async fn follow(
&self,
domain: &str,
record_type: RecordType,
nameserver: Option<&str>,
config: FollowConfig,
callback: Option<FollowProgressCallback>,
cancel_rx: Option<watch::Receiver<bool>>,
) -> Result<FollowResult> {
let domain = crate::validation::normalize_domain(domain)?;
let started_at = Utc::now();
let mut iterations: Vec<FollowIteration> = Vec::with_capacity(config.iterations);
let mut previous_values: HashSet<String> = HashSet::new();
let mut total_changes = 0;
let mut interrupted = false;
debug!(
domain = %domain,
record_type = %record_type,
iterations = config.iterations,
interval_secs = config.interval_secs,
"Starting DNS follow"
);
for i in 0..config.iterations {
if let Some(ref rx) = cancel_rx {
if *rx.borrow() {
debug!("Follow operation cancelled");
interrupted = true;
break;
}
}
let timestamp = Utc::now();
let iteration_num = i + 1;
let (records, error) = match self
.resolver
.resolve(&domain, record_type, nameserver)
.await
{
Ok(records) => (records, None),
Err(e) => {
debug!(domain = %domain, error = %e, "DNS follow query failed");
(Vec::new(), Some(e.sanitized_message()))
}
};
let current_values: HashSet<String> =
records.iter().map(|r| r.data.to_string()).collect();
let (changed, added, removed) = if i == 0 {
(false, Vec::new(), Vec::new())
} else {
let added: Vec<String> = current_values
.difference(&previous_values)
.cloned()
.collect();
let removed: Vec<String> = previous_values
.difference(¤t_values)
.cloned()
.collect();
let changed = !added.is_empty() || !removed.is_empty();
(changed, added, removed)
};
if changed {
total_changes += 1;
}
let iteration = FollowIteration {
iteration: iteration_num,
total_iterations: config.iterations,
timestamp,
records,
changed,
added,
removed,
error,
};
if let Some(ref cb) = callback {
if !config.changes_only || iteration_num == 1 || changed {
cb(&iteration);
}
}
iterations.push(iteration);
previous_values = current_values;
if i < config.iterations - 1 {
let sleep_duration = Duration::from_secs(config.interval_secs);
if let Some(ref rx) = cancel_rx {
let mut rx_clone = rx.clone();
tokio::select! {
_ = tokio::time::sleep(sleep_duration) => {}
_ = rx_clone.changed() => {
if *rx_clone.borrow() {
debug!("Follow operation cancelled during sleep");
interrupted = true;
break;
}
}
}
} else {
tokio::time::sleep(sleep_duration).await;
}
}
}
let ended_at = Utc::now();
Ok(FollowResult {
domain: domain.to_string(),
record_type,
nameserver: nameserver.map(|s| s.to_string()),
iterations_requested: config.iterations,
interval_secs: config.interval_secs,
iterations,
interrupted,
total_changes,
started_at,
ended_at,
})
}
#[instrument(skip(self, config), fields(domain = %domain, record_type = ?record_type))]
pub async fn follow_simple(
&self,
domain: &str,
record_type: RecordType,
nameserver: Option<&str>,
config: FollowConfig,
) -> Result<FollowResult> {
self.follow(domain, record_type, nameserver, config, None, None)
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Pure value checks: no async runtime is needed for the config tests.
    #[test]
    fn test_follow_config_default() {
        let config = FollowConfig::default();
        assert_eq!(config.iterations, 10);
        assert_eq!(config.interval_secs, 60);
        assert!(!config.changes_only);
    }

    #[test]
    fn follow_config_rejects_unbounded_iterations() {
        assert!(FollowConfig::new(MAX_FOLLOW_ITERATIONS, 1.0).is_ok());
        let err = FollowConfig::new(MAX_FOLLOW_ITERATIONS + 1, 1.0).unwrap_err();
        assert!(matches!(err, SeerError::InvalidInput(_)));
        assert!(FollowConfig::new(usize::MAX, 1.0).is_err());
    }

    #[test]
    fn test_follow_config_new() {
        let config = FollowConfig::new(5, 0.5).unwrap();
        assert_eq!(config.iterations, 5);
        assert_eq!(config.interval_secs, 30);
    }

    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn test_follow_single_iteration() {
        let follower = DnsFollower::new();
        let config = FollowConfig::new(1, 0.0).unwrap();
        let result = follower
            .follow_simple("example.com", RecordType::A, None, config)
            .await;
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.completed_iterations(), 1);
        assert!(!result.interrupted);
    }

    #[test]
    fn follow_config_rejects_zero_iterations() {
        assert!(FollowConfig::new(0, 1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_infinite_interval() {
        assert!(FollowConfig::new(10, f64::INFINITY).is_err());
        assert!(FollowConfig::new(10, f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn follow_config_rejects_nan_interval() {
        assert!(FollowConfig::new(10, f64::NAN).is_err());
    }

    #[test]
    fn follow_config_rejects_negative_interval() {
        assert!(FollowConfig::new(10, -1.0).is_err());
    }

    #[test]
    fn follow_config_rejects_interval_above_cap() {
        assert!(FollowConfig::new(10, 60.1).is_err());
    }

    #[test]
    fn follow_config_accepts_valid() {
        assert!(FollowConfig::new(10, 1.5).is_ok());
        assert!(FollowConfig::new(1, 0.0).is_ok());
        assert!(FollowConfig::new(1, 60.0).is_ok());
    }

    // 0.001 min truncates to 0 s; multi-iteration runs must still pause >= 1 s.
    #[test]
    fn follow_config_floors_subsecond_interval_for_multi_iteration() {
        let config = FollowConfig::new(10_000, 0.001).unwrap();
        assert!(
            config.interval_secs >= 1,
            "multi-iteration interval must be floored to >= 1s, got {}",
            config.interval_secs
        );
    }

    #[test]
    fn follow_config_allows_zero_interval_for_single_iteration() {
        let config = FollowConfig::new(1, 0.0).unwrap();
        assert_eq!(config.interval_secs, 0);
    }

    // Signals cancellation shortly after the run starts and expects the run to
    // return well before its 100 iterations (30 s apart) could complete.
    #[tokio::test]
    #[ignore = "live network; run with --ignored or SEER_LIVE_TESTS=1"]
    async fn follow_honors_cancel() {
        use tokio::sync::watch;
        let (tx, rx) = watch::channel(false);
        let config = FollowConfig::new(100, 0.5).unwrap();
        let follower = DnsFollower::new();
        let handle = tokio::spawn(async move {
            follower
                .follow("example.com", RecordType::A, None, config, None, Some(rx))
                .await
        });
        tokio::time::sleep(Duration::from_millis(200)).await;
        tx.send(true).unwrap();
        let joined = tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .expect("follow should return promptly after cancel");
        let result = joined.expect("join").expect("follow result");
        assert!(result.interrupted, "follow should be interrupted");
        assert!(
            result.completed_iterations() < 100,
            "should not complete all iterations"
        );
    }
}