kaspa_mining/mempool/model/frontier.rs

use crate::{
    feerate::{FeerateEstimator, FeerateEstimatorArgs},
    model::candidate_tx::CandidateTransaction,
    Policy, RebalancingWeightedTransactionSelector,
};

use feerate_key::FeerateTransactionKey;
use kaspa_consensus_core::{block::TemplateTransactionSelector, tx::Transaction};
use kaspa_core::trace;
use rand::{distributions::Uniform, prelude::Distribution, Rng};
use search_tree::SearchTree;
use selectors::{SequenceSelector, SequenceSelectorInput, TakeAllSelector};
use std::{collections::HashSet, iter::FusedIterator, sync::Arc};

pub(crate) mod feerate_key;
pub(crate) mod search_tree;
pub(crate) mod selectors;
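
/// Multiple of the block mass limit above which `build_selector` switches to in-place sampling:
/// when the frontier's total mass exceeds `COLLISION_FACTOR` full blocks, sampling collisions are
/// rare enough for `sample_inplace` to be efficient; frontiers between one and `COLLISION_FACTOR`
/// blocks fall back to the rebalancing selector instead.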
const COLLISION_FACTOR: u64 = 4;
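
/// Factor by which `sample_inplace` overshoots the policy block mass limit: the sampled sequence
/// carries roughly 20% more mass than a single block can hold, giving the sequence selector some
/// slack when building the final template.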
const MASS_LIMIT_FACTOR: f64 = 1.2;
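
/// Fallback estimate of the average transaction mass, used by `build_feerate_estimator`
/// when the frontier is empty.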
const TYPICAL_TX_MASS: f64 = 2000.0;
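
/// The frontier of the transaction pool: the set of transactions that are ready to enter the next
/// block template, kept in a feerate-ordered search tree that supports weighted random sampling
/// and prefix-weight queries.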
#[derive(Default)]
pub struct Frontier {
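    /// Frontier transactions sorted by feerate order and searchable by cumulative weight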
    search_tree: SearchTree,
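
    /// Total mass of all frontier transactions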
    total_mass: u64,
}

impl Frontier {
    pub fn total_weight(&self) -> f64 {
        self.search_tree.total_weight()
    }

    pub fn total_mass(&self) -> u64 {
        self.total_mass
    }

    pub fn len(&self) -> usize {
        self.search_tree.len()
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    pub fn insert(&mut self, key: FeerateTransactionKey) -> bool {
        let mass = key.mass;
        if self.search_tree.insert(key) {
            self.total_mass += mass;
            true
        } else {
            false
        }
    }

    pub fn remove(&mut self, key: &FeerateTransactionKey) -> bool {
        let mass = key.mass;
        if self.search_tree.remove(key) {
            self.total_mass -= mass;
            true
        } else {
            false
        }
    }
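
    /// Samples the frontier in place, drawing transactions at random with probability proportional
    /// to their weight and without repetition, until the accumulated mass reaches `MASS_LIMIT_FACTOR`
    /// times the policy block mass limit or the frontier is exhausted. Collisions with already-sampled
    /// transactions are resolved by re-sampling, progressively narrowing the sampling range once the
    /// top of the frontier is fully consumed; the total number of collisions is reported via
    /// `_collisions`. Intended for frontiers whose total mass is much larger than the block mass
    /// limit, where collisions are rare (see `COLLISION_FACTOR`).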
    pub fn sample_inplace<R>(&self, rng: &mut R, policy: &Policy, _collisions: &mut u64) -> SequenceSelectorInput
    where
        R: Rng + ?Sized,
    {
        debug_assert!(!self.search_tree.is_empty(), "expected to be called only if not empty");

        let desired_mass = (policy.max_block_mass as f64 * MASS_LIMIT_FACTOR) as u64;

        let mut distr = Uniform::new(0f64, self.total_weight());
        let mut down_iter = self.search_tree.descending_iter();
        let mut top = down_iter.next().unwrap();
        let mut cache = HashSet::new();
        let mut sequence = SequenceSelectorInput::default();
        let mut total_selected_mass: u64 = 0;
        let mut collisions = 0;

        'outer: while cache.len() < self.search_tree.len() && total_selected_mass <= desired_mass {
            let query = distr.sample(rng);
            let item = {
                let mut item = self.search_tree.search(query);
                while !cache.insert(item.tx.id()) {
                    collisions += 1;
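                    // The sampled item was already taken. If the highest-feerate candidate (`top`)
                    // is also cached, advance it down past all cached keys and narrow the sampling
                    // range to the weight prefix ending at the new top, excluding the exhausted
                    // high-feerate suffix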
                    if cache.contains(&top.tx.id()) {
                        loop {
                            match down_iter.next() {
                                Some(next) => top = next,
                                None => break 'outer,
                            }
                            if !cache.contains(&top.tx.id()) {
                                break;
                            }
                        }
                        let remaining_weight = self.search_tree.prefix_weight(top);
                        distr = Uniform::new(0f64, remaining_weight);
                    }
                    let query = distr.sample(rng);
                    item = self.search_tree.search(query);
                }
                item
            };
            sequence.push(item.tx.clone(), item.mass);
            total_selected_mass += item.mass;
        }
        trace!("[mempool frontier sample inplace] collisions: {collisions}, cache: {}", cache.len());
        *_collisions += collisions;
        sequence
    }
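
    /// Builds a block template transaction selector for this frontier. Three regimes are
    /// distinguished by total frontier mass: if everything fits in a single block, a selector that
    /// simply takes all transactions is returned; if the frontier holds more than `COLLISION_FACTOR`
    /// full blocks, a sequence selector fed by weighted in-place sampling is used; otherwise the
    /// rebalancing weighted selector is applied over all candidates.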
    pub fn build_selector(&self, policy: &Policy) -> Box<dyn TemplateTransactionSelector> {
        if self.total_mass <= policy.max_block_mass {
            Box::new(TakeAllSelector::new(self.search_tree.ascending_iter().map(|k| k.tx.clone()).collect()))
        } else if self.total_mass > policy.max_block_mass * COLLISION_FACTOR {
            let mut rng = rand::thread_rng();
            Box::new(SequenceSelector::new(self.sample_inplace(&mut rng, policy, &mut 0), policy.clone()))
        } else {
            Box::new(RebalancingWeightedTransactionSelector::new(
                policy.clone(),
                self.search_tree.ascending_iter().cloned().map(CandidateTransaction::from_key).collect(),
            ))
        }
    }
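
    /// Builds a selector driven by in-place sampling with a default 500,000 mass policy, reporting
    /// the number of sampling collisions through `_collisions`; unlike `build_selector`, it samples
    /// regardless of the frontier's total mass (used directly by the tests in this module).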
    pub fn build_selector_sample_inplace(&self, _collisions: &mut u64) -> Box<dyn TemplateTransactionSelector> {
        let mut rng = rand::thread_rng();
        let policy = Policy::new(500_000);
        Box::new(SequenceSelector::new(self.sample_inplace(&mut rng, &policy, _collisions), policy))
    }
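
    /// Builds a selector that simply clones all frontier transactions, with no mass-based branching.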
    pub fn build_selector_take_all(&self) -> Box<dyn TemplateTransactionSelector> {
        Box::new(TakeAllSelector::new(self.search_tree.ascending_iter().map(|k| k.tx.clone()).collect()))
    }
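
    /// Builds a rebalancing weighted selector over all frontier candidates, using a default
    /// 500,000 mass policy.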
    pub fn build_rebalancing_selector(&self) -> Box<dyn TemplateTransactionSelector> {
        Box::new(RebalancingWeightedTransactionSelector::new(
            Policy::new(500_000),
            self.search_tree.ascending_iter().cloned().map(CandidateTransaction::from_key).collect(),
        ))
    }
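
    /// Builds a feerate estimator for the current frontier. Starting from an estimator over the
    /// full frontier weight, the frontier is scanned from the highest feerate downward: each step
    /// removes the current transaction's mass from the per-block capacity, recomputes the inclusion
    /// interval, and builds a candidate estimator over the remaining (lower-feerate) weight prefix,
    /// keeping it only while it improves the estimated time at the minimal feerate.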
    pub fn build_feerate_estimator(&self, args: FeerateEstimatorArgs) -> FeerateEstimator {
        let average_transaction_mass = match self.len() {
            0 => TYPICAL_TX_MASS,
            n => self.total_mass() as f64 / n as f64,
        };
        let bps = args.network_blocks_per_second as f64;
        let mut mass_per_block = args.maximum_mass_per_block as f64;
        let mut inclusion_interval = average_transaction_mass / (mass_per_block * bps);
        let mut estimator = FeerateEstimator::new(self.total_weight(), inclusion_interval);

        let mut down_iter = self.search_tree.descending_iter().peekable();
        while let Some(current) = down_iter.next() {
            mass_per_block -= current.mass as f64;
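            // Stop refining once the remaining per-block capacity, after discounting the mass of
            // the higher-feerate transactions scanned so far, drops to roughly a single average
            // transaction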
            if mass_per_block <= average_transaction_mass {
                break;
            }

            inclusion_interval = average_transaction_mass / (mass_per_block * bps);

            let prefix_weight = down_iter.peek().map(|key| self.search_tree.prefix_weight(key)).unwrap_or_default();
            let pending_estimator = FeerateEstimator::new(prefix_weight, inclusion_interval);

            if pending_estimator.feerate_to_time(1.0) < estimator.feerate_to_time(1.0) {
                estimator = pending_estimator;
            } else {
                break;
            }
        }
        estimator
    }
    pub fn ascending_iter(&self) -> impl DoubleEndedIterator<Item = &Arc<Transaction>> + ExactSizeIterator + FusedIterator {
        self.search_tree.ascending_iter().map(|key| &key.tx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use feerate_key::tests::build_feerate_key;
    use itertools::Itertools;
    use rand::thread_rng;
    use std::collections::HashMap;
    #[test]
    pub fn test_highly_irregular_sampling() {
        let mut rng = thread_rng();
        let cap = 1000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let mut fee: u64 = if i % (cap as u64 / 100) == 0 { 1000000 } else { rng.gen_range(1..10000) };
            if i == 0 {
                // Add a single tx with an extremely large fee in order to create extreme variance
                fee = 100_000_000 * 1_000_000;
            }
            let mass: u64 = 1650;
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        let mut frontier = Frontier::default();
        for item in map.values().cloned() {
            frontier.insert(item).then_some(()).unwrap();
        }

        let _sample = frontier.sample_inplace(&mut rng, &Policy::new(500_000), &mut 0);
    }
    #[test]
    pub fn test_mempool_sampling_small() {
        let mut rng = thread_rng();
        let cap = 2000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let fee: u64 = rng.gen_range(1..1000000);
            let mass: u64 = 1650;
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        let mut frontier = Frontier::default();
        for item in map.values().cloned() {
            frontier.insert(item).then_some(()).unwrap();
        }

        let mut selector = frontier.build_selector(&Policy::new(500_000));
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();

        let mut selector = frontier.build_rebalancing_selector();
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();

        let mut selector = frontier.build_selector_sample_inplace(&mut 0);
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();

        let mut selector = frontier.build_selector_take_all();
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();

        let mut selector = frontier.build_selector(&Policy::new(500_000));
        selector.select_transactions().iter().map(|k| k.gas).sum::<u64>();
    }
    #[test]
    pub fn test_total_mass_tracking() {
        let mut rng = thread_rng();
        let cap = 10000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let fee: u64 = if i % (cap as u64 / 100) == 0 { 1000000 } else { rng.gen_range(1..10000) };
            let mass: u64 = rng.gen_range(1..100000);
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        let len = cap / 2;
        let mut frontier = Frontier::default();
        for item in map.values().take(len).cloned() {
            frontier.insert(item).then_some(()).unwrap();
        }

        let prev_total_mass = frontier.total_mass();
        assert_eq!(frontier.total_mass(), frontier.search_tree.ascending_iter().map(|k| k.mass).sum::<u64>());

        // Reinserting duplicate keys must not change the total mass
        let mut dup_items = frontier.search_tree.ascending_iter().take(len / 2).cloned().collect_vec();
        for dup in dup_items.iter().cloned() {
            (!frontier.insert(dup)).then_some(()).unwrap();
        }
        assert_eq!(prev_total_mass, frontier.total_mass());
        assert_eq!(frontier.total_mass(), frontier.search_tree.ascending_iter().map(|k| k.mass).sum::<u64>());

        dup_items.iter().take(10).for_each(|k| {
            map.remove(&k.tx.id());
        });

        // Interleave removals and insertions and verify the total mass remains consistent
        for item in map.values().step_by(2) {
            frontier.remove(item);
            if let Some(item2) = dup_items.pop() {
                frontier.insert(item2);
            }
        }
        assert_eq!(frontier.total_mass(), frontier.search_tree.ascending_iter().map(|k| k.mass).sum::<u64>());
    }
    #[test]
    fn test_feerate_estimator() {
        let mut rng = thread_rng();
        let cap = 2000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let mut fee: u64 = rng.gen_range(1..1000000);
            let mass: u64 = 1650;
            if i <= 303 {
                fee = i * 10_000_000 * 1_000_000;
            }
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        for len in [0, 1, 10, 100, 200, 300, 500, 750, cap / 2, (cap * 2) / 3, (cap * 4) / 5, (cap * 5) / 6, cap] {
            let mut frontier = Frontier::default();
            for item in map.values().take(len).cloned() {
                frontier.insert(item).then_some(()).unwrap();
            }

            let args = FeerateEstimatorArgs { network_blocks_per_second: 1, maximum_mass_per_block: 500_000 };
            let estimator = frontier.build_feerate_estimator(args);
            let estimations = estimator.calc_estimations(1.0);

            let buckets = estimations.ordered_buckets();
            for b in buckets.iter() {
                assert!(
                    b.feerate.is_normal() && b.feerate >= 1.0,
                    "bucket feerate must be a finite number greater or equal to the minimum standard feerate"
                );
                assert!(
                    b.estimated_seconds.is_normal() && b.estimated_seconds > 0.0,
                    "bucket estimated seconds must be a finite number greater than zero"
                );
            }
            dbg!(len, estimator);
            dbg!(estimations);
        }
    }
    #[test]
    fn test_constant_feerate_estimator() {
        const MIN_FEERATE: f64 = 1.0;
        let cap = 20_000;
        let mut map = HashMap::with_capacity(cap);
        for i in 0..cap as u64 {
            let mass: u64 = 1650;
            let fee = (mass as f64 * MIN_FEERATE) as u64;
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        for len in [0, 1, 10, 100, 200, 300, 500, 750, cap / 2, (cap * 2) / 3, (cap * 4) / 5, (cap * 5) / 6, cap] {
            println!();
            println!("Testing a frontier with {} txs...", len.min(cap));
            let mut frontier = Frontier::default();
            for item in map.values().take(len).cloned() {
                frontier.insert(item).then_some(()).unwrap();
            }

            let args = FeerateEstimatorArgs { network_blocks_per_second: 1, maximum_mass_per_block: 500_000 };
            let estimator = frontier.build_feerate_estimator(args);
            let estimations = estimator.calc_estimations(MIN_FEERATE);
            let buckets = estimations.ordered_buckets();
            for b in buckets.iter() {
                assert!(
                    b.feerate.is_normal() && b.feerate >= MIN_FEERATE,
                    "bucket feerate must be a finite number greater or equal to the minimum standard feerate"
                );
                assert!(
                    b.estimated_seconds.is_normal() && b.estimated_seconds > 0.0,
                    "bucket estimated seconds must be a finite number greater than zero"
                );
            }
            dbg!(len, estimator);
            dbg!(estimations);
        }
    }
    #[test]
    fn test_feerate_estimator_with_low_mass_outliers() {
        const MIN_FEERATE: f64 = 1.0;
        const STD_FEERATE: f64 = 10.0;
        const HIGH_FEERATE: f64 = 1000.0;

        let cap = 20_000;
        let mut frontier = Frontier::default();
        for i in 0..cap as u64 {
            let (mass, fee) = if i < 200 {
                let mass = 1650;
                (mass, (HIGH_FEERATE * mass as f64) as u64)
            } else {
                let mass = 90_000;
                (mass, (STD_FEERATE * mass as f64) as u64)
            };
            let key = build_feerate_key(fee, mass, i);
            frontier.insert(key).then_some(()).unwrap();
        }

        let args = FeerateEstimatorArgs { network_blocks_per_second: 1, maximum_mass_per_block: 500_000 };
        let estimator = frontier.build_feerate_estimator(args);
        let estimations = estimator.calc_estimations(MIN_FEERATE);

        let normal_feerate = estimations.normal_buckets.first().unwrap().feerate;
        assert!(
            normal_feerate < HIGH_FEERATE / 10.0,
            "Normal bucket feerate is expected to be << high feerate due to small mass of high feerate txs ({}, {})",
            normal_feerate,
            HIGH_FEERATE
        );

        let buckets = estimations.ordered_buckets();
        for b in buckets.iter() {
            assert!(
                b.feerate.is_normal() && b.feerate >= MIN_FEERATE,
                "bucket feerate must be a finite number greater or equal to the minimum standard feerate"
            );
            assert!(
                b.estimated_seconds.is_normal() && b.estimated_seconds > 0.0,
                "bucket estimated seconds must be a finite number greater than zero"
            );
        }
        dbg!(estimator);
        dbg!(estimations);
    }
    #[test]
    fn test_feerate_estimator_with_less_than_block_capacity() {
        let mut map = HashMap::new();
        for i in 0..304 {
            let mass: u64 = 1650;
            let fee = 10_000_000 * 1_000_000;
            let key = build_feerate_key(fee, mass, i);
            map.insert(key.tx.id(), key);
        }

        // Each tx has mass 1650, so every len below keeps the frontier within a single block's
        // capacity (300 * 1650 = 495,000 <= 500,000)
        for len in [0, 1, 10, 100, 200, 250, 300] {
            let mut frontier = Frontier::default();
            for item in map.values().take(len).cloned() {
                frontier.insert(item).then_some(()).unwrap();
            }

            let args = FeerateEstimatorArgs { network_blocks_per_second: 1, maximum_mass_per_block: 500_000 };
            let estimator = frontier.build_feerate_estimator(args);
            let estimations = estimator.calc_estimations(1.0);

            let buckets = estimations.ordered_buckets();
            for b in buckets.iter() {
                assert!(b.feerate == 1.0, "bucket feerate is expected to be equal to the minimum standard feerate");
                assert!(
                    b.estimated_seconds.is_normal() && b.estimated_seconds > 0.0 && b.estimated_seconds <= 1.0,
                    "bucket estimated seconds must be a finite number greater than zero & less than 1.0"
                );
            }
            dbg!(len, estimator);
            dbg!(estimations);
        }
    }
}