1use std::collections::{HashSet, VecDeque};
2use tokio::sync::Mutex;
3use tracing::debug;
4
5pub type TxHash = [u8; 32];
7
8pub struct Mempool {
10 txs: Mutex<VecDeque<Vec<u8>>>,
11 seen: Mutex<HashSet<TxHash>>,
12 max_size: usize,
13 max_tx_bytes: usize,
14}
15
16impl Mempool {
17 pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
18 Self {
19 txs: Mutex::new(VecDeque::new()),
20 seen: Mutex::new(HashSet::new()),
21 max_size,
22 max_tx_bytes,
23 }
24 }
25
26 pub async fn add_tx(&self, tx: Vec<u8>) -> bool {
28 if tx.len() > self.max_tx_bytes {
29 debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
30 return false;
31 }
32
33 let hash = Self::hash_tx(&tx);
34
35 let mut txs = self.txs.lock().await;
37 let mut seen = self.seen.lock().await;
38
39 if seen.contains(&hash) {
40 return false;
41 }
42 if txs.len() >= self.max_size {
43 debug!(size = txs.len(), max = self.max_size, "mempool full");
44 return false;
45 }
46
47 seen.insert(hash);
48 txs.push_back(tx);
49 true
50 }
51
52 pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
56 let mut txs = self.txs.lock().await;
57 let mut seen = self.seen.lock().await;
58 let mut payload = Vec::new();
59
60 while let Some(tx) = txs.front() {
61 if payload.len() + 4 + tx.len() > max_bytes {
63 break;
64 }
65 let tx = txs.pop_front().unwrap();
66 seen.remove(&Self::hash_tx(&tx));
67 let len = tx.len() as u32;
68 payload.extend_from_slice(&len.to_le_bytes());
69 payload.extend_from_slice(&tx);
70 }
71
72 payload
73 }
74
75 pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
77 let mut txs = Vec::new();
78 let mut offset = 0;
79 while offset + 4 <= payload.len() {
80 let len = u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
81 offset += 4;
82 if offset + len > payload.len() {
83 break;
84 }
85 txs.push(payload[offset..offset + len].to_vec());
86 offset += len;
87 }
88 txs
89 }
90
91 pub async fn size(&self) -> usize {
92 self.txs.lock().await.len()
93 }
94
95 fn hash_tx(tx: &[u8]) -> TxHash {
96 blake3_hash(tx)
97 }
98}
99
100fn blake3_hash(data: &[u8]) -> TxHash {
101 *blake3::hash(data).as_bytes()
102}
103
104impl Default for Mempool {
105 fn default() -> Self {
106 Self::new(10_000, 1_048_576) }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_add_and_collect() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec()).await);
        assert!(pool.add_tx(b"tx2".to_vec()).await);
        assert_eq!(pool.size().await, 2);

        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 2);
        assert_eq!(txs[0], b"tx1");
        assert_eq!(txs[1], b"tx2");
        // Collecting everything drains the pool.
        assert_eq!(pool.size().await, 0);
    }

    #[tokio::test]
    async fn test_dedup() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec()).await);
        assert!(!pool.add_tx(b"tx1".to_vec()).await);
        assert_eq!(pool.size().await, 1);
    }

    #[tokio::test]
    async fn test_readd_after_collect() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec()).await);
        pool.collect_payload(1024).await;
        // Collected transactions are removed from the duplicate filter.
        assert!(pool.add_tx(b"tx1".to_vec()).await);
        assert_eq!(pool.size().await, 1);
    }

    #[tokio::test]
    async fn test_max_size() {
        let pool = Mempool::new(2, 1024);
        assert!(pool.add_tx(b"tx1".to_vec()).await);
        assert!(pool.add_tx(b"tx2".to_vec()).await);
        assert!(!pool.add_tx(b"tx3".to_vec()).await);
        assert_eq!(pool.size().await, 2);
    }

    #[tokio::test]
    async fn test_tx_too_large() {
        let pool = Mempool::new(100, 4);
        assert!(!pool.add_tx(b"toolarge".to_vec()).await);
        assert!(pool.add_tx(b"ok".to_vec()).await);
    }

    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"aaaa".to_vec()).await);
        assert!(pool.add_tx(b"bbbb".to_vec()).await);
        assert!(pool.add_tx(b"cccc".to_vec()).await);

        // Each framed tx is 4 (length prefix) + 4 (data) = 8 bytes, so a 17-byte budget
        // fits exactly two of them.
        let payload = pool.collect_payload(17).await;
        assert_eq!(payload.len(), 16);
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs, vec![b"aaaa".to_vec(), b"bbbb".to_vec()]);

        // The transaction that did not fit stays queued for the next payload.
        assert_eq!(pool.size().await, 1);
        let rest = Mempool::decode_payload(&pool.collect_payload(17).await);
        assert_eq!(rest, vec![b"cccc".to_vec()]);
    }

    #[tokio::test]
    async fn test_collect_with_budget_smaller_than_front_tx() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"aaaa".to_vec()).await);
        // 8 framed bytes do not fit in 7: nothing is drained.
        assert!(pool.collect_payload(7).await.is_empty());
        assert_eq!(pool.size().await, 1);
    }

    #[test]
    fn test_decode_empty_payload() {
        let txs = Mempool::decode_payload(&[]);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_truncated_payload() {
        // Fewer bytes than a single length prefix.
        let txs = Mempool::decode_payload(&[1, 2]);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_payload_with_truncated_data() {
        // The prefix claims 100 bytes, but only 3 follow.
        let mut payload = Vec::new();
        payload.extend_from_slice(&100u32.to_le_bytes());
        payload.extend_from_slice(&[1, 2, 3]);
        let txs = Mempool::decode_payload(&payload);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_stops_at_oversized_length_prefix() {
        // A valid frame followed by a u32::MAX prefix: the valid frame is kept and
        // decoding stops without panicking.
        let mut payload = Vec::new();
        payload.extend_from_slice(&2u32.to_le_bytes());
        payload.extend_from_slice(b"ok");
        payload.extend_from_slice(&u32::MAX.to_le_bytes());
        payload.push(0);
        assert_eq!(Mempool::decode_payload(&payload), vec![b"ok".to_vec()]);
    }

    #[tokio::test]
    async fn test_empty_tx() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(vec![]).await);
        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 1);
        assert!(txs[0].is_empty());
    }
}