Skip to main content

hotmint_mempool/
lib.rs

1use std::cmp::Ordering;
2use std::collections::HashMap;
3
4use tokio::sync::Mutex;
5use tracing::debug;
6
/// Fixed-size transaction digest used as the mempool's deduplication key
/// (produced by `Mempool::hash_tx` from the raw tx bytes).
pub type TxHash = [u8; 32];

/// One pending transaction together with the metadata the pool orders
/// and budgets by.
#[derive(Clone)]
struct TxEntry {
    tx: Vec<u8>,
    priority: u64,
    gas_wanted: u64,
    hash: TxHash,
}

impl TxEntry {
    /// The key that both identifies and orders an entry: `(priority, hash)`.
    ///
    /// `tx` and `gas_wanted` are intentionally left out, so a key built with
    /// only the right priority and hash is enough to locate an entry in a
    /// `BTreeSet` (equality is consistent with ordering).
    fn order_key(&self) -> (u64, &TxHash) {
        (self.priority, &self.hash)
    }
}

impl PartialEq for TxEntry {
    fn eq(&self, other: &Self) -> bool {
        self.order_key() == other.order_key()
    }
}

impl Eq for TxEntry {}

/// Ascending by priority, ties broken by hash. Inside a `BTreeSet` the first
/// element is therefore the lowest-priority entry and the last the highest.
impl Ord for TxEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.order_key().cmp(&other.order_key())
    }
}

impl PartialOrd for TxEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
43
44/// Priority-based mempool with deduplication, eviction, and replace-by-fee (RBF).
45///
46/// Transactions are ordered by priority (highest first). When the pool is
47/// full, a new transaction with higher priority than the lowest-priority
48/// entry will evict it. This prevents spam DoS and enables fee-based
49/// ordering for DeFi applications.
50///
51/// RBF: submitting the same tx bytes with a higher priority replaces the
52/// existing pending entry. This allows wallets to bump fees on stuck txs.
53pub struct Mempool {
54    entries: Mutex<std::collections::BTreeSet<TxEntry>>,
55    /// Maps tx hash → current priority for RBF and for safe removal from the BTreeSet.
56    seen: Mutex<HashMap<TxHash, u64>>,
57    max_size: usize,
58    max_tx_bytes: usize,
59}
60
61impl Mempool {
62    pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
63        Self {
64            entries: Mutex::new(std::collections::BTreeSet::new()),
65            seen: Mutex::new(HashMap::new()),
66            max_size,
67            max_tx_bytes,
68        }
69    }
70
71    /// Add a transaction with a given priority.
72    ///
73    /// Returns `true` if accepted. When the pool is full, the new tx is
74    /// accepted only if its priority exceeds the lowest-priority entry,
75    /// which is then evicted.
76    ///
77    /// **Replace-by-fee (RBF):** if the same tx bytes are already pending
78    /// with a *lower* priority, the existing entry is replaced with the
79    /// new higher-priority one. This lets wallets bump stuck transactions.
80    pub async fn add_tx(&self, tx: Vec<u8>, priority: u64) -> bool {
81        self.add_tx_with_gas(tx, priority, 0).await
82    }
83
84    /// Add a transaction with priority and gas_wanted.
85    pub async fn add_tx_with_gas(&self, tx: Vec<u8>, priority: u64, gas_wanted: u64) -> bool {
86        if tx.len() > self.max_tx_bytes {
87            debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
88            return false;
89        }
90
91        let hash = Self::hash_tx(&tx);
92
93        // Lock order: entries first, then seen
94        let mut entries = self.entries.lock().await;
95        let mut seen = self.seen.lock().await;
96
97        if let Some(&existing_priority) = seen.get(&hash) {
98            if priority <= existing_priority {
99                // Exact duplicate or lower-fee resubmission: reject.
100                return false;
101            }
102            // RBF: remove the old lower-priority entry and replace it.
103            let old = TxEntry {
104                tx: tx.clone(),
105                priority: existing_priority,
106                gas_wanted: 0,
107                hash,
108            };
109            entries.remove(&old);
110            seen.insert(hash, priority);
111            entries.insert(TxEntry {
112                tx,
113                priority,
114                gas_wanted,
115                hash,
116            });
117            debug!(
118                old = existing_priority,
119                new = priority,
120                "replaced tx via RBF"
121            );
122            return true;
123        }
124
125        if entries.len() >= self.max_size {
126            // Check if new tx beats the lowest-priority entry
127            if let Some(lowest) = entries.first() {
128                if priority <= lowest.priority {
129                    debug!(
130                        priority,
131                        lowest = lowest.priority,
132                        "mempool full, priority too low"
133                    );
134                    return false;
135                }
136                // Evict lowest
137                let evicted = entries.pop_first().expect("just checked non-empty");
138                seen.remove(&evicted.hash);
139                debug!(
140                    evicted_priority = evicted.priority,
141                    new_priority = priority,
142                    "evicted low-priority tx"
143                );
144            }
145        }
146
147        seen.insert(hash, priority);
148        entries.insert(TxEntry {
149            tx,
150            priority,
151            gas_wanted,
152            hash,
153        });
154        true
155    }
156
157    /// Collect transactions for a block proposal (up to max_bytes and max_gas total).
158    /// Collected transactions are removed from the pool and the seen set.
159    /// Transactions are collected in priority order (highest first).
160    /// The payload is length-prefixed: `[u32_le len][bytes]...`
161    ///
162    /// `max_gas` of 0 disables gas accounting (byte limit only).
163    pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
164        self.collect_payload_with_gas(max_bytes, 0).await
165    }
166
167    /// Collect with both byte and gas limits.
168    ///
169    /// Skips transactions that exceed the remaining gas budget (instead of
170    /// stopping) to avoid head-of-line starvation by a single high-gas tx.
171    pub async fn collect_payload_with_gas(&self, max_bytes: usize, max_gas: u64) -> Vec<u8> {
172        let mut entries = self.entries.lock().await;
173        let mut seen = self.seen.lock().await;
174        let mut payload = Vec::new();
175        let mut total_gas = 0u64;
176        let mut skipped = Vec::new();
177
178        while let Some(entry) = entries.pop_last() {
179            // Byte limit: hard stop (all remaining txs are at most this size or smaller,
180            // but we can't know without iterating — stop here for simplicity).
181            if payload.len() + 4 + entry.tx.len() > max_bytes {
182                skipped.push(entry);
183                break;
184            }
185            // Gas limit: skip this tx but continue collecting smaller ones.
186            if max_gas > 0 && total_gas + entry.gas_wanted > max_gas {
187                skipped.push(entry);
188                continue;
189            }
190            seen.remove(&entry.hash);
191            total_gas += entry.gas_wanted;
192            let len = entry.tx.len() as u32;
193            payload.extend_from_slice(&len.to_le_bytes());
194            payload.extend_from_slice(&entry.tx);
195        }
196
197        // Re-insert skipped entries back into the pool.
198        for entry in skipped {
199            entries.insert(entry);
200        }
201
202        payload
203    }
204
205    /// Reap collected payload back into individual transactions
206    pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
207        let mut txs = Vec::new();
208        let mut offset = 0;
209        while offset + 4 <= payload.len() {
210            let len = u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
211            offset += 4;
212            if offset + len > payload.len() {
213                break;
214            }
215            txs.push(payload[offset..offset + len].to_vec());
216            offset += len;
217        }
218        txs
219    }
220
221    pub async fn size(&self) -> usize {
222        self.entries.lock().await.len()
223    }
224
225    fn hash_tx(tx: &[u8]) -> TxHash {
226        *blake3::hash(tx).as_bytes()
227    }
228}
229
230impl Default for Mempool {
231    fn default() -> Self {
232        Self::new(10_000, 1_048_576) // 10k txs, 1MB max per tx
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_add_and_collect() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(pool.add_tx(b"tx2".to_vec(), 20).await);
        assert_eq!(pool.size().await, 2);

        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 2);
        // Higher priority first
        assert_eq!(txs[0], b"tx2");
        assert_eq!(txs[1], b"tx1");
    }

    #[tokio::test]
    async fn test_dedup() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(!pool.add_tx(b"tx1".to_vec(), 10).await); // same priority → rejected
        assert!(!pool.add_tx(b"tx1".to_vec(), 5).await); // lower priority → rejected
        assert_eq!(pool.size().await, 1);
    }

    #[tokio::test]
    async fn test_collected_tx_can_be_resubmitted() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 10).await);
        let _ = pool.collect_payload(1024).await;
        assert_eq!(pool.size().await, 0);
        // Collection clears the dedup record, so the same bytes are accepted again.
        assert!(pool.add_tx(b"tx1".to_vec(), 10).await);
    }

    #[tokio::test]
    async fn test_rbf_replace_by_fee() {
        let pool = Mempool::new(100, 1024);
        // Submit tx with low priority
        assert!(pool.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(pool.size().await, 1);
        // Re-submit same bytes with higher priority → RBF accepted
        assert!(pool.add_tx(b"tx1".to_vec(), 20).await);
        // Pool should still have exactly 1 entry
        assert_eq!(pool.size().await, 1);
        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 1);
        assert_eq!(txs[0], b"tx1");
    }

    #[tokio::test]
    async fn test_rbf_reorders_by_new_priority() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 5).await);
        assert!(pool.add_tx(b"tx2".to_vec(), 10).await);
        // Bump tx1 above tx2.
        assert!(pool.add_tx(b"tx1".to_vec(), 20).await);
        assert_eq!(pool.size().await, 2);
        // The bumped priority is now the dedup threshold.
        assert!(!pool.add_tx(b"tx1".to_vec(), 20).await);

        let txs = Mempool::decode_payload(&pool.collect_payload(1024).await);
        assert_eq!(txs, vec![b"tx1".to_vec(), b"tx2".to_vec()]);
    }

    #[tokio::test]
    async fn test_eviction_by_priority() {
        let pool = Mempool::new(2, 1024);
        assert!(pool.add_tx(b"low".to_vec(), 1).await);
        assert!(pool.add_tx(b"mid".to_vec(), 5).await);
        // Pool full, but new tx has higher priority → evicts lowest
        assert!(pool.add_tx(b"high".to_vec(), 10).await);
        assert_eq!(pool.size().await, 2);

        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 2);
        assert_eq!(txs[0], b"high");
        assert_eq!(txs[1], b"mid");
    }

    #[tokio::test]
    async fn test_reject_low_priority_when_full() {
        let pool = Mempool::new(2, 1024);
        assert!(pool.add_tx(b"a".to_vec(), 5).await);
        assert!(pool.add_tx(b"b".to_vec(), 10).await);
        // New tx has lower priority than lowest → rejected
        assert!(!pool.add_tx(b"c".to_vec(), 3).await);
        assert_eq!(pool.size().await, 2);
    }

    #[tokio::test]
    async fn test_tx_too_large() {
        let pool = Mempool::new(100, 4);
        assert!(!pool.add_tx(b"toolarge".to_vec(), 10).await);
        assert!(pool.add_tx(b"ok".to_vec(), 10).await);
    }

    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let pool = Mempool::new(100, 1024);
        pool.add_tx(b"aaaa".to_vec(), 1).await;
        pool.add_tx(b"bbbb".to_vec(), 2).await;
        pool.add_tx(b"cccc".to_vec(), 3).await;

        // Each tx: 4 bytes len prefix + 4 bytes data = 8 bytes
        // max_bytes = 17 should fit 2 txs (16 bytes) but not 3 (24 bytes)
        let payload = pool.collect_payload(17).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 2);
        // Highest priority first
        assert_eq!(txs[0], b"cccc");
        assert_eq!(txs[1], b"bbbb");
        // The tx that did not fit stays pending.
        assert_eq!(pool.size().await, 1);
    }

    #[tokio::test]
    async fn test_collect_skips_tx_over_gas_budget() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx_with_gas(b"a".to_vec(), 3, 50).await);
        assert!(pool.add_tx_with_gas(b"b".to_vec(), 2, 100).await);
        assert!(pool.add_tx_with_gas(b"c".to_vec(), 1, 30).await);

        // Budget 90: a (50) fits, b (150 total) is skipped, c (80 total) fits.
        let txs = Mempool::decode_payload(&pool.collect_payload_with_gas(1024, 90).await);
        assert_eq!(txs, vec![b"a".to_vec(), b"c".to_vec()]);

        // The skipped tx stays pending and is still deduplicated.
        assert_eq!(pool.size().await, 1);
        assert!(!pool.add_tx_with_gas(b"b".to_vec(), 2, 100).await);
        let txs = Mempool::decode_payload(&pool.collect_payload_with_gas(1024, 100).await);
        assert_eq!(txs, vec![b"b".to_vec()]);
    }

    #[test]
    fn test_decode_empty_payload() {
        let txs = Mempool::decode_payload(&[]);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_truncated_payload() {
        // Only 2 bytes when expecting at least 4 for length prefix
        let txs = Mempool::decode_payload(&[1, 2]);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_payload_with_truncated_data() {
        // Length prefix says 100 bytes but only 3 available
        let mut payload = vec![];
        payload.extend_from_slice(&100u32.to_le_bytes());
        payload.extend_from_slice(&[1, 2, 3]);
        let txs = Mempool::decode_payload(&payload);
        assert!(txs.is_empty());
    }

    #[tokio::test]
    async fn test_empty_tx() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(vec![], 0).await);
        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 1);
        assert!(txs[0].is_empty());
    }
}