1use std::collections::VecDeque;
2use tokio::sync::Mutex;
3use tracing::debug;
4
5pub type TxHash = [u8; 32];
7
8pub struct Mempool {
10 txs: Mutex<VecDeque<Vec<u8>>>,
11 seen: Mutex<std::collections::HashSet<TxHash>>,
12 max_size: usize,
13 max_tx_bytes: usize,
14}
15
16impl Mempool {
17 pub fn new(max_size: usize, max_tx_bytes: usize) -> Self {
18 Self {
19 txs: Mutex::new(VecDeque::new()),
20 seen: Mutex::new(std::collections::HashSet::new()),
21 max_size,
22 max_tx_bytes,
23 }
24 }
25
26 pub async fn add_tx(&self, tx: Vec<u8>) -> bool {
28 if tx.len() > self.max_tx_bytes {
29 debug!(size = tx.len(), max = self.max_tx_bytes, "tx too large");
30 return false;
31 }
32
33 let hash = Self::hash_tx(&tx);
34
35 let mut seen = self.seen.lock().await;
36 if seen.contains(&hash) {
37 return false;
38 }
39
40 let mut txs = self.txs.lock().await;
41 if txs.len() >= self.max_size {
42 debug!(size = txs.len(), max = self.max_size, "mempool full");
43 return false;
44 }
45
46 seen.insert(hash);
47 txs.push_back(tx);
48 true
49 }
50
51 pub async fn collect_payload(&self, max_bytes: usize) -> Vec<u8> {
55 let mut txs = self.txs.lock().await;
56 let mut payload = Vec::new();
57
58 while let Some(tx) = txs.front() {
59 if payload.len() + 4 + tx.len() > max_bytes {
61 break;
62 }
63 let tx = txs.pop_front().unwrap();
64 let len = tx.len() as u32;
65 payload.extend_from_slice(&len.to_le_bytes());
66 payload.extend_from_slice(&tx);
67 }
68
69 payload
70 }
71
72 pub fn decode_payload(payload: &[u8]) -> Vec<Vec<u8>> {
74 let mut txs = Vec::new();
75 let mut offset = 0;
76 while offset + 4 <= payload.len() {
77 let len = u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
78 offset += 4;
79 if offset + len > payload.len() {
80 break;
81 }
82 txs.push(payload[offset..offset + len].to_vec());
83 offset += len;
84 }
85 txs
86 }
87
88 pub async fn size(&self) -> usize {
89 self.txs.lock().await.len()
90 }
91
92 fn hash_tx(tx: &[u8]) -> TxHash {
93 blake3_hash(tx)
94 }
95}
96
97fn blake3_hash(data: &[u8]) -> TxHash {
98 use std::collections::hash_map::DefaultHasher;
101 use std::hash::{Hash, Hasher};
102 let mut hasher = DefaultHasher::new();
103 data.hash(&mut hasher);
104 let h = hasher.finish();
105 let mut out = [0u8; 32];
106 out[..8].copy_from_slice(&h.to_le_bytes());
107 out[8..12].copy_from_slice(&(data.len() as u32).to_le_bytes());
109 out
110}
111
112impl Default for Mempool {
113 fn default() -> Self {
114 Self::new(10_000, 1_048_576) }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_add_and_collect() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec()).await);
        assert!(pool.add_tx(b"tx2".to_vec()).await);
        assert_eq!(pool.size().await, 2);

        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 2);
        assert_eq!(txs[0], b"tx1");
        assert_eq!(txs[1], b"tx2");
    }

    #[tokio::test]
    async fn test_dedup() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec()).await);
        assert!(!pool.add_tx(b"tx1".to_vec()).await);
        assert_eq!(pool.size().await, 1);
    }

    #[tokio::test]
    async fn test_max_size() {
        let pool = Mempool::new(2, 1024);
        assert!(pool.add_tx(b"tx1".to_vec()).await);
        assert!(pool.add_tx(b"tx2".to_vec()).await);
        assert!(!pool.add_tx(b"tx3".to_vec()).await);
    }

    #[tokio::test]
    async fn test_tx_too_large() {
        let pool = Mempool::new(100, 4);
        assert!(!pool.add_tx(b"toolarge".to_vec()).await);
        assert!(pool.add_tx(b"ok".to_vec()).await);
    }

    #[tokio::test]
    async fn test_collect_respects_max_bytes() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"aaaa".to_vec()).await);
        assert!(pool.add_tx(b"bbbb".to_vec()).await);
        assert!(pool.add_tx(b"cccc".to_vec()).await);

        // Each frame is 4 (length prefix) + 4 (body) = 8 bytes, so 17 fits two.
        let payload = pool.collect_payload(17).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 2);
        assert_eq!(txs[0], b"aaaa");
        assert_eq!(txs[1], b"bbbb");

        // The tx that did not fit stays queued and is collected next time.
        assert_eq!(pool.size().await, 1);
        let rest = Mempool::decode_payload(&pool.collect_payload(17).await);
        assert_eq!(rest, vec![b"cccc".to_vec()]);
        assert_eq!(pool.size().await, 0);
    }

    #[tokio::test]
    async fn test_collect_zero_budget_keeps_queue() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(b"tx1".to_vec()).await);
        assert!(pool.collect_payload(0).await.is_empty());
        assert_eq!(pool.size().await, 1);
    }

    #[test]
    fn test_decode_empty_payload() {
        let txs = Mempool::decode_payload(&[]);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_truncated_payload() {
        // Fewer than 4 bytes: not even a complete length prefix.
        let txs = Mempool::decode_payload(&[1, 2]);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_payload_with_truncated_data() {
        // Prefix claims 100 bytes but only 3 follow.
        let mut payload = vec![];
        payload.extend_from_slice(&100u32.to_le_bytes());
        payload.extend_from_slice(&[1, 2, 3]);
        let txs = Mempool::decode_payload(&payload);
        assert!(txs.is_empty());
    }

    #[test]
    fn test_decode_stops_at_trailing_partial_frame() {
        let mut payload = vec![];
        payload.extend_from_slice(&3u32.to_le_bytes());
        payload.extend_from_slice(b"abc");
        payload.extend_from_slice(&[9, 9]);
        assert_eq!(Mempool::decode_payload(&payload), vec![b"abc".to_vec()]);
    }

    #[test]
    fn test_decode_huge_length_prefix_does_not_panic() {
        let mut payload = vec![];
        payload.extend_from_slice(&u32::MAX.to_le_bytes());
        payload.extend_from_slice(&[1, 2, 3]);
        assert!(Mempool::decode_payload(&payload).is_empty());
    }

    #[tokio::test]
    async fn test_empty_tx() {
        let pool = Mempool::new(100, 1024);
        assert!(pool.add_tx(vec![]).await);
        let payload = pool.collect_payload(1024).await;
        let txs = Mempool::decode_payload(&payload);
        assert_eq!(txs.len(), 1);
        assert!(txs[0].is_empty());
    }
}