use std::collections::HashSet;
use crate::{codec::name::Name, resolver::matchset::MatchSet};
/// Per-source record of how many names one source handed to an aggregator.
///
/// `count` is the length of the set that was passed in for this source, so
/// names that other sources also supplied are still counted here.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct SourceContribution<K> {
    /// Identifier of the source, exactly as it was passed in.
    pub source: K,
    /// Number of distinct names in this source's own set.
    pub count: usize,
}
/// Collects name sets from several sources into one deduplicated set, while
/// keeping a per-source tally of how many names each source supplied.
///
/// `K` identifies a source; the aggregator only stores it and never inspects it.
pub struct Aggregator<K> {
    /// Union of every name set added so far.
    merged: HashSet<Name>,
    /// One entry per added source, in the order the sources were added.
    contributions: Vec<SourceContribution<K>>,
}
impl<K> Aggregator<K> {
    /// Creates an aggregator with no sources and no names.
    #[must_use]
    pub fn new() -> Self {
        Self {
            merged: HashSet::new(),
            contributions: Vec::new(),
        }
    }

    /// Records `names` as the contribution of `source` and merges them into
    /// the combined set.
    ///
    /// The recorded [`SourceContribution::count`] is `names.len()`, the size
    /// of this source's own set. It does not subtract names that an earlier
    /// source already contributed.
    pub fn add(&mut self, source: K, mut names: HashSet<Name>) {
        self.contributions.push(SourceContribution {
            source,
            count: names.len(),
        });
        // Union is symmetric, so fold the smaller set into the larger one.
        // This costs O(min(len)) hash insertions instead of always re-hashing
        // every incoming name; in particular the first `add` becomes a move.
        if names.len() > self.merged.len() {
            std::mem::swap(&mut self.merged, &mut names);
        }
        self.merged.extend(names);
    }

    /// Number of distinct names across all sources added so far.
    #[must_use]
    pub fn len(&self) -> usize {
        self.merged.len()
    }

    /// Returns `true` when no names have been merged.
    ///
    /// This is also `true` if only empty sets were added, even though the
    /// contributions for those sources were recorded.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.merged.is_empty()
    }

    /// Consumes the aggregator and returns the merged set together with the
    /// per-source contributions, in the order the sources were added.
    #[must_use]
    pub fn into_parts(self) -> (HashSet<Name>, Vec<SourceContribution<K>>) {
        (self.merged, self.contributions)
    }

    /// Consumes the aggregator, hands the merged set to `target` via
    /// `MatchSet::store`, and returns the per-source contributions.
    ///
    /// NOTE(review): this module's tests expect `store` to replace the
    /// target's previous contents entirely; that behaviour lives in
    /// `MatchSet`, not here.
    #[must_use]
    pub fn install(self, target: &MatchSet) -> Vec<SourceContribution<K>> {
        let (merged, contributions) = self.into_parts();
        target.store(merged);
        contributions
    }
}
impl<K> Default for Aggregator<K> {
fn default() -> Self {
Self::new()
}
}
impl<K: std::fmt::Debug> std::fmt::Debug for Aggregator<K> {
    /// Prints the size of the merged set instead of every name in it, along
    /// with the full list of per-source contributions.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let merged_len = self.merged.len();
        let mut builder = f.debug_struct("Aggregator");
        builder.field("merged_len", &merged_len);
        builder.field("contributions", &self.contributions);
        builder.finish()
    }
}
impl<K> FromIterator<(K, HashSet<Name>)> for Aggregator<K> {
    /// Builds an aggregator by adding each `(source, names)` pair in
    /// iteration order, exactly as repeated calls to `add` would.
    fn from_iter<I: IntoIterator<Item = (K, HashSet<Name>)>>(iter: I) -> Self {
        iter.into_iter()
            .fold(Self::new(), |mut acc, (source, names)| {
                acc.add(source, names);
                acc
            })
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Parses a domain name, panicking on invalid test input.
    fn name(s: &str) -> Name {
        s.parse().expect("valid domain name in test helper")
    }

    /// Builds a name set from string literals.
    fn set_of(names: &[&str]) -> HashSet<Name> {
        names.iter().map(|s| name(s)).collect()
    }

    #[test]
    fn new_is_empty() {
        let agg: Aggregator<u32> = Aggregator::new();
        assert!(agg.is_empty());
        assert_eq!(agg.len(), 0);
    }

    #[test]
    fn default_is_empty() {
        let agg: Aggregator<&str> = Aggregator::default();
        assert!(agg.is_empty());
    }

    #[test]
    fn single_source_add_records_contribution_and_len() {
        let mut agg: Aggregator<&str> = Aggregator::new();
        agg.add("source-a", set_of(&["ads.example.com", "tracker.net"]));
        assert_eq!(agg.len(), 2);
        let (merged, contributions) = agg.into_parts();
        assert_eq!(contributions.len(), 1);
        assert_eq!(contributions[0].source, "source-a");
        assert_eq!(contributions[0].count, 2);
        assert!(merged.contains(&name("ads.example.com")));
        assert!(merged.contains(&name("tracker.net")));
    }

    #[test]
    fn dedup_across_overlapping_sources() {
        let mut agg: Aggregator<&str> = Aggregator::new();
        agg.add(
            "source-a",
            set_of(&["ads.example.com", "tracker.net", "shared.bad"]),
        );
        agg.add(
            "source-b",
            set_of(&["shared.bad", "malware.io", "evil.org"]),
        );
        assert_eq!(agg.len(), 5);
        let (merged, contributions) = agg.into_parts();
        assert!(merged.contains(&name("ads.example.com")));
        assert!(merged.contains(&name("tracker.net")));
        assert!(merged.contains(&name("shared.bad")));
        assert!(merged.contains(&name("malware.io")));
        assert!(merged.contains(&name("evil.org")));
        assert_eq!(contributions[0].source, "source-a");
        assert_eq!(contributions[0].count, 3);
        assert_eq!(contributions[1].source, "source-b");
        assert_eq!(contributions[1].count, 3);
    }

    #[test]
    fn larger_later_source_still_unions_with_earlier() {
        let mut agg: Aggregator<&str> = Aggregator::new();
        agg.add("small", set_of(&["only-small.example.com"]));
        agg.add(
            "large",
            set_of(&["a.example.com", "b.example.com", "c.example.com"]),
        );
        assert_eq!(agg.len(), 4);
        let (merged, contributions) = agg.into_parts();
        assert!(merged.contains(&name("only-small.example.com")));
        assert!(merged.contains(&name("a.example.com")));
        assert!(merged.contains(&name("b.example.com")));
        assert!(merged.contains(&name("c.example.com")));
        assert_eq!(contributions[0].source, "small");
        assert_eq!(contributions[0].count, 1);
        assert_eq!(contributions[1].source, "large");
        assert_eq!(contributions[1].count, 3);
    }

    #[test]
    fn empty_source_is_recorded_with_zero_count() {
        let mut agg: Aggregator<u32> = Aggregator::new();
        agg.add(7, HashSet::new());
        assert!(agg.is_empty());
        let (merged, contributions) = agg.into_parts();
        assert!(merged.is_empty());
        assert_eq!(
            contributions,
            vec![SourceContribution {
                source: 7,
                count: 0
            }]
        );
    }

    #[test]
    fn install_swaps_into_matchset_and_removes_stale() {
        let target = MatchSet::new(set_of(&["stale.example.com"]));
        assert!(target.contains(&name("stale.example.com")));
        let mut agg: Aggregator<&str> = Aggregator::new();
        agg.add("source-a", set_of(&["ads.example.com", "tracker.net"]));
        agg.add("source-b", set_of(&["tracker.net", "malware.io"]));
        let contributions = agg.install(&target);
        assert!(target.contains(&name("ads.example.com")));
        assert!(target.contains(&name("tracker.net")));
        assert!(target.contains(&name("malware.io")));
        assert_eq!(target.len(), 3);
        assert!(!target.contains(&name("stale.example.com")));
        assert_eq!(contributions.len(), 2);
        assert_eq!(contributions[0].source, "source-a");
        assert_eq!(contributions[0].count, 2);
        assert_eq!(contributions[1].source, "source-b");
        assert_eq!(contributions[1].count, 2);
    }

    #[test]
    fn install_new_set_immediately_visible() {
        let target = MatchSet::empty();
        let mut agg: Aggregator<u32> = Aggregator::new();
        agg.add(1, set_of(&["new.example.com"]));
        let _ = agg.install(&target);
        assert!(target.contains(&name("new.example.com")));
        assert_eq!(target.len(), 1);
    }

    #[test]
    fn zero_sources_installs_empty_set_no_panic() {
        let target = MatchSet::new(set_of(&["previously-blocked.example.com"]));
        let agg: Aggregator<u32> = Aggregator::new();
        let contributions = agg.install(&target);
        assert!(
            target.is_empty(),
            "install of empty aggregator must clear the set"
        );
        assert!(contributions.is_empty());
    }

    #[test]
    fn per_source_count_is_intra_source_deduped_size() {
        let mut raw: HashSet<Name> = HashSet::new();
        raw.insert(name("ADS.EXAMPLE.COM"));
        raw.insert(name("ads.example.com"));
        raw.insert(name("tracker.net"));
        assert_eq!(raw.len(), 2, "pre-condition: set has 2 deduped entries");
        let mut agg: Aggregator<&str> = Aggregator::new();
        agg.add("source-a", raw);
        let (_, contributions) = agg.into_parts();
        assert_eq!(
            contributions[0].count, 2,
            "count must be the deduped set size"
        );
    }

    #[test]
    fn into_parts_returns_same_merged_set_and_contributions() {
        let mut agg: Aggregator<i64> = Aggregator::new();
        agg.add(10, set_of(&["a.example.com", "b.example.com"]));
        agg.add(20, set_of(&["b.example.com", "c.example.com"]));
        let (merged, contributions) = agg.into_parts();
        assert_eq!(merged.len(), 3);
        assert!(merged.contains(&name("a.example.com")));
        assert!(merged.contains(&name("b.example.com")));
        assert!(merged.contains(&name("c.example.com")));
        assert_eq!(contributions.len(), 2);
        assert_eq!(
            contributions[0],
            SourceContribution {
                source: 10,
                count: 2
            }
        );
        assert_eq!(
            contributions[1],
            SourceContribution {
                source: 20,
                count: 2
            }
        );
    }

    #[test]
    fn from_iterator_builds_same_result_as_manual_add() {
        let sources = vec![
            ("a", set_of(&["x.com", "y.com"])),
            ("b", set_of(&["y.com", "z.com"])),
        ];
        let agg: Aggregator<&str> = sources.into_iter().collect();
        assert_eq!(agg.len(), 3);
        let (merged, contributions) = agg.into_parts();
        assert!(merged.contains(&name("x.com")));
        assert!(merged.contains(&name("y.com")));
        assert!(merged.contains(&name("z.com")));
        assert_eq!(contributions[0].source, "a");
        assert_eq!(contributions[0].count, 2);
        assert_eq!(contributions[1].source, "b");
        assert_eq!(contributions[1].count, 2);
    }

    #[test]
    fn debug_reports_merged_len_and_contributions() {
        let mut agg: Aggregator<u32> = Aggregator::new();
        agg.add(1, set_of(&["a.example.com"]));
        assert_eq!(
            format!("{:?}", agg),
            "Aggregator { merged_len: 1, contributions: [SourceContribution { source: 1, count: 1 }] }"
        );
    }
}