kaspa_mining/mempool/model/frontier/search_tree.rs

1use super::feerate_key::FeerateTransactionKey;
2use std::iter::FusedIterator;
3use sweep_bptree::tree::visit::{DescendVisit, DescendVisitResult};
4use sweep_bptree::tree::{Argument, SearchArgument};
5use sweep_bptree::{BPlusTree, NodeStoreVec};
6
/// Short local alias for the key type stored in (and weighted by) the search tree
type FeerateKey = FeerateTransactionKey;
8
/// Per-node search argument implementing "weight space" search via sweep_bptree's `SearchArgument`.
///
/// The weight space is the range `[0, total_weight)`. Every key owns a "logical" interval in this
/// space whose position follows the key's tree order and whose length equals the key's weight.
///
/// Each node stores the total weight of its subtree, and these totals are refreshed on every
/// insertion/removal. A query point `p ∈ [0, total_weight)` is then resolved in log time by
/// walking down from the root and shifting the query by the weights of skipped subtrees.
/// For instance, if the query point is `123.56` and the top 3 subtrees have weights `120, 10.5, 100`,
/// the middle subtree is queried recursively with the point `123.56 - 120 = 3.56`.
///
/// The `SearchArgument` implementation below contains the details.
#[derive(Clone, Copy, Debug, Default)]
struct FeerateWeight(f64);

impl FeerateWeight {
    /// The (subtree) weight carried by this argument
    pub fn weight(&self) -> f64 {
        let FeerateWeight(total) = *self;
        total
    }
}
29
30impl Argument<FeerateKey> for FeerateWeight {
31    fn from_leaf(keys: &[FeerateKey]) -> Self {
32        Self(keys.iter().map(|k| k.weight()).sum())
33    }
34
35    fn from_inner(_keys: &[FeerateKey], arguments: &[Self]) -> Self {
36        Self(arguments.iter().map(|a| a.0).sum())
37    }
38}
39
40impl SearchArgument<FeerateKey> for FeerateWeight {
41    type Query = f64;
42
43    fn locate_in_leaf(query: Self::Query, keys: &[FeerateKey]) -> Option<usize> {
44        let mut sum = 0.0;
45        for (i, k) in keys.iter().enumerate() {
46            let w = k.weight();
47            sum += w;
48            if query < sum {
49                return Some(i);
50            }
51        }
52        // In order to avoid sensitivity to floating number arithmetics,
53        // we logically "clamp" the search, returning the last leaf if the query
54        // value is out of bounds
55        match keys.len() {
56            0 => None,
57            n => Some(n - 1),
58        }
59    }
60
61    fn locate_in_inner(mut query: Self::Query, _keys: &[FeerateKey], arguments: &[Self]) -> Option<(usize, Self::Query)> {
62        // Search algorithm: Locate the next subtree to visit by iterating through `arguments`
63        // and subtracting the query until the correct range is found
64        for (i, a) in arguments.iter().enumerate() {
65            if query >= a.0 {
66                query -= a.0;
67            } else {
68                return Some((i, query));
69            }
70        }
71        // In order to avoid sensitivity to floating number arithmetics,
72        // we logically "clamp" the search, returning the last subtree if the query
73        // value is out of bounds. Eventually this will lead to the return of the
74        // last leaf (see locate_in_leaf as well)
75        match arguments.len() {
76            0 => None,
77            n => Some((n - 1, arguments[n - 1].0)),
78        }
79    }
80}
81
82/// Visitor struct which accumulates the prefix weight up to a provided key (inclusive) in log time.
83///
84/// The basic idea is to use the subtree weights stored in the tree for walking down from the root
85/// to the leaf (corresponding to the searched key), and accumulating all weights proceeding the walk-down path
86struct PrefixWeightVisitor<'a> {
87    /// The key to search up to
88    key: &'a FeerateKey,
89    /// This field accumulates the prefix weight during the visit process
90    accumulated_weight: f64,
91}
92
93impl<'a> PrefixWeightVisitor<'a> {
94    pub fn new(key: &'a FeerateKey) -> Self {
95        Self { key, accumulated_weight: Default::default() }
96    }
97
98    /// Returns the index of the first `key ∈ keys` such that `key > self.key`. If no such key
99    /// exists, the returned index will be the length of `keys`.
100    fn search_in_keys(&self, keys: &[FeerateKey]) -> usize {
101        match keys.binary_search(self.key) {
102            Err(idx) => {
103                // self.key is not in keys, idx is the index of the following key
104                idx
105            }
106            Ok(idx) => {
107                // Exact match, return the following index
108                idx + 1
109            }
110        }
111    }
112}
113
114impl<'a> DescendVisit<FeerateKey, (), FeerateWeight> for PrefixWeightVisitor<'a> {
115    type Result = f64;
116
117    fn visit_inner(&mut self, keys: &[FeerateKey], arguments: &[FeerateWeight]) -> DescendVisitResult<Self::Result> {
118        let idx = self.search_in_keys(keys);
119        // Invariants:
120        //      a. arguments.len() == keys.len() + 1 (n inner node keys are the separators between n+1 subtrees)
121        //      b. idx <= keys.len() (hence idx < arguments.len())
122
123        // Based on the invariants, we first accumulate all the subtree weights up to idx
124        for argument in arguments.iter().take(idx) {
125            self.accumulated_weight += argument.weight();
126        }
127
128        // ..and then go down to the idx'th subtree
129        DescendVisitResult::GoDown(idx)
130    }
131
132    fn visit_leaf(&mut self, keys: &[FeerateKey], _values: &[()]) -> Option<Self::Result> {
133        // idx is the index of the key following self.key
134        let idx = self.search_in_keys(keys);
135        // Accumulate all key weights up to idx (which is inclusive if self.key ∈ tree)
136        for key in keys.iter().take(idx) {
137            self.accumulated_weight += key.weight();
138        }
139        // ..and return the final result
140        Some(self.accumulated_weight)
141    }
142}
143
/// The underlying B+ tree. Keys are `FeerateKey`s, and values are unit, so the tree acts as an
/// ordered set. Each node carries a `FeerateWeight` argument aggregating the weights of its subtree.
type InnerTree = BPlusTree<NodeStoreVec<FeerateKey, (), FeerateWeight>>;
145
146/// A transaction search tree sorted by feerate order and searchable for probabilistic weighted sampling.
147///
148/// All `log(n)` expressions below are in base 64 (based on constants chosen within the sweep_bptree crate).
149///
150/// The tree has the following properties:
151///     1. Linear time ordered access (ascending / descending)
152///     2. Insertions/removals in log(n) time
153///     3. Search for a weight point `p ∈ [0, total_weight)` in log(n) time
154///     4. Compute the prefix weight of a key, i.e., the sum of weights up to that key (inclusive)
155///        according to key order, in log(n) time
156///     5. Access the total weight in O(1) time. The total weight has numerical stability since it
157///        is recomputed from subtree weights for each item insertion/removal
158///
159/// Computing the prefix weight is a crucial operation if the tree is used for random sampling and
160/// the tree is highly imbalanced in terms of weight variance.
161/// See [`Frontier::sample_inplace()`](crate::mempool::model::frontier::Frontier::sample_inplace)
162/// for more details.  
163pub struct SearchTree {
164    tree: InnerTree,
165}
166
167impl Default for SearchTree {
168    fn default() -> Self {
169        Self { tree: InnerTree::new(Default::default()) }
170    }
171}
172
173impl SearchTree {
174    pub fn new() -> Self {
175        Self { tree: InnerTree::new(Default::default()) }
176    }
177
178    pub fn len(&self) -> usize {
179        self.tree.len()
180    }
181
182    pub fn is_empty(&self) -> bool {
183        self.len() == 0
184    }
185
186    /// Inserts a key into the tree in log(n) time. Returns `false` if the key was already in the tree.
187    pub fn insert(&mut self, key: FeerateKey) -> bool {
188        self.tree.insert(key, ()).is_none()
189    }
190
191    /// Remove a key from the tree in log(n) time. Returns `false` if the key was not in the tree.
192    pub fn remove(&mut self, key: &FeerateKey) -> bool {
193        self.tree.remove(key).is_some()
194    }
195
196    /// Search for a weight point `query ∈ [0, total_weight)` in log(n) time
197    pub fn search(&self, query: f64) -> &FeerateKey {
198        self.tree.get_by_argument(query).expect("clamped").0
199    }
200
201    /// Access the total weight in O(1) time
202    pub fn total_weight(&self) -> f64 {
203        self.tree.root_argument().weight()
204    }
205
206    /// Computes the prefix weight of a key, i.e., the sum of weights up to that key (inclusive)
207    /// according to key order, in log(n) time
208    pub fn prefix_weight(&self, key: &FeerateKey) -> f64 {
209        self.tree.descend_visit(PrefixWeightVisitor::new(key)).unwrap()
210    }
211
212    /// Iterate the tree in descending key order (going down from the
213    /// highest key). Linear in the number of keys *actually* iterated.
214    pub fn descending_iter(&self) -> impl DoubleEndedIterator<Item = &FeerateKey> + ExactSizeIterator + FusedIterator {
215        self.tree.iter().rev().map(|(key, ())| key)
216    }
217
218    /// Iterate the tree in ascending key order (going up from the
219    /// lowest key). Linear in the number of keys *actually* iterated.
220    pub fn ascending_iter(&self) -> impl DoubleEndedIterator<Item = &FeerateKey> + ExactSizeIterator + FusedIterator {
221        self.tree.iter().map(|(key, ())| key)
222    }
223
224    /// The lowest key in the tree (by key order)
225    pub fn first(&self) -> Option<&FeerateKey> {
226        self.tree.first().map(|(k, ())| k)
227    }
228
229    /// The highest key in the tree (by key order)
230    pub fn last(&self) -> Option<&FeerateKey> {
231        self.tree.last().map(|(k, ())| k)
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::super::feerate_key::tests::build_feerate_key;
    use super::*;
    use itertools::Itertools;
    use std::collections::HashSet;

    #[test]
    fn test_feerate_weight_queries() {
        let mut tree = SearchTree::new();
        let mass = 2000;
        // The btree stores N=64 keys at each node/leaf, so we make sure the tree has more than
        // 64^2 keys in order to trigger at least a few intermediate tree nodes
        let fees = vec![[123, 113, 10_000, 1000, 2050, 2048]; 64 * (64 + 1)].into_iter().flatten().collect_vec();

        // clippy flags the key type as a "mutable key type" (it presumably holds data with interior
        // mutability). NOTE(review): confirm that its Hash/Eq only depend on immutable fields.
        #[allow(clippy::mutable_key_type)]
        let mut s = HashSet::with_capacity(fees.len());
        for (i, fee) in fees.iter().copied().enumerate() {
            let key = build_feerate_key(fee, mass, i as u64);
            s.insert(key.clone());
            tree.insert(key);
        }

        // Remove an arbitrary 1/6 of the items (HashSet iteration order is randomized per process)
        let remove = s.iter().take(fees.len() / 6).cloned().collect_vec();
        for r in remove {
            s.remove(&r);
            tree.remove(&r);
        }

        // Collect to vec and sort for reference
        let mut v = s.into_iter().collect_vec();
        v.sort();

        // The comparisons below `zip` reference and tree sequences, which stops at the shorter
        // side, so pin the lengths first to catch missing or extra keys
        assert_eq!(v.len(), tree.len());

        // Test reverse iteration
        for (expected, item) in v.iter().rev().zip(tree.descending_iter()) {
            assert_eq!(&expected, &item);
            assert!(expected.cmp(item).is_eq()); // Assert Ord equality as well
        }

        // Sweep through the tree and verify that weight search queries are handled correctly
        let eps: f64 = 0.001;
        let mut sum = 0.0;
        for expected in v.iter() {
            let weight = expected.weight();
            let eps = eps.min(weight / 3.0);
            let samples = [sum + eps, sum + weight / 2.0, sum + weight - eps];
            for sample in samples {
                let key = tree.search(sample);
                assert_eq!(expected, key);
                assert!(expected.cmp(key).is_eq()); // Assert Ord equality as well
            }
            sum += weight;
        }

        // The tree's aggregated total weight must agree with the sequential sum up to rounding
        let total = tree.total_weight();
        assert!((sum - total).abs() <= sum * 1e-6, "sequential sum {sum} differs from tree total weight {total}");

        // Test clamped search bounds
        assert_eq!(tree.first(), Some(tree.search(f64::NEG_INFINITY)));
        assert_eq!(tree.first(), Some(tree.search(-1.0)));
        assert_eq!(tree.first(), Some(tree.search(-eps)));
        assert_eq!(tree.first(), Some(tree.search(0.0)));
        assert_eq!(tree.last(), Some(tree.search(sum)));
        assert_eq!(tree.last(), Some(tree.search(sum + eps)));
        assert_eq!(tree.last(), Some(tree.search(sum + 1.0)));
        assert_eq!(tree.last(), Some(tree.search(1.0 / 0.0)));
        assert_eq!(tree.last(), Some(tree.search(f64::INFINITY)));
        // NaN has no defined position; only require that the search does not panic
        let _ = tree.search(f64::NAN);

        // Assert prefix weights
        let mut prefix = Vec::with_capacity(v.len());
        prefix.push(v[0].weight());
        for i in 1..v.len() {
            prefix.push(prefix[i - 1] + v[i].weight());
        }
        let eps = v.iter().map(|k| k.weight()).min_by(f64::total_cmp).unwrap() * 1e-4;
        for (expected_prefix, key) in prefix.into_iter().zip(v) {
            let prefix = tree.prefix_weight(&key);
            assert!(
                (expected_prefix - prefix).abs() < eps,
                "prefix weight mismatch: expected {expected_prefix}, got {prefix} (eps {eps})"
            );
        }
    }

    #[test]
    fn test_tree_rev_iter() {
        let mut tree = SearchTree::new();
        let mass = 2000;
        let fees = vec![[123, 113, 10_000, 1000, 2050, 2048]; 64 * (64 + 1)].into_iter().flatten().collect_vec();
        let mut v = Vec::with_capacity(fees.len());
        for (i, fee) in fees.iter().copied().enumerate() {
            let key = build_feerate_key(fee, mass, i as u64);
            v.push(key.clone());
            tree.insert(key);
        }
        v.sort();

        // `zip` stops at the shorter side, so make sure nothing was lost on insertion
        assert_eq!(v.len(), tree.len());

        for (expected, item) in v.into_iter().rev().zip(tree.descending_iter()) {
            assert_eq!(&expected, item);
            assert!(expected.cmp(item).is_eq()); // Assert Ord equality as well
        }
    }
}