use crate::config::MAX_STATEMENT_NOTIFICATION_SIZE;
use codec::{Decode, Encode};
use fastbloom::{BloomFilter, DefaultHasher as BloomDefaultHasher};
use sp_statement_store::Statement;
use std::hash::{BuildHasher, Hasher};
/// Upper bound on the bloom filter's bit-array size, derived from the maximum
/// statement-notification payload (bytes -> bits) so an encoded filter cannot
/// exceed what a single notification may carry.
const MAX_BLOOM_BITS: usize = MAX_STATEMENT_NOTIFICATION_SIZE as usize * 8;
/// Upper bound on the number of hash functions accepted when decoding a
/// remote peer's filter; rejects maliciously expensive parameters.
const MAX_NUM_HASHES: u32 = 64;
/// Wrapper around `fastbloom`'s default build-hasher that produces
/// [`PortableHasher`] instances, pinning integer hashing to a fixed byte
/// representation so filters hash identically on every platform.
#[derive(Clone, Debug)]
struct PortableBuildHasher(BloomDefaultHasher);

impl PortableBuildHasher {
    /// Construct a build-hasher deterministically derived from `seed`.
    fn seeded(seed: u128) -> Self {
        let seed_bytes = seed.to_le_bytes();
        PortableBuildHasher(BloomDefaultHasher::seeded(&seed_bytes))
    }
}
impl BuildHasher for PortableBuildHasher {
    type Hasher = PortableHasher;

    /// Delegate to the inner builder and wrap the result so the portable
    /// `write_usize`/`write_isize` overrides take effect.
    fn build_hasher(&self) -> Self::Hasher {
        let inner = self.0.build_hasher();
        PortableHasher(inner)
    }
}
/// Hasher that funnels pointer-sized integers through a fixed 64-bit
/// little-endian encoding, keeping digests identical across 32-bit and
/// 64-bit targets.
#[derive(Clone)]
struct PortableHasher(<BloomDefaultHasher as BuildHasher>::Hasher);
impl Hasher for PortableHasher {
    #[inline]
    fn finish(&self) -> u64 {
        self.0.finish()
    }

    #[inline]
    fn write(&mut self, bytes: &[u8]) {
        self.0.write(bytes);
    }

    // Widen `usize` to a fixed 8-byte little-endian encoding before feeding it
    // to the inner hasher, instead of the platform-dependent default. This
    // keeps bloom hashes stable across targets of different pointer widths.
    // Note: deliberately goes through `self.0.write`, never the inner hasher's
    // own `write_usize`.
    #[inline]
    fn write_usize(&mut self, i: usize) {
        self.0.write(&(i as u64).to_le_bytes());
    }

    // Same portability treatment for `isize`.
    #[inline]
    fn write_isize(&mut self, i: isize) {
        self.0.write(&(i as i64).to_le_bytes());
    }
}
/// SCALE-serializable wire representation of an [`AffinityFilter`].
#[derive(Encode, Decode)]
struct EncodedBloomFilter {
    // Seed used to derive the portable hasher; required to rebuild an
    // identical filter on the receiving side.
    seed: u128,
    // Number of hash functions; validated against MAX_NUM_HASHES on decode.
    num_hashes: u32,
    // Raw bloom filter bit array as 64-bit words.
    bits: Vec<u64>,
}
impl TryFrom<EncodedBloomFilter> for AffinityFilter {
    type Error = &'static str;

    /// Validate an untrusted, decoded wire filter and reconstruct the bloom
    /// filter with the portable, seed-derived hasher.
    ///
    /// # Errors
    /// Rejects an empty bit array, a bit array larger than
    /// [`MAX_BLOOM_BITS`], and a hash count outside `1..=MAX_NUM_HASHES`.
    fn try_from(encoded: EncodedBloomFilter) -> Result<Self, Self::Error> {
        if encoded.bits.is_empty() {
            return Err("bloom filter bits must not be empty");
        }
        // Compare word counts rather than computing `len * 64`, which could
        // overflow `usize` on adversarial input. For integers,
        // `len * 64 > MAX_BLOOM_BITS` is equivalent to
        // `len > MAX_BLOOM_BITS / 64` (floor division), so behavior for all
        // previously valid inputs is unchanged.
        if encoded.bits.len() > MAX_BLOOM_BITS / u64::BITS as usize {
            return Err("bloom filter bits exceed maximum allowed size");
        }
        if encoded.num_hashes == 0 || encoded.num_hashes > MAX_NUM_HASHES {
            return Err("num_hashes out of allowed range");
        }
        let bloom = BloomFilter::from_vec(encoded.bits)
            .hasher(PortableBuildHasher::seeded(encoded.seed))
            .hashes(encoded.num_hashes);
        Ok(AffinityFilter { bloom, seed: encoded.seed })
    }
}
/// Bloom filter describing which statement topics a peer is interested in.
///
/// The `seed` is stored alongside the filter so the filter can be re-encoded
/// and so a remote peer can rebuild an identical hasher from the wire form.
#[derive(Debug)]
pub struct AffinityFilter {
    bloom: BloomFilter<PortableBuildHasher>,
    seed: u128,
}
impl AffinityFilter {
    /// Test-only constructor: sizes the filter for `expected_items` at the
    /// requested false-positive rate, hashing with the given `seed`.
    #[cfg(test)]
    pub(crate) fn new(seed: u128, false_pos: f64, expected_items: usize) -> Self {
        let hasher = PortableBuildHasher::seeded(seed);
        let bloom = BloomFilter::with_false_pos(false_pos)
            .hasher(hasher)
            .expected_items(expected_items);
        Self { bloom, seed }
    }

    /// Test-only helper: record a topic in the filter.
    #[cfg(test)]
    pub(crate) fn insert(&mut self, topic: &[u8; 32]) {
        self.bloom.insert(topic);
    }

    /// Whether `topic` is (probabilistically) present in the filter.
    pub(crate) fn contains(&self, topic: &[u8; 32]) -> bool {
        self.bloom.contains(topic)
    }

    /// A statement matches when it carries no topics at all, or when at
    /// least one of its topics is present in the filter (ANY semantics).
    pub(crate) fn matches_statement(&self, statement: &Statement) -> bool {
        let topics = statement.topics();
        topics.is_empty() || topics.iter().any(|topic| self.contains(topic))
    }
}
impl Encode for AffinityFilter {
    /// Serialize by lowering into the [`EncodedBloomFilter`] wire form
    /// (seed, hash count, and a copy of the raw bit words).
    fn encode_to<T: codec::Output + ?Sized>(&self, dest: &mut T) {
        EncodedBloomFilter {
            seed: self.seed,
            num_hashes: self.bloom.num_hashes(),
            bits: self.bloom.as_slice().to_vec(),
        }
        .encode_to(dest);
    }
}
impl Decode for AffinityFilter {
fn decode<I: codec::Input>(input: &mut I) -> Result<Self, codec::Error> {
let encoded = EncodedBloomFilter::decode(input)?;
AffinityFilter::try_from(encoded).map_err(|e| codec::Error::from(e))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed seed so hashing — and the encoding snapshot below — stays
    // deterministic across runs and platforms.
    const BLOOM_SEED: u128 = 0x5EED_5EED_5EED_5EED;
    // Largest number of 64-bit words a decoded filter may carry.
    const MAX_BLOOM_WORDS: usize = MAX_BLOOM_BITS / u64::BITS as usize;

    // Encode/decode must preserve every membership answer (including any
    // false positives) and re-encoding must reproduce the exact same bytes.
    #[test]
    fn affinity_filter_encode_decode_roundtrip() {
        const TOTAL: usize = 100_000;
        const SET_COUNT: usize = TOTAL / 10;
        let items: Vec<[u8; 32]> = (0..TOTAL)
            .map(|i| {
                let mut key = [0u8; 32];
                key[..8].copy_from_slice(&(i as u64).to_le_bytes());
                key
            })
            .collect();
        let mut filter = AffinityFilter::new(BLOOM_SEED, 0.01, SET_COUNT);
        for item in &items[..SET_COUNT] {
            filter.insert(item);
        }
        // Capture pre-encoding answers for ALL items, so the decoded filter
        // can be checked for behavioral identity rather than mere membership.
        let expected: Vec<bool> = items.iter().map(|item| filter.contains(item)).collect();
        for i in 0..SET_COUNT {
            assert!(expected[i], "inserted item {i} must be present");
        }
        let encoded = filter.encode();
        let decoded =
            AffinityFilter::decode(&mut encoded.as_slice()).expect("decoding should succeed");
        for (i, item) in items.iter().enumerate() {
            assert_eq!(decoded.contains(item), expected[i], "mismatch for item {i}");
        }
        assert_eq!(encoded, decoded.encode(), "re-encoding should produce identical bytes");
    }

    // Snapshot test: guards the wire format and the portable hashing scheme.
    // If this digest changes, the encoding is no longer compatible with
    // previously released nodes.
    #[test]
    fn affinity_filter_encoding_snapshot() {
        const ITEM_COUNT: usize = 10_000;
        let items: Vec<[u8; 32]> = (0..ITEM_COUNT)
            .map(|i| {
                let mut key = [0u8; 32];
                key[..8].copy_from_slice(&(i as u64).to_le_bytes());
                key
            })
            .collect();
        let mut filter = AffinityFilter::new(BLOOM_SEED, 0.01, ITEM_COUNT);
        for item in &items {
            filter.insert(item);
        }
        let encoded = filter.encode();
        assert_eq!(
            sp_core::blake2_256(&encoded),
            [
                180, 34, 58, 78, 198, 24, 137, 83, 154, 127, 9, 152, 171, 50, 197, 27, 242, 158,
                30, 79, 143, 192, 53, 151, 174, 106, 132, 105, 20, 145, 133, 0
            ],
            "blake2_256 digest of encoded bytes must match snapshot"
        );
        let decoded =
            AffinityFilter::decode(&mut encoded.as_slice()).expect("snapshot must decode");
        for (i, item) in items.iter().enumerate() {
            assert!(decoded.contains(item), "item {i} must be present after decoding");
        }
        let absent: [u8; 32] = [0xFF; 32];
        assert!(!decoded.contains(&absent), "absent item must not match");
    }

    // Topic-less statements are broadcast: they match every filter.
    #[test]
    fn matches_statement_no_topics_always_matches() {
        let filter = AffinityFilter::new(BLOOM_SEED, 0.01, 10);
        let mut stmt = Statement::new();
        stmt.set_plain_data(b"broadcast".to_vec());
        assert!(filter.matches_statement(&stmt));
    }

    #[test]
    fn matches_statement_single_matching_topic() {
        let topic: [u8; 32] = [0xAA; 32];
        let mut filter = AffinityFilter::new(BLOOM_SEED, 0.01, 10);
        filter.insert(&topic);
        let mut stmt = Statement::new();
        stmt.set_plain_data(b"matching".to_vec());
        stmt.set_topic(0, topic.into());
        assert!(filter.matches_statement(&stmt));
    }

    #[test]
    fn matches_statement_single_non_matching_topic() {
        let topic_in_filter: [u8; 32] = [0xAA; 32];
        let topic_on_stmt: [u8; 32] = [0xBB; 32];
        let mut filter = AffinityFilter::new(BLOOM_SEED, 0.01, 10);
        filter.insert(&topic_in_filter);
        let mut stmt = Statement::new();
        stmt.set_plain_data(b"not matching".to_vec());
        stmt.set_topic(0, topic_on_stmt.into());
        assert!(!filter.matches_statement(&stmt));
    }

    // With several topics, matching is ANY-semantics: one hit suffices.
    #[test]
    fn matches_statement_multiple_topics_any_semantics() {
        let topic_aa: [u8; 32] = [0xAA; 32];
        let topic_bb: [u8; 32] = [0xBB; 32];
        let topic_cc: [u8; 32] = [0xCC; 32];
        let mut filter = AffinityFilter::new(BLOOM_SEED, 0.01, 10);
        filter.insert(&topic_bb);
        let mut stmt = Statement::new();
        stmt.set_plain_data(b"multi topic".to_vec());
        stmt.set_topic(0, topic_aa.into());
        stmt.set_topic(1, topic_bb.into());
        assert!(filter.matches_statement(&stmt), "should match when ANY topic is in the filter");
        let mut stmt2 = Statement::new();
        stmt2.set_plain_data(b"no match multi".to_vec());
        stmt2.set_topic(0, topic_aa.into());
        stmt2.set_topic(1, topic_cc.into());
        assert!(
            !filter.matches_statement(&stmt2),
            "should not match when NO topic is in the filter"
        );
    }

    // The following tests exercise the TryFrom validation via Decode.
    #[test]
    fn decode_rejects_empty_bits() {
        let encoded = EncodedBloomFilter { seed: BLOOM_SEED, num_hashes: 7, bits: vec![] };
        let bytes = encoded.encode();
        assert!(AffinityFilter::decode(&mut bytes.as_slice()).is_err());
    }

    #[test]
    fn decode_rejects_oversized_bits() {
        // One word over the limit must be refused.
        let encoded = EncodedBloomFilter {
            seed: BLOOM_SEED,
            num_hashes: 7,
            bits: vec![0u64; MAX_BLOOM_WORDS + 1],
        };
        let bytes = encoded.encode();
        assert!(AffinityFilter::decode(&mut bytes.as_slice()).is_err());
    }

    #[test]
    fn decode_rejects_zero_num_hashes() {
        let encoded = EncodedBloomFilter { seed: BLOOM_SEED, num_hashes: 0, bits: vec![0u64; 16] };
        let bytes = encoded.encode();
        assert!(AffinityFilter::decode(&mut bytes.as_slice()).is_err());
    }

    #[test]
    fn decode_rejects_excessive_num_hashes() {
        let encoded =
            EncodedBloomFilter { seed: BLOOM_SEED, num_hashes: u32::MAX, bits: vec![0u64; 16] };
        let bytes = encoded.encode();
        assert!(AffinityFilter::decode(&mut bytes.as_slice()).is_err());
    }

    // Exactly at the limits, decoding must succeed.
    #[test]
    fn decode_accepts_valid_bounds() {
        let encoded = EncodedBloomFilter {
            seed: BLOOM_SEED,
            num_hashes: MAX_NUM_HASHES,
            bits: vec![0u64; MAX_BLOOM_WORDS],
        };
        let bytes = encoded.encode();
        assert!(AffinityFilter::decode(&mut bytes.as_slice()).is_ok());
    }
}