Skip to main content

hotmint_mempool/
lib.rs

1use std::cmp::Ordering;
2use std::collections::HashMap;
3
4use tokio::sync::Mutex;
5use tracing::debug;
6
/// 32-byte transaction digest used as the deduplication key.
pub type TxHash = [u8; 32];

/// One pending transaction inside the priority mempool.
///
/// Identity and ordering look only at `(priority, hash)`; `tx` and
/// `gas_wanted` are ignored. A probe entry carrying just those two keys
/// therefore locates (and can remove) the real entry in a `BTreeSet`.
#[derive(Clone)]
struct TxEntry {
    /// Raw transaction bytes.
    tx: Vec<u8>,
    /// Ordering key: larger means served earlier.
    priority: u64,
    /// Gas requested by the tx, charged against a collection gas budget.
    gas_wanted: u64,
    /// Digest of `tx`; breaks priority ties.
    hash: TxHash,
}

/// Ascending by priority, ties broken by hash. In a `BTreeSet` the first
/// element is the lowest-priority tx and the last is the highest.
impl Ord for TxEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        (self.priority, &self.hash).cmp(&(other.priority, &other.hash))
    }
}

impl PartialOrd for TxEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Equality agrees with `Ord`: same priority and same hash. This is required
/// for `BTreeSet` consistency.
impl PartialEq for TxEntry {
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority && self.hash == other.hash
    }
}

impl Eq for TxEntry {}
43
44/// Priority-based mempool with deduplication, eviction, and replace-by-fee (RBF).
45///
46/// Transactions are ordered by priority (highest first). When the pool is
47/// full, a new transaction with higher priority than the lowest-priority
48/// entry will evict it. This prevents spam DoS and enables fee-based
49/// ordering for DeFi applications.
50///
51/// RBF: submitting the same tx bytes with a higher priority replaces the
52/// existing pending entry. This allows wallets to bump fees on stuck txs.
53pub struct Mempool {
54    entries: Mutex<std::collections::BTreeSet<TxEntry>>,
55    /// Maps tx hash → current priority for RBF and for safe removal from the BTreeSet.
56    seen: Mutex<HashMap<TxHash, u64>>,
57    max_size: usize,
58    max_tx_bytes: usize,
59}
60
61impl Mempool {
62    pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
63        Self {
64            entries: Mutex::new(std::collections::BTreeSet::new()),
65            seen: Mutex::new(HashMap::new()),
66            max_size,
67            max_tx_bytes,
68        }
69    }
70
71    /// Add a transaction with a given priority.
72    ///
73    /// Returns `true` if accepted. When the pool is full, the new tx is
74    /// accepted only if its priority exceeds the lowest-priority entry,
75    /// which is then evicted.
76    ///
77    /// **Replace-by-fee (RBF):** if the same tx bytes are already pending
78    /// with a *lower* priority, the existing entry is replaced with the
79    /// new higher-priority one. This lets wallets bump stuck transactions.
80    pub async fn add_tx(&self, tx: Vec<u8>, priority: u64) -> bool {
81        self.add_tx_with_gas(tx, priority, 0).await
82    }
83
84    /// Add a transaction with priority and gas_wanted.
85    pub async fn add_tx_with_gas(&self, tx: Vec<u8>, priority: u64, gas_wanted: u64) -> bool {
86        if tx.len() > self.max_tx_bytes {
87            debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
88            return false;
89        }
90
91        let hash = Self::hash_tx(&tx);
92
93        // Lock order: entries first, then seen
94        let mut entries = self.entries.lock().await;
95        let mut seen = self.seen.lock().await;
96
97        if let Some(&existing_priority) = seen.get(&hash) {
98            if priority <= existing_priority {
99                // Exact duplicate or lower-fee resubmission: reject.
100                return false;
101            }
102            // RBF: remove the old lower-priority entry and replace it.
103            let old = TxEntry {
104                tx: tx.clone(),
105                priority: existing_priority,
106                gas_wanted: 0,
107                hash,
108            };
109            entries.remove(&old);
110            seen.insert(hash, priority);
111            entries.insert(TxEntry {
112                tx,
113                priority,
114                gas_wanted,
115                hash,
116            });
117            debug!(
118                old = existing_priority,
119                new = priority,
120                "replaced tx via RBF"
121            );
122            return true;
123        }
124
125        if entries.len() >= self.max_size {
126            // Check if new tx beats the lowest-priority entry
127            if let Some(lowest) = entries.first() {
128                if priority <= lowest.priority {
129                    debug!(
130                        priority,
131                        lowest = lowest.priority,
132                        "mempool full, priority too low"
133                    );
134                    return false;
135                }
136                // Evict lowest
137                let evicted = entries.pop_first().expect("just checked non-empty");
138                seen.remove(&evicted.hash);
139                debug!(
140                    evicted_priority = evicted.priority,
141                    new_priority = priority,
142                    "evicted low-priority tx"
143                );
144            }
145        }
146
147        seen.insert(hash, priority);
148        entries.insert(TxEntry {
149            tx,
150            priority,
151            gas_wanted,
152            hash,
153        });
154        true
155    }
156
157    /// Collect transactions for a block proposal (up to max_bytes and max_gas total).
158    /// Collected transactions are removed from the pool and the seen set.
159    /// Transactions are collected in priority order (highest first).
160    /// The payload is length-prefixed: `[u32_le len][bytes]...`
161    ///
162    /// `max_gas` of 0 disables gas accounting (byte limit only).
163    pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
164        self.collect_payload_with_gas(max_bytes, 0).await
165    }
166
167    /// Collect with both byte and gas limits.
168    ///
169    /// Skips transactions that exceed the remaining gas budget (instead of
170    /// stopping) to avoid head-of-line starvation by a single high-gas tx.
171    pub async fn collect_payload_with_gas(&self, max_bytes: usize, max_gas: u64) -> Vec<u8> {
172        let mut entries = self.entries.lock().await;
173        let mut seen = self.seen.lock().await;
174        let mut payload = Vec::new();
175        let mut total_gas = 0u64;
176        let mut skipped = Vec::new();
177
178        while let Some(entry) = entries.pop_last() {
179            // Byte limit: hard stop (all remaining txs are at most this size or smaller,
180            // but we can't know without iterating — stop here for simplicity).
181            if payload.len() + 4 + entry.tx.len() > max_bytes {
182                skipped.push(entry);
183                break;
184            }
185            // Gas limit: skip this tx but continue collecting smaller ones.
186            if max_gas > 0 && total_gas + entry.gas_wanted > max_gas {
187                skipped.push(entry);
188                continue;
189            }
190            seen.remove(&entry.hash);
191            total_gas += entry.gas_wanted;
192            let len = entry.tx.len() as u32;
193            payload.extend_from_slice(&len.to_le_bytes());
194            payload.extend_from_slice(&entry.tx);
195        }
196
197        // Re-insert skipped entries back into the pool.
198        for entry in skipped {
199            entries.insert(entry);
200        }
201
202        payload
203    }
204
205    /// Reap collected payload back into individual transactions
206    pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
207        let mut txs = Vec::new();
208        let mut offset = 0;
209        while offset + 4 <= payload.len() {
210            let len = u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
211            offset += 4;
212            if offset + len > payload.len() {
213                break;
214            }
215            txs.push(payload[offset..offset + len].to_vec());
216            offset += len;
217        }
218        txs
219    }
220
221    pub async fn size(&self) -> usize {
222        self.entries.lock().await.len()
223    }
224
225    /// Re-validate all pending transactions after a block commit.
226    ///
227    /// Calls `validator` on each pending tx. If it returns `None`, the tx
228    /// is evicted (no longer valid against updated state). If it returns
229    /// `Some((priority, gas_wanted))`, the tx is kept with possibly updated
230    /// priority.
231    pub async fn recheck(&self, validator: impl Fn(&[u8]) -> Option<(u64, u64)>) {
232        let mut entries = self.entries.lock().await;
233        let mut seen = self.seen.lock().await;
234
235        let old: Vec<TxEntry> = entries.iter().cloned().collect();
236        entries.clear();
237        seen.clear();
238
239        let mut removed = 0usize;
240        for entry in old {
241            match validator(&entry.tx) {
242                Some((new_priority, new_gas)) => {
243                    seen.insert(entry.hash, new_priority);
244                    entries.insert(TxEntry {
245                        tx: entry.tx,
246                        priority: new_priority,
247                        gas_wanted: new_gas,
248                        hash: entry.hash,
249                    });
250                }
251                None => {
252                    removed += 1;
253                }
254            }
255        }
256
257        if removed > 0 {
258            debug!(
259                removed,
260                remaining = entries.len(),
261                "mempool recheck complete"
262            );
263        }
264    }
265
266    fn hash_tx(tx: &[u8]) -> TxHash {
267        *blake3::hash(tx).as_bytes()
268    }
269}
270
271impl Default for Mempool {
272    fn default() -> Self {
273        Self::new(10_000, 1_048_576) // 10k txs, 1MB max per tx
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Collect up to `max_bytes` from the pool and decode the payload.
    async fn drain(mp: &Mempool, max_bytes: usize) -> Vec<Vec<u8>> {
        let raw = mp.collect_payload(max_bytes).await;
        Mempool::decode_payload(&raw)
    }

    #[tokio::test]
    async fn test_add_and_collect() {
        let mp = Mempool::new(100, 1024);
        assert!(mp.add_tx(b"tx1".to_vec(), 10).await);
        assert!(mp.add_tx(b"tx2".to_vec(), 20).await);
        assert_eq!(mp.size().await, 2);

        // Highest priority comes out first.
        assert_eq!(drain(&mp, 1024).await, vec![b"tx2".to_vec(), b"tx1".to_vec()]);
    }

    #[tokio::test]
    async fn test_dedup() {
        let mp = Mempool::new(100, 1024);
        assert!(mp.add_tx(b"tx1".to_vec(), 10).await);
        // Same-priority and lower-priority resubmissions are both rejected.
        for prio in [10, 5] {
            assert!(!mp.add_tx(b"tx1".to_vec(), prio).await);
        }
        assert_eq!(mp.size().await, 1);
    }

    #[tokio::test]
    async fn test_rbf_replace_by_fee() {
        let mp = Mempool::new(100, 1024);
        assert!(mp.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(mp.size().await, 1);

        // Same bytes with a higher fee replace the pending entry in place.
        assert!(mp.add_tx(b"tx1".to_vec(), 20).await);
        assert_eq!(mp.size().await, 1);
        assert_eq!(drain(&mp, 1024).await, vec![b"tx1".to_vec()]);
    }

    #[tokio::test]
    async fn test_eviction_by_priority() {
        let mp = Mempool::new(2, 1024);
        assert!(mp.add_tx(b"low".to_vec(), 1).await);
        assert!(mp.add_tx(b"mid".to_vec(), 5).await);

        // Full pool: a higher-priority newcomer pushes out the lowest entry.
        assert!(mp.add_tx(b"high".to_vec(), 10).await);
        assert_eq!(mp.size().await, 2);
        assert_eq!(drain(&mp, 1024).await, vec![b"high".to_vec(), b"mid".to_vec()]);
    }

    #[tokio::test]
    async fn test_reject_low_priority_when_full() {
        let mp = Mempool::new(2, 1024);
        assert!(mp.add_tx(b"a".to_vec(), 5).await);
        assert!(mp.add_tx(b"b".to_vec(), 10).await);

        // Below the current minimum: refused, pool unchanged.
        assert!(!mp.add_tx(b"c".to_vec(), 3).await);
        assert_eq!(mp.size().await, 2);
    }

    #[tokio::test]
    async fn test_tx_too_large() {
        let mp = Mempool::new(100, 4);
        assert!(!mp.add_tx(b"toolarge".to_vec(), 10).await);
        assert!(mp.add_tx(b"ok".to_vec(), 10).await);
    }

    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let mp = Mempool::new(100, 1024);
        for (tx, prio) in [(b"aaaa", 1), (b"bbbb", 2), (b"cccc", 3)] {
            mp.add_tx(tx.to_vec(), prio).await;
        }

        // Each record is 4 (prefix) + 4 (data) = 8 bytes, so 17 bytes hold two
        // records (16 bytes) but not three (24 bytes).
        assert_eq!(drain(&mp, 17).await, vec![b"cccc".to_vec(), b"bbbb".to_vec()]);
    }

    #[test]
    fn test_decode_empty_payload() {
        assert!(Mempool::decode_payload(&[]).is_empty());
    }

    #[test]
    fn test_decode_truncated_payload() {
        // Too short to hold even a 4-byte length prefix.
        assert!(Mempool::decode_payload(&[1, 2]).is_empty());
    }

    #[test]
    fn test_decode_payload_with_truncated_data() {
        // Prefix claims 100 bytes, but only 3 follow.
        let payload: Vec<u8> = 100u32
            .to_le_bytes()
            .iter()
            .copied()
            .chain([1, 2, 3])
            .collect();
        assert!(Mempool::decode_payload(&payload).is_empty());
    }

    #[tokio::test]
    async fn test_empty_tx() {
        let mp = Mempool::new(100, 1024);
        assert!(mp.add_tx(Vec::new(), 0).await);
        let txs = drain(&mp, 1024).await;
        assert_eq!(txs.len(), 1);
        assert!(txs[0].is_empty());
    }
}