use std::{collections::HashSet, fmt, sync::Arc};
use arc_swap::{ArcSwap, Guard};
use crate::codec::name::Name;
/// A set of domain [`Name`]s that can be queried and wholesale-replaced concurrently.
///
/// The set lives behind an [`ArcSwap`]: readers load the currently published
/// `Arc<HashSet<Name>>`, and writers publish a complete replacement set with a
/// single `store`. A reader therefore always works against one whole set, never a
/// mix of an old and a new one.
pub struct MatchSet {
    /// The currently published set. It is only ever replaced as a unit (via
    /// `store`/`store_arc`); no method hands out mutable access to it.
    inner: ArcSwap<HashSet<Name>>,
}
impl MatchSet {
    /// Creates a match set whose initial contents are exactly `set`.
    #[must_use]
    pub fn new(set: HashSet<Name>) -> Self {
        let shared = Arc::new(set);
        Self {
            inner: ArcSwap::new(shared),
        }
    }

    /// Creates a match set that contains no names.
    #[must_use]
    pub fn empty() -> Self {
        Self::new(HashSet::default())
    }

    /// Reports whether `name` is a member of the currently published set.
    ///
    /// Equality and hashing are whatever `Name` defines; no normalisation is
    /// applied here.
    #[must_use]
    pub fn contains(&self, name: &Name) -> bool {
        let current = self.inner.load();
        current.contains(name)
    }

    /// Borrows the currently published set through an `arc_swap` guard.
    ///
    /// Every query made through the returned guard sees the same set, even if a
    /// writer calls [`MatchSet::store`] meanwhile. The guard is meant for
    /// short-lived access; prefer [`MatchSet::load_full`] to keep an owned handle.
    #[must_use]
    pub fn snapshot(&self) -> Guard<Arc<HashSet<Name>>> {
        self.inner.load()
    }

    /// Returns an owned `Arc` to the currently published set.
    #[must_use]
    pub fn load_full(&self) -> Arc<HashSet<Name>> {
        self.inner.load_full()
    }

    /// Number of names in the currently published set.
    #[must_use]
    pub fn len(&self) -> usize {
        self.snapshot().len()
    }

    /// Whether the currently published set has no names.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.snapshot().is_empty()
    }

    /// Replaces the whole published set with `set`.
    pub fn store(&self, set: HashSet<Name>) {
        self.inner.store(Arc::new(set));
    }

    /// Replaces the whole published set with an already shared `arc`, so the same
    /// allocation can be published without copying it.
    pub fn store_arc(&self, arc: Arc<HashSet<Name>>) {
        self.inner.store(arc);
    }
}
impl Default for MatchSet {
fn default() -> Self {
Self::empty()
}
}
impl fmt::Debug for MatchSet {
    /// Renders as `MatchSet { len: N }`.
    ///
    /// Only the size of the currently published set is shown, not the names it
    /// contains.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let len = self.inner.load().len();
        let mut out = f.debug_struct("MatchSet");
        out.field("len", &len);
        out.finish()
    }
}
impl FromIterator<Name> for MatchSet {
fn from_iter<I: IntoIterator<Item = Name>>(iter: I) -> Self {
Self::new(iter.into_iter().collect())
}
}
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, sync::Arc, thread};
    use super::*;

    /// Parses `s` into a [`Name`], panicking if `Name`'s `FromStr` rejects it.
    fn name(s: &str) -> Name {
        s.parse().expect("valid domain name")
    }

    /// Builds a `HashSet<Name>` from string literals, parsing each via [`name`].
    fn set_of(names: &[&str]) -> HashSet<Name> {
        names.iter().map(|s| name(s)).collect()
    }

    #[test]
    fn empty_is_empty() {
        let ms = MatchSet::empty();
        assert!(ms.is_empty());
        assert_eq!(ms.len(), 0);
    }

    #[test]
    fn default_is_empty() {
        let ms = MatchSet::default();
        assert!(ms.is_empty());
    }

    #[test]
    fn new_with_set() {
        let ms = MatchSet::new(set_of(&["example.com", "ads.example.net"]));
        assert_eq!(ms.len(), 2);
        assert!(!ms.is_empty());
    }

    #[test]
    fn from_iterator() {
        // Exercises the `FromIterator<Name>` impl via `collect`.
        let ms: MatchSet = ["a.com", "b.com", "c.com"]
            .iter()
            .map(|s| name(s))
            .collect();
        assert_eq!(ms.len(), 3);
    }

    #[test]
    fn contains_hit() {
        let ms = MatchSet::new(set_of(&["blocked.example.com", "ads.tracker.net"]));
        assert!(ms.contains(&name("blocked.example.com")));
        assert!(ms.contains(&name("ads.tracker.net")));
    }

    #[test]
    fn contains_miss() {
        // Membership is exact: neither a sibling name nor a parent domain
        // ("example.com") of a stored name counts as a hit.
        let ms = MatchSet::new(set_of(&["blocked.example.com"]));
        assert!(!ms.contains(&name("safe.example.com")));
        assert!(!ms.contains(&name("example.com")));
        assert!(!ms.contains(&name("other.net")));
    }

    #[test]
    fn contains_case_insensitive() {
        // `MatchSet::contains` does no case folding itself; this test expects
        // `Name`'s `Eq`/`Hash` to ignore ASCII case.
        let ms = MatchSet::new(set_of(&["Blocked.Example.COM"]));
        assert!(ms.contains(&name("blocked.example.com")));
        assert!(ms.contains(&name("BLOCKED.EXAMPLE.COM")));
    }

    #[test]
    fn empty_contains_nothing() {
        let ms = MatchSet::empty();
        assert!(!ms.contains(&name("example.com")));
    }

    #[test]
    fn snapshot_reflects_current_set() {
        let ms = MatchSet::new(set_of(&["a.com", "b.com"]));
        let snap = ms.snapshot();
        assert_eq!(snap.len(), 2);
        assert!(snap.contains(&name("a.com")));
    }

    #[test]
    fn load_full_reflects_current_set() {
        let ms = MatchSet::new(set_of(&["x.org"]));
        let arc = ms.load_full();
        assert!(arc.contains(&name("x.org")));
    }

    #[test]
    fn store_replaces_set() {
        // `store` swaps in a whole new set; names only in the old set disappear.
        let ms = MatchSet::new(set_of(&["old.com"]));
        assert!(ms.contains(&name("old.com")));
        ms.store(set_of(&["new.com", "other.net"]));
        assert!(ms.contains(&name("new.com")));
        assert!(ms.contains(&name("other.net")));
        assert!(!ms.contains(&name("old.com")));
        assert_eq!(ms.len(), 2);
    }

    #[test]
    fn store_arc_replaces_set() {
        let ms = MatchSet::new(set_of(&["before.com"]));
        ms.store_arc(Arc::new(set_of(&["after.com"])));
        assert!(ms.contains(&name("after.com")));
        assert!(!ms.contains(&name("before.com")));
    }

    #[test]
    fn store_empty_clears_set() {
        let ms = MatchSet::new(set_of(&["a.com", "b.com"]));
        ms.store(HashSet::new());
        assert!(ms.is_empty());
        assert!(!ms.contains(&name("a.com")));
    }

    /// Readers must never observe a snapshot that mixes two published sets.
    ///
    /// The writer only ever publishes `v1` = {alpha, beta} or `v2` = {gamma, delta}
    /// as whole `Arc`s. Each reader iteration takes ONE snapshot and checks that
    /// both `v1` names agree, that both `v2` names agree, and that exactly one of
    /// the two versions is present.
    #[test]
    fn concurrent_store_and_contains_sees_consistent_snapshot() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let v1 = Arc::new(set_of(&["alpha.com", "beta.com"]));
        let v2 = Arc::new(set_of(&["gamma.com", "delta.com"]));
        let ms = Arc::new(MatchSet::new((*v1).clone()));
        // Termination signal only; no data is published through this flag, so
        // `Relaxed` is sufficient for it.
        let stop = Arc::new(AtomicBool::new(false));
        const READER_THREADS: usize = 4;
        let mut handles = Vec::with_capacity(READER_THREADS + 1);
        for _ in 0..READER_THREADS {
            let ms_r = Arc::clone(&ms);
            let stop_r = Arc::clone(&stop);
            // Parse the probe names up front so the reader loop only does lookups.
            let alpha = name("alpha.com");
            let beta = name("beta.com");
            let gamma = name("gamma.com");
            let delta = name("delta.com");
            handles.push(thread::spawn(move || {
                while !stop_r.load(Ordering::Relaxed) {
                    // All four lookups below go through this single snapshot.
                    let snap = ms_r.snapshot();
                    let in_v1 = snap.contains(&alpha);
                    let also_v1 = snap.contains(&beta);
                    let in_v2 = snap.contains(&gamma);
                    let also_v2 = snap.contains(&delta);
                    assert_eq!(in_v1, also_v1, "torn V1 read: alpha={in_v1} beta={also_v1}");
                    assert_eq!(
                        in_v2, also_v2,
                        "torn V2 read: gamma={in_v2} delta={also_v2}"
                    );
                    // Exactly one version is visible: not both, not neither.
                    assert_ne!(
                        in_v1, in_v2,
                        "snapshot mixed V1 and V2: alpha={in_v1} gamma={in_v2}"
                    );
                }
            }));
        }
        // Writer: alternate v2 / v1 (even / odd `i`), then tell readers to stop.
        // NOTE(review): nothing synchronises thread start, so the writer may finish
        // all 2_000 stores before some readers enter their loop; those readers then
        // see `stop == true` and exit without checking anything. Consider a
        // `std::sync::Barrier` if stronger overlap is wanted.
        {
            let ms_w = Arc::clone(&ms);
            let stop_w = Arc::clone(&stop);
            let v1_w = Arc::clone(&v1);
            let v2_w = Arc::clone(&v2);
            handles.push(thread::spawn(move || {
                for i in 0..2_000 {
                    if i % 2 == 0 {
                        ms_w.store_arc(Arc::clone(&v2_w));
                    } else {
                        ms_w.store_arc(Arc::clone(&v1_w));
                    }
                }
                stop_w.store(true, Ordering::Relaxed);
            }));
        }
        // Any reader assertion failure surfaces here as a panicked join.
        for h in handles {
            h.join().expect("thread panicked");
        }
    }

    #[test]
    fn debug_shows_len() {
        let ms = MatchSet::new(set_of(&["a.com", "b.com", "c.com"]));
        let s = format!("{ms:?}");
        assert!(s.contains("len: 3"), "debug output was: {s}");
    }
}