use std::sync::atomic::{AtomicU8, Ordering};
use dashmap::DashMap;
use xxhash_rust::xxh3::xxh3_64;
use super::{ChannelHash, ChannelName};
/// Fixed-size Bloom filter over `(origin_hash, channel_hash)` pairs.
///
/// Holds `2^BLOOM_BITS` bits packed eight to a byte in `AtomicU8` cells, so
/// [`mark`](Self::mark) and [`probe`](Self::probe) work through `&self`. Each
/// pair sets two bits derived from one 64-bit key produced by `bloom_key`.
/// `probe` may report false positives, but never misses a pair that was marked
/// since the last [`clear`](Self::clear).
#[derive(Debug)]
pub struct BloomCache {
    /// Filter bits, eight per byte: bit `i` lives in byte `i >> 3` at position `i & 7`.
    bloom: Vec<AtomicU8>,
    /// `2^BLOOM_BITS - 1`; selects one `BLOOM_BITS`-wide index from a key.
    bloom_mask: u64,
}

impl BloomCache {
    /// Creates a filter of `2^BLOOM_BITS` bits, all cleared.
    pub fn new() -> Self {
        let num_bytes = 1usize << (BLOOM_BITS - 3);
        let bloom = (0..num_bytes).map(|_| AtomicU8::new(0)).collect();
        Self {
            bloom,
            bloom_mask: (1u64 << BLOOM_BITS) - 1,
        }
    }

    /// Returns the two bit positions for a pair: the low `BLOOM_BITS` bits of its
    /// key, and the `BLOOM_BITS` bits directly above them.
    #[inline]
    fn indices(&self, origin_hash: u64, channel_hash: ChannelHash) -> (usize, usize) {
        let key = bloom_key(origin_hash, channel_hash);
        let h1 = (key & self.bloom_mask) as usize;
        let h2 = ((key >> BLOOM_BITS) & self.bloom_mask) as usize;
        (h1, h2)
    }

    /// Sets both bits for the pair.
    #[inline]
    pub fn mark(&self, origin_hash: u64, channel_hash: ChannelHash) {
        let (h1, h2) = self.indices(origin_hash, channel_hash);
        self.set_bit(h1);
        self.set_bit(h2);
    }

    /// Returns `true` if both bits for the pair are set (the pair was possibly
    /// marked). Returns `false` if it definitely was not marked since the last `clear`.
    #[inline]
    pub fn probe(&self, origin_hash: u64, channel_hash: ChannelHash) -> bool {
        let (h1, h2) = self.indices(origin_hash, channel_hash);
        self.test_bit(h1) && self.test_bit(h2)
    }

    /// Resets every bit to zero.
    ///
    /// `&mut self` guarantees there are no concurrent readers or writers, so the
    /// bytes are written through `AtomicU8::get_mut` without atomic operations.
    pub fn clear(&mut self) {
        for byte in &mut self.bloom {
            *byte.get_mut() = 0;
        }
    }

    /// Atomically sets bit `bit_index`.
    #[inline]
    fn set_bit(&self, bit_index: usize) {
        let (byte_index, mask) = Self::locate(bit_index);
        self.bloom[byte_index].fetch_or(mask, Ordering::Relaxed);
    }

    /// Reads bit `bit_index`.
    #[inline]
    fn test_bit(&self, bit_index: usize) -> bool {
        let (byte_index, mask) = Self::locate(bit_index);
        (self.bloom[byte_index].load(Ordering::Relaxed) & mask) != 0
    }

    /// Maps a bit index to its byte index and the single-bit mask within that byte.
    #[inline]
    fn locate(bit_index: usize) -> (usize, u8) {
        (bit_index >> 3, 1u8 << (bit_index & 7))
    }

    /// Size of the filter's storage in bytes (not the number of marked pairs).
    pub fn len(&self) -> usize {
        self.bloom.len()
    }

    /// Returns `true` if the filter has no storage.
    ///
    /// Always `false` for a filter built by [`BloomCache::new`], which allocates
    /// `2^(BLOOM_BITS - 3)` bytes. Provided so that `len` follows the usual
    /// `len`/`is_empty` API pairing.
    pub fn is_empty(&self) -> bool {
        self.bloom.is_empty()
    }
}

impl Default for BloomCache {
    /// Same as [`BloomCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Result of the hash-keyed fast-path check performed by [`AuthGuard::check_fast`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthVerdict {
    /// The pair currently holds a fast-path grant.
    Allowed,
    /// The Bloom filter has no bits for the pair, so it holds no fast-path grant.
    Denied,
    /// The Bloom filter matched but no fast-path grant is recorded (a false
    /// positive, or a revoked pair whose bits are still set).
    NeedsFullCheck,
}
/// Two-tier authorization cache for `(origin_hash, channel)` pairs.
///
/// The fast tier is keyed by channel hash: a Bloom filter sits in front of a
/// concurrent grant set. The exact tier is keyed by the full channel name, so
/// names whose hashes collide can still be told apart.
pub struct AuthGuard {
    /// Hash-keyed grants. Only key presence is consulted; the `bool` is always `true`.
    verified: DashMap<(u64, ChannelHash), bool>,
    /// Exact-name grants recorded by `allow_channel`.
    exact: DashMap<(u64, ChannelName), ()>,
    /// Bloom filter over hash-keyed grants; revocations leave their bits set
    /// until the filter is rebuilt.
    bloom: BloomCache,
    /// Number of `revoke` calls since construction or the last `rebuild_bloom`.
    revocations_since_rebuild: std::sync::atomic::AtomicU64,
}
/// Base-2 logarithm of the Bloom filter size in bits: `2^15` bits = 4 KiB.
///
/// `BloomCache::new` allocates `2^(BLOOM_BITS - 3)` bytes, and `BloomCache::indices`
/// takes two consecutive `BLOOM_BITS`-wide slices of one 64-bit key. The value must
/// therefore stay within `3..=32`; the assertions below reject anything else at
/// compile time.
const BLOOM_BITS: u32 = 15;
const _: () = {
    assert!(
        BLOOM_BITS >= 3,
        "BLOOM_BITS must be at least 3: the filter is stored as whole bytes"
    );
    assert!(
        BLOOM_BITS <= 32,
        "BLOOM_BITS must be at most 32: both Bloom indices come from one 64-bit key"
    );
};
impl AuthGuard {
pub fn new() -> Self {
Self {
bloom: BloomCache::new(),
verified: DashMap::new(),
exact: DashMap::new(),
revocations_since_rebuild: std::sync::atomic::AtomicU64::new(0),
}
}
#[inline]
pub fn check_fast(&self, origin_hash: u64, channel_hash: ChannelHash) -> AuthVerdict {
if !self.bloom.probe(origin_hash, channel_hash) {
return AuthVerdict::Denied;
}
if self.verified.contains_key(&(origin_hash, channel_hash)) {
AuthVerdict::Allowed
} else {
AuthVerdict::NeedsFullCheck
}
}
pub fn authorize(&self, origin_hash: u64, channel_hash: ChannelHash) {
self.bloom.mark(origin_hash, channel_hash);
self.verified.insert((origin_hash, channel_hash), true);
}
pub fn revoke(&self, origin_hash: u64, channel_hash: ChannelHash) {
self.verified.remove(&(origin_hash, channel_hash));
self.revocations_since_rebuild
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
#[inline]
pub fn revocations_since_rebuild(&self) -> u64 {
self.revocations_since_rebuild
.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn is_authorized(&self, origin_hash: u64, channel_hash: ChannelHash) -> bool {
self.verified.contains_key(&(origin_hash, channel_hash))
}
pub fn allow_channel(&self, origin_hash: u64, name: &ChannelName) {
self.exact.insert((origin_hash, name.clone()), ());
self.authorize(origin_hash, name.hash());
}
pub fn revoke_channel(&self, origin_hash: u64, name: &ChannelName) {
self.exact.remove(&(origin_hash, name.clone()));
self.revoke(origin_hash, name.hash());
}
pub fn is_authorized_full(&self, origin_hash: u64, name: &ChannelName) -> bool {
self.exact.contains_key(&(origin_hash, name.clone()))
}
pub fn authorized_count(&self) -> usize {
self.verified.len()
}
pub fn exact_authorized_count(&self) -> usize {
self.exact.len()
}
pub fn rebuild_bloom(&mut self) {
self.bloom.clear();
for entry in self.verified.iter() {
let (origin_hash, channel_hash) = *entry.key();
self.bloom.mark(origin_hash, channel_hash);
}
self.revocations_since_rebuild
.store(0, std::sync::atomic::Ordering::Relaxed);
}
}
impl Default for AuthGuard {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for AuthGuard {
    /// Reports sizes only (filter bytes and grant counts), not the grants themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("AuthGuard");
        summary.field("bloom_size_bytes", &self.bloom.len());
        summary.field("authorized_pairs", &self.verified.len());
        summary.field("exact_authorized_pairs", &self.exact.len());
        summary.finish()
    }
}
/// Hashes a pair into the 64-bit key from which `BloomCache` derives its two bit
/// indices.
///
/// The 16 hashed bytes are `origin_hash` followed by `channel_hash`, each
/// little-endian, so the key does not depend on host byte order.
/// NOTE(review): `copy_from_slice` panics unless `channel_hash.to_le_bytes()` is
/// exactly 8 bytes, so this assumes `ChannelHash` is a 64-bit integer.
#[inline]
fn bloom_key(origin_hash: u64, channel_hash: ChannelHash) -> u64 {
    let mut packed = [0u8; 16];
    let (origin_half, channel_half) = packed.split_at_mut(8);
    origin_half.copy_from_slice(&origin_hash.to_le_bytes());
    channel_half.copy_from_slice(&channel_hash.to_le_bytes());
    xxh3_64(&packed)
}
// Tests for the Bloom-filter fast path (`check_fast`), the hash-keyed grant set
// (`is_authorized`), the exact-name grant set (`is_authorized_full`), and their
// behavior under concurrent use.
#[cfg(test)]
mod tests {
use super::*;
// A fresh guard has an all-zero Bloom filter, so any pair is Denied.
#[test]
fn test_empty_guard_denies() {
let guard = AuthGuard::new();
assert_eq!(guard.check_fast(0x1234, 0xABCD), AuthVerdict::Denied);
}
// `authorize` sets both the Bloom bits and the grant-set entry.
#[test]
fn test_authorize_then_allow() {
let guard = AuthGuard::new();
guard.authorize(0x1234, 0xABCD);
assert_eq!(guard.check_fast(0x1234, 0xABCD), AuthVerdict::Allowed);
}
// Other pairs may be Bloom false positives (NeedsFullCheck) but are never Allowed.
#[test]
fn test_different_pair_denied() {
let guard = AuthGuard::new();
guard.authorize(0x1234, 0xABCD);
assert_ne!(guard.check_fast(0x5678, 0xABCD), AuthVerdict::Allowed);
assert_ne!(guard.check_fast(0x1234, 0x1111), AuthVerdict::Allowed);
}
// Revocation cannot clear Bloom bits, so the pair degrades to NeedsFullCheck
// rather than Denied.
#[test]
fn test_revoke() {
let guard = AuthGuard::new();
guard.authorize(0x1234, 0xABCD);
assert_eq!(guard.check_fast(0x1234, 0xABCD), AuthVerdict::Allowed);
guard.revoke(0x1234, 0xABCD);
assert_eq!(
guard.check_fast(0x1234, 0xABCD),
AuthVerdict::NeedsFullCheck
);
}
// After a rebuild, only the surviving grants are in the filter.
// NOTE(review): the Denied assertion assumes (0x1234, 0xABCD) is not a Bloom
// false positive of the remaining pair. bloom_key is deterministic, so this is
// fixed for a given hash implementation.
#[test]
fn test_rebuild_bloom_after_revoke() {
let mut guard = AuthGuard::new();
guard.authorize(0x1234, 0xABCD);
guard.authorize(0x5678, 0xBEEF);
guard.revoke(0x1234, 0xABCD);
guard.rebuild_bloom();
assert_eq!(guard.check_fast(0x1234, 0xABCD), AuthVerdict::Denied);
assert_eq!(guard.check_fast(0x5678, 0xBEEF), AuthVerdict::Allowed);
}
// No false negatives: every authorized pair is Allowed.
#[test]
fn test_multiple_authorizations() {
let guard = AuthGuard::new();
for i in 0..100u64 {
guard.authorize(i, (i * 7) as ChannelHash);
}
assert_eq!(guard.authorized_count(), 100);
for i in 0..100u64 {
assert_eq!(
guard.check_fast(i, (i * 7) as ChannelHash),
AuthVerdict::Allowed,
"pair ({}, {}) should be allowed",
i,
i * 7
);
}
}
// `is_authorized` reflects the hash-keyed grant set directly.
#[test]
fn test_is_authorized() {
let guard = AuthGuard::new();
assert!(!guard.is_authorized(0x1234, 0xABCD));
guard.authorize(0x1234, 0xABCD);
assert!(guard.is_authorized(0x1234, 0xABCD));
guard.revoke(0x1234, 0xABCD);
assert!(!guard.is_authorized(0x1234, 0xABCD));
}
// 1,000 grants in a 2^BLOOM_BITS-bit filter. 10,000 pairs that were never
// authorized must come back non-Denied less than 1% of the time.
#[test]
fn test_bloom_false_positive_rate() {
let guard = AuthGuard::new();
for i in 0..1000u64 {
guard.authorize(i, i as ChannelHash);
}
let mut false_positives = 0;
for i in 10000..20000u64 {
let verdict = guard.check_fast(i, i as ChannelHash);
if verdict != AuthVerdict::Denied {
false_positives += 1;
}
}
let fp_rate = false_positives as f64 / 10000.0;
assert!(fp_rate < 0.01, "false positive rate {} exceeds 1%", fp_rate);
}
// Regression: two origins that agree in their low 32 bits must not share a
// grant, so origin_hash has to be used at its full u64 width.
#[test]
fn test_regression_u64_origin_hash_defeats_32bit_collision() {
let guard = AuthGuard::new();
let name = ChannelName::new("regression-u64-origin").unwrap();
let legit: u64 = 0x0000_ABCD_1234_5678;
let forged: u64 = 0xFFFF_FFFF_1234_5678; assert_eq!(legit as u32, forged as u32);
assert_ne!(legit, forged);
guard.allow_channel(legit, &name);
assert_eq!(
guard.check_fast(legit, name.hash()),
AuthVerdict::Allowed,
"legit subscriber must be admitted"
);
assert!(guard.is_authorized_full(legit, &name));
assert_ne!(
guard.check_fast(forged, name.hash()),
AuthVerdict::Allowed,
"forged subscriber must not ride the legit grant"
);
assert!(!guard.is_authorized_full(forged, &name));
}
// Regression: find two distinct names with the same 16-bit wire hash. Each
// candidate is compared only against the first name. The hash-keyed fast path
// can separate them only if their `hash()` values differ; the exact-name check
// must always separate them.
#[test]
fn test_regression_channel_hash_collision_distinguishable_by_exact_name() {
let guard = AuthGuard::new();
let base = "regression/coll-";
let mut name_a: Option<ChannelName> = None;
let mut name_b: Option<ChannelName> = None;
'outer: for i in 0..200_000u32 {
let cand = ChannelName::new(&format!("{base}{i}")).unwrap();
if name_a.is_none() {
name_a = Some(cand);
continue;
}
if cand.wire_hash() == name_a.as_ref().unwrap().wire_hash()
&& cand.as_str() != name_a.as_ref().unwrap().as_str()
{
name_b = Some(cand);
break 'outer;
}
}
let name_a = name_a.expect("seeded name");
let name_b = name_b.expect(
"two distinct ChannelNames with the same 16-bit wire hash — widen the search range",
);
assert_eq!(name_a.wire_hash(), name_b.wire_hash());
assert_ne!(name_a.as_str(), name_b.as_str());
let origin: u64 = 0xDEAD_BEEF_CAFE_F00D;
guard.allow_channel(origin, &name_a);
// Equal `hash()` values are indistinguishable on the fast path. The Denied arm
// assumes name_b's hash is not a Bloom false positive for the single grant.
let fast_b = guard.check_fast(origin, name_b.hash());
if name_a.hash() == name_b.hash() {
assert_eq!(fast_b, AuthVerdict::Allowed);
} else {
assert_eq!(fast_b, AuthVerdict::Denied);
}
assert!(guard.is_authorized_full(origin, &name_a));
assert!(!guard.is_authorized_full(origin, &name_b));
}
// Four writer threads authorize disjoint pairs while four readers call
// check_fast. Afterwards all 1,000 written pairs must be present.
#[test]
fn test_regression_concurrent_authorize_and_check() {
use std::sync::Arc;
use std::thread;
let guard = Arc::new(AuthGuard::new());
let mut handles = Vec::new();
for t in 0..4u64 {
let g = Arc::clone(&guard);
handles.push(thread::spawn(move || {
for i in 0..250u64 {
g.authorize(t * 1000 + i, (t * 1000 + i) as ChannelHash);
}
}));
}
for _ in 0..4 {
let g = Arc::clone(&guard);
handles.push(thread::spawn(move || {
for i in 0..1000u64 {
let _ = g.check_fast(i, i as ChannelHash);
}
}));
}
for h in handles {
h.join().unwrap();
}
assert_eq!(guard.authorized_count(), 1000);
for t in 0..4u64 {
for i in 0..250u64 {
assert!(
guard.is_authorized(t * 1000 + i, (t * 1000 + i) as ChannelHash),
"pair ({}, {}) should be authorized after concurrent insertion",
t * 1000 + i,
t * 1000 + i
);
}
}
}
// authorize/revoke race on one key, with a concurrent observer. All three
// threads start together via a barrier. The test checks only for panics and
// for a stable final state with 0 or 1 entries, not which call won.
#[test]
fn concurrent_authorize_and_revoke_on_same_key_is_panic_free() {
use std::sync::{Arc, Barrier};
use std::thread;
let guard = Arc::new(AuthGuard::new());
let origin = 0x1234_5678_9ABC_DEF0u64;
let channel: ChannelHash = 0x4242_4242;
let iters = 1_000u32;
let start = Arc::new(Barrier::new(3));
let authorizer = {
let guard = guard.clone();
let start = start.clone();
thread::spawn(move || {
start.wait();
for _ in 0..iters {
guard.authorize(origin, channel);
}
})
};
let revoker = {
let guard = guard.clone();
let start = start.clone();
thread::spawn(move || {
start.wait();
for _ in 0..iters {
guard.revoke(origin, channel);
}
})
};
let observer = {
let guard = guard.clone();
let start = start.clone();
thread::spawn(move || {
start.wait();
for _ in 0..iters {
let _ = guard.is_authorized(origin, channel);
let _ = guard.check_fast(origin, channel);
}
})
};
authorizer.join().expect("authorizer panicked");
revoker.join().expect("revoker panicked");
observer.join().expect("observer panicked");
let final_state = guard.is_authorized(origin, channel);
assert_eq!(
final_state,
guard.is_authorized(origin, channel),
"two sequential is_authorized calls must agree — \
torn read would indicate DashMap corruption",
);
let count = guard.authorized_count();
assert!(
count == 0 || count == 1,
"authorized_count should be 0 or 1 after the race; got {count}",
);
}
// allow_channel/revoke_channel race on one (origin, name), with a concurrent
// observer. The test checks only for panics and a stable final exact-name state.
#[test]
fn concurrent_allow_and_revoke_channel_on_same_key_is_panic_free() {
use std::sync::{Arc, Barrier};
use std::thread;
let guard = Arc::new(AuthGuard::new());
let origin = 0xDEAD_BEEF_FEED_CAFEu64;
let name = ChannelName::new("auth/contended").expect("channel name");
let iters = 1_000u32;
let start = Arc::new(Barrier::new(3));
let allower = {
let guard = guard.clone();
let name = name.clone();
let start = start.clone();
thread::spawn(move || {
start.wait();
for _ in 0..iters {
guard.allow_channel(origin, &name);
}
})
};
let revoker = {
let guard = guard.clone();
let name = name.clone();
let start = start.clone();
thread::spawn(move || {
start.wait();
for _ in 0..iters {
guard.revoke_channel(origin, &name);
}
})
};
let observer = {
let guard = guard.clone();
let name = name.clone();
let start = start.clone();
thread::spawn(move || {
start.wait();
for _ in 0..iters {
let _ = guard.is_authorized_full(origin, &name);
}
})
};
allower.join().expect("allower panicked");
revoker.join().expect("revoker panicked");
observer.join().expect("observer panicked");
let final_state = guard.is_authorized_full(origin, &name);
assert_eq!(
final_state,
guard.is_authorized_full(origin, &name),
"sequential is_authorized_full reads must agree",
);
}
}