use {
super::{
error::{CuckooBuildError, TableFullError},
filter::CuckooFilter,
},
crate::geyser::{
CuckooFilter as ProtoCuckooFilter, SubscribeRequest, SubscribeRequestFilterAccounts,
},
solana_pubkey::Pubkey,
std::collections::HashSet,
};
/// Set of account public keys kept both exactly and as a cuckoo filter.
///
/// The exact `HashSet` answers local membership queries (`contains`, `len`,
/// `iter`), while the cuckoo filter is the compact representation that is
/// exported to protobuf (`to_proto`) and attached to subscription requests.
pub struct CompressedAccountFilterSet {
    /// Exact set of tracked keys, stored as raw 32-byte pubkeys.
    items: HashSet<[u8; 32]>,
    /// Cuckoo filter kept in step with `items` by `insert` / `remove`.
    filter: CuckooFilter<[u8; 32]>,
    /// Set whenever the tracked keys change; cleared by `take_dirty` and
    /// `insert_into_subscribe_request`.
    dirty: bool,
}
impl CompressedAccountFilterSet {
    /// Creates an empty set sized for up to `max_capacity` keys.
    ///
    /// The cuckoo filter is built for `max_capacity` entries and the exact set
    /// reserves room for at least `max_capacity` keys up front.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`CuckooFilter::with_capacity`], and returns
    /// [`CuckooBuildError::CapacityOverflow`] if the exact set cannot reserve
    /// space for `max_capacity` entries.
    pub fn with_capacity(max_capacity: usize) -> Result<Self, CuckooBuildError> {
        let filter = CuckooFilter::with_capacity(max_capacity)?;
        let mut items: HashSet<[u8; 32]> = HashSet::new();
        // `try_reserve` reports allocation failure instead of aborting, so an
        // absurd capacity surfaces as an error rather than a crash.
        items
            .try_reserve(max_capacity)
            .map_err(|_| CuckooBuildError::CapacityOverflow)?;
        Ok(Self {
            items,
            filter,
            dirty: false,
        })
    }

    /// Adds `key` to the set.
    ///
    /// Returns `Ok(true)` if the key was newly added (and marks the set dirty),
    /// or `Ok(false)` if it was already tracked (nothing changes).
    ///
    /// # Errors
    ///
    /// Returns [`TableFullError`] if the cuckoo filter rejects the key. In that
    /// case the exact set and the dirty flag are left unchanged.
    pub fn insert(&mut self, key: Pubkey) -> Result<bool, TableFullError> {
        let bytes = key.to_bytes();
        // Skip keys already tracked so the filter never receives the same key twice.
        if self.items.contains(&bytes) {
            return Ok(false);
        }
        // Filter first: the exact set is only updated after the filter accepts
        // the key, so a `TableFullError` leaves `items` untouched.
        // NOTE(review): whether a failed insert can still perturb the filter's
        // internal state depends on `CuckooFilter::insert` — confirm.
        self.filter.insert(&bytes)?;
        self.items.insert(bytes);
        self.dirty = true;
        Ok(true)
    }

    /// Removes `key` from the set, returning `true` if it was tracked.
    ///
    /// A successful removal marks the set dirty; removing an untracked key is a no-op.
    pub fn remove(&mut self, key: Pubkey) -> bool {
        let bytes = key.to_bytes();
        // Only delete from the filter keys that are known to be present: deleting
        // an absent key from a cuckoo filter can evict a colliding fingerprint
        // that belongs to a different key.
        if !self.items.remove(&bytes) {
            return false;
        }
        self.filter.remove(&bytes);
        self.dirty = true;
        true
    }

    /// Exact membership test against the backing set (no false positives).
    pub fn contains(&self, key: Pubkey) -> bool {
        self.items.contains(&key.to_bytes())
    }

    /// Number of keys currently tracked.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Capacity of the exact backing `HashSet`.
    ///
    /// This is at least the `max_capacity` passed to [`Self::with_capacity`] and
    /// may be larger. It is not the cuckoo filter's limit, which is what makes
    /// [`Self::insert`] fail.
    pub fn capacity(&self) -> usize {
        self.items.capacity()
    }

    /// Iterates over the tracked keys as raw 32-byte arrays, in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = &[u8; 32]> {
        self.items.iter()
    }

    /// Returns `true` if no keys are tracked.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Returns `true` if the tracked keys changed since the flag was last cleared.
    pub const fn is_dirty(&self) -> bool {
        self.dirty
    }

    /// Returns the dirty flag and resets it to `false`.
    pub fn take_dirty(&mut self) -> bool {
        std::mem::take(&mut self.dirty)
    }

    /// Converts the cuckoo filter (not the exact set) into its protobuf form.
    pub fn to_proto(&self) -> ProtoCuckooFilter {
        ProtoCuckooFilter::from(&self.filter)
    }

    /// Builds an accounts filter whose only matcher is this set's cuckoo filter.
    ///
    /// The `account`, `owner`, and `filters` lists are empty and
    /// `nonempty_txn_signature` is `None`.
    pub fn to_account_filter(&self) -> SubscribeRequestFilterAccounts {
        SubscribeRequestFilterAccounts {
            account: vec![],
            owner: vec![],
            filters: vec![],
            nonempty_txn_signature: None,
            cuckoo_accounts_filter: Some(self.to_proto()),
        }
    }

    /// Stores [`Self::to_account_filter`] in `req.accounts` under `name` and clears the dirty flag.
    ///
    /// Any entry already stored under `name` is overwritten. Entries under other
    /// names are left as they are.
    pub fn insert_into_subscribe_request(&mut self, req: &mut SubscribeRequest, name: &str) {
        req.accounts
            .insert(name.to_string(), self.to_account_filter());
        self.dirty = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic test key whose 32 bytes are all `b`.
    fn key(b: u8) -> Pubkey {
        Pubkey::new_from_array([b; 32])
    }

    #[test]
    fn basic_insert_contains() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        assert!(filter.insert(key(1)).unwrap());
        assert!(filter.contains(key(1)));
        assert!(!filter.contains(key(2)));
    }

    #[test]
    fn insert_duplicate_returns_false() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        assert!(filter.insert(key(1)).unwrap());
        assert!(!filter.insert(key(1)).unwrap());
    }

    #[test]
    fn remove_existing() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        filter.insert(key(1)).unwrap();
        assert!(filter.remove(key(1)));
        assert!(!filter.contains(key(1)));
    }

    #[test]
    fn remove_nonexistent_is_safe() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        filter.insert(key(1)).unwrap();
        assert!(!filter.remove(key(2)));
        assert!(filter.contains(key(1)));
    }

    #[test]
    fn len_and_is_empty() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        assert!(filter.is_empty());
        assert_eq!(filter.len(), 0);
        filter.insert(key(1)).unwrap();
        filter.insert(key(2)).unwrap();
        assert!(!filter.is_empty());
        assert_eq!(filter.len(), 2);
        filter.remove(key(1));
        assert_eq!(filter.len(), 1);
    }

    /// `to_proto` only exposes the serialized filter; this checks it is populated.
    #[test]
    fn to_proto_produces_non_empty_filter() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        filter.insert(key(1)).unwrap();
        filter.insert(key(2)).unwrap();
        let proto = filter.to_proto();
        assert!(!proto.data.is_empty());
        assert!(proto.bucket_count > 0);
    }

    #[test]
    fn to_account_filter_carries_cuckoo_and_no_other_matchers() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        filter.insert(key(1)).unwrap();
        let f = filter.to_account_filter();
        assert!(f.cuckoo_accounts_filter.is_some());
        assert!(f.account.is_empty());
        assert!(f.owner.is_empty());
        assert!(f.filters.is_empty());
        assert_eq!(f.nonempty_txn_signature, None);
    }

    #[test]
    fn insert_into_subscribe_request_uses_given_name_and_preserves_other_filters() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        filter.insert(key(1)).unwrap();
        let mut req = SubscribeRequest::default();
        req.accounts.insert(
            "pre_existing".to_string(),
            SubscribeRequestFilterAccounts::default(),
        );
        filter.insert_into_subscribe_request(&mut req, "tracked_accounts");
        assert!(req.accounts.contains_key("tracked_accounts"));
        assert!(req.accounts.contains_key("pre_existing"));
        assert_eq!(req.accounts.len(), 2);
        assert!(req
            .accounts
            .get("tracked_accounts")
            .unwrap()
            .cuckoo_accounts_filter
            .is_some());
    }

    #[test]
    fn insert_into_subscribe_request_clears_dirty_flag() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        filter.insert(key(1)).unwrap();
        assert!(filter.is_dirty());
        let mut req = SubscribeRequest::default();
        filter.insert_into_subscribe_request(&mut req, "accounts");
        assert!(!filter.is_dirty());
    }

    /// Only real membership changes should mark the set dirty, and
    /// `take_dirty` must both report and reset the flag.
    #[test]
    fn dirty_flag_tracks_membership_changes() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        assert!(!filter.is_dirty());

        filter.insert(key(1)).unwrap();
        assert!(filter.take_dirty());
        assert!(!filter.is_dirty());
        assert!(!filter.take_dirty());

        // Re-inserting a tracked key or removing an untracked one is a no-op.
        assert!(!filter.insert(key(1)).unwrap());
        assert!(!filter.remove(key(2)));
        assert!(!filter.is_dirty());

        // A successful removal marks the set dirty again.
        assert!(filter.remove(key(1)));
        assert!(filter.is_dirty());
    }

    #[test]
    fn iter_yields_inserted_key_bytes() {
        let mut filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        filter.insert(key(1)).unwrap();
        filter.insert(key(2)).unwrap();
        let seen: HashSet<[u8; 32]> = filter.iter().copied().collect();
        assert_eq!(seen, HashSet::from([[1u8; 32], [2u8; 32]]));
    }

    #[test]
    fn capacity_covers_requested_size() {
        let filter = CompressedAccountFilterSet::with_capacity(100).unwrap();
        assert!(filter.capacity() >= 100);
    }

    #[test]
    fn pubkey_like_usage() {
        let mut filter = CompressedAccountFilterSet::with_capacity(1000).unwrap();
        for i in 0..100u8 {
            filter.insert(key(i)).unwrap();
        }
        assert_eq!(filter.len(), 100);
        for i in 0..100u8 {
            assert!(filter.contains(key(i)));
        }
        assert!(!filter.contains(key(255)));
    }

    #[test]
    fn capacity_overflow() {
        let result = CompressedAccountFilterSet::with_capacity(usize::MAX);
        assert!(matches!(result, Err(CuckooBuildError::CapacityOverflow)));
    }
}