kaspa_mining/mempool/model/frontier/
search_tree.rs1use super::feerate_key::FeerateTransactionKey;
2use std::iter::FusedIterator;
3use sweep_bptree::tree::visit::{DescendVisit, DescendVisitResult};
4use sweep_bptree::tree::{Argument, SearchArgument};
5use sweep_bptree::{BPlusTree, NodeStoreVec};
6
7type FeerateKey = FeerateTransactionKey;
8
/// Per-node weight aggregate used as the `sweep_bptree` argument of the search tree.
///
/// A leaf's aggregate is the sum of its keys' weights, and an inner node's aggregate is the
/// sum of its children's aggregates (see the `Argument` impl). The root therefore holds the
/// total weight, and a weight-space query can be routed down the tree by subtracting
/// subtree weights (see the `SearchArgument` impl).
#[derive(Clone, Copy, Debug, Default)]
struct FeerateWeight(f64);

impl FeerateWeight {
    /// Returns the wrapped aggregate weight.
    pub fn weight(&self) -> f64 {
        let FeerateWeight(value) = *self;
        value
    }
}
29
30impl Argument<FeerateKey> for FeerateWeight {
31 fn from_leaf(keys: &[FeerateKey]) -> Self {
32 Self(keys.iter().map(|k| k.weight()).sum())
33 }
34
35 fn from_inner(_keys: &[FeerateKey], arguments: &[Self]) -> Self {
36 Self(arguments.iter().map(|a| a.0).sum())
37 }
38}
39
40impl SearchArgument<FeerateKey> for FeerateWeight {
41 type Query = f64;
42
43 fn locate_in_leaf(query: Self::Query, keys: &[FeerateKey]) -> Option<usize> {
44 let mut sum = 0.0;
45 for (i, k) in keys.iter().enumerate() {
46 let w = k.weight();
47 sum += w;
48 if query < sum {
49 return Some(i);
50 }
51 }
52 match keys.len() {
56 0 => None,
57 n => Some(n - 1),
58 }
59 }
60
61 fn locate_in_inner(mut query: Self::Query, _keys: &[FeerateKey], arguments: &[Self]) -> Option<(usize, Self::Query)> {
62 for (i, a) in arguments.iter().enumerate() {
65 if query >= a.0 {
66 query -= a.0;
67 } else {
68 return Some((i, query));
69 }
70 }
71 match arguments.len() {
76 0 => None,
77 n => Some((n - 1, arguments[n - 1].0)),
78 }
79 }
80}
81
82struct PrefixWeightVisitor<'a> {
87 key: &'a FeerateKey,
89 accumulated_weight: f64,
91}
92
93impl<'a> PrefixWeightVisitor<'a> {
94 pub fn new(key: &'a FeerateKey) -> Self {
95 Self { key, accumulated_weight: Default::default() }
96 }
97
98 fn search_in_keys(&self, keys: &[FeerateKey]) -> usize {
101 match keys.binary_search(self.key) {
102 Err(idx) => {
103 idx
105 }
106 Ok(idx) => {
107 idx + 1
109 }
110 }
111 }
112}
113
114impl<'a> DescendVisit<FeerateKey, (), FeerateWeight> for PrefixWeightVisitor<'a> {
115 type Result = f64;
116
117 fn visit_inner(&mut self, keys: &[FeerateKey], arguments: &[FeerateWeight]) -> DescendVisitResult<Self::Result> {
118 let idx = self.search_in_keys(keys);
119 for argument in arguments.iter().take(idx) {
125 self.accumulated_weight += argument.weight();
126 }
127
128 DescendVisitResult::GoDown(idx)
130 }
131
132 fn visit_leaf(&mut self, keys: &[FeerateKey], _values: &[()]) -> Option<Self::Result> {
133 let idx = self.search_in_keys(keys);
135 for key in keys.iter().take(idx) {
137 self.accumulated_weight += key.weight();
138 }
139 Some(self.accumulated_weight)
141 }
142}
143
144type InnerTree = BPlusTree<NodeStoreVec<FeerateKey, (), FeerateWeight>>;
145
146pub struct SearchTree {
164 tree: InnerTree,
165}
166
167impl Default for SearchTree {
168 fn default() -> Self {
169 Self { tree: InnerTree::new(Default::default()) }
170 }
171}
172
173impl SearchTree {
174 pub fn new() -> Self {
175 Self { tree: InnerTree::new(Default::default()) }
176 }
177
178 pub fn len(&self) -> usize {
179 self.tree.len()
180 }
181
182 pub fn is_empty(&self) -> bool {
183 self.len() == 0
184 }
185
186 pub fn insert(&mut self, key: FeerateKey) -> bool {
188 self.tree.insert(key, ()).is_none()
189 }
190
191 pub fn remove(&mut self, key: &FeerateKey) -> bool {
193 self.tree.remove(key).is_some()
194 }
195
196 pub fn search(&self, query: f64) -> &FeerateKey {
198 self.tree.get_by_argument(query).expect("clamped").0
199 }
200
201 pub fn total_weight(&self) -> f64 {
203 self.tree.root_argument().weight()
204 }
205
206 pub fn prefix_weight(&self, key: &FeerateKey) -> f64 {
209 self.tree.descend_visit(PrefixWeightVisitor::new(key)).unwrap()
210 }
211
212 pub fn descending_iter(&self) -> impl DoubleEndedIterator<Item = &FeerateKey> + ExactSizeIterator + FusedIterator {
215 self.tree.iter().rev().map(|(key, ())| key)
216 }
217
218 pub fn ascending_iter(&self) -> impl DoubleEndedIterator<Item = &FeerateKey> + ExactSizeIterator + FusedIterator {
221 self.tree.iter().map(|(key, ())| key)
222 }
223
224 pub fn first(&self) -> Option<&FeerateKey> {
226 self.tree.first().map(|(k, ())| k)
227 }
228
229 pub fn last(&self) -> Option<&FeerateKey> {
231 self.tree.last().map(|(k, ())| k)
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::super::feerate_key::tests::build_feerate_key;
    use super::*;
    use itertools::Itertools;
    use std::collections::HashSet;
    use std::ops::Sub;

    #[test]
    fn test_feerate_weight_queries() {
        let mut tree = SearchTree::new();
        let mass = 2000;
        // 6 * 64 * 65 keys: presumably enough for the tree to span several levels.
        let fees = vec![[123, 113, 10_000, 1000, 2050, 2048]; 64 * (64 + 1)].into_iter().flatten().collect_vec();

        // `mutable_key_type` flags the key type; keys are never mutated while in the set,
        // so their hashes stay consistent for the lifetime of this test.
        #[allow(clippy::mutable_key_type)]
        let mut s = HashSet::with_capacity(fees.len());
        for (i, fee) in fees.iter().copied().enumerate() {
            let key = build_feerate_key(fee, mass, i as u64);
            s.insert(key.clone());
            tree.insert(key);
        }

        // Remove an arbitrary sixth of the keys (HashSet iteration order) to exercise removal.
        let remove = s.iter().take(fees.len() / 6).cloned().collect_vec();
        for r in remove {
            s.remove(&r);
            assert!(tree.remove(&r), "key inserted above must be removable");
        }
        assert_eq!(tree.len(), s.len());

        let mut v = s.into_iter().collect_vec();
        v.sort();

        // Descending iteration must match the sorted reference in reverse
        // (lengths were checked above, so `zip` cannot hide missing keys).
        for (expected, item) in v.iter().rev().zip(tree.descending_iter()) {
            assert_eq!(&expected, &item);
            assert!(expected.cmp(item).is_eq());
        }

        // Probe points inside each key's weight interval [sum, sum + weight).
        let eps: f64 = 0.001;
        let mut sum: f64 = 0.0;
        for expected in v.iter() {
            let weight = expected.weight();
            let eps = eps.min(weight / 3.0);
            let samples = [sum + eps, sum + weight / 2.0, sum + weight - eps];
            for sample in samples {
                let key = tree.search(sample);
                assert_eq!(expected, key);
                assert!(expected.cmp(key).is_eq());
            }
            sum += weight;
        }

        // The root aggregate must agree with the sequential sum up to rounding.
        let total = tree.total_weight();
        assert!((sum - total).abs() <= sum.abs() * 1e-9, "total weight mismatch: {sum} vs {total}");

        // Out-of-range queries are clamped to the first / last key.
        assert_eq!(tree.first(), Some(tree.search(f64::NEG_INFINITY)));
        assert_eq!(tree.first(), Some(tree.search(-1.0)));
        assert_eq!(tree.first(), Some(tree.search(-eps)));
        assert_eq!(tree.first(), Some(tree.search(0.0)));
        assert_eq!(tree.last(), Some(tree.search(sum)));
        assert_eq!(tree.last(), Some(tree.search(sum + eps)));
        assert_eq!(tree.last(), Some(tree.search(sum + 1.0)));
        assert_eq!(tree.last(), Some(tree.search(f64::INFINITY)));
        // NaN has no meaningful position; only require that it does not panic.
        let _ = tree.search(f64::NAN);

        // Inclusive prefix sums of the sorted reference, compared with the tree's prefix weights.
        let prefix = v
            .iter()
            .scan(0.0, |acc: &mut f64, key| {
                *acc += key.weight();
                Some(*acc)
            })
            .collect_vec();
        let eps = v.iter().map(|k| k.weight()).min_by(f64::total_cmp).unwrap() * 1e-4;
        for (expected_prefix, key) in prefix.into_iter().zip(v) {
            let prefix = tree.prefix_weight(&key);
            assert!(expected_prefix.sub(prefix).abs() < eps, "prefix weight mismatch: {expected_prefix} vs {prefix}");
        }
    }

    #[test]
    fn test_tree_rev_iter() {
        let mut tree = SearchTree::new();
        let mass = 2000;
        let fees = vec![[123, 113, 10_000, 1000, 2050, 2048]; 64 * (64 + 1)].into_iter().flatten().collect_vec();
        let mut v = Vec::with_capacity(fees.len());
        for (i, fee) in fees.iter().copied().enumerate() {
            let key = build_feerate_key(fee, mass, i as u64);
            v.push(key.clone());
            tree.insert(key);
        }
        v.sort();

        // `zip` stops at the shorter side, so check lengths first.
        assert_eq!(tree.len(), v.len());
        for (expected, item) in v.into_iter().rev().zip(tree.descending_iter()) {
            assert_eq!(&expected, item);
            assert!(expected.cmp(item).is_eq());
        }
    }
}