use std::{
collections::HashMap,
net::IpAddr,
sync::{
Mutex,
atomic::{AtomicU64, Ordering},
},
};
use super::event::QueryEvent;
use crate::resolver::pipeline::Outcome;
/// Point-in-time copy of the figures held by [`Stats`], produced by
/// [`Stats::snapshot`].
#[derive(Debug, Clone, PartialEq)]
pub struct StatsSnapshot {
    /// Number of query events recorded.
    pub total: u64,
    /// Events whose outcome reported `is_blocked()`.
    pub blocked: u64,
    /// Events whose outcome was `Outcome::Cached`.
    pub cached: u64,
    /// Events whose outcome was `Outcome::Forwarded`.
    pub forwarded: u64,
    /// `blocked / total`, or `0.0` when `total` is zero.
    pub blocked_ratio: f64,
    /// Up to `top_n` `(qname, count)` pairs, highest count first, ties broken
    /// by ascending name.
    pub top_domains: Vec<(String, u64)>,
    /// Up to `top_n` `(client IP, count)` pairs, highest count first, ties
    /// broken by ascending address.
    pub top_clients: Vec<(IpAddr, u64)>,
}
/// Thread-safe collector of per-query statistics.
///
/// Outcome counters are atomics; the per-domain and per-client tallies each
/// live behind their own [`Mutex`].
pub struct Stats {
    /// Every recorded query event.
    total: AtomicU64,
    /// Events whose outcome reported itself as blocked.
    blocked: AtomicU64,
    /// Events with the `Cached` outcome.
    cached: AtomicU64,
    /// Events with the `Forwarded` outcome.
    forwarded: AtomicU64,
    /// Event count per query name, keyed by the name's string form.
    domains: Mutex<HashMap<String, u64>>,
    /// Event count per client IP address.
    clients: Mutex<HashMap<IpAddr, u64>>,
}
impl Stats {
pub fn new() -> Self {
Self {
total: AtomicU64::new(0),
blocked: AtomicU64::new(0),
cached: AtomicU64::new(0),
forwarded: AtomicU64::new(0),
domains: Mutex::new(HashMap::new()),
clients: Mutex::new(HashMap::new()),
}
}
pub fn record(&self, event: &QueryEvent) {
self.total.fetch_add(1, Ordering::Relaxed);
if event.outcome.is_blocked() {
self.blocked.fetch_add(1, Ordering::Relaxed);
}
match event.outcome {
Outcome::Cached => {
self.cached.fetch_add(1, Ordering::Relaxed);
}
Outcome::Forwarded => {
self.forwarded.fetch_add(1, Ordering::Relaxed);
}
_ => {}
}
{
let mut map = self.domains.lock().expect("domains mutex poisoned");
*map.entry(event.qname.to_string()).or_insert(0) += 1;
}
{
let mut map = self.clients.lock().expect("clients mutex poisoned");
*map.entry(event.client.ip()).or_insert(0) += 1;
}
}
pub fn snapshot(&self, top_n: usize) -> StatsSnapshot {
let total = self.total.load(Ordering::Relaxed);
let blocked = self.blocked.load(Ordering::Relaxed);
let cached = self.cached.load(Ordering::Relaxed);
let forwarded = self.forwarded.load(Ordering::Relaxed);
let blocked_ratio = if total == 0 {
0.0_f64
} else {
blocked as f64 / total as f64
};
let top_domains = {
let map = self.domains.lock().expect("domains mutex poisoned");
let mut pairs: Vec<(String, u64)> = map.iter().map(|(k, &v)| (k.clone(), v)).collect();
pairs.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
pairs.truncate(top_n);
pairs
};
let top_clients = {
let map = self.clients.lock().expect("clients mutex poisoned");
let mut pairs: Vec<(IpAddr, u64)> = map.iter().map(|(k, &v)| (*k, v)).collect();
pairs.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
pairs.truncate(top_n);
pairs
};
StatsSnapshot {
total,
blocked,
cached,
forwarded,
blocked_ratio,
top_domains,
top_clients,
}
}
}
impl Default for Stats {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::{message::Qtype, name::Name},
        resolver::pipeline::Outcome,
        telemetry::event::QueryEvent,
    };
    use std::net::SocketAddr;

    /// Builds an `A` query event for `domain` sent from the socket address `client`.
    fn event(domain: &str, client: &str, outcome: Outcome) -> QueryEvent {
        let source = client.parse::<SocketAddr>().unwrap();
        let name = domain.parse::<Name>().unwrap();
        QueryEvent::new(source, name, Qtype::A, outcome)
    }

    #[test]
    fn counters_mixed_outcomes() {
        let stats = Stats::new();
        for (domain, client, outcome) in [
            ("a.test", "203.0.113.1:1000", Outcome::BlockedByAdmin),
            ("b.test", "203.0.113.2:1001", Outcome::BlockedByBlocklist),
            ("c.test", "203.0.113.3:1002", Outcome::Cached),
            ("d.test", "203.0.113.4:1003", Outcome::Forwarded),
        ] {
            stats.record(&event(domain, client, outcome));
        }

        let snap = stats.snapshot(10);
        assert_eq!(
            (snap.total, snap.blocked, snap.cached, snap.forwarded),
            (4, 2, 1, 1)
        );
        // Compare in whole permille so the check does not rely on exact float equality.
        let ratio_permille = (snap.blocked_ratio * 1000.0).round() as u64;
        assert_eq!(ratio_permille, 500, "blocked_ratio should be 0.5");
    }

    #[test]
    fn blocked_ratio_zero_when_no_queries() {
        let snap = Stats::new().snapshot(10);
        assert_eq!(snap.total, 0);
        // The empty case yields the literal `0.0`, so an exact comparison is intended.
        #[allow(clippy::float_cmp)]
        let is_zero = snap.blocked_ratio == 0.0_f64;
        assert!(is_zero);
    }

    #[test]
    fn top_domains_sorted_by_count() {
        let stats = Stats::new();
        (0..4).for_each(|_| {
            stats.record(&event("popular.test", "10.0.0.1:1000", Outcome::Forwarded));
        });
        (0..2).for_each(|_| {
            stats.record(&event("medium.test", "10.0.0.1:1001", Outcome::Forwarded));
        });
        stats.record(&event("rare.test", "10.0.0.1:1002", Outcome::Forwarded));

        let snap = stats.snapshot(2);
        let ranked: Vec<(&str, u64)> = snap
            .top_domains
            .iter()
            .map(|(domain, count)| (domain.as_str(), *count))
            .collect();
        assert_eq!(ranked, [("popular.test.", 4), ("medium.test.", 2)]);
    }

    #[test]
    fn top_clients_sorted_by_count() {
        let stats = Stats::new();
        (0..3).for_each(|_| {
            stats.record(&event("x.test", "10.0.0.1:1000", Outcome::Forwarded));
        });
        (0..5).for_each(|_| {
            stats.record(&event("y.test", "10.0.0.2:1001", Outcome::Forwarded));
        });
        stats.record(&event("z.test", "10.0.0.3:1002", Outcome::Forwarded));

        let snap = stats.snapshot(2);
        let ip = |text: &str| text.parse::<IpAddr>().unwrap();
        assert_eq!(
            snap.top_clients,
            [(ip("10.0.0.2"), 5), (ip("10.0.0.1"), 3)]
        );
    }

    #[test]
    fn top_domains_tie_broken_by_name() {
        let stats = Stats::new();
        stats.record(&event("zebra.test", "10.0.0.1:1000", Outcome::Forwarded));
        stats.record(&event("apple.test", "10.0.0.1:1001", Outcome::Forwarded));

        let snap = stats.snapshot(2);
        let names: Vec<&str> = snap
            .top_domains
            .iter()
            .map(|(name, _)| name.as_str())
            .collect();
        assert_eq!(names, ["apple.test.", "zebra.test."]);
    }
}