kaspa_mining/mempool/model/frontier.rs

1use crate::{
2    feerate::{FeerateEstimator, FeerateEstimatorArgs},
3    model::candidate_tx::CandidateTransaction,
4    Policy, RebalancingWeightedTransactionSelector,
5};
6
7use feerate_key::FeerateTransactionKey;
8use kaspa_consensus_core::{block::TemplateTransactionSelector, tx::Transaction};
9use kaspa_core::trace;
10use rand::{distributions::Uniform, prelude::Distribution, Rng};
11use search_tree::SearchTree;
12use selectors::{SequenceSelector, SequenceSelectorInput, TakeAllSelector};
13use std::{collections::HashSet, iter::FusedIterator, sync::Arc};
14
15pub(crate) mod feerate_key;
16pub(crate) mod search_tree;
17pub(crate) mod selectors;
18
/// Block-capacity multiple that decides between the two sampling strategies: when the frontier
/// holds less than `COLLISION_FACTOR` times the block mass limit, in-place sampling suffers from
/// frequent collisions and is therefore less efficient, so the rebalancing selector is used instead.
const COLLISION_FACTOR: u64 = 4;

/// Over-sampling multiplier for in-place sampling. Collecting 20% more mass than the hard limit
/// lets the SequenceSelector compensate for transactions that end up rejected by consensus.
const MASS_LIMIT_FACTOR: f64 = 1.2;

/// Rough estimate of an average transaction mass. It is only needed in a non-important edge case
/// (an empty frontier), hence a fixed value is used here as opposed to an accurate estimation.
const TYPICAL_TX_MASS: f64 = 2_000.0;
31
32/// Management of the transaction pool frontier, that is, the set of transactions in
33/// the transaction pool which have no mempool ancestors and are essentially ready
34/// to enter the next block template.
35#[derive(Default)]
36pub struct Frontier {
37    /// Frontier transactions sorted by feerate order and searchable for weight sampling
38    search_tree: SearchTree,
39
40    /// Total masses: Σ_{tx in frontier} tx.mass
41    total_mass: u64,
42}
43
44impl Frontier {
45    pub fn total_weight(&self) -> f64 {
46        self.search_tree.total_weight()
47    }
48
49    pub fn total_mass(&self) -> u64 {
50        self.total_mass
51    }
52
53    pub fn len(&self) -> usize {
54        self.search_tree.len()
55    }
56
57    pub fn is_empty(&self) -> bool {
58        self.len() == 0
59    }
60
61    pub fn insert(&mut self, key: FeerateTransactionKey) -> bool {
62        let mass = key.mass;
63        if self.search_tree.insert(key) {
64            self.total_mass += mass;
65            true
66        } else {
67            false
68        }
69    }
70
71    pub fn remove(&mut self, key: &FeerateTransactionKey) -> bool {
72        let mass = key.mass;
73        if self.search_tree.remove(key) {
74            self.total_mass -= mass;
75            true
76        } else {
77            false
78        }
79    }
80
81    /// Samples the frontier in-place based on the provided policy and returns a SequenceSelector.
82    ///
83    /// This sampling algorithm should be used when frontier total mass is high enough compared to
84    /// policy mass limit so that the probability of sampling collisions remains low.
85    ///
86    /// Convergence analysis:
87    ///     1. Based on the above we can safely assume that `k << n`, where `n` is the total number of frontier items
88    ///        and `k` is the number of actual samples (since `desired_mass << total_mass` and mass per item is bounded)
89    ///     2. Indeed, if the weight distribution is not too spread (i.e., `max(weights) = O(min(weights))`), `k << n` means
90    ///        that the probability of collisions is low enough and the sampling process will converge in `O(k log(n))` w.h.p.
91    ///     3. It remains to deal with the case where the weight distribution is highly biased. The process implemented below
92    ///        keeps track of the top-weight element. If the distribution is highly biased, this element will be sampled with
93    ///        sufficient probability (in constant time). Following each sampling collision we search for a consecutive range of
94    ///        top elements which were already sampled and narrow the sampling space to exclude them all. We do this by computing
95    ///        the prefix weight up to the top most item which wasn't sampled yet (inclusive) and then continue the sampling process
96    ///        over the narrowed space. This process is repeated until acquiring the desired mass.  
97    ///     4. Numerical stability. Naively, one would simply subtract `total_weight -= top.weight` in order to narrow the sampling
98    ///        space. However, if `top.weight` is much larger than the remaining weight, the above f64 subtraction will yield a number
99    ///        close or equal to zero. We fix this by implementing a `log(n)` prefix weight operation.
100    ///     5. Q. Why not just use u64 weights?
101    ///        A. The current weight calculation is `feerate^alpha` with `alpha=3`. Using u64 would mean that the feerate space
102    ///           is limited to a range of size `(2^64)^(1/3) = ~2^21 = ~2M`. Already with current usages, the feerate can vary
103    ///           from `~1/50` (2000 sompi for a transaction with 100K storage mass), to `5M` (100 KAS fee for a transaction with
104    ///           2000 mass = 100·100_000_000/2000), resulting in a range of size 250M (`5M/(1/50)`).
105    ///           By using floating point arithmetics we gain the adjustment of the probability space to the accuracy level required for
106    ///           current samples. And if the space is highly biased, the repeated elimination of top items and the prefix weight computation
107    ///           will readjust it.
108    pub fn sample_inplace<R>(&self, rng: &mut R, policy: &Policy, _collisions: &mut u64) -> SequenceSelectorInput
109    where
110        R: Rng + ?Sized,
111    {
112        debug_assert!(!self.search_tree.is_empty(), "expected to be called only if not empty");
113
114        // Sample 20% more than the hard limit in order to allow the SequenceSelector to
115        // compensate for consensus rejections.
116        // Note: this is a soft limit which is why the loop below might pass it if the
117        //       next sampled transaction happens to cross the bound
118        let desired_mass = (policy.max_block_mass as f64 * MASS_LIMIT_FACTOR) as u64;
119
120        let mut distr = Uniform::new(0f64, self.total_weight());
121        let mut down_iter = self.search_tree.descending_iter();
122        let mut top = down_iter.next().unwrap();
123        let mut cache = HashSet::new();
124        let mut sequence = SequenceSelectorInput::default();
125        let mut total_selected_mass: u64 = 0;
126        let mut collisions = 0;
127
128        // The sampling process is converging so the cache will eventually hold all entries, which guarantees loop exit
129        'outer: while cache.len() < self.search_tree.len() && total_selected_mass <= desired_mass {
130            let query = distr.sample(rng);
131            let item = {
132                let mut item = self.search_tree.search(query);
133                while !cache.insert(item.tx.id()) {
134                    collisions += 1;
135                    // Try to narrow the sampling space in order to reduce further sampling collisions
136                    if cache.contains(&top.tx.id()) {
137                        loop {
138                            match down_iter.next() {
139                                Some(next) => top = next,
140                                None => break 'outer,
141                            }
142                            // Loop until finding a top item which was not sampled yet
143                            if !cache.contains(&top.tx.id()) {
144                                break;
145                            }
146                        }
147                        let remaining_weight = self.search_tree.prefix_weight(top);
148                        distr = Uniform::new(0f64, remaining_weight);
149                    }
150                    let query = distr.sample(rng);
151                    item = self.search_tree.search(query);
152                }
153                item
154            };
155            sequence.push(item.tx.clone(), item.mass);
156            total_selected_mass += item.mass; // Max standard mass + Mempool capacity bound imply this will not overflow
157        }
158        trace!("[mempool frontier sample inplace] collisions: {collisions}, cache: {}", cache.len());
159        *_collisions += collisions;
160        sequence
161    }
162
163    /// Dynamically builds a transaction selector based on the specific state of the ready transactions frontier.
164    ///
165    /// The logic is divided into three cases:
166    ///     1. The frontier is small and can fit entirely into a block: perform no sampling and return
167    ///        a TakeAllSelector
168    ///     2. The frontier has at least ~4x the capacity of a block: expected collision rate is low, perform
169    ///        in-place k*log(n) sampling and return a SequenceSelector
170    ///     3. The frontier has 1-4x capacity of a block. In this case we expect a high collision rate while
171    ///        the number of overall transactions is still low, so we take all of the transactions and use the
172    ///        rebalancing weighted selector (performing the actual sampling out of the mempool lock)
173    ///
174    /// The above thresholds were selected based on benchmarks. Overall, this dynamic selection provides
175    /// full transaction selection in less than 150 µs even if the frontier has 1M entries (!!). See mining/benches
176    /// for more details.  
177    pub fn build_selector(&self, policy: &Policy) -> Box<dyn TemplateTransactionSelector> {
178        if self.total_mass <= policy.max_block_mass {
179            Box::new(TakeAllSelector::new(self.search_tree.ascending_iter().map(|k| k.tx.clone()).collect()))
180        } else if self.total_mass > policy.max_block_mass * COLLISION_FACTOR {
181            let mut rng = rand::thread_rng();
182            Box::new(SequenceSelector::new(self.sample_inplace(&mut rng, policy, &mut 0), policy.clone()))
183        } else {
184            Box::new(RebalancingWeightedTransactionSelector::new(
185                policy.clone(),
186                self.search_tree.ascending_iter().cloned().map(CandidateTransaction::from_key).collect(),
187            ))
188        }
189    }
190
191    /// Exposed for benchmarking purposes
192    pub fn build_selector_sample_inplace(&self, _collisions: &mut u64) -> Box<dyn TemplateTransactionSelector> {
193        let mut rng = rand::thread_rng();
194        let policy = Policy::new(500_000);
195        Box::new(SequenceSelector::new(self.sample_inplace(&mut rng, &policy, _collisions), policy))
196    }
197
198    /// Exposed for benchmarking purposes
199    pub fn build_selector_take_all(&self) -> Box<dyn TemplateTransactionSelector> {
200        Box::new(TakeAllSelector::new(self.search_tree.ascending_iter().map(|k| k.tx.clone()).collect()))
201    }
202
203    /// Exposed for benchmarking purposes
204    pub fn build_rebalancing_selector(&self) -> Box<dyn TemplateTransactionSelector> {
205        Box::new(RebalancingWeightedTransactionSelector::new(
206            Policy::new(500_000),
207            self.search_tree.ascending_iter().cloned().map(CandidateTransaction::from_key).collect(),
208        ))
209    }
210
211    /// Builds a feerate estimator based on internal state of the ready transactions frontier
212    pub fn build_feerate_estimator(&self, args: FeerateEstimatorArgs) -> FeerateEstimator {
213        let average_transaction_mass = match self.len() {
214            0 => TYPICAL_TX_MASS,
215            n => self.total_mass() as f64 / n as f64,
216        };
217        let bps = args.network_blocks_per_second as f64;
218        let mut mass_per_block = args.maximum_mass_per_block as f64;
219        let mut inclusion_interval = average_transaction_mass / (mass_per_block * bps);
220        let mut estimator = FeerateEstimator::new(self.total_weight(), inclusion_interval);
221
222        // Search for better estimators by possibly removing extremely high outliers
223        let mut down_iter = self.search_tree.descending_iter().peekable();
224        while let Some(current) = down_iter.next() {
225            // Update values for the coming iteration. In order to remove the outlier from the
226            // total weight, we must compensate by capturing a block slot. Note we capture the
227            // slot with correspondence to the outlier actual mass. This is important in cases
228            // where the high-feerate txs have mass which deviates from the average.
229            mass_per_block -= current.mass as f64;
230            if mass_per_block <= average_transaction_mass {
231                // Out of block slots, break
232                break;
233            }
234
235            // Re-calc the inclusion interval based on the new block "capacity".
236            // Note that inclusion_interval < 1.0 as required by the estimator, since mass_per_block > average_transaction_mass (by condition above) and bps >= 1
237            inclusion_interval = average_transaction_mass / (mass_per_block * bps);
238
239            // Compute the weight up to, and excluding, current key (which translates to zero weight if peek() is none)
240            let prefix_weight = down_iter.peek().map(|key| self.search_tree.prefix_weight(key)).unwrap_or_default();
241            let pending_estimator = FeerateEstimator::new(prefix_weight, inclusion_interval);
242
243            // Test the pending estimator vs. the current one
244            if pending_estimator.feerate_to_time(1.0) < estimator.feerate_to_time(1.0) {
245                estimator = pending_estimator;
246            } else {
247                // The pending estimator is no better, break. Indicates that the reduction in
248                // network mass per second is more significant than the removed weight
249                break;
250            }
251        }
252        estimator
253    }
254
255    /// Returns an iterator to the transactions in the frontier in increasing feerate order
256    pub fn ascending_iter(&self) -> impl DoubleEndedIterator<Item = &Arc<Transaction>> + ExactSizeIterator + FusedIterator {
257        self.search_tree.ascending_iter().map(|key| &key.tx)
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use feerate_key::tests::build_feerate_key;
    use itertools::Itertools;
    use rand::thread_rng;
    use std::collections::HashMap;

    #[test]
    pub fn test_highly_irregular_sampling() {
        let mut rng = thread_rng();
        let cap = 1000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let mut fee: u64 = if i % (cap as u64 / 100) == 0 { 1000000 } else { rng.gen_range(1..10000) };
            if i == 0 {
                // Add an extremely large fee in order to create extremely high variance
                fee = 100_000_000 * 1_000_000; // 1M KAS
            }
            let mass: u64 = 1650;
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        let mut frontier = Frontier::default();
        for item in map.values().cloned() {
            assert!(frontier.insert(item), "inserting a distinct key must succeed");
        }

        let _sample = frontier.sample_inplace(&mut rng, &Policy::new(500_000), &mut 0);
    }

    #[test]
    pub fn test_mempool_sampling_small() {
        let mut rng = thread_rng();
        let cap = 2000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let fee: u64 = rng.gen_range(1..1000000);
            let mass: u64 = 1650;
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        let mut frontier = Frontier::default();
        for item in map.values().cloned() {
            assert!(frontier.insert(item), "inserting a distinct key must succeed");
        }

        let mut selector = frontier.build_selector(&Policy::new(500_000));
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();

        let mut selector = frontier.build_rebalancing_selector();
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();

        let mut selector = frontier.build_selector_sample_inplace(&mut 0);
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();

        let mut selector = frontier.build_selector_take_all();
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();

        let mut selector = frontier.build_selector(&Policy::new(500_000));
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();
    }

    #[test]
    pub fn test_total_mass_tracking() {
        let mut rng = thread_rng();
        let cap = 10000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let fee: u64 = if i % (cap as u64 / 100) == 0 { 1000000 } else { rng.gen_range(1..10000) };
            let mass: u64 = rng.gen_range(1..100000); // Use distinct mass values to challenge the test
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        let len = cap / 2;
        let mut frontier = Frontier::default();
        for item in map.values().take(len).cloned() {
            assert!(frontier.insert(item), "inserting a distinct key must succeed");
        }

        let prev_total_mass = frontier.total_mass();
        // Assert the total mass
        assert_eq!(frontier.total_mass(), frontier.search_tree.ascending_iter().map(|k| k.mass).sum::<u64>());

        // Add a bunch of duplicates and make sure the total mass remains the same
        let mut dup_items = frontier.search_tree.ascending_iter().take(len / 2).cloned().collect_vec();
        for dup in dup_items.iter().cloned() {
            assert!(!frontier.insert(dup), "inserting a duplicate key must be rejected");
        }
        assert_eq!(prev_total_mass, frontier.total_mass());
        assert_eq!(frontier.total_mass(), frontier.search_tree.ascending_iter().map(|k| k.mass).sum::<u64>());

        // Remove a few elements from the map in order to randomize the iterator
        dup_items.iter().take(10).for_each(|k| {
            map.remove(&k.tx.id());
        });

        // Add and remove random elements some of which will be duplicate insertions and some missing removals
        for item in map.values().step_by(2) {
            frontier.remove(item);
            if let Some(item2) = dup_items.pop() {
                frontier.insert(item2);
            }
        }
        assert_eq!(frontier.total_mass(), frontier.search_tree.ascending_iter().map(|k| k.mass).sum::<u64>());
    }

    #[test]
    fn test_feerate_estimator() {
        let mut rng = thread_rng();
        let cap = 2000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let mut fee: u64 = rng.gen_range(1..1000000);
            let mass: u64 = 1650;
            // 304 (~500,000/1650) extreme outliers (i in 0..=303) is an edge case where the build estimator logic should be tested at
            if i <= 303 {
                // Add an extremely large fee in order to create extremely high variance.
                // Use `i + 1` so that i == 0 is an outlier too (rather than a zero-fee tx).
                fee = (i + 1) * 10_000_000 * 1_000_000;
            }
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        for len in [0, 1, 10, 100, 200, 300, 500, 750, cap / 2, (cap * 2) / 3, (cap * 4) / 5, (cap * 5) / 6, cap] {
            let mut frontier = Frontier::default();
            for item in map.values().take(len).cloned() {
                assert!(frontier.insert(item), "inserting a distinct key must succeed");
            }

            let args = FeerateEstimatorArgs { network_blocks_per_second: 1, maximum_mass_per_block: 500_000 };
            // We are testing that the build function actually returns and is not looping indefinitely
            let estimator = frontier.build_feerate_estimator(args);
            let estimations = estimator.calc_estimations(1.0);

            let buckets = estimations.ordered_buckets();
            // Test for the absence of NaN, infinite or zero values in buckets
            for b in buckets.iter() {
                assert!(
                    b.feerate.is_normal() && b.feerate >= 1.0,
                    "bucket feerate must be a finite number greater or equal to the minimum standard feerate"
                );
                assert!(
                    b.estimated_seconds.is_normal() && b.estimated_seconds > 0.0,
                    "bucket estimated seconds must be a finite number greater than zero"
                );
            }
            dbg!(len, estimator);
            dbg!(estimations);
        }
    }

    #[test]
    fn test_constant_feerate_estimator() {
        const MIN_FEERATE: f64 = 1.0;
        let cap = 20_000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let mass: u64 = 1650;
            let fee = (mass as f64 * MIN_FEERATE) as u64;
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        for len in [0, 1, 10, 100, 200, 300, 500, 750, cap / 2, (cap * 2) / 3, (cap * 4) / 5, (cap * 5) / 6, cap] {
            println!();
            println!("Testing a frontier with {} txs...", len.min(cap));
            let mut frontier = Frontier::default();
            for item in map.values().take(len).cloned() {
                assert!(frontier.insert(item), "inserting a distinct key must succeed");
            }

            let args = FeerateEstimatorArgs { network_blocks_per_second: 1, maximum_mass_per_block: 500_000 };
            // We are testing that the build function actually returns and is not looping indefinitely
            let estimator = frontier.build_feerate_estimator(args);
            let estimations = estimator.calc_estimations(MIN_FEERATE);
            let buckets = estimations.ordered_buckets();
            // Test for the absence of NaN, infinite or zero values in buckets
            for b in buckets.iter() {
                assert!(
                    b.feerate.is_normal() && b.feerate >= MIN_FEERATE,
                    "bucket feerate must be a finite number greater or equal to the minimum standard feerate"
                );
                assert!(
                    b.estimated_seconds.is_normal() && b.estimated_seconds > 0.0,
                    "bucket estimated seconds must be a finite number greater than zero"
                );
            }
            dbg!(len, estimator);
            dbg!(estimations);
        }
    }

    #[test]
    fn test_feerate_estimator_with_low_mass_outliers() {
        const MIN_FEERATE: f64 = 1.0;
        const STD_FEERATE: f64 = 10.0;
        const HIGH_FEERATE: f64 = 1000.0;

        let cap = 20_000;
        let mut frontier = Frontier::default();
        for i in 0..cap as u64 {
            let (mass, fee) = if i < 200 {
                let mass = 1650;
                (mass, (HIGH_FEERATE * mass as f64) as u64)
            } else {
                let mass = 90_000;
                (mass, (STD_FEERATE * mass as f64) as u64)
            };
            let key = build_feerate_key(fee, mass, i);
            assert!(frontier.insert(key), "inserting a distinct key must succeed");
        }

        let args = FeerateEstimatorArgs { network_blocks_per_second: 1, maximum_mass_per_block: 500_000 };
        // We are testing that the build function actually returns and is not looping indefinitely
        let estimator = frontier.build_feerate_estimator(args);
        let estimations = estimator.calc_estimations(MIN_FEERATE);

        // Test that estimations are not biased by the average high mass
        let normal_feerate = estimations.normal_buckets.first().unwrap().feerate;
        assert!(
            normal_feerate < HIGH_FEERATE / 10.0,
            "Normal bucket feerate is expected to be << high feerate due to small mass of high feerate txs ({}, {})",
            normal_feerate,
            HIGH_FEERATE
        );

        let buckets = estimations.ordered_buckets();
        // Test for the absence of NaN, infinite or zero values in buckets
        for b in buckets.iter() {
            assert!(
                b.feerate.is_normal() && b.feerate >= MIN_FEERATE,
                "bucket feerate must be a finite number greater or equal to the minimum standard feerate"
            );
            assert!(
                b.estimated_seconds.is_normal() && b.estimated_seconds > 0.0,
                "bucket estimated seconds must be a finite number greater than zero"
            );
        }
        dbg!(estimator);
        dbg!(estimations);
    }

    #[test]
    fn test_feerate_estimator_with_less_than_block_capacity() {
        let mut map = HashMap::new();
        for i in 0..304 {
            let mass: u64 = 1650;
            let fee = 10_000_000 * 1_000_000;
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        // All lens make for less than block capacity (given the mass used)
        for len in [0, 1, 10, 100, 200, 250, 300] {
            let mut frontier = Frontier::default();
            for item in map.values().take(len).cloned() {
                assert!(frontier.insert(item), "inserting a distinct key must succeed");
            }

            let args = FeerateEstimatorArgs { network_blocks_per_second: 1, maximum_mass_per_block: 500_000 };
            // We are testing that the build function actually returns and is not looping indefinitely
            let estimator = frontier.build_feerate_estimator(args);
            let estimations = estimator.calc_estimations(1.0);

            let buckets = estimations.ordered_buckets();
            // Test for the absence of NaN, infinite or zero values in buckets
            for b in buckets.iter() {
                // Expect min feerate bcs blocks are not full
                assert!(b.feerate == 1.0, "bucket feerate is expected to be equal to the minimum standard feerate");
                assert!(
                    b.estimated_seconds.is_normal() && b.estimated_seconds > 0.0 && b.estimated_seconds <= 1.0,
                    "bucket estimated seconds must be a finite number greater than zero & less than 1.0"
                );
            }
            dbg!(len, estimator);
            dbg!(estimations);
        }
    }
}