1use std::cmp::Ordering;
2use std::collections::HashMap;
3
4use tokio::sync::Mutex;
5use tracing::debug;
6
/// A 32-byte transaction identifier (the mempool derives it with `blake3::hash` over the raw
/// transaction bytes).
pub type TxHash = [u8; 32];

/// A pooled transaction together with its scheduling metadata.
///
/// Entries are ordered by `(priority, hash)`. Neither `tx` nor `gas_wanted` takes part in
/// comparisons or equality, so a key carrying only the right priority and hash is enough to
/// locate a stored entry in an ordered set.
#[derive(Clone)]
struct TxEntry {
    /// Raw transaction bytes.
    tx: Vec<u8>,
    /// Scheduling priority; the mempool collects higher values first.
    priority: u64,
    /// Gas requested by the transaction.
    gas_wanted: u64,
    /// Hash of `tx`; breaks priority ties.
    hash: TxHash,
}

impl Ord for TxEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Lexicographic: priority first, then hash, so entries with different hashes never
        // compare equal even when their priorities match.
        (self.priority, &self.hash).cmp(&(other.priority, &other.hash))
    }
}

impl PartialOrd for TxEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for TxEntry {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp`, which likewise ignores `tx` and `gas_wanted`.
        self.priority == other.priority && self.hash == other.hash
    }
}

impl Eq for TxEntry {}
43
44pub struct Mempool {
54 entries: Mutex<std::collections::BTreeSet<TxEntry>>,
55 seen: Mutex<HashMap<TxHash, u64>>,
57 max_size: usize,
58 max_tx_bytes: usize,
59}
60
61impl Mempool {
62 pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
63 Self {
64 entries: Mutex::new(std::collections::BTreeSet::new()),
65 seen: Mutex::new(HashMap::new()),
66 max_size,
67 max_tx_bytes,
68 }
69 }
70
71 pub async fn add_tx(&self, tx: Vec<u8>, priority: u64) -> bool {
81 self.add_tx_with_gas(tx, priority, 0).await
82 }
83
84 pub async fn add_tx_with_gas(&self, tx: Vec<u8>, priority: u64, gas_wanted: u64) -> bool {
86 if tx.len() > self.max_tx_bytes {
87 debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
88 return false;
89 }
90
91 let hash = Self::hash_tx(&tx);
92
93 let mut entries = self.entries.lock().await;
95 let mut seen = self.seen.lock().await;
96
97 if let Some(&existing_priority) = seen.get(&hash) {
98 if priority <= existing_priority {
99 return false;
101 }
102 let old = TxEntry {
104 tx: tx.clone(),
105 priority: existing_priority,
106 gas_wanted: 0,
107 hash,
108 };
109 entries.remove(&old);
110 seen.insert(hash, priority);
111 entries.insert(TxEntry {
112 tx,
113 priority,
114 gas_wanted,
115 hash,
116 });
117 debug!(
118 old = existing_priority,
119 new = priority,
120 "replaced tx via RBF"
121 );
122 return true;
123 }
124
125 if entries.len() >= self.max_size {
126 if let Some(lowest) = entries.first() {
128 if priority <= lowest.priority {
129 debug!(
130 priority,
131 lowest = lowest.priority,
132 "mempool full, priority too low"
133 );
134 return false;
135 }
136 let evicted = entries.pop_first().expect("just checked non-empty");
138 seen.remove(&evicted.hash);
139 debug!(
140 evicted_priority = evicted.priority,
141 new_priority = priority,
142 "evicted low-priority tx"
143 );
144 }
145 }
146
147 seen.insert(hash, priority);
148 entries.insert(TxEntry {
149 tx,
150 priority,
151 gas_wanted,
152 hash,
153 });
154 true
155 }
156
157 pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
164 self.collect_payload_with_gas(max_bytes, 0).await
165 }
166
167 pub async fn collect_payload_with_gas(&self, max_bytes: usize, max_gas: u64) -> Vec<u8> {
172 let mut entries = self.entries.lock().await;
173 let mut seen = self.seen.lock().await;
174 let mut payload = Vec::new();
175 let mut total_gas = 0u64;
176 let mut skipped = Vec::new();
177
178 while let Some(entry) = entries.pop_last() {
179 if payload.len() + 4 + entry.tx.len() > max_bytes {
182 skipped.push(entry);
183 break;
184 }
185 if max_gas > 0 && total_gas + entry.gas_wanted > max_gas {
187 skipped.push(entry);
188 continue;
189 }
190 seen.remove(&entry.hash);
191 total_gas += entry.gas_wanted;
192 let len = entry.tx.len() as u32;
193 payload.extend_from_slice(&len.to_le_bytes());
194 payload.extend_from_slice(&entry.tx);
195 }
196
197 for entry in skipped {
199 entries.insert(entry);
200 }
201
202 payload
203 }
204
205 pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
207 let mut txs = Vec::new();
208 let mut offset = 0;
209 while offset + 4 <= payload.len() {
210 let len = u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
211 offset += 4;
212 if offset + len > payload.len() {
213 break;
214 }
215 txs.push(payload[offset..offset + len].to_vec());
216 offset += len;
217 }
218 txs
219 }
220
221 pub async fn size(&self) -> usize {
222 self.entries.lock().await.len()
223 }
224
225 pub async fn recheck(&self, validator: impl Fn(&[u8]) -> Option<(u64, u64)>) {
232 let mut entries = self.entries.lock().await;
233 let mut seen = self.seen.lock().await;
234
235 let old: Vec<TxEntry> = entries.iter().cloned().collect();
236 entries.clear();
237 seen.clear();
238
239 let mut removed = 0usize;
240 for entry in old {
241 match validator(&entry.tx) {
242 Some((new_priority, new_gas)) => {
243 seen.insert(entry.hash, new_priority);
244 entries.insert(TxEntry {
245 tx: entry.tx,
246 priority: new_priority,
247 gas_wanted: new_gas,
248 hash: entry.hash,
249 });
250 }
251 None => {
252 removed += 1;
253 }
254 }
255 }
256
257 if removed > 0 {
258 debug!(
259 removed,
260 remaining = entries.len(),
261 "mempool recheck complete"
262 );
263 }
264 }
265
266 fn hash_tx(tx: &[u8]) -> TxHash {
267 *blake3::hash(tx).as_bytes()
268 }
269}
270
271impl Default for Mempool {
272 fn default() -> Self {
273 Self::new(10_000, 1_048_576) }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_add_and_collect() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(pool.add_tx(b"tx2".to_vec(), 20).await);
        assert_eq!(pool.size().await, 2);

        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 2);
        // Highest priority comes first.
        assert_eq!(txs[0], b"tx2");
        assert_eq!(txs[1], b"tx1");
    }

    #[tokio::test]
    async fn test_dedup() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 10).await);
        // Same bytes at equal or lower priority are rejected.
        assert!(!pool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(!pool.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(pool.size().await, 1);
    }

    #[tokio::test]
    async fn test_rbf_replace_by_fee() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(pool.size().await, 1);
        assert!(pool.add_tx(b"tx1".to_vec(), 20).await);
        // Replacement, not a second copy.
        assert_eq!(pool.size().await, 1);

        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 1);
        assert_eq!(txs[0], b"tx1");
    }

    #[tokio::test]
    async fn test_rbf_reorders_by_new_priority() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 5).await);
        assert!(pool.add_tx(b"other".to_vec(), 10).await);
        assert!(pool.add_tx(b"tx1".to_vec(), 20).await);
        assert_eq!(pool.size().await, 2);

        let txs = Mempool::decode_payload(&pool.collect_payload(1024).await);
        assert_eq!(txs, vec![b"tx1".to_vec(), b"other".to_vec()]);
    }

    #[tokio::test]
    async fn test_eviction_by_priority() {
        let pool = Mempool::new(2, 1024);
        assert!(pool.add_tx(b"low".to_vec(), 1).await);
        assert!(pool.add_tx(b"mid".to_vec(), 5).await);
        // Full: "high" evicts "low".
        assert!(pool.add_tx(b"high".to_vec(), 10).await);
        assert_eq!(pool.size().await, 2);

        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 2);
        assert_eq!(txs[0], b"high");
        assert_eq!(txs[1], b"mid");
    }

    #[tokio::test]
    async fn test_reject_low_priority_when_full() {
        let pool = Mempool::new(2, 1024);
        assert!(pool.add_tx(b"a".to_vec(), 5).await);
        assert!(pool.add_tx(b"b".to_vec(), 10).await);
        assert!(!pool.add_tx(b"c".to_vec(), 3).await);
        assert_eq!(pool.size().await, 2);
    }

    #[tokio::test]
    async fn test_tx_too_large() {
        let pool = Mempool::new(100, 4);
        assert!(!pool.add_tx(b"toolarge".to_vec(), 10).await);
        assert!(pool.add_tx(b"ok".to_vec(), 10).await);
    }

    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let pool = Mempool::new(100, 1024);
        pool.add_tx(b"aaaa".to_vec(), 1).await;
        pool.add_tx(b"bbbb".to_vec(), 2).await;
        pool.add_tx(b"cccc".to_vec(), 3).await;

        // Each frame is 4 (length) + 4 (data) = 8 bytes, so only two fit in 17.
        let payload = pool.collect_payload(17).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 2);
        assert_eq!(txs[0], b"cccc");
        assert_eq!(txs[1], b"bbbb");
    }

    #[tokio::test]
    async fn test_collect_leaves_unselected_and_forgets_selected() {
        let pool = Mempool::new(100, 1024);
        pool.add_tx(b"aaaa".to_vec(), 1).await;
        pool.add_tx(b"bbbb".to_vec(), 2).await;

        let first = Mempool::decode_payload(&pool.collect_payload(8).await);
        assert_eq!(first, vec![b"bbbb".to_vec()]);
        assert_eq!(pool.size().await, 1);

        // A collected tx is no longer tracked, so it can be submitted again.
        assert!(pool.add_tx(b"bbbb".to_vec(), 2).await);
        // The one that did not fit is still pooled.
        let second = Mempool::decode_payload(&pool.collect_payload(1024).await);
        assert_eq!(second, vec![b"bbbb".to_vec(), b"aaaa".to_vec()]);
        assert_eq!(pool.size().await, 0);
    }

    #[tokio::test]
    async fn test_collect_respects_max_gas() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx_with_gas(b"a".to_vec(), 3, 10).await);
        assert!(pool.add_tx_with_gas(b"b".to_vec(), 2, 5).await);
        assert!(pool.add_tx_with_gas(b"c".to_vec(), 1, 3).await);

        // "a" alone exceeds the budget and is skipped; "b" + "c" = 8 <= 9 both fit.
        let txs = Mempool::decode_payload(&pool.collect_payload_with_gas(1024, 9).await);
        assert_eq!(txs, vec![b"b".to_vec(), b"c".to_vec()]);
        assert_eq!(pool.size().await, 1);

        // max_gas == 0 means unlimited.
        let rest = Mempool::decode_payload(&pool.collect_payload_with_gas(1024, 0).await);
        assert_eq!(rest, vec![b"a".to_vec()]);
    }

    #[tokio::test]
    async fn test_recheck_drops_and_reprioritizes() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"keep".to_vec(), 1).await);
        assert!(pool.add_tx(b"drop".to_vec(), 5).await);
        assert!(pool.add_tx(b"boost".to_vec(), 2).await);

        pool.recheck(|tx: &[u8]| {
            if tx == b"drop" {
                None
            } else if tx == b"boost" {
                Some((10, 0))
            } else {
                Some((1, 0))
            }
        })
        .await;
        assert_eq!(pool.size().await, 2);

        // Dedup now compares against the rechecked priority.
        assert!(!pool.add_tx(b"boost".to_vec(), 10).await);

        let txs = Mempool::decode_payload(&pool.collect_payload(1024).await);
        assert_eq!(txs, vec![b"boost".to_vec(), b"keep".to_vec()]);
    }

    #[test]
    fn test_decode_empty_payload() {
        let txs = Mempool::decode_payload(&[]);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_truncated_payload() {
        // Fewer than 4 bytes: no complete length prefix.
        let txs = Mempool::decode_payload(&[1, 2]);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_payload_with_truncated_data() {
        let mut payload = vec![];
        // Claims 100 bytes but only 3 follow.
        payload.extend_from_slice(&100u32.to_le_bytes());
        payload.extend_from_slice(&[1, 2, 3]);
        let txs = Mempool::decode_payload(&payload);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_multiple_frames_and_trailing_garbage() {
        let mut payload = Vec::new();
        for tx in [&b"ab"[..], &b""[..], &b"xyz"[..]] {
            payload.extend_from_slice(&(tx.len() as u32).to_le_bytes());
            payload.extend_from_slice(tx);
        }
        // An incomplete trailing length prefix is ignored.
        payload.extend_from_slice(&[7, 7]);
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs, vec![b"ab".to_vec(), Vec::new(), b"xyz".to_vec()]);
    }

    #[tokio::test]
    async fn test_empty_tx() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(vec![], 0).await);
        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 1);
        assert!(txs[0].is_empty());
    }
}