Skip to main content

hotmint_mempool/
lib.rs

1use std::collections::{HashSet, VecDeque};
2use tokio::sync::Mutex;
3use tracing::debug;
4
/// Transaction hash for deduplication.
///
/// A 32-byte digest of the raw transaction bytes; in this file it is produced
/// by `blake3_hash` and stored in the mempool's `seen` set.
pub type TxHash = [u8; 32];
7
/// Simple mempool: FIFO queue with deduplication.
///
/// Pending transactions are kept in arrival order in `txs`, while `seen`
/// holds the hash of every transaction currently queued so that duplicates
/// can be rejected. Each collection sits behind its own async mutex; methods
/// that need both acquire `txs` first and `seen` second.
pub struct Mempool {
    /// Pending raw transactions, oldest at the front.
    txs: Mutex<VecDeque<Vec<u8>>>,
    /// Hashes of the transactions currently held in `txs`.
    seen: Mutex<HashSet<TxHash>>,
    /// Maximum number of transactions the pool will hold.
    max_size: usize,
    /// Maximum size, in bytes, of a single accepted transaction.
    max_tx_bytes: usize,
}
15
16impl Mempool {
17    pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
18        Self {
19            txs: Mutex::new(VecDeque::new()),
20            seen: Mutex::new(HashSet::new()),
21            max_size,
22            max_tx_bytes,
23        }
24    }
25
26    /// Add a transaction to the mempool. Returns false if rejected.
27    pub async fn add_tx(&self, tx: Vec<u8>) -> bool {
28        if tx.len() > self.max_tx_bytes {
29            debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
30            return false;
31        }
32
33        let hash = Self::hash_tx(&tx);
34
35        // Lock order: txs first, then seen (same as collect_payload)
36        let mut txs = self.txs.lock().await;
37        let mut seen = self.seen.lock().await;
38
39        if seen.contains(&hash) {
40            return false;
41        }
42        if txs.len() >= self.max_size {
43            debug!(size = txs.len(), max = self.max_size, "mempool full");
44            return false;
45        }
46
47        seen.insert(hash);
48        txs.push_back(tx);
49        true
50    }
51
52    /// Collect transactions for a block proposal (up to max_bytes total).
53    /// Collected transactions are removed from the pool and the seen set.
54    /// The payload is length-prefixed: `[u32_le len][bytes]...`
55    pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
56        let mut txs = self.txs.lock().await;
57        let mut seen = self.seen.lock().await;
58        let mut payload = Vec::new();
59
60        while let Some(tx) = txs.front() {
61            // 4 bytes length prefix + tx bytes
62            if payload.len() + 4 + tx.len() > max_bytes {
63                break;
64            }
65            let tx = txs.pop_front().unwrap();
66            seen.remove(&Self::hash_tx(&tx));
67            let len = tx.len() as u32;
68            payload.extend_from_slice(&len.to_le_bytes());
69            payload.extend_from_slice(&tx);
70        }
71
72        payload
73    }
74
75    /// Reap collected payload back into individual transactions
76    pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
77        let mut txs = Vec::new();
78        let mut offset = 0;
79        while offset + 4 <= payload.len() {
80            let len = u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
81            offset += 4;
82            if offset + len > payload.len() {
83                break;
84            }
85            txs.push(payload[offset..offset + len].to_vec());
86            offset += len;
87        }
88        txs
89    }
90
91    pub async fn size(&self) -> usize {
92        self.txs.lock().await.len()
93    }
94
    /// Computes the deduplication key for a raw transaction by delegating to
    /// `blake3_hash`.
    fn hash_tx(tx: &[u8]) -> TxHash {
        blake3_hash(tx)
    }
98}
99
100fn blake3_hash(data: &[u8]) -> TxHash {
101    *blake3::hash(data).as_bytes()
102}
103
104impl Default for Mempool {
105    fn default() -> Self {
106        Self::new(10_000, 1_048_576) // 10k txs, 1MB max per tx
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Accepted transactions come back out of the pool in insertion order.
    #[tokio::test]
    async fn test_add_and_collect() {
        let mempool = Mempool::new(100, 1024);
        let first_ok = mempool.add_tx(b"tx1".to_vec()).await;
        assert!(first_ok);
        let second_ok = mempool.add_tx(b"tx2".to_vec()).await;
        assert!(second_ok);
        assert_eq!(mempool.size().await, 2);

        let encoded = mempool.collect_payload(1024).await;
        let decoded = Mempool::decode_payload(&encoded);
        assert_eq!(decoded, [b"tx1".to_vec(), b"tx2".to_vec()]);
    }

    /// Submitting identical bytes twice only queues them once.
    #[tokio::test]
    async fn test_dedup() {
        let mempool = Mempool::new(100, 1024);
        let raw = b"tx1".to_vec();
        assert!(mempool.add_tx(raw.clone()).await);
        // duplicate
        assert!(!mempool.add_tx(raw).await);
        assert_eq!(mempool.size().await, 1);
    }

    /// A pool at capacity rejects further transactions.
    #[tokio::test]
    async fn test_max_size() {
        let mempool = Mempool::new(2, 1024);
        assert!(mempool.add_tx(b"tx1".to_vec()).await);
        assert!(mempool.add_tx(b"tx2".to_vec()).await);
        // full
        let third_ok = mempool.add_tx(b"tx3".to_vec()).await;
        assert!(!third_ok);
    }

    /// The per-transaction byte limit is enforced on insertion.
    #[tokio::test]
    async fn test_tx_too_large() {
        let mempool = Mempool::new(100, 4);
        let oversized_ok = mempool.add_tx(b"toolarge".to_vec()).await;
        assert!(!oversized_ok);
        let small_ok = mempool.add_tx(b"ok".to_vec()).await;
        assert!(small_ok);
    }

    /// Collection stops before exceeding the byte budget.
    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let mempool = Mempool::new(100, 1024);
        for raw in [b"aaaa", b"bbbb", b"cccc"] {
            mempool.add_tx(raw.to_vec()).await;
        }

        // Each tx: 4 bytes len prefix + 4 bytes data = 8 bytes
        // max_bytes = 17 should fit 2 txs (16 bytes) but not 3 (24 bytes)
        let encoded = mempool.collect_payload(17).await;
        let decoded = Mempool::decode_payload(&encoded);
        assert_eq!(decoded.len(), 2);
    }

    /// An empty payload decodes to no transactions.
    #[test]
    fn test_decode_empty_payload() {
        let empty: &[u8] = &[];
        assert!(Mempool::decode_payload(empty).is_empty());
    }

    /// A payload shorter than one length prefix decodes to nothing.
    #[test]
    fn test_decode_truncated_payload() {
        // Only 2 bytes when expecting at least 4 for length prefix
        let short = [1u8, 2];
        assert!(Mempool::decode_payload(&short).is_empty());
    }

    /// A record whose declared length exceeds the available bytes is dropped.
    #[test]
    fn test_decode_payload_with_truncated_data() {
        // Length prefix says 100 bytes but only 3 available
        let prefix = 100u32.to_le_bytes();
        let encoded = [&prefix[..], &[1u8, 2, 3][..]].concat();
        let decoded = Mempool::decode_payload(&encoded);
        assert!(decoded.is_empty());
    }

    /// A zero-length transaction round-trips as an empty record.
    #[tokio::test]
    async fn test_empty_tx() {
        let mempool = Mempool::new(100, 1024);
        assert!(mempool.add_tx(Vec::new()).await);

        let encoded = mempool.collect_payload(1024).await;
        let decoded = Mempool::decode_payload(&encoded);
        assert_eq!(decoded.len(), 1);
        assert!(decoded[0].is_empty());
    }
}