use std::{
collections::HashMap,
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
net::IpAddr,
sync::{
Mutex,
atomic::{AtomicU64, Ordering},
},
};
use super::event::QueryEvent;
use crate::resolver::pipeline::Outcome;
/// Number of independently locked maps in each `ShardedCounts` table.
const STATS_SHARDS: usize = 16;
/// Distinct-domain budget for `Stats`; split across the shards by integer division.
const MAX_TRACKED_DOMAINS: usize = 4096;
/// Distinct-client-IP budget for `Stats`; split across the shards by integer division.
const MAX_TRACKED_CLIENTS: usize = 2048;
/// Copy of the statistics produced by `Stats::snapshot`.
#[derive(Debug, Clone)]
pub struct StatsSnapshot {
/// Number of recorded queries.
pub total: u64,
/// Queries whose outcome reported `is_blocked()`.
pub blocked: u64,
/// Queries whose outcome was `Outcome::Cached`.
pub cached: u64,
/// Queries whose outcome was `Outcome::Forwarded`.
pub forwarded: u64,
/// Fraction of `total` that was blocked; `0.0` when `total` is zero.
pub blocked_ratio: f64,
/// At most `top_n` `(name, count)` pairs, highest count first, ties by ascending name.
pub top_domains: Vec<(String, u64)>,
/// At most `top_n` `(ip, count)` pairs, highest count first, ties by ascending address.
pub top_clients: Vec<(IpAddr, u64)>,
}
/// Aggregated query statistics.
///
/// Outcome totals are atomic counters; per-domain and per-client tallies live
/// in bounded, mutex-sharded tables.
pub struct Stats {
/// Every recorded query.
total: AtomicU64,
/// Queries whose outcome reported `is_blocked()`.
blocked: AtomicU64,
/// Queries whose outcome was `Outcome::Cached`.
cached: AtomicU64,
/// Queries whose outcome was `Outcome::Forwarded`.
forwarded: AtomicU64,
/// Per-query-name counts, keyed by the string form of the query name.
domains: ShardedCounts<String>,
/// Per-client counts, keyed by the client's IP address (source port ignored).
clients: ShardedCounts<IpAddr>,
}
impl Stats {
    /// Creates an empty collector.
    ///
    /// The domain table is budgeted for `MAX_TRACKED_DOMAINS` distinct names and
    /// the client table for `MAX_TRACKED_CLIENTS` distinct IP addresses; the
    /// budget is split across the table's shards.
    pub fn new() -> Self {
        Self {
            total: AtomicU64::new(0),
            blocked: AtomicU64::new(0),
            cached: AtomicU64::new(0),
            forwarded: AtomicU64::new(0),
            domains: ShardedCounts::new(MAX_TRACKED_DOMAINS),
            clients: ShardedCounts::new(MAX_TRACKED_CLIENTS),
        }
    }

    /// Accounts for one query event.
    ///
    /// `total` is always incremented first. `blocked` is incremented when
    /// `event.outcome.is_blocked()`, and `cached` / `forwarded` for the matching
    /// outcomes. The query name and the client's IP address (the source port is
    /// ignored) are also tallied, subject to each table's capacity limit.
    pub fn record(&self, event: &QueryEvent) {
        self.total.fetch_add(1, Ordering::Relaxed);
        if event.outcome.is_blocked() {
            self.blocked.fetch_add(1, Ordering::Relaxed);
        }
        match event.outcome {
            Outcome::Cached => {
                self.cached.fetch_add(1, Ordering::Relaxed);
            }
            Outcome::Forwarded => {
                self.forwarded.fetch_add(1, Ordering::Relaxed);
            }
            _ => {}
        }
        self.domains.record(event.qname.to_string());
        self.clients.record(event.client.ip());
    }

    /// Returns the current counters plus the `top_n` busiest domains and clients.
    ///
    /// Each ranking is ordered by descending count, with ties broken by
    /// ascending key, and holds at most `top_n` entries. Counters are read one
    /// at a time with relaxed ordering. While other threads are recording, the
    /// snapshot is therefore not an atomic cut across fields.
    pub fn snapshot(&self, top_n: usize) -> StatsSnapshot {
        // `record` bumps `total` before the per-outcome counters, so load
        // `total` last. Loading it first lets a query recorded between the two
        // loads show up in `blocked` but not in `total`, pushing the ratio
        // past 1.0.
        let blocked = self.blocked.load(Ordering::Relaxed);
        let cached = self.cached.load(Ordering::Relaxed);
        let forwarded = self.forwarded.load(Ordering::Relaxed);
        let total = self.total.load(Ordering::Relaxed);
        let blocked_ratio = if total == 0 {
            0.0_f64
        } else {
            // Relaxed loads give no formal ordering across different atomics,
            // so still clamp to keep the ratio meaningful.
            (blocked as f64 / total as f64).min(1.0)
        };
        StatsSnapshot {
            total,
            blocked,
            cached,
            forwarded,
            blocked_ratio,
            top_domains: Self::ranked(self.domains.snapshot(), top_n),
            top_clients: Self::ranked(self.clients.snapshot(), top_n),
        }
    }

    /// Orders `(key, count)` pairs by descending count (ties by ascending key)
    /// and keeps only the first `top_n`.
    fn ranked<K: Ord>(mut entries: Vec<(K, u64)>, top_n: usize) -> Vec<(K, u64)> {
        entries.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        entries.truncate(top_n);
        entries
    }
}
/// A bounded `key -> count` table split into `STATS_SHARDS` mutex-protected maps.
///
/// Each key is routed to one shard by hashing it, so recorders whose keys land
/// in different shards lock different mutexes.
struct ShardedCounts<K> {
/// One mutex-protected map per shard.
shards: Vec<Mutex<HashMap<K, u64>>>,
/// Maximum number of distinct keys a single shard will admit.
max_per_shard: usize,
}
impl<K> ShardedCounts<K>
where
    K: Clone + Eq + Hash,
{
    /// Creates a table with a capacity budget of `max_entries` distinct keys.
    ///
    /// The budget is divided across `STATS_SHARDS` shards by integer division,
    /// with a floor of one slot per shard. The effective limit is therefore
    /// `max_per_shard * STATS_SHARDS`. That equals `max_entries` only when
    /// `max_entries` is a non-zero multiple of the shard count. Otherwise it is
    /// lower (remainder dropped) or higher (one-slot floor).
    fn new(max_entries: usize) -> Self {
        let max_per_shard = (max_entries / STATS_SHARDS).max(1);
        let shards = (0..STATS_SHARDS)
            .map(|_| Mutex::new(HashMap::with_capacity(max_per_shard)))
            .collect();
        Self {
            shards,
            max_per_shard,
        }
    }

    /// Adds one to `key`'s count.
    ///
    /// A key that is already tracked is always incremented. An unseen key is
    /// admitted only while its shard holds fewer than `max_per_shard` entries.
    /// Once a shard is full, new keys that hash to it are silently dropped;
    /// nothing is evicted.
    fn record(&self, key: K) {
        let mut map = Self::lock_shard(&self.shards[self.shard_for(&key)]);
        if let Some(count) = map.get_mut(&key) {
            *count += 1;
            return;
        }
        if map.len() < self.max_per_shard {
            map.insert(key, 1);
        }
    }

    /// Copies every tracked `(key, count)` pair, in no particular order.
    ///
    /// Shards are locked one at a time. While other threads are recording, the
    /// result is therefore not an atomic cut across shards.
    fn snapshot(&self) -> Vec<(K, u64)> {
        // One output buffer filled shard by shard, instead of an intermediate
        // Vec per shard (which also lost the size hint for the final collect).
        let mut entries = Vec::new();
        for shard in &self.shards {
            let map = Self::lock_shard(shard);
            entries.extend(map.iter().map(|(key, &count)| (key.clone(), count)));
        }
        entries
    }

    /// Maps `key` to a shard index. `DefaultHasher::new()` uses fixed keys, so
    /// the mapping is deterministic.
    fn shard_for(&self, key: &K) -> usize {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        // On 32-bit targets the cast drops the high bits; the low bits are
        // still enough to pick among the shards.
        (hasher.finish() as usize) % self.shards.len()
    }

    /// Locks one shard, recovering the map if a previous holder panicked.
    ///
    /// The tables only hold best-effort counters, so a lost or skipped increment
    /// is acceptable. Recovering stops one panicking thread from making every
    /// later `record` / `snapshot` call panic as well.
    fn lock_shard(shard: &Mutex<HashMap<K, u64>>) -> std::sync::MutexGuard<'_, HashMap<K, u64>> {
        shard
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
impl Default for Stats {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::{message::Qtype, name::Name},
        resolver::pipeline::Outcome,
        telemetry::event::QueryEvent,
    };
    use std::net::SocketAddr;

    /// Builds an A-record query event from textual domain and client socket forms.
    fn make_event(domain: &str, client: &str, outcome: Outcome) -> QueryEvent {
        let client: SocketAddr = client.parse().unwrap();
        let qname: Name = domain.parse().unwrap();
        QueryEvent::new(client, qname, Qtype::A, outcome)
    }

    #[test]
    fn counters_mixed_outcomes() {
        let stats = Stats::new();
        stats.record(&make_event(
            "a.test",
            "203.0.113.1:1000",
            Outcome::BlockedByAdmin,
        ));
        stats.record(&make_event(
            "b.test",
            "203.0.113.2:1001",
            Outcome::BlockedByBlocklist,
        ));
        stats.record(&make_event("c.test", "203.0.113.3:1002", Outcome::Cached));
        stats.record(&make_event(
            "d.test",
            "203.0.113.4:1003",
            Outcome::Forwarded,
        ));
        let snap = stats.snapshot(10);
        assert_eq!(snap.total, 4);
        assert_eq!(snap.blocked, 2);
        assert_eq!(snap.cached, 1);
        assert_eq!(snap.forwarded, 1);
        let ratio_permille = (snap.blocked_ratio * 1000.0).round() as u64;
        assert_eq!(ratio_permille, 500, "blocked_ratio should be 0.5");
    }

    #[test]
    fn blocked_ratio_zero_when_no_queries() {
        let stats = Stats::new();
        let snap = stats.snapshot(10);
        assert_eq!(snap.total, 0);
        // An exact comparison is intended: with no queries the snapshot
        // returns the literal `0.0` rather than the result of a division.
        #[allow(clippy::float_cmp)]
        {
            assert_eq!(snap.blocked_ratio, 0.0_f64);
        }
    }

    #[test]
    fn top_domains_sorted_by_count() {
        let stats = Stats::new();
        for _ in 0..4 {
            stats.record(&make_event(
                "popular.test",
                "10.0.0.1:1000",
                Outcome::Forwarded,
            ));
        }
        for _ in 0..2 {
            stats.record(&make_event(
                "medium.test",
                "10.0.0.1:1001",
                Outcome::Forwarded,
            ));
        }
        stats.record(&make_event(
            "rare.test",
            "10.0.0.1:1002",
            Outcome::Forwarded,
        ));
        let snap = stats.snapshot(2);
        assert_eq!(snap.top_domains.len(), 2);
        assert_eq!(snap.top_domains[0].0, "popular.test.");
        assert_eq!(snap.top_domains[0].1, 4);
        assert_eq!(snap.top_domains[1].0, "medium.test.");
        assert_eq!(snap.top_domains[1].1, 2);
    }

    #[test]
    fn top_clients_sorted_by_count() {
        let stats = Stats::new();
        for _ in 0..3 {
            stats.record(&make_event("x.test", "10.0.0.1:1000", Outcome::Forwarded));
        }
        for _ in 0..5 {
            stats.record(&make_event("y.test", "10.0.0.2:1001", Outcome::Forwarded));
        }
        stats.record(&make_event("z.test", "10.0.0.3:1002", Outcome::Forwarded));
        let snap = stats.snapshot(2);
        assert_eq!(snap.top_clients.len(), 2);
        assert_eq!(snap.top_clients[0].0, "10.0.0.2".parse::<IpAddr>().unwrap());
        assert_eq!(snap.top_clients[0].1, 5);
        assert_eq!(snap.top_clients[1].0, "10.0.0.1".parse::<IpAddr>().unwrap());
        assert_eq!(snap.top_clients[1].1, 3);
    }

    #[test]
    fn top_domains_tie_broken_by_name() {
        let stats = Stats::new();
        stats.record(&make_event(
            "zebra.test",
            "10.0.0.1:1000",
            Outcome::Forwarded,
        ));
        stats.record(&make_event(
            "apple.test",
            "10.0.0.1:1001",
            Outcome::Forwarded,
        ));
        let snap = stats.snapshot(2);
        assert_eq!(snap.top_domains.len(), 2);
        assert_eq!(snap.top_domains[0].0, "apple.test.");
        assert_eq!(snap.top_domains[1].0, "zebra.test.");
    }

    #[test]
    fn top_n_zero_yields_empty_rankings() {
        let stats = Stats::new();
        stats.record(&make_event("a.test", "10.0.0.1:1000", Outcome::Forwarded));
        let snap = stats.snapshot(0);
        assert_eq!(snap.total, 1);
        assert_eq!(snap.forwarded, 1);
        assert!(snap.top_domains.is_empty());
        assert!(snap.top_clients.is_empty());
    }

    #[test]
    fn clients_aggregate_across_source_ports() {
        let stats = Stats::new();
        stats.record(&make_event("a.test", "10.0.0.9:1000", Outcome::Cached));
        stats.record(&make_event("b.test", "10.0.0.9:2000", Outcome::Cached));
        let snap = stats.snapshot(10);
        assert_eq!(
            snap.top_clients,
            vec![("10.0.0.9".parse::<IpAddr>().unwrap(), 2)]
        );
    }

    /// Key type whose hash is constant, so every value lands in the same shard.
    #[derive(Clone, Debug, Eq, PartialEq)]
    struct SameShard(&'static str);

    impl Hash for SameShard {
        fn hash<H: Hasher>(&self, state: &mut H) {
            0u8.hash(state);
        }
    }

    #[test]
    fn sharded_counts_are_bounded_but_keep_existing_keys() {
        let counts = ShardedCounts::new(1);
        counts.record(SameShard("first"));
        counts.record(SameShard("second"));
        counts.record(SameShard("first"));
        let snapshot = counts.snapshot();
        assert_eq!(snapshot.len(), 1);
        assert_eq!(snapshot[0], (SameShard("first"), 2));
    }

    #[test]
    fn sharded_counts_merge_every_shard() {
        let counts = ShardedCounts::new(1024);
        for key in 0u32..100 {
            for _ in 0..=(key % 3) {
                counts.record(key);
            }
        }
        let mut snapshot = counts.snapshot();
        snapshot.sort_unstable();
        let expected: Vec<(u32, u64)> = (0u32..100)
            .map(|key| (key, u64::from(key % 3) + 1))
            .collect();
        assert_eq!(snapshot, expected);
    }
}