1use std::cmp::Ordering;
2use std::collections::HashMap;
3
4use tokio::sync::Mutex;
5use tracing::debug;
6
/// 32-byte digest used as the identity of a transaction inside the mempool.
pub type TxHash = [u8; 32];
9
10#[derive(Clone)]
12struct TxEntry {
13 tx: Vec<u8>,
14 priority: u64,
15 gas_wanted: u64,
16 hash: TxHash,
17}
18
19impl Eq for TxEntry {}
20
21impl PartialEq for TxEntry {
22 fn eq(&self, other: &Self) -> bool {
23 self.cmp(other) == Ordering::Equal
24 }
25}
26
27impl Ord for TxEntry {
31 fn cmp(&self, other: &Self) -> Ordering {
32 self.priority
33 .cmp(&other.priority)
34 .then_with(|| self.hash.cmp(&other.hash))
35 }
36}
37
38impl PartialOrd for TxEntry {
39 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
40 Some(self.cmp(other))
41 }
42}
43
44pub struct Mempool {
54 entries: Mutex<std::collections::BTreeSet<TxEntry>>,
55 seen: Mutex<HashMap<TxHash, u64>>,
57 max_size: usize,
58 max_tx_bytes: usize,
59}
60
61impl Mempool {
62 pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
63 Self {
64 entries: Mutex::new(std::collections::BTreeSet::new()),
65 seen: Mutex::new(HashMap::new()),
66 max_size,
67 max_tx_bytes,
68 }
69 }
70
71 pub async fn add_tx(&self, tx: Vec<u8>, priority: u64) -> bool {
81 self.add_tx_with_gas(tx, priority, 0).await
82 }
83
84 pub async fn add_tx_with_gas(&self, tx: Vec<u8>, priority: u64, gas_wanted: u64) -> bool {
86 if tx.len() > self.max_tx_bytes {
87 debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
88 return false;
89 }
90
91 let hash = Self::hash_tx(&tx);
92
93 let mut entries = self.entries.lock().await;
95 let mut seen = self.seen.lock().await;
96
97 if let Some(&existing_priority) = seen.get(&hash) {
98 if priority <= existing_priority {
99 return false;
101 }
102 let old = TxEntry {
104 tx: tx.clone(),
105 priority: existing_priority,
106 gas_wanted: 0,
107 hash,
108 };
109 entries.remove(&old);
110 seen.insert(hash, priority);
111 entries.insert(TxEntry {
112 tx,
113 priority,
114 gas_wanted,
115 hash,
116 });
117 debug!(
118 old = existing_priority,
119 new = priority,
120 "replaced tx via RBF"
121 );
122 return true;
123 }
124
125 if entries.len() >= self.max_size {
126 if let Some(lowest) = entries.first() {
128 if priority <= lowest.priority {
129 debug!(
130 priority,
131 lowest = lowest.priority,
132 "mempool full, priority too low"
133 );
134 return false;
135 }
136 let evicted = entries.pop_first().expect("just checked non-empty");
138 seen.remove(&evicted.hash);
139 debug!(
140 evicted_priority = evicted.priority,
141 new_priority = priority,
142 "evicted low-priority tx"
143 );
144 }
145 }
146
147 seen.insert(hash, priority);
148 entries.insert(TxEntry {
149 tx,
150 priority,
151 gas_wanted,
152 hash,
153 });
154 true
155 }
156
157 pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
164 self.collect_payload_with_gas(max_bytes, 0).await
165 }
166
167 pub async fn collect_payload_with_gas(&self, max_bytes: usize, max_gas: u64) -> Vec<u8> {
172 let mut entries = self.entries.lock().await;
173 let mut seen = self.seen.lock().await;
174 let mut payload = Vec::new();
175 let mut total_gas = 0u64;
176 let mut skipped = Vec::new();
177 const MAX_SKIPPED: usize = 200;
179
180 while let Some(entry) = entries.pop_last() {
181 if payload.len() + 4 + entry.tx.len() > max_bytes {
183 skipped.push(entry);
184 if skipped.len() >= MAX_SKIPPED {
185 break;
186 }
187 continue;
188 }
189 if max_gas > 0 && total_gas + entry.gas_wanted > max_gas {
191 skipped.push(entry);
192 if skipped.len() >= MAX_SKIPPED {
193 break;
194 }
195 continue;
196 }
197 seen.remove(&entry.hash);
198 total_gas += entry.gas_wanted;
199 let len = entry.tx.len() as u32;
200 payload.extend_from_slice(&len.to_le_bytes());
201 payload.extend_from_slice(&entry.tx);
202 }
203
204 for entry in skipped {
206 entries.insert(entry);
207 }
208
209 payload
210 }
211
/// Splits a length-prefixed payload back into individual transactions.
///
/// The payload is read as a sequence of frames, each a 4-byte little-endian
/// length followed by that many bytes. Decoding stops at the first
/// incomplete frame (a truncated prefix or a body shorter than its declared
/// length); every frame decoded before that point is returned.
pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
    const LEN_PREFIX_BYTES: usize = 4;

    let mut txs = Vec::new();
    let mut rest = payload;
    while rest.len() >= LEN_PREFIX_BYTES {
        let (prefix, body) = rest.split_at(LEN_PREFIX_BYTES);
        let len = u32::from_le_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;
        // `get` bounds-checks without computing `offset + len`, so a hostile
        // length prefix can neither overflow nor trigger a slicing panic.
        let Some(tx) = body.get(..len) else {
            break;
        };
        txs.push(tx.to_vec());
        rest = &body[len..];
    }
    txs
}
227
228 pub async fn size(&self) -> usize {
229 self.entries.lock().await.len()
230 }
231
232 pub async fn recheck(&self, validator: impl Fn(&[u8]) -> Option<(u64, u64)>) {
239 let mut entries = self.entries.lock().await;
240 let mut seen = self.seen.lock().await;
241
242 let old: Vec<TxEntry> = entries.iter().cloned().collect();
243 entries.clear();
244 seen.clear();
245
246 let mut removed = 0usize;
247 for entry in old {
248 match validator(&entry.tx) {
249 Some((new_priority, new_gas)) => {
250 seen.insert(entry.hash, new_priority);
251 entries.insert(TxEntry {
252 tx: entry.tx,
253 priority: new_priority,
254 gas_wanted: new_gas,
255 hash: entry.hash,
256 });
257 }
258 None => {
259 removed += 1;
260 }
261 }
262 }
263
264 if removed > 0 {
265 debug!(
266 removed,
267 remaining = entries.len(),
268 "mempool recheck complete"
269 );
270 }
271 }
272
273 fn hash_tx(tx: &[u8]) -> TxHash {
274 *blake3::hash(tx).as_bytes()
275 }
276}
277
278impl Default for Mempool {
279 fn default() -> Self {
280 Self::new(10_000, 1_048_576) }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects up to `max_bytes` from `mempool` and decodes the payload.
    async fn collect_decoded(mempool: &Mempool, max_bytes: usize) -> Vec<Vec<u8>> {
        let raw = mempool.collect_payload(max_bytes).await;
        Mempool::decode_payload(&raw)
    }

    #[tokio::test]
    async fn test_add_and_collect() {
        let mempool = Mempool::new(100, 1024);
        assert!(mempool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(mempool.add_tx(b"tx2".to_vec(), 20).await);
        assert_eq!(mempool.size().await, 2);

        // Higher priority comes out first.
        let decoded = collect_decoded(&mempool, 1024).await;
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], b"tx2");
        assert_eq!(decoded[1], b"tx1");
    }

    #[tokio::test]
    async fn test_dedup() {
        let mempool = Mempool::new(100, 1024);
        assert!(mempool.add_tx(b"tx1".to_vec(), 10).await);
        // Same bytes at an equal or lower priority are rejected.
        assert!(!mempool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(!mempool.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(mempool.size().await, 1);
    }

    #[tokio::test]
    async fn test_rbf_replace_by_fee() {
        let mempool = Mempool::new(100, 1024);
        assert!(mempool.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(mempool.size().await, 1);
        // Same bytes at a higher priority replace the pooled entry.
        assert!(mempool.add_tx(b"tx1".to_vec(), 20).await);
        assert_eq!(mempool.size().await, 1);

        let decoded = collect_decoded(&mempool, 1024).await;
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], b"tx1");
    }

    #[tokio::test]
    async fn test_eviction_by_priority() {
        let mempool = Mempool::new(2, 1024);
        assert!(mempool.add_tx(b"low".to_vec(), 1).await);
        assert!(mempool.add_tx(b"mid".to_vec(), 5).await);
        // The pool is full, so the lowest-priority entry is evicted.
        assert!(mempool.add_tx(b"high".to_vec(), 10).await);
        assert_eq!(mempool.size().await, 2);

        let decoded = collect_decoded(&mempool, 1024).await;
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], b"high");
        assert_eq!(decoded[1], b"mid");
    }

    #[tokio::test]
    async fn test_reject_low_priority_when_full() {
        let mempool = Mempool::new(2, 1024);
        assert!(mempool.add_tx(b"a".to_vec(), 5).await);
        assert!(mempool.add_tx(b"b".to_vec(), 10).await);
        assert!(!mempool.add_tx(b"c".to_vec(), 3).await);
        assert_eq!(mempool.size().await, 2);
    }

    #[tokio::test]
    async fn test_tx_too_large() {
        let mempool = Mempool::new(100, 4);
        assert!(!mempool.add_tx(b"toolarge".to_vec(), 10).await);
        assert!(mempool.add_tx(b"ok".to_vec(), 10).await);
    }

    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let mempool = Mempool::new(100, 1024);
        mempool.add_tx(b"aaaa".to_vec(), 1).await;
        mempool.add_tx(b"bbbb".to_vec(), 2).await;
        mempool.add_tx(b"cccc".to_vec(), 3).await;

        // Each frame is 4 + 4 bytes, so only two fit into 17 bytes.
        let decoded = collect_decoded(&mempool, 17).await;
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], b"cccc");
        assert_eq!(decoded[1], b"bbbb");
    }

    #[test]
    fn test_decode_empty_payload() {
        let decoded = Mempool::decode_payload(&[]);
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_decode_truncated_payload() {
        // Too short to hold even a length prefix.
        let decoded = Mempool::decode_payload(&[1, 2]);
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_decode_payload_with_truncated_data() {
        // Prefix announces 100 bytes but only 3 follow.
        let mut raw = vec![];
        raw.extend_from_slice(&100u32.to_le_bytes());
        raw.extend_from_slice(&[1, 2, 3]);
        let decoded = Mempool::decode_payload(&raw);
        assert!(decoded.is_empty());
    }

    #[tokio::test]
    async fn test_empty_tx() {
        let mempool = Mempool::new(100, 1024);
        assert!(mempool.add_tx(vec![], 0).await);
        let decoded = collect_decoded(&mempool, 1024).await;
        assert_eq!(decoded.len(), 1);
        assert!(decoded[0].is_empty());
    }
}