use std::collections::{HashMap, HashSet};
use crate::{codec::name::Name, resolver::matchset::AttributedSet};
/// Record of one `Aggregator::add` call: which source was added and how many
/// names it supplied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SourceContribution {
/// Identifier of the source, exactly as passed to `Aggregator::add`.
pub source: i64,
/// Size of the name set handed to `Aggregator::add` for this source. Names
/// that were already attributed to an earlier source are still counted.
pub count: usize,
}
/// Accumulates name sets from several sources into a single attributed map.
pub struct Aggregator {
// Each name mapped to the id of the source that supplied it first; later
// sources never overwrite an existing entry.
merged: HashMap<Name, i64>,
// One entry per `add` call, in call order.
contributions: Vec<SourceContribution>,
}
impl Aggregator {
    /// Creates an aggregator that holds no names and no contributions.
    #[must_use]
    pub fn new() -> Self {
        Self {
            contributions: Vec::new(),
            merged: HashMap::new(),
        }
    }

    /// Merges the names of one source into the aggregate.
    ///
    /// Every call appends a [`SourceContribution`] holding `source_id` and the
    /// size of `names`; that count includes names an earlier source already
    /// supplied. A name that is already present keeps its existing
    /// attribution, so the source added first wins regardless of id value.
    pub fn add(&mut self, source_id: i64, names: HashSet<Name>) {
        self.contributions.push(SourceContribution {
            count: names.len(),
            source: source_id,
        });

        // `or_insert` only writes when the slot is vacant: first writer wins.
        names.into_iter().for_each(|name| {
            self.merged.entry(name).or_insert(source_id);
        });
    }

    /// Number of distinct names merged so far.
    #[must_use]
    pub fn len(&self) -> usize {
        self.merged.len()
    }

    /// Returns `true` when no name has been merged yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.merged.is_empty()
    }

    /// Consumes the aggregator and returns the merged name-to-source map
    /// together with the per-call contributions (in call order).
    #[must_use]
    pub fn into_parts(self) -> (HashMap<Name, i64>, Vec<SourceContribution>) {
        let Self {
            merged,
            contributions,
        } = self;
        (merged, contributions)
    }

    /// Consumes the aggregator, passes the merged map to `target.store`, and
    /// returns the per-call contributions.
    ///
    /// NOTE(review): `store` presumably replaces the set's previous contents —
    /// the tests in this file expect stale entries to be gone afterwards —
    /// confirm against `AttributedSet`.
    #[must_use]
    pub fn install(self, target: &AttributedSet) -> Vec<SourceContribution> {
        let Self {
            merged,
            contributions,
        } = self;
        target.store(merged);
        contributions
    }
}
impl Default for Aggregator {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for Aggregator {
    /// Formats the aggregator as a struct with two fields: `merged_len` (the
    /// number of merged names, not the names themselves) and `contributions`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the size of the merged map is printed, never its entries.
        let merged_len = self.merged.len();

        let mut builder = f.debug_struct("Aggregator");
        builder.field("merged_len", &merged_len);
        builder.field("contributions", &self.contributions);
        builder.finish()
    }
}
impl FromIterator<(i64, HashSet<Name>)> for Aggregator {
    /// Builds an aggregator by calling [`Aggregator::add`] once per
    /// `(source_id, names)` pair, in iteration order.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = (i64, HashSet<Name>)>,
    {
        iter.into_iter()
            .fold(Self::new(), |mut acc, (source_id, names)| {
                acc.add(source_id, names);
                acc
            })
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `Aggregator`: first-writer-wins attribution, per-source
    // contribution counts, and installation into an `AttributedSet`.
    use std::collections::HashSet;

    use super::*;

    /// Parses `s` into a [`Name`], panicking if the test literal is invalid.
    fn name(s: &str) -> Name {
        s.parse().expect("valid domain name in test helper")
    }

    /// Builds a `HashSet<Name>` from string literals via [`name`].
    fn set_of(names: &[&str]) -> HashSet<Name> {
        names.iter().map(|s| name(s)).collect()
    }

    #[test]
    fn new_is_empty() {
        let agg = Aggregator::new();
        assert!(agg.is_empty());
        assert_eq!(agg.len(), 0);
    }

    #[test]
    fn default_is_empty() {
        let agg = Aggregator::default();
        assert!(agg.is_empty());
    }

    #[test]
    fn single_source_add_records_contribution_and_len() {
        let mut agg = Aggregator::new();
        agg.add(1, set_of(&["ads.example.com", "tracker.net"]));
        assert_eq!(agg.len(), 2);

        let (merged, contributions) = agg.into_parts();
        assert_eq!(contributions.len(), 1);
        assert_eq!(contributions[0].source, 1);
        assert_eq!(contributions[0].count, 2);
        assert_eq!(merged.get(&name("ads.example.com")), Some(&1));
        assert_eq!(merged.get(&name("tracker.net")), Some(&1));
    }

    #[test]
    fn dedup_across_overlapping_sources_attributes_to_first() {
        let mut agg = Aggregator::new();
        agg.add(1, set_of(&["ads.example.com", "tracker.net", "shared.bad"]));
        agg.add(2, set_of(&["shared.bad", "malware.io", "evil.org"]));
        // Six names were supplied in total, but "shared.bad" is merged once.
        assert_eq!(agg.len(), 5);

        let (merged, contributions) = agg.into_parts();
        // Attribution follows insertion order, not id magnitude; see
        // `first_writer_wins_regardless_of_id_value`.
        assert_eq!(
            merged.get(&name("shared.bad")),
            Some(&1),
            "overlap must be attributed to the source added first"
        );
        assert_eq!(merged.get(&name("malware.io")), Some(&2));
        assert_eq!(merged.get(&name("evil.org")), Some(&2));
        // Per-source counts still include the overlapping name.
        assert_eq!(contributions[0].source, 1);
        assert_eq!(contributions[0].count, 3);
        assert_eq!(contributions[1].source, 2);
        assert_eq!(contributions[1].count, 3);
    }

    #[test]
    fn first_writer_wins_regardless_of_id_value() {
        let mut agg = Aggregator::new();
        agg.add(10, set_of(&["shared.bad"]));
        agg.add(2, set_of(&["shared.bad"]));

        let (merged, _) = agg.into_parts();
        assert_eq!(
            merged.get(&name("shared.bad")),
            Some(&10),
            "the source added first keeps the attribution"
        );
    }

    #[test]
    fn install_swaps_into_set_and_removes_stale() {
        let target: AttributedSet = [(name("stale.example.com"), 99)].into_iter().collect();
        assert!(target.contains(&name("stale.example.com")));

        let mut agg = Aggregator::new();
        agg.add(1, set_of(&["ads.example.com", "tracker.net"]));
        agg.add(2, set_of(&["tracker.net", "malware.io"]));
        let contributions = agg.install(&target);

        assert_eq!(target.primary_source(&name("ads.example.com")), Some(1));
        assert_eq!(target.primary_source(&name("tracker.net")), Some(1));
        assert_eq!(target.primary_source(&name("malware.io")), Some(2));
        assert_eq!(target.len(), 3);
        // The entry present before the install must no longer be in the set.
        assert!(!target.contains(&name("stale.example.com")));

        assert_eq!(contributions.len(), 2);
        assert_eq!(contributions[0].source, 1);
        assert_eq!(contributions[0].count, 2);
        assert_eq!(contributions[1].source, 2);
        assert_eq!(contributions[1].count, 2);
    }

    #[test]
    fn install_new_set_immediately_visible() {
        let target = AttributedSet::empty();
        let mut agg = Aggregator::new();
        agg.add(1, set_of(&["new.example.com"]));
        let _ = agg.install(&target);
        assert!(target.contains(&name("new.example.com")));
        assert_eq!(target.len(), 1);
    }

    #[test]
    fn zero_sources_installs_empty_set_no_panic() {
        let target: AttributedSet = [(name("previously-blocked.example.com"), 1)]
            .into_iter()
            .collect();
        let agg = Aggregator::new();
        let contributions = agg.install(&target);
        assert!(
            target.is_empty(),
            "install of empty aggregator must clear the set"
        );
        assert!(contributions.is_empty());
    }

    #[test]
    fn per_source_count_is_intra_source_deduped_size() {
        // The two spellings of the same domain are expected to collapse into
        // one `Name`; the pre-condition below asserts exactly that.
        let mut raw: HashSet<Name> = HashSet::new();
        raw.insert(name("ADS.EXAMPLE.COM"));
        raw.insert(name("ads.example.com"));
        raw.insert(name("tracker.net"));
        assert_eq!(raw.len(), 2, "pre-condition: set has 2 deduped entries");

        let mut agg = Aggregator::new();
        agg.add(1, raw);
        let (_, contributions) = agg.into_parts();
        assert_eq!(
            contributions[0].count, 2,
            "count must be the deduped set size"
        );
    }

    #[test]
    fn into_parts_returns_merged_map_and_contributions() {
        let mut agg = Aggregator::new();
        agg.add(10, set_of(&["a.example.com", "b.example.com"]));
        agg.add(20, set_of(&["b.example.com", "c.example.com"]));

        let (merged, contributions) = agg.into_parts();
        assert_eq!(merged.len(), 3);
        assert_eq!(merged.get(&name("a.example.com")), Some(&10));
        // Shared name stays with the source that was added first.
        assert_eq!(merged.get(&name("b.example.com")), Some(&10));
        assert_eq!(merged.get(&name("c.example.com")), Some(&20));

        assert_eq!(contributions.len(), 2);
        assert_eq!(
            contributions[0],
            SourceContribution {
                source: 10,
                count: 2
            }
        );
        assert_eq!(
            contributions[1],
            SourceContribution {
                source: 20,
                count: 2
            }
        );
    }

    #[test]
    fn from_iterator_builds_same_result_as_manual_add() {
        let sources = vec![
            (1, set_of(&["x.com", "y.com"])),
            (2, set_of(&["y.com", "z.com"])),
        ];
        let agg: Aggregator = sources.into_iter().collect();
        assert_eq!(agg.len(), 3);

        let (merged, contributions) = agg.into_parts();
        assert_eq!(merged.get(&name("x.com")), Some(&1));
        // Shared name stays with the source that came first in the iterator.
        assert_eq!(merged.get(&name("y.com")), Some(&1));
        assert_eq!(merged.get(&name("z.com")), Some(&2));
        assert_eq!(contributions[0].source, 1);
        assert_eq!(contributions[0].count, 2);
        assert_eq!(contributions[1].source, 2);
        assert_eq!(contributions[1].count, 2);
    }
}