Skip to main content

hotmint_mempool/
lib.rs

1use std::collections::VecDeque;
2use tokio::sync::Mutex;
3use tracing::debug;
4
/// Fixed-width, 32-byte key that identifies a transaction when the mempool
/// checks for duplicates.
pub type TxHash = [u8; 32];
7
8/// Simple mempool: FIFO queue with deduplication
9pub struct Mempool {
10    txs: Mutex<VecDeque<Vec<u8>>>,
11    seen: Mutex<std::collections::HashSet<TxHash>>,
12    max_size: usize,
13    max_tx_bytes: usize,
14}
15
16impl Mempool {
17    pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
18        Self {
19            txs: Mutex::new(VecDeque::new()),
20            seen: Mutex::new(std::collections::HashSet::new()),
21            max_size,
22            max_tx_bytes,
23        }
24    }
25
26    /// Add a transaction to the mempool. Returns false if rejected.
27    pub async fn add_tx(&self, tx: Vec<u8>) -> bool {
28        if tx.len() > self.max_tx_bytes {
29            debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
30            return false;
31        }
32
33        let hash = Self::hash_tx(&tx);
34
35        let mut seen = self.seen.lock().await;
36        if seen.contains(&hash) {
37            return false;
38        }
39
40        let mut txs = self.txs.lock().await;
41        if txs.len() >= self.max_size {
42            debug!(size = txs.len(), max = self.max_size, "mempool full");
43            return false;
44        }
45
46        seen.insert(hash);
47        txs.push_back(tx);
48        true
49    }
50
51    /// Collect transactions for a block proposal (up to max_bytes total).
52    /// Collected transactions are removed from the pool.
53    /// The payload is length-prefixed: `[u32_le len][bytes]...`
54    pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
55        let mut txs = self.txs.lock().await;
56        let mut payload = Vec::new();
57
58        while let Some(tx) = txs.front() {
59            // 4 bytes length prefix + tx bytes
60            if payload.len() + 4 + tx.len() > max_bytes {
61                break;
62            }
63            let tx = txs.pop_front().unwrap();
64            let len = tx.len() as u32;
65            payload.extend_from_slice(&len.to_le_bytes());
66            payload.extend_from_slice(&tx);
67        }
68
69        payload
70    }
71
72    /// Reap collected payload back into individual transactions
73    pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
74        let mut txs = Vec::new();
75        let mut offset = 0;
76        while offset + 4 <= payload.len() {
77            let len = u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
78            offset += 4;
79            if offset + len > payload.len() {
80                break;
81            }
82            txs.push(payload[offset..offset + len].to_vec());
83            offset += len;
84        }
85        txs
86    }
87
88    pub async fn size(&self) -> usize {
89        self.txs.lock().await.len()
90    }
91
92    fn hash_tx(tx: &[u8]) -> TxHash {
93        blake3_hash(tx)
94    }
95}
96
97fn blake3_hash(data: &[u8]) -> TxHash {
98    // Simple hash using first 32 bytes of a basic hash
99    // Not using blake3 crate here to keep mempool dependency-light
100    use std::collections::hash_map::DefaultHasher;
101    use std::hash::{Hash, Hasher};
102    let mut hasher = DefaultHasher::new();
103    data.hash(&mut hasher);
104    let h = hasher.finish();
105    let mut out = [0u8; 32];
106    out[..8].copy_from_slice(&h.to_le_bytes());
107    // Add length for disambiguation
108    out[8..12].copy_from_slice(&(data.len() as u32).to_le_bytes());
109    out
110}
111
112impl Default for Mempool {
113    fn default() -> Self {
114        Self::new(10_000, 1_048_576) // 10k txs, 1MB max per tx
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects a payload with the given byte budget and decodes it again.
    async fn drain(pool: &Mempool, budget: usize) -> Vec<Vec<u8>> {
        let bytes = pool.collect_payload(budget).await;
        Mempool::decode_payload(&bytes)
    }

    #[tokio::test]
    async fn test_add_and_collect() {
        let pool = Mempool::new(100, 1024);
        for tx in [b"tx1", b"tx2"] {
            assert!(pool.add_tx(tx.to_vec()).await);
        }
        assert_eq!(pool.size().await, 2);

        let decoded = drain(&pool, 1024).await;
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], b"tx1");
        assert_eq!(decoded[1], b"tx2");
    }

    #[tokio::test]
    async fn test_dedup() {
        let pool = Mempool::new(100, 1024);
        let first = pool.add_tx(b"tx1".to_vec()).await;
        let repeat = pool.add_tx(b"tx1".to_vec()).await; // identical bytes
        assert!(first);
        assert!(!repeat);
        assert_eq!(pool.size().await, 1);
    }

    #[tokio::test]
    async fn test_max_size() {
        let pool = Mempool::new(2, 1024);
        let mut accepted = Vec::new();
        for tx in [b"tx1", b"tx2", b"tx3"] {
            accepted.push(pool.add_tx(tx.to_vec()).await);
        }
        // The third insert is rejected because the pool is full.
        assert_eq!(accepted, [true, true, false]);
    }

    #[tokio::test]
    async fn test_tx_too_large() {
        let pool = Mempool::new(100, 4);
        assert!(!pool.add_tx(b"toolarge".to_vec()).await, "8 bytes exceeds the 4-byte limit");
        assert!(pool.add_tx(b"ok".to_vec()).await);
    }

    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let pool = Mempool::new(100, 1024);
        for tx in [b"aaaa", b"bbbb", b"cccc"] {
            pool.add_tx(tx.to_vec()).await;
        }

        // Each tx encodes to 4 (prefix) + 4 (data) = 8 bytes, so a 17-byte budget
        // holds two of them (16 bytes) but not all three (24 bytes).
        assert_eq!(drain(&pool, 17).await.len(), 2);
    }

    #[test]
    fn test_decode_empty_payload() {
        assert!(Mempool::decode_payload(&[]).is_empty());
    }

    #[test]
    fn test_decode_truncated_payload() {
        // Two bytes cannot hold the 4-byte length prefix.
        assert!(Mempool::decode_payload(&[1, 2]).is_empty());
    }

    #[test]
    fn test_decode_payload_with_truncated_data() {
        // The prefix announces 100 bytes but only 3 follow.
        let payload: Vec<u8> = 100u32
            .to_le_bytes()
            .into_iter()
            .chain([1, 2, 3])
            .collect();
        assert!(Mempool::decode_payload(&payload).is_empty());
    }

    #[tokio::test]
    async fn test_empty_tx() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(Vec::new()).await);

        let decoded = drain(&pool, 1024).await;
        assert_eq!(decoded.len(), 1);
        assert!(decoded[0].is_empty());
    }
}
205}