Skip to main content

abtc_adapters/mempool/
mod.rs

1//! In-Memory Mempool Implementation
2//!
3//! Provides a real mempool that stores unconfirmed transactions, orders them by
4//! fee rate for mining, enforces size limits, and supports eviction of low-fee
5//! transactions when the pool is full.
6//!
7//! Features:
8//! - RBF (BIP125) replacement support
9//! - Ancestor/descendant tracking with configurable limits
10//! - CPFP-aware mining selection (ancestor fee rate ordering)
11//! - Fee estimation from observed transaction patterns
12//!
13//! All mutable state lives behind a single `RwLock<MempoolInner>` to
14//! eliminate the deadlock risk identified in the code review (finding #17).
15
16use abtc_domain::policy::limits::{MempoolLimits, PackageInfo};
17use abtc_domain::policy::rbf::{RbfPolicy, SignalsRbf};
18use abtc_domain::primitives::{Amount, Transaction, Txid};
19use abtc_ports::{MempoolEntry, MempoolInfo, MempoolPort};
20use async_trait::async_trait;
21use std::collections::{HashMap, HashSet};
22use std::sync::Arc;
23use std::time::{SystemTime, UNIX_EPOCH};
24use tokio::sync::RwLock;
25
/// Maximum mempool size in bytes (default: 300 MB, matching Bitcoin Core).
///
/// This is the `max_bytes` cap installed by `InMemoryMempool::new`.
const DEFAULT_MAX_MEMPOOL_BYTES: u64 = 300_000_000;

/// Minimum relay fee rate in satoshis per byte.
///
/// `estimate_fee` uses this as both its fallback (no observations yet) and
/// its floor; `get_mempool_info` reports it divided by 100_000_000.
const MIN_RELAY_FEE_RATE: f64 = 1.0;
31
/// All mutable mempool state, held behind a single `RwLock` to prevent
/// deadlocks from multi-lock acquisition ordering.
struct MempoolInner {
    /// Every transaction currently in the pool, keyed by txid.
    entries: HashMap<Txid, MempoolEntry>,
    /// Per-transaction ancestor/descendant aggregates (counts, sizes, fees).
    packages: HashMap<Txid, PackageInfo>,
    /// Direct in-mempool children of each transaction.
    children: HashMap<Txid, HashSet<Txid>>,
    /// Direct in-mempool parents of each transaction.
    parents: HashMap<Txid, HashSet<Txid>>,
    /// Running sum of `MempoolEntry::size` over all entries.
    total_bytes: u64,
    /// Chain height last supplied via `set_height`; stamped onto new entries.
    current_height: u32,
    /// Fee rates (fee in satoshis / estimated size) of recently added
    /// transactions; `add_transaction` caps this at the latest 10000 samples.
    fee_rate_buckets: Vec<f64>,
}
43
44impl MempoolInner {
45    fn new() -> Self {
46        MempoolInner {
47            entries: HashMap::new(),
48            packages: HashMap::new(),
49            children: HashMap::new(),
50            parents: HashMap::new(),
51            total_bytes: 0,
52            current_height: 0,
53            fee_rate_buckets: Vec::new(),
54        }
55    }
56
57    /// Collect all descendants of a txid into the given set (recursive).
58    fn collect_descendants(&self, txid: Txid, result: &mut HashSet<Txid>) {
59        if let Some(kids) = self.children.get(&txid) {
60            for child in kids {
61                if result.insert(*child) {
62                    self.collect_descendants(*child, result);
63                }
64            }
65        }
66    }
67
68    /// Collect all ancestors of a txid (recursive).
69    fn collect_ancestors(&self, txid: &Txid, result: &mut HashSet<Txid>) {
70        if let Some(pars) = self.parents.get(txid) {
71            for parent in pars {
72                if result.insert(*parent) {
73                    self.collect_ancestors(parent, result);
74                }
75            }
76        }
77    }
78
79    /// Remove a single transaction and clean up graph edges and package info.
80    fn remove_entry(&mut self, txid: &Txid) {
81        if let Some(entry) = self.entries.remove(txid) {
82            self.total_bytes = self.total_bytes.saturating_sub(entry.size as u64);
83            let fee = entry.fee;
84            let vsize = entry.size as u32;
85
86            // Update ancestors: decrement their descendant counts
87            let mut ancestors = HashSet::new();
88            self.collect_ancestors(txid, &mut ancestors);
89            for anc_txid in &ancestors {
90                if let Some(anc_pkg) = self.packages.get_mut(anc_txid) {
91                    anc_pkg.descendant_count = anc_pkg.descendant_count.saturating_sub(1);
92                    anc_pkg.descendant_size = anc_pkg.descendant_size.saturating_sub(vsize);
93                    anc_pkg.descendant_fee = Amount::from_sat(
94                        anc_pkg.descendant_fee.as_sat().saturating_sub(fee.as_sat()),
95                    );
96                }
97            }
98
99            // Clean up graph edges
100            if let Some(my_parents) = self.parents.remove(txid) {
101                for parent in &my_parents {
102                    if let Some(parent_children) = self.children.get_mut(parent) {
103                        parent_children.remove(txid);
104                    }
105                }
106            }
107            if let Some(my_children) = self.children.remove(txid) {
108                for child in &my_children {
109                    if let Some(child_parents) = self.parents.get_mut(child) {
110                        child_parents.remove(txid);
111                    }
112                }
113            }
114
115            self.packages.remove(txid);
116        }
117    }
118
119    /// Update ancestor/descendant package info after adding a transaction.
120    fn update_package_info(&mut self, txid: Txid, fee: Amount, vsize: u32) {
121        let mut ancestors = HashSet::new();
122        self.collect_ancestors(&txid, &mut ancestors);
123
124        let ancestor_count = (ancestors.len() + 1) as u32;
125        let mut ancestor_size = vsize;
126        let mut ancestor_fee = fee;
127
128        for anc_txid in &ancestors {
129            if let Some(pkg) = self.packages.get(anc_txid) {
130                ancestor_size += pkg.vsize;
131                ancestor_fee = Amount::from_sat(ancestor_fee.as_sat() + pkg.fee.as_sat());
132            }
133        }
134
135        let pkg = PackageInfo {
136            txid,
137            vsize,
138            fee,
139            ancestor_count,
140            ancestor_size,
141            ancestor_fee,
142            descendant_count: 1,
143            descendant_size: vsize,
144            descendant_fee: fee,
145        };
146        self.packages.insert(txid, pkg);
147
148        // Update descendant info on all ancestors
149        for anc_txid in &ancestors {
150            if let Some(anc_pkg) = self.packages.get_mut(anc_txid) {
151                anc_pkg.descendant_count += 1;
152                anc_pkg.descendant_size += vsize;
153                anc_pkg.descendant_fee =
154                    Amount::from_sat(anc_pkg.descendant_fee.as_sat() + fee.as_sat());
155            }
156        }
157    }
158
159    /// Evict lowest descendant-fee-rate transactions to make room.
160    fn evict_if_needed(&mut self, max_bytes: u64) {
161        if self.total_bytes <= max_bytes {
162            return;
163        }
164
165        let mut by_desc_rate: Vec<(Txid, f64, usize)> = self
166            .entries
167            .iter()
168            .map(|(txid, entry)| {
169                let rate = self
170                    .packages
171                    .get(txid)
172                    .map(|p| p.descendant_fee_rate())
173                    .unwrap_or(0.0);
174                (*txid, rate, entry.size)
175            })
176            .collect();
177
178        by_desc_rate.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
179
180        for (txid, _rate, size) in by_desc_rate {
181            if self.total_bytes <= max_bytes {
182                break;
183            }
184            self.entries.remove(&txid);
185            self.total_bytes = self.total_bytes.saturating_sub(size as u64);
186            tracing::debug!(
187                "Evicted transaction {} from mempool (low descendant fee rate)",
188                txid
189            );
190        }
191    }
192
193    /// Attempt RBF replacement: check if a new transaction can replace existing ones.
194    /// Returns the set of txids that would be evicted if replacement succeeds.
195    fn try_rbf_replacement(
196        &self,
197        tx: &Transaction,
198        new_fee: Amount,
199        new_size: usize,
200    ) -> Result<Vec<Txid>, String> {
201        let mut conflicting_txids: HashSet<Txid> = HashSet::new();
202
203        for input in &tx.inputs {
204            for (existing_txid, existing_entry) in self.entries.iter() {
205                for existing_input in &existing_entry.tx.inputs {
206                    if existing_input.previous_output == input.previous_output {
207                        conflicting_txids.insert(*existing_txid);
208                    }
209                }
210            }
211        }
212
213        if conflicting_txids.is_empty() {
214            return Err("no conflicting transactions".into());
215        }
216
217        let mut to_evict: HashSet<Txid> = HashSet::new();
218        for txid in &conflicting_txids {
219            self.collect_descendants(*txid, &mut to_evict);
220            to_evict.insert(*txid);
221        }
222
223        let originals: Vec<(Txid, Amount, usize, bool)> = conflicting_txids
224            .iter()
225            .filter_map(|txid| {
226                self.entries
227                    .get(txid)
228                    .map(|entry| (*txid, entry.fee, entry.size, entry.tx.signals_rbf()))
229            })
230            .collect();
231
232        RbfPolicy::check_replacement(new_fee, new_size, &originals, to_evict.len())
233            .map_err(|e| format!("RBF rejected: {}", e))?;
234
235        Ok(to_evict.into_iter().collect())
236    }
237}
238
/// In-memory mempool implementation with fee-rate ordering and eviction.
///
/// All mutable state is held behind a single `RwLock<MempoolInner>` to
/// eliminate the deadlock risk from acquiring multiple independent locks
/// (code review finding #17).
pub struct InMemoryMempool {
    /// Shared, lock-protected transaction store and dependency graph.
    inner: Arc<RwLock<MempoolInner>>,
    /// Size cap in bytes; exceeding it after an insert triggers eviction.
    max_bytes: u64,
    /// Ancestor/descendant limits checked before a transaction is accepted.
    limits: MempoolLimits,
}
249
250impl InMemoryMempool {
251    /// Create a new in-memory mempool with default settings
252    pub fn new() -> Self {
253        InMemoryMempool {
254            inner: Arc::new(RwLock::new(MempoolInner::new())),
255            max_bytes: DEFAULT_MAX_MEMPOOL_BYTES,
256            limits: MempoolLimits::default(),
257        }
258    }
259
260    /// Create a new in-memory mempool with custom max size
261    pub fn with_max_bytes(max_bytes: u64) -> Self {
262        InMemoryMempool {
263            max_bytes,
264            ..Self::new()
265        }
266    }
267
268    /// Set the current chain height
269    pub async fn set_height(&self, height: u32) {
270        let mut inner = self.inner.write().await;
271        inner.current_height = height;
272    }
273
274    /// Get the number of transactions in the mempool
275    pub async fn size(&self) -> usize {
276        let inner = self.inner.read().await;
277        inner.entries.len()
278    }
279
280    /// Get transactions ordered by ancestor fee rate (CPFP-aware) for mining.
281    pub async fn get_transactions_by_fee_rate(&self, max_weight: u32) -> Vec<MempoolEntry> {
282        let inner = self.inner.read().await;
283
284        let mut by_ancestor_rate: Vec<(Txid, f64)> = inner
285            .entries
286            .keys()
287            .map(|txid| {
288                let rate = inner
289                    .packages
290                    .get(txid)
291                    .map(|p| p.ancestor_fee_rate())
292                    .unwrap_or(0.0);
293                (*txid, rate)
294            })
295            .collect();
296
297        by_ancestor_rate.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
298
299        let mut total_weight: u32 = 0;
300        let mut selected = Vec::new();
301        let mut included: HashSet<Txid> = HashSet::new();
302
303        for (txid, _rate) in by_ancestor_rate {
304            if let Some(entry) = inner.entries.get(&txid) {
305                let tx_weight = (entry.size as u32) * 4;
306                if total_weight + tx_weight > max_weight {
307                    continue;
308                }
309                if included.contains(&txid) {
310                    continue;
311                }
312                total_weight += tx_weight;
313                included.insert(txid);
314                selected.push(entry.clone());
315            }
316        }
317
318        selected
319    }
320
321    /// Remove transactions that were confirmed in a block
322    pub async fn remove_for_block(&self, transactions: &[Transaction]) {
323        let mut inner = self.inner.write().await;
324        for tx in transactions {
325            let txid = tx.txid();
326            inner.remove_entry(&txid);
327            tracing::debug!("Removed confirmed transaction {} from mempool", txid);
328        }
329    }
330
331    /// Compute the serialized size of a transaction (simplified estimate)
332    fn estimate_tx_size(tx: &Transaction) -> usize {
333        let mut size = 10usize;
334
335        for input in &tx.inputs {
336            size += 41 + input.script_sig.len();
337            if !input.witness.is_empty() {
338                for item in input.witness.stack() {
339                    size += 1 + item.len();
340                }
341            }
342        }
343
344        for output in &tx.outputs {
345            size += 9 + output.script_pubkey.len();
346        }
347
348        size
349    }
350}
351
352impl Default for InMemoryMempool {
353    fn default() -> Self {
354        Self::new()
355    }
356}
357
358#[async_trait]
359impl MempoolPort for InMemoryMempool {
360    async fn add_transaction(
361        &self,
362        tx: &Transaction,
363    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
364        let txid = tx.txid();
365        let size = Self::estimate_tx_size(tx);
366        let max_bytes = self.max_bytes;
367
368        let mut inner = self.inner.write().await;
369
370        // Check if already in mempool
371        if inner.entries.contains_key(&txid) {
372            return Err(format!("Transaction {} already in mempool", txid).into());
373        }
374
375        // Compute fee from in-mempool parent outputs where available.
376        let fee = {
377            let mut input_total: i64 = 0;
378            let mut all_inputs_resolved = true;
379            for input in &tx.inputs {
380                let parent_txid = &input.previous_output.txid;
381                let vout = input.previous_output.vout as usize;
382                if let Some(parent) = inner.entries.get(parent_txid) {
383                    if let Some(output) = parent.tx.outputs.get(vout) {
384                        input_total += output.value.as_sat();
385                    } else {
386                        all_inputs_resolved = false;
387                    }
388                } else {
389                    all_inputs_resolved = false;
390                }
391            }
392            if all_inputs_resolved && !tx.inputs.is_empty() {
393                let output_total: i64 = tx.outputs.iter().map(|o| o.value.as_sat()).sum();
394                Amount::from_sat(std::cmp::max(0, input_total - output_total))
395            } else {
396                Amount::from_sat(0)
397            }
398        };
399
400        // Check for conflicts (same inputs) and attempt RBF if applicable
401        let evict_list = inner.try_rbf_replacement(tx, fee, size);
402        if let Ok(to_evict) = evict_list {
403            for evict_txid in to_evict {
404                inner.remove_entry(&evict_txid);
405                tracing::info!("RBF: evicted conflicting transaction {}", evict_txid);
406            }
407        }
408
409        // Identify in-mempool parents
410        let in_mempool_parents: HashSet<Txid> = tx
411            .inputs
412            .iter()
413            .filter(|input| inner.entries.contains_key(&input.previous_output.txid))
414            .map(|input| input.previous_output.txid)
415            .collect();
416
417        // Check ancestor limits
418        {
419            let mut all_ancestors = HashSet::new();
420            for parent_txid in &in_mempool_parents {
421                all_ancestors.insert(*parent_txid);
422                inner.collect_ancestors(parent_txid, &mut all_ancestors);
423            }
424
425            let ancestor_count = (all_ancestors.len() + 1) as u32;
426            let ancestor_size: u32 = all_ancestors
427                .iter()
428                .filter_map(|t| inner.packages.get(t))
429                .map(|p| p.vsize)
430                .sum::<u32>()
431                + size as u32;
432
433            self.limits
434                .check_ancestor_limits(ancestor_count, ancestor_size)
435                .map_err(|e| format!("Ancestor limit exceeded: {}", e))?;
436        }
437
438        // Check descendant limits (on all ancestors)
439        {
440            for parent_txid in &in_mempool_parents {
441                if let Some(parent_pkg) = inner.packages.get(parent_txid) {
442                    let new_desc_count = parent_pkg.descendant_count + 1;
443                    let new_desc_size = parent_pkg.descendant_size + size as u32;
444
445                    self.limits
446                        .check_descendant_limits(new_desc_count, new_desc_size)
447                        .map_err(|e| format!("Descendant limit exceeded: {}", e))?;
448                }
449
450                let mut ancestors_of_parent = HashSet::new();
451                inner.collect_ancestors(parent_txid, &mut ancestors_of_parent);
452                for anc in &ancestors_of_parent {
453                    if let Some(anc_pkg) = inner.packages.get(anc) {
454                        let new_desc_count = anc_pkg.descendant_count + 1;
455                        let new_desc_size = anc_pkg.descendant_size + size as u32;
456
457                        self.limits
458                            .check_descendant_limits(new_desc_count, new_desc_size)
459                            .map_err(|e| format!("Descendant limit exceeded on ancestor: {}", e))?;
460                    }
461                }
462            }
463        }
464
465        let now = SystemTime::now()
466            .duration_since(UNIX_EPOCH)
467            .unwrap_or_default()
468            .as_secs();
469
470        let height = inner.current_height;
471
472        let entry = MempoolEntry {
473            tx: tx.clone(),
474            fee,
475            size,
476            time: now,
477            height,
478            descendant_count: 0,
479            descendant_size: 0,
480            ancestor_count: in_mempool_parents.len() as u32,
481            ancestor_size: size as u32,
482        };
483
484        inner.entries.insert(txid, entry);
485
486        // Update graph
487        inner.parents.insert(txid, in_mempool_parents.clone());
488        for parent_txid in &in_mempool_parents {
489            inner.children.entry(*parent_txid).or_default().insert(txid);
490        }
491        inner.children.entry(txid).or_default();
492
493        // Update total bytes
494        inner.total_bytes += size as u64;
495
496        // Update package info
497        inner.update_package_info(txid, fee, size as u32);
498
499        // Track fee rate for estimation
500        let fee_rate = fee.as_sat() as f64 / size.max(1) as f64;
501        inner.fee_rate_buckets.push(fee_rate);
502        if inner.fee_rate_buckets.len() > 10000 {
503            let drain_end = inner.fee_rate_buckets.len() - 10000;
504            inner.fee_rate_buckets.drain(0..drain_end);
505        }
506
507        // Evict if over limit
508        inner.evict_if_needed(max_bytes);
509
510        tracing::debug!("Added transaction {} to mempool ({} bytes)", txid, size);
511        Ok(())
512    }
513
514    async fn remove_transaction(
515        &self,
516        txid: &Txid,
517        recursive: bool,
518    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
519        let mut inner = self.inner.write().await;
520
521        if recursive {
522            let mut desc = HashSet::new();
523            inner.collect_descendants(*txid, &mut desc);
524            for desc_txid in desc {
525                inner.remove_entry(&desc_txid);
526                tracing::debug!("Removed dependent transaction {} from mempool", desc_txid);
527            }
528        }
529
530        inner.remove_entry(txid);
531        tracing::debug!("Removed transaction {} from mempool", txid);
532        Ok(())
533    }
534
535    async fn get_transaction(
536        &self,
537        txid: &Txid,
538    ) -> Result<Option<MempoolEntry>, Box<dyn std::error::Error + Send + Sync>> {
539        let inner = self.inner.read().await;
540        Ok(inner.entries.get(txid).cloned())
541    }
542
543    async fn get_all_transactions(
544        &self,
545    ) -> Result<Vec<MempoolEntry>, Box<dyn std::error::Error + Send + Sync>> {
546        let inner = self.inner.read().await;
547        Ok(inner.entries.values().cloned().collect())
548    }
549
550    async fn get_transaction_count(&self) -> Result<u32, Box<dyn std::error::Error + Send + Sync>> {
551        let inner = self.inner.read().await;
552        Ok(inner.entries.len() as u32)
553    }
554
555    async fn estimate_fee(
556        &self,
557        target_blocks: u32,
558    ) -> Result<f64, Box<dyn std::error::Error + Send + Sync>> {
559        let inner = self.inner.read().await;
560
561        if inner.fee_rate_buckets.is_empty() {
562            return Ok(MIN_RELAY_FEE_RATE);
563        }
564
565        let mut sorted = inner.fee_rate_buckets.clone();
566        sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
567
568        let percentile = match target_blocks {
569            1 => 0.10,
570            2..=3 => 0.25,
571            4..=6 => 0.50,
572            7..=12 => 0.75,
573            _ => 0.90,
574        };
575
576        let index = ((sorted.len() as f64 * percentile) as usize).min(sorted.len() - 1);
577        let estimated = sorted[index].max(MIN_RELAY_FEE_RATE);
578
579        Ok(estimated)
580    }
581
582    async fn get_mempool_info(
583        &self,
584    ) -> Result<MempoolInfo, Box<dyn std::error::Error + Send + Sync>> {
585        let inner = self.inner.read().await;
586
587        Ok(MempoolInfo {
588            size: inner.entries.len() as u32,
589            bytes: inner.total_bytes,
590            usage: inner.total_bytes + (inner.entries.len() as u64 * 200),
591            max_mempool: self.max_bytes,
592            min_relay_fee: MIN_RELAY_FEE_RATE / 100_000_000.0,
593        })
594    }
595
596    async fn clear(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
597        let mut inner = self.inner.write().await;
598
599        inner.entries.clear();
600        inner.packages.clear();
601        inner.children.clear();
602        inner.parents.clear();
603        inner.total_bytes = 0;
604
605        tracing::info!("Mempool cleared");
606        Ok(())
607    }
608}
609
#[cfg(test)]
mod tests {
    use super::*;
    use abtc_domain::primitives::{OutPoint, TxIn, TxOut};
    use abtc_domain::Script;

    /// Build a one-input, one-output transaction paying `value` satoshis.
    ///
    /// The spent output index is derived from `value`, so fixtures built
    /// with different values spend different outpoints and do not conflict
    /// with each other. Tests therefore do not depend on how the mempool
    /// treats double-spends.
    fn make_test_tx(value: i64) -> Transaction {
        let input = TxIn::final_input(OutPoint::new(Txid::zero(), value as u32), Script::new());
        let output = TxOut::new(Amount::from_sat(value), Script::new());
        Transaction::v1(vec![input], vec![output], 0)
    }

    /// Build a transaction spending output 0 of `parent_txid` and paying
    /// `value` satoshis.
    fn make_child_tx(parent_txid: Txid, value: i64) -> Transaction {
        let input = TxIn::final_input(OutPoint::new(parent_txid, 0), Script::new());
        let output = TxOut::new(Amount::from_sat(value), Script::new());
        Transaction::v1(vec![input], vec![output], 0)
    }

    #[tokio::test]
    async fn test_mempool_add_and_get() {
        let mempool = InMemoryMempool::new();
        let tx = make_test_tx(1000);
        let txid = tx.txid();

        mempool.add_transaction(&tx).await.unwrap();

        let entry = mempool.get_transaction(&txid).await.unwrap();
        assert!(entry.is_some());
        assert_eq!(entry.unwrap().tx, tx);
    }

    #[tokio::test]
    async fn test_mempool_duplicate_rejection() {
        let mempool = InMemoryMempool::new();
        let tx = make_test_tx(1000);

        assert!(mempool.add_transaction(&tx).await.is_ok());
        assert!(mempool.add_transaction(&tx).await.is_err());
    }

    #[tokio::test]
    async fn test_mempool_remove() {
        let mempool = InMemoryMempool::new();
        let tx = make_test_tx(1000);
        let txid = tx.txid();

        mempool.add_transaction(&tx).await.unwrap();
        assert_eq!(mempool.get_transaction_count().await.unwrap(), 1);

        mempool.remove_transaction(&txid, false).await.unwrap();
        assert_eq!(mempool.get_transaction_count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_mempool_clear() {
        let mempool = InMemoryMempool::new();

        for i in 0..5 {
            let tx = make_test_tx(1000 + i);
            mempool.add_transaction(&tx).await.unwrap();
        }

        assert_eq!(mempool.get_transaction_count().await.unwrap(), 5);

        mempool.clear().await.unwrap();
        assert_eq!(mempool.get_transaction_count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_mempool_info() {
        let mempool = InMemoryMempool::new();
        let tx = make_test_tx(1000);
        mempool.add_transaction(&tx).await.unwrap();

        let info = mempool.get_mempool_info().await.unwrap();
        assert_eq!(info.size, 1);
        assert!(info.bytes > 0);
        assert_eq!(info.max_mempool, DEFAULT_MAX_MEMPOOL_BYTES);
    }

    #[tokio::test]
    async fn test_fee_estimation() {
        let mempool = InMemoryMempool::new();

        // No samples recorded yet: the estimate is the relay floor.
        let fee = mempool.estimate_fee(1).await.unwrap();
        assert_eq!(fee, MIN_RELAY_FEE_RATE);
    }

    #[tokio::test]
    async fn test_eviction() {
        let mempool = InMemoryMempool::with_max_bytes(500);

        for i in 0..10 {
            let tx = make_test_tx(1000 * (i + 1));
            // Individual inserts may be refused once the pool is full.
            let _ = mempool.add_transaction(&tx).await;
        }

        let info = mempool.get_mempool_info().await.unwrap();
        assert!(info.bytes <= 500);
    }

    #[tokio::test]
    async fn test_parent_child_tracking() {
        let mempool = InMemoryMempool::new();

        let parent = make_test_tx(50_000);
        let parent_txid = parent.txid();
        mempool.add_transaction(&parent).await.unwrap();

        let child = make_child_tx(parent_txid, 40_000);
        mempool.add_transaction(&child).await.unwrap();

        assert_eq!(mempool.get_transaction_count().await.unwrap(), 2);

        // The descendant count includes the parent itself.
        let inner = mempool.inner.read().await;
        let parent_pkg = inner.packages.get(&parent_txid).unwrap();
        assert_eq!(parent_pkg.descendant_count, 2);
    }

    #[tokio::test]
    async fn test_recursive_remove() {
        let mempool = InMemoryMempool::new();

        let parent = make_test_tx(50_000);
        let parent_txid = parent.txid();
        mempool.add_transaction(&parent).await.unwrap();

        let child = make_child_tx(parent_txid, 40_000);
        mempool.add_transaction(&child).await.unwrap();

        assert_eq!(mempool.get_transaction_count().await.unwrap(), 2);

        mempool
            .remove_transaction(&parent_txid, true)
            .await
            .unwrap();
        assert_eq!(mempool.get_transaction_count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_cpfp_mining_order() {
        let mempool = InMemoryMempool::new();

        let parent = make_test_tx(50_000);
        let parent_txid = parent.txid();
        mempool.add_transaction(&parent).await.unwrap();

        let child = make_child_tx(parent_txid, 40_000);
        mempool.add_transaction(&child).await.unwrap();

        let selected = mempool.get_transactions_by_fee_rate(4_000_000).await;
        assert_eq!(selected.len(), 2);
    }

    #[tokio::test]
    async fn test_remove_for_block() {
        let mempool = InMemoryMempool::new();

        let tx1 = make_test_tx(1000);
        let tx2 = make_test_tx(2000);
        mempool.add_transaction(&tx1).await.unwrap();
        mempool.add_transaction(&tx2).await.unwrap();

        assert_eq!(mempool.get_transaction_count().await.unwrap(), 2);

        mempool.remove_for_block(&[tx1]).await;
        assert_eq!(mempool.get_transaction_count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_ancestor_chain_limit() {
        let mempool = InMemoryMempool::new();

        let mut prev_txid = Txid::zero();
        let mut txids = Vec::new();

        // Build a 25-deep chain; each link spends the previous one.
        for i in 0..25 {
            let tx = if i == 0 {
                make_test_tx(100_000)
            } else {
                make_child_tx(prev_txid, 100_000 - (i * 1000))
            };
            prev_txid = tx.txid();
            txids.push(prev_txid);
            mempool.add_transaction(&tx).await.unwrap();
        }

        assert_eq!(mempool.get_transaction_count().await.unwrap(), 25);

        // The 26th link must be refused by the ancestor limit.
        let tx26 = make_child_tx(prev_txid, 50_000);
        let result = mempool.add_transaction(&tx26).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Ancestor limit"));
    }

    #[tokio::test]
    async fn test_get_all_transactions() {
        let mempool = InMemoryMempool::new();

        let tx1 = make_test_tx(1000);
        let tx2 = make_test_tx(2000);
        mempool.add_transaction(&tx1).await.unwrap();
        mempool.add_transaction(&tx2).await.unwrap();

        let all = mempool.get_all_transactions().await.unwrap();
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn test_get_nonexistent_transaction() {
        let mempool = InMemoryMempool::new();
        let result = mempool.get_transaction(&Txid::zero()).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_remove_nonexistent_transaction() {
        let mempool = InMemoryMempool::new();
        mempool
            .remove_transaction(&Txid::zero(), false)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_mempool_set_height() {
        let mempool = InMemoryMempool::new();
        mempool.set_height(500).await;

        let tx = make_test_tx(10_000);
        mempool.add_transaction(&tx).await.unwrap();

        let entry = mempool.get_transaction(&tx.txid()).await.unwrap().unwrap();
        assert_eq!(entry.height, 500);
    }

    #[tokio::test]
    async fn test_mining_selection_weight_limit() {
        let mempool = InMemoryMempool::new();

        for i in 0..10 {
            let tx = make_test_tx(1000 * (i + 1));
            mempool.add_transaction(&tx).await.unwrap();
        }

        let selected = mempool.get_transactions_by_fee_rate(200).await;
        assert!(selected.len() < 10);

        let all = mempool.get_transactions_by_fee_rate(4_000_000).await;
        assert_eq!(all.len(), 10);
    }

    #[tokio::test]
    async fn test_remove_for_block_nonexistent() {
        let mempool = InMemoryMempool::new();

        let fake_tx = make_test_tx(999);
        mempool.remove_for_block(&[fake_tx]).await;
        assert_eq!(mempool.get_transaction_count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_child_removal_updates_parent_descendants() {
        let mempool = InMemoryMempool::new();

        let parent = make_test_tx(50_000);
        let parent_txid = parent.txid();
        mempool.add_transaction(&parent).await.unwrap();

        let child = make_child_tx(parent_txid, 40_000);
        let child_txid = child.txid();
        mempool.add_transaction(&child).await.unwrap();

        {
            let inner = mempool.inner.read().await;
            assert_eq!(
                inner.packages.get(&parent_txid).unwrap().descendant_count,
                2
            );
        }

        mempool
            .remove_transaction(&child_txid, false)
            .await
            .unwrap();

        {
            let inner = mempool.inner.read().await;
            assert_eq!(
                inner.packages.get(&parent_txid).unwrap().descendant_count,
                1
            );
        }
    }

    #[tokio::test]
    async fn test_default_impl() {
        let mempool = InMemoryMempool::default();
        assert_eq!(mempool.size().await, 0);
    }

    #[tokio::test]
    async fn test_fee_estimation_with_data() {
        let mempool = InMemoryMempool::new();

        for i in 0..100 {
            let tx = make_test_tx(1000 + i);
            mempool.add_transaction(&tx).await.unwrap();
        }

        let fee_1 = mempool.estimate_fee(1).await.unwrap();
        let fee_12 = mempool.estimate_fee(12).await.unwrap();

        assert!(fee_1 >= MIN_RELAY_FEE_RATE);
        assert!(fee_12 >= MIN_RELAY_FEE_RATE);
    }
}