1use std::cmp::Ordering;
2use std::collections::HashMap;
3
4use tokio::sync::Mutex;
5use tracing::debug;
6
/// Fixed-width, 32-byte digest that identifies a transaction inside the pool.
pub type TxHash = [u8; 32];
9
10#[derive(Clone)]
12struct TxEntry {
13 tx: Vec<u8>,
14 priority: u64,
15 gas_wanted: u64,
16 hash: TxHash,
17}
18
19impl Eq for TxEntry {}
20
21impl PartialEq for TxEntry {
22 fn eq(&self, other: &Self) -> bool {
23 self.cmp(other) == Ordering::Equal
24 }
25}
26
27impl Ord for TxEntry {
31 fn cmp(&self, other: &Self) -> Ordering {
32 self.priority
33 .cmp(&other.priority)
34 .then_with(|| self.hash.cmp(&other.hash))
35 }
36}
37
38impl PartialOrd for TxEntry {
39 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
40 Some(self.cmp(other))
41 }
42}
43
44pub struct Mempool {
54 entries: Mutex<std::collections::BTreeSet<TxEntry>>,
55 seen: Mutex<HashMap<TxHash, u64>>,
57 max_size: usize,
58 max_tx_bytes: usize,
59}
60
61impl Mempool {
62 pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
63 Self {
64 entries: Mutex::new(std::collections::BTreeSet::new()),
65 seen: Mutex::new(HashMap::new()),
66 max_size,
67 max_tx_bytes,
68 }
69 }
70
71 pub async fn add_tx(&self, tx: Vec<u8>, priority: u64) -> bool {
81 self.add_tx_with_gas(tx, priority, 0).await
82 }
83
84 pub async fn add_tx_with_gas(&self, tx: Vec<u8>, priority: u64, gas_wanted: u64) -> bool {
86 if tx.len() > self.max_tx_bytes {
87 debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
88 return false;
89 }
90
91 let hash = Self::hash_tx(&tx);
92
93 let mut entries = self.entries.lock().await;
95 let mut seen = self.seen.lock().await;
96
97 if let Some(&existing_priority) = seen.get(&hash) {
98 if priority <= existing_priority {
99 return false;
101 }
102 let old = TxEntry {
104 tx: tx.clone(),
105 priority: existing_priority,
106 gas_wanted: 0,
107 hash,
108 };
109 entries.remove(&old);
110 seen.insert(hash, priority);
111 entries.insert(TxEntry {
112 tx,
113 priority,
114 gas_wanted,
115 hash,
116 });
117 debug!(
118 old = existing_priority,
119 new = priority,
120 "replaced tx via RBF"
121 );
122 return true;
123 }
124
125 if entries.len() >= self.max_size {
126 if let Some(lowest) = entries.first() {
128 if priority <= lowest.priority {
129 debug!(
130 priority,
131 lowest = lowest.priority,
132 "mempool full, priority too low"
133 );
134 return false;
135 }
136 let evicted = entries.pop_first().expect("just checked non-empty");
138 seen.remove(&evicted.hash);
139 debug!(
140 evicted_priority = evicted.priority,
141 new_priority = priority,
142 "evicted low-priority tx"
143 );
144 }
145 }
146
147 seen.insert(hash, priority);
148 entries.insert(TxEntry {
149 tx,
150 priority,
151 gas_wanted,
152 hash,
153 });
154 true
155 }
156
157 pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
164 self.collect_payload_with_gas(max_bytes, 0).await
165 }
166
167 pub async fn collect_payload_with_gas(&self, max_bytes: usize, max_gas: u64) -> Vec<u8> {
172 let mut entries = self.entries.lock().await;
173 let mut seen = self.seen.lock().await;
174 let mut payload = Vec::new();
175 let mut total_gas = 0u64;
176 let mut skipped = Vec::new();
177
178 while let Some(entry) = entries.pop_last() {
179 if payload.len() + 4 + entry.tx.len() > max_bytes {
182 skipped.push(entry);
183 break;
184 }
185 if max_gas > 0 && total_gas + entry.gas_wanted > max_gas {
187 skipped.push(entry);
188 continue;
189 }
190 seen.remove(&entry.hash);
191 total_gas += entry.gas_wanted;
192 let len = entry.tx.len() as u32;
193 payload.extend_from_slice(&len.to_le_bytes());
194 payload.extend_from_slice(&entry.tx);
195 }
196
197 for entry in skipped {
199 entries.insert(entry);
200 }
201
202 payload
203 }
204
205 pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
207 let mut txs = Vec::new();
208 let mut offset = 0;
209 while offset + 4 <= payload.len() {
210 let len = u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
211 offset += 4;
212 if offset + len > payload.len() {
213 break;
214 }
215 txs.push(payload[offset..offset + len].to_vec());
216 offset += len;
217 }
218 txs
219 }
220
221 pub async fn size(&self) -> usize {
222 self.entries.lock().await.len()
223 }
224
225 fn hash_tx(tx: &[u8]) -> TxHash {
226 *blake3::hash(tx).as_bytes()
227 }
228}
229
230impl Default for Mempool {
231 fn default() -> Self {
232 Self::new(10_000, 1_048_576) }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects a payload of at most `max_bytes` from `pool` and decodes it.
    async fn drain(pool: &Mempool, max_bytes: usize) -> Vec<Vec<u8>> {
        let payload = pool.collect_payload(max_bytes).await;
        Mempool::decode_payload(&payload)
    }

    // Transactions come back highest priority first.
    #[tokio::test]
    async fn test_add_and_collect() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(pool.add_tx(b"tx2".to_vec(), 20).await);
        assert_eq!(pool.size().await, 2);

        let collected = drain(&pool, 1024).await;
        assert_eq!(collected, vec![b"tx2".to_vec(), b"tx1".to_vec()]);
    }

    // Resubmitting at the same or a lower priority is rejected.
    #[tokio::test]
    async fn test_dedup() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(!pool.add_tx(b"tx1".to_vec(), 10).await);
        assert!(!pool.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(pool.size().await, 1);
    }

    // Resubmitting at a higher priority replaces the pooled copy.
    #[tokio::test]
    async fn test_rbf_replace_by_fee() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec(), 5).await);
        assert_eq!(pool.size().await, 1);
        assert!(pool.add_tx(b"tx1".to_vec(), 20).await);
        assert_eq!(pool.size().await, 1);

        let collected = drain(&pool, 1024).await;
        assert_eq!(collected, vec![b"tx1".to_vec()]);
    }

    // A full pool evicts its lowest-priority entry for a higher-priority newcomer.
    #[tokio::test]
    async fn test_eviction_by_priority() {
        let pool = Mempool::new(2, 1024);
        assert!(pool.add_tx(b"low".to_vec(), 1).await);
        assert!(pool.add_tx(b"mid".to_vec(), 5).await);
        assert!(pool.add_tx(b"high".to_vec(), 10).await);
        assert_eq!(pool.size().await, 2);

        let collected = drain(&pool, 1024).await;
        assert_eq!(collected, vec![b"high".to_vec(), b"mid".to_vec()]);
    }

    // A full pool rejects a newcomer that does not outrank its lowest entry.
    #[tokio::test]
    async fn test_reject_low_priority_when_full() {
        let pool = Mempool::new(2, 1024);
        assert!(pool.add_tx(b"a".to_vec(), 5).await);
        assert!(pool.add_tx(b"b".to_vec(), 10).await);
        assert!(!pool.add_tx(b"c".to_vec(), 3).await);
        assert_eq!(pool.size().await, 2);
    }

    #[tokio::test]
    async fn test_tx_too_large() {
        let pool = Mempool::new(100, 4);
        assert!(!pool.add_tx(b"toolarge".to_vec(), 10).await);
        assert!(pool.add_tx(b"ok".to_vec(), 10).await);
    }

    // Each frame is 4 + 4 bytes, so a 17-byte budget holds two of the three.
    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let pool = Mempool::new(100, 1024);
        for (tx, priority) in [(b"aaaa", 1), (b"bbbb", 2), (b"cccc", 3)] {
            pool.add_tx(tx.to_vec(), priority).await;
        }

        let collected = drain(&pool, 17).await;
        assert_eq!(collected, vec![b"cccc".to_vec(), b"bbbb".to_vec()]);
    }

    #[test]
    fn test_decode_empty_payload() {
        assert!(Mempool::decode_payload(&[]).is_empty());
    }

    // Fewer bytes than a length prefix decode to nothing.
    #[test]
    fn test_decode_truncated_payload() {
        assert!(Mempool::decode_payload(&[1, 2]).is_empty());
    }

    // A prefix announcing more bytes than are present decodes to nothing.
    #[test]
    fn test_decode_payload_with_truncated_data() {
        let mut payload = 100u32.to_le_bytes().to_vec();
        payload.extend_from_slice(&[1, 2, 3]);
        assert!(Mempool::decode_payload(&payload).is_empty());
    }

    // An empty transaction round-trips as one zero-length frame.
    #[tokio::test]
    async fn test_empty_tx() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(vec![], 0).await);

        let collected = drain(&pool, 1024).await;
        assert_eq!(collected, vec![Vec::<u8>::new()]);
    }
}