Skip to main content

hotmint_mempool/
lib.rs

1use std::cmp::Ordering;
2use std::collections::HashMap;
3
4use tokio::sync::Mutex;
5use tracing::debug;
6
/// Fixed-size (32-byte) transaction digest used as the deduplication key.
pub type TxHash = [u8; 32];

/// One pending transaction held by the priority mempool.
///
/// Only `priority` and `hash` take part in comparisons; `tx` and
/// `gas_wanted` are carried along as payload. This lets a `(priority, hash)`
/// probe locate an entry inside a `BTreeSet<TxEntry>` without its bytes.
#[derive(Clone)]
struct TxEntry {
    /// Raw transaction bytes.
    tx: Vec<u8>,
    /// Primary sort key; larger values sort later (collected first).
    priority: u64,
    /// Gas declared for this transaction.
    gas_wanted: u64,
    /// Digest of `tx`; secondary sort key and key into the `seen` map.
    hash: TxHash,
}

/// Entries sort ascending by `(priority, hash)`: the first element of a
/// `BTreeSet<TxEntry>` is the lowest-priority entry and the last one is the
/// highest. Ties on priority fall back to the hash, so entries with equal
/// priority but different hashes remain distinct set members.
impl Ord for TxEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs = (self.priority, &self.hash);
        let rhs = (other.priority, &other.hash);
        lhs.cmp(&rhs)
    }
}

impl PartialOrd for TxEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Equality agrees with `Ord` (required for `BTreeSet` consistency): two
/// entries are equal exactly when their `(priority, hash)` keys match,
/// whatever their `tx` bytes or `gas_wanted`.
impl PartialEq for TxEntry {
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority && self.hash == other.hash
    }
}

impl Eq for TxEntry {}
43
44/// Priority-based mempool with deduplication, eviction, and replace-by-fee (RBF).
45///
46/// Transactions are ordered by priority (highest first). When the pool is
47/// full, a new transaction with higher priority than the lowest-priority
48/// entry will evict it. This prevents spam DoS and enables fee-based
49/// ordering for DeFi applications.
50///
51/// RBF: submitting the same tx bytes with a higher priority replaces the
52/// existing pending entry. This allows wallets to bump fees on stuck txs.
53pub struct Mempool {
54    entries: Mutex<std::collections::BTreeSet<TxEntry>>,
55    /// Maps tx hash → current priority for RBF and for safe removal from the BTreeSet.
56    seen: Mutex<HashMap<TxHash, u64>>,
57    max_size: usize,
58    max_tx_bytes: usize,
59}
60
61impl Mempool {
62    pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
63        Self {
64            entries: Mutex::new(std::collections::BTreeSet::new()),
65            seen: Mutex::new(HashMap::new()),
66            max_size,
67            max_tx_bytes,
68        }
69    }
70
71    /// Add a transaction with a given priority.
72    ///
73    /// Returns `true` if accepted. When the pool is full, the new tx is
74    /// accepted only if its priority exceeds the lowest-priority entry,
75    /// which is then evicted.
76    ///
77    /// **Replace-by-fee (RBF):** if the same tx bytes are already pending
78    /// with a *lower* priority, the existing entry is replaced with the
79    /// new higher-priority one. This lets wallets bump stuck transactions.
80    pub async fn add_tx(&self, tx: Vec<u8>, priority: u64) -> bool {
81        self.add_tx_with_gas(tx, priority, 0).await
82    }
83
84    /// Add a transaction with priority and gas_wanted.
85    pub async fn add_tx_with_gas(&self, tx: Vec<u8>, priority: u64, gas_wanted: u64) -> bool {
86        if tx.len() > self.max_tx_bytes {
87            debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
88            return false;
89        }
90
91        let hash = Self::hash_tx(&tx);
92
93        // Lock order: entries first, then seen
94        let mut entries = self.entries.lock().await;
95        let mut seen = self.seen.lock().await;
96
97        if let Some(&existing_priority) = seen.get(&hash) {
98            if priority <= existing_priority {
99                // Exact duplicate or lower-fee resubmission: reject.
100                return false;
101            }
102            // RBF: remove the old lower-priority entry and replace it.
103            let old = TxEntry {
104                tx: tx.clone(),
105                priority: existing_priority,
106                gas_wanted: 0,
107                hash,
108            };
109            entries.remove(&old);
110            seen.insert(hash, priority);
111            entries.insert(TxEntry {
112                tx,
113                priority,
114                gas_wanted,
115                hash,
116            });
117            debug!(
118                old = existing_priority,
119                new = priority,
120                "replaced tx via RBF"
121            );
122            return true;
123        }
124
125        if entries.len() >= self.max_size {
126            // Check if new tx beats the lowest-priority entry
127            if let Some(lowest) = entries.first() {
128                if priority <= lowest.priority {
129                    debug!(
130                        priority,
131                        lowest = lowest.priority,
132                        "mempool full, priority too low"
133                    );
134                    return false;
135                }
136                // Evict lowest
137                let evicted = entries.pop_first().expect("just checked non-empty");
138                seen.remove(&evicted.hash);
139                debug!(
140                    evicted_priority = evicted.priority,
141                    new_priority = priority,
142                    "evicted low-priority tx"
143                );
144            }
145        }
146
147        seen.insert(hash, priority);
148        entries.insert(TxEntry {
149            tx,
150            priority,
151            gas_wanted,
152            hash,
153        });
154        true
155    }
156
157    /// Collect transactions for a block proposal (up to max_bytes and max_gas total).
158    /// Collected transactions are removed from the pool and the seen set.
159    /// Transactions are collected in priority order (highest first).
160    /// The payload is length-prefixed: `[u32_le len][bytes]...`
161    ///
162    /// `max_gas` of 0 disables gas accounting (byte limit only).
163    pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
164        self.collect_payload_with_gas(max_bytes, 0).await
165    }
166
167    /// Collect with both byte and gas limits.
168    ///
169    /// Skips transactions that exceed the remaining gas budget (instead of
170    /// stopping) to avoid head-of-line starvation by a single high-gas tx.
171    pub async fn collect_payload_with_gas(&self, max_bytes: usize, max_gas: u64) -> Vec<u8> {
172        let mut entries = self.entries.lock().await;
173        let mut seen = self.seen.lock().await;
174        let mut payload = Vec::new();
175        let mut total_gas = 0u64;
176        let mut skipped = Vec::new();
177        // B-2/A-5: Cap skipped transactions to bound the loop.
178        const MAX_SKIPPED: usize = 200;
179
180        while let Some(entry) = entries.pop_last() {
181            // Byte limit: skip this tx if it doesn't fit, continue with smaller ones.
182            if payload.len() + 4 + entry.tx.len() > max_bytes {
183                skipped.push(entry);
184                if skipped.len() >= MAX_SKIPPED {
185                    break;
186                }
187                continue;
188            }
189            // Gas limit: skip this tx but continue collecting smaller ones.
190            if max_gas > 0 && total_gas + entry.gas_wanted > max_gas {
191                skipped.push(entry);
192                if skipped.len() >= MAX_SKIPPED {
193                    break;
194                }
195                continue;
196            }
197            seen.remove(&entry.hash);
198            total_gas += entry.gas_wanted;
199            let len = entry.tx.len() as u32;
200            payload.extend_from_slice(&len.to_le_bytes());
201            payload.extend_from_slice(&entry.tx);
202        }
203
204        // Re-insert skipped entries back into the pool.
205        for entry in skipped {
206            entries.insert(entry);
207        }
208
209        payload
210    }
211
212    /// Reap collected payload back into individual transactions
213    pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
214        let mut txs = Vec::new();
215        let mut offset = 0;
216        while offset + 4 <= payload.len() {
217            let len = u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
218            offset += 4;
219            if offset + len > payload.len() {
220                break;
221            }
222            txs.push(payload[offset..offset + len].to_vec());
223            offset += len;
224        }
225        txs
226    }
227
228    pub async fn size(&self) -> usize {
229        self.entries.lock().await.len()
230    }
231
232    /// Re-validate all pending transactions after a block commit.
233    ///
234    /// Calls `validator` on each pending tx. If it returns `None`, the tx
235    /// is evicted (no longer valid against updated state). If it returns
236    /// `Some((priority, gas_wanted))`, the tx is kept with possibly updated
237    /// priority.
238    pub async fn recheck(&self, validator: impl Fn(&[u8]) -> Option<(u64, u64)>) {
239        let mut entries = self.entries.lock().await;
240        let mut seen = self.seen.lock().await;
241
242        let old: Vec<TxEntry> = entries.iter().cloned().collect();
243        entries.clear();
244        seen.clear();
245
246        let mut removed = 0usize;
247        for entry in old {
248            match validator(&entry.tx) {
249                Some((new_priority, new_gas)) => {
250                    seen.insert(entry.hash, new_priority);
251                    entries.insert(TxEntry {
252                        tx: entry.tx,
253                        priority: new_priority,
254                        gas_wanted: new_gas,
255                        hash: entry.hash,
256                    });
257                }
258                None => {
259                    removed += 1;
260                }
261            }
262        }
263
264        if removed > 0 {
265            debug!(
266                removed,
267                remaining = entries.len(),
268                "mempool recheck complete"
269            );
270        }
271    }
272
273    fn hash_tx(tx: &[u8]) -> TxHash {
274        *blake3::hash(tx).as_bytes()
275    }
276}
277
278impl Default for Mempool {
279    fn default() -> Self {
280        Self::new(10_000, 1_048_576) // 10k txs, 1MB max per tx
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Drain the pool into a payload capped at `max_bytes` and decode it back
    /// into individual transactions (highest priority first).
    async fn drain(pool: &Mempool, max_bytes: usize) -> Vec<Vec<u8>> {
        let payload = pool.collect_payload(max_bytes).await;
        Mempool::decode_payload(&payload)
    }

    #[tokio::test]
    async fn test_add_and_collect() {
        let mempool = Mempool::new(100, 1024);
        assert!(mempool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(mempool.add_tx(b"tx2".to_vec(), 20).await);
        assert_eq!(mempool.size().await, 2);

        // Higher priority comes out first.
        let collected = drain(&mempool, 1024).await;
        assert_eq!(collected, vec![b"tx2".to_vec(), b"tx1".to_vec()]);
    }

    #[tokio::test]
    async fn test_dedup() {
        let mempool = Mempool::new(100, 1024);
        assert!(mempool.add_tx(b"tx1".to_vec(), 10).await);
        // Resubmitting at the same or a lower priority is rejected.
        assert!(!mempool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(!mempool.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(mempool.size().await, 1);
    }

    #[tokio::test]
    async fn test_rbf_replace_by_fee() {
        let mempool = Mempool::new(100, 1024);
        assert!(mempool.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(mempool.size().await, 1);

        // Same bytes at a higher priority replace the pending entry in place.
        assert!(mempool.add_tx(b"tx1".to_vec(), 20).await);
        assert_eq!(mempool.size().await, 1);

        let collected = drain(&mempool, 1024).await;
        assert_eq!(collected, vec![b"tx1".to_vec()]);
    }

    #[tokio::test]
    async fn test_eviction_by_priority() {
        let mempool = Mempool::new(2, 1024);
        assert!(mempool.add_tx(b"low".to_vec(), 1).await);
        assert!(mempool.add_tx(b"mid".to_vec(), 5).await);

        // Full pool: a higher-priority newcomer displaces the lowest entry.
        assert!(mempool.add_tx(b"high".to_vec(), 10).await);
        assert_eq!(mempool.size().await, 2);

        let collected = drain(&mempool, 1024).await;
        assert_eq!(collected, vec![b"high".to_vec(), b"mid".to_vec()]);
    }

    #[tokio::test]
    async fn test_reject_low_priority_when_full() {
        let mempool = Mempool::new(2, 1024);
        assert!(mempool.add_tx(b"a".to_vec(), 5).await);
        assert!(mempool.add_tx(b"b".to_vec(), 10).await);

        // Full pool: a newcomer below the current minimum is turned away.
        assert!(!mempool.add_tx(b"c".to_vec(), 3).await);
        assert_eq!(mempool.size().await, 2);
    }

    #[tokio::test]
    async fn test_tx_too_large() {
        let mempool = Mempool::new(100, 4);
        assert!(!mempool.add_tx(b"toolarge".to_vec(), 10).await);
        assert!(mempool.add_tx(b"ok".to_vec(), 10).await);
    }

    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let mempool = Mempool::new(100, 1024);
        for (tx, priority) in [(b"aaaa", 1), (b"bbbb", 2), (b"cccc", 3)] {
            mempool.add_tx(tx.to_vec(), priority).await;
        }

        // Each record is a 4-byte prefix plus 4 data bytes = 8 bytes, so a
        // 17-byte budget holds two records (16 bytes) but not three (24).
        let collected = drain(&mempool, 17).await;
        assert_eq!(collected, vec![b"cccc".to_vec(), b"bbbb".to_vec()]);
    }

    #[test]
    fn test_decode_empty_payload() {
        assert!(Mempool::decode_payload(&[]).is_empty());
    }

    #[test]
    fn test_decode_truncated_payload() {
        // Two bytes cannot hold the 4-byte length prefix.
        assert!(Mempool::decode_payload(&[1, 2]).is_empty());
    }

    #[test]
    fn test_decode_payload_with_truncated_data() {
        // The prefix announces 100 bytes, but only 3 follow.
        let mut payload = 100u32.to_le_bytes().to_vec();
        payload.extend_from_slice(&[1, 2, 3]);
        assert!(Mempool::decode_payload(&payload).is_empty());
    }

    #[tokio::test]
    async fn test_empty_tx() {
        let mempool = Mempool::new(100, 1024);
        assert!(mempool.add_tx(vec![], 0).await);
        let collected = drain(&mempool, 1024).await;
        assert_eq!(collected, vec![Vec::<u8>::new()]);
    }
}