Skip to main content

laserstream_core_proto/cuckoo/
set.rs

//! Safe, tracked wrapper around [`CuckooFilter`] for client-side filter construction.
//!
//! [`CompressedAccountFilterSet`] keeps an exact [`HashSet`] alongside the probabilistic
//! filter, so `remove` is safe (it never evicts a fingerprint-colliding item) and `contains`
//! is exact. The cuckoo filter is used only on the wire: the server accepts false positives,
//! and clients filter locally on receipt.
//!
//! [`HashSet`]: std::collections::HashSet
9
10use {
11    super::{
12        error::{CuckooBuildError, TableFullError},
13        filter::CuckooFilter,
14    },
15    crate::geyser::{
16        CuckooFilter as ProtoCuckooFilter, SubscribeRequest, SubscribeRequestFilterAccounts,
17    },
18    solana_pubkey::Pubkey,
19    std::collections::HashSet,
20};
21
22/// Safe builder for cuckoo filters sent in subscribe requests.
23///
24/// Holds two parallel collections: a [`HashSet`] (exact source of truth for `contains`/`len`
25/// and the guard for writes) and a [`CuckooFilter`] (compact wire form). Writes go to both;
26/// serialization reads the filter. Strictly safer than a raw [`CuckooFilter`], whose `remove`
27/// can silently evict the wrong fingerprint-colliding item.
28///
29/// [`HashSet`]: std::collections::HashSet
30pub struct CompressedAccountFilterSet {
31    items: HashSet<[u8; 32]>,
32    filter: CuckooFilter<[u8; 32]>,
33    dirty: bool,
34}
35
36impl CompressedAccountFilterSet {
37    /// Empty map pre-sized for `max_capacity` items (both the `HashSet` and filter are
38    /// allocated up front). Errors [`CuckooBuildError::CapacityOverflow`] if it can't be allocated.
39    pub fn with_capacity(max_capacity: usize) -> Result<Self, CuckooBuildError> {
40        let filter = CuckooFilter::with_capacity(max_capacity)?;
41
42        let mut items: HashSet<[u8; 32]> = HashSet::new();
43        items
44            .try_reserve(max_capacity)
45            .map_err(|_| CuckooBuildError::CapacityOverflow)?;
46
47        Ok(Self {
48            items,
49            filter,
50            dirty: false,
51        })
52    }
53
54    /// Inserts a key. `Ok(true)` if newly added, `Ok(false)` if already present (idempotent).
55    /// Errors [`TableFullError`] if the filter is saturated (map under-sized); state is
56    /// unchanged on error.
57    pub fn insert(&mut self, key: Pubkey) -> Result<bool, TableFullError> {
58        let bytes = key.to_bytes();
59
60        if self.items.contains(&bytes) {
61            return Ok(false);
62        }
63        self.filter.insert(&bytes)?;
64        self.items.insert(bytes);
65        self.dirty = true;
66        Ok(true)
67    }
68
69    /// Removes a key; returns whether it was present. Safe unlike [`CuckooFilter::remove`]:
70    /// checks the `HashSet` first and only touches the filter when the key genuinely exists.
71    pub fn remove(&mut self, key: Pubkey) -> bool {
72        let bytes = key.to_bytes();
73
74        if self.items.remove(&bytes) {
75            self.filter.remove(&bytes);
76            self.dirty = true;
77            true
78        } else {
79            false
80        }
81    }
82
83    /// Exact membership (from the `HashSet`, no false positives), unlike [`CuckooFilter::contains`].
84    pub fn contains(&self, key: Pubkey) -> bool {
85        self.items.contains(&key.to_bytes())
86    }
87
88    /// Returns the number of items in the map.
89    pub fn len(&self) -> usize {
90        self.items.len()
91    }
92
93    /// Items the map can hold without reallocating (≥ the `with_capacity` argument; the
94    /// `HashSet` rounds up). Useful for headroom checks before a batch of inserts.
95    pub fn capacity(&self) -> usize {
96        self.items.capacity()
97    }
98
99    /// Iterates items in arbitrary `HashSet` order (no ordering guarantee).
100    pub fn iter(&self) -> impl Iterator<Item = &[u8; 32]> {
101        self.items.iter()
102    }
103
104    /// Returns `true` if the map contains no items.
105    pub fn is_empty(&self) -> bool {
106        self.items.is_empty()
107    }
108
109    /// `true` if mutated since the last [`take_dirty`](Self::take_dirty) (or construction);
110    /// does not clear the flag.
111    pub const fn is_dirty(&self) -> bool {
112        self.dirty
113    }
114
115    /// Returns the dirty flag and clears it. Call when transmitting: `true` → rebuild and send.
116    pub fn take_dirty(&mut self) -> bool {
117        let dirty = self.dirty;
118        self.dirty = false;
119        dirty
120    }
121
122    /// Serializes the cuckoo filter to proto wire format (carries bucket geometry + hash seed).
123    pub fn to_proto(&self) -> ProtoCuckooFilter {
124        ProtoCuckooFilter::from(&self.filter)
125    }
126
127    /// A `SubscribeRequestFilterAccounts` carrying only this cuckoo filter (no account list,
128    /// owner, or predicates) — add it to a request under a name of your choosing.
129    pub fn to_account_filter(&self) -> SubscribeRequestFilterAccounts {
130        SubscribeRequestFilterAccounts {
131            account: vec![],
132            owner: vec![],
133            filters: vec![],
134            nonempty_txn_signature: None,
135            cuckoo_accounts_filter: Some(self.to_proto()),
136        }
137    }
138
139    /// Inserts this filter into `req.accounts` under `name` (replacing any existing entry,
140    /// preserving others) and marks the map clean.
141    pub fn insert_into_subscribe_request(&mut self, req: &mut SubscribeRequest, name: &str) {
142        req.accounts
143            .insert(name.to_string(), self.to_account_filter());
144        self.dirty = false;
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic key whose 32 bytes all equal `b`.
    fn key(b: u8) -> Pubkey {
        Pubkey::new_from_array([b; 32])
    }

    /// Fresh set sized for `cap` keys; panics if construction fails.
    fn set_with(cap: usize) -> CompressedAccountFilterSet {
        CompressedAccountFilterSet::with_capacity(cap).unwrap()
    }

    #[test]
    fn basic_insert_contains() {
        let mut set = set_with(100);

        let newly_added = set.insert(key(1)).unwrap();

        assert!(newly_added);
        assert!(set.contains(key(1)));
        assert!(!set.contains(key(2)));
    }

    #[test]
    fn insert_duplicate_returns_false() {
        let mut set = set_with(100);

        let first = set.insert(key(1)).unwrap();
        let second = set.insert(key(1)).unwrap();

        assert!(first);
        assert!(!second);
    }

    #[test]
    fn remove_existing() {
        let mut set = set_with(100);
        set.insert(key(1)).unwrap();

        let removed = set.remove(key(1));

        assert!(removed);
        assert!(!set.contains(key(1)));
    }

    #[test]
    fn remove_nonexistent_is_safe() {
        let mut set = set_with(100);
        set.insert(key(1)).unwrap();

        // Removing a key that was never inserted must not disturb the one that was.
        let removed = set.remove(key(2));

        assert!(!removed);
        assert!(set.contains(key(1)));
    }

    #[test]
    fn len_and_is_empty() {
        let mut set = set_with(100);
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);

        for seed in [1u8, 2] {
            set.insert(key(seed)).unwrap();
        }
        assert!(!set.is_empty());
        assert_eq!(set.len(), 2);

        set.remove(key(1));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn to_proto_round_trip() {
        let mut set = set_with(100);
        for seed in [1u8, 2] {
            set.insert(key(seed)).unwrap();
        }

        let wire = set.to_proto();

        assert!(!wire.data.is_empty());
        assert!(wire.bucket_count > 0);
    }

    #[test]
    fn to_account_filter_carries_cuckoo_and_no_other_matchers() {
        let mut set = set_with(100);
        set.insert(key(1)).unwrap();

        let account_filter = set.to_account_filter();

        assert!(account_filter.cuckoo_accounts_filter.is_some());
        assert!(account_filter.account.is_empty());
        assert!(account_filter.owner.is_empty());
        assert!(account_filter.filters.is_empty());
        assert_eq!(account_filter.nonempty_txn_signature, None);
    }

    #[test]
    fn insert_into_subscribe_request_uses_given_name_and_preserves_other_filters() {
        let mut set = set_with(100);
        set.insert(key(1)).unwrap();

        let mut req = SubscribeRequest::default();
        req.accounts.insert(
            "pre_existing".to_string(),
            SubscribeRequestFilterAccounts::default(),
        );

        set.insert_into_subscribe_request(&mut req, "tracked_accounts");

        assert!(req.accounts.contains_key("tracked_accounts"));
        assert!(req.accounts.contains_key("pre_existing"));
        assert_eq!(req.accounts.len(), 2);

        let tracked = req.accounts.get("tracked_accounts").unwrap();
        assert!(tracked.cuckoo_accounts_filter.is_some());
    }

    #[test]
    fn insert_into_subscribe_request_clears_dirty_flag() {
        let mut set = set_with(100);
        set.insert(key(1)).unwrap();
        assert!(set.is_dirty());

        let mut req = SubscribeRequest::default();
        set.insert_into_subscribe_request(&mut req, "accounts");

        assert!(!set.is_dirty());
    }

    #[test]
    fn pubkey_like_usage() {
        let mut set = set_with(1000);

        (0..100u8).for_each(|seed| {
            set.insert(key(seed)).unwrap();
        });

        assert_eq!(set.len(), 100);
        assert!((0..100u8).all(|seed| set.contains(key(seed))));
        assert!(!set.contains(key(255)));
    }

    #[test]
    fn capacity_overflow() {
        let outcome = CompressedAccountFilterSet::with_capacity(usize::MAX);

        assert!(matches!(outcome, Err(CuckooBuildError::CapacityOverflow)));
    }
}