amaru_mempool/strategies/
in_memory_mempool.rs1use std::{
16 collections::{BTreeMap, BTreeSet},
17 mem,
18 pin::Pin,
19 sync::Arc,
20};
21
22use amaru_kernel::{cbor, to_cbor};
23use amaru_ouroboros_traits::{
24 CanValidateTransactions, MempoolSeqNo, TransactionValidationError, TxId, TxOrigin, TxRejectReason,
25 TxSubmissionMempool, mempool::Mempool,
26};
27use tokio::sync::Notify;
28
29#[derive(Clone)]
37pub struct InMemoryMempool<Tx> {
38 config: MempoolConfig,
39 inner: Arc<parking_lot::RwLock<MempoolInner<Tx>>>,
40 tx_validator: Arc<dyn CanValidateTransactions<Tx>>,
41}
42
43impl<Tx> Default for InMemoryMempool<Tx> {
44 fn default() -> Self {
45 Self::from_config(MempoolConfig::default())
46 }
47}
48
49impl<Tx> InMemoryMempool<Tx> {
50 pub fn new(config: MempoolConfig, tx_validator: Arc<dyn CanValidateTransactions<Tx>>) -> Self {
51 InMemoryMempool { config, inner: Arc::new(parking_lot::RwLock::new(MempoolInner::default())), tx_validator }
52 }
53
54 pub fn from_config(config: MempoolConfig) -> Self {
55 Self::new(config, Arc::new(DefaultCanValidateTransactions))
56 }
57}
58
59#[derive(Clone, Debug, Default)]
61pub struct DefaultCanValidateTransactions;
62
63impl<Tx> CanValidateTransactions<Tx> for DefaultCanValidateTransactions {
64 fn validate_transaction(&self, _tx: Tx) -> Result<(), TransactionValidationError> {
65 Ok(())
66 }
67}
68
69#[derive(Debug)]
70struct MempoolInner<Tx> {
71 next_seq: u64,
72 entries_by_id: BTreeMap<TxId, MempoolEntry<Tx>>,
73 entries_by_seq: BTreeMap<MempoolSeqNo, TxId>,
74 notify: Arc<Notify>,
75}
76
77impl<Tx> Default for MempoolInner<Tx> {
78 fn default() -> Self {
79 MempoolInner {
80 next_seq: 1,
81 entries_by_id: Default::default(),
82 entries_by_seq: Default::default(),
83 notify: Arc::new(Notify::new()),
84 }
85 }
86}
87
88impl<Tx: cbor::Encode<()> + Clone> MempoolInner<Tx> {
89 fn insert(
92 &mut self,
93 config: &MempoolConfig,
94 tx: Tx,
95 tx_origin: TxOrigin,
96 ) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
97 if let Some(max_txs) = config.max_txs
98 && self.entries_by_id.len() >= max_txs
99 {
100 return Err(TxRejectReason::MempoolFull);
101 }
102
103 let tx_id = TxId::from(&tx);
104 if self.entries_by_id.contains_key(&tx_id) {
105 return Err(TxRejectReason::Duplicate);
106 }
107
108 let seq_no = MempoolSeqNo(self.next_seq);
109 self.next_seq += 1;
110
111 let tx_size = to_cbor(&tx).len() as u32;
112 let entry = MempoolEntry { seq_no, tx_id, tx, tx_size, origin: tx_origin };
113
114 self.entries_by_id.insert(tx_id, entry);
115 self.entries_by_seq.insert(seq_no, tx_id);
116 Ok((tx_id, seq_no))
117 }
118
119 fn get_tx(&self, tx_id: &TxId) -> Option<Tx> {
121 self.entries_by_id.get(tx_id).map(|entry| entry.tx.clone())
122 }
123
124 fn tx_ids_since(&self, from_seq: MempoolSeqNo, limit: u16) -> Vec<(TxId, u32, MempoolSeqNo)> {
126 let mut result: Vec<(TxId, u32, MempoolSeqNo)> = self
127 .entries_by_seq
128 .range(from_seq..)
129 .take(limit as usize)
130 .map(|(seq, tx_id)| {
131 let Some(entry) = self.entries_by_id.get(tx_id) else {
132 panic!("Inconsistent mempool state: entry missing for tx_id {:?}", tx_id)
133 };
134 (*tx_id, entry.tx_size, *seq)
135 })
136 .collect();
137 result.sort_by_key(|(_, _, seq_no)| *seq_no);
138 result
139 }
140
141 fn get_txs_for_ids(&self, ids: &[TxId]) -> Vec<Tx> {
143 let mut result: Vec<(&TxId, &MempoolEntry<Tx>)> =
145 self.entries_by_id.iter().filter(|(key, _)| ids.contains(*key)).collect();
146 result.sort_by_key(|(_, entry)| entry.seq_no);
147 result.into_iter().map(|(_, entry)| entry.tx.clone()).collect()
148 }
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
152pub struct MempoolEntry<Tx> {
153 seq_no: MempoolSeqNo,
154 tx_id: TxId,
155 tx: Tx,
156 tx_size: u32,
157 origin: TxOrigin,
158}
159
/// Tunable limits for an [`InMemoryMempool`].
#[derive(Default, Clone, Debug)]
pub struct MempoolConfig {
    /// Maximum number of pending transactions; `None` (the default) means unbounded.
    max_txs: Option<usize>,
}

impl MempoolConfig {
    /// Returns this configuration with the pending-transaction cap set to `max`,
    /// replacing any previously configured cap.
    pub fn with_max_txs(self, max: usize) -> Self {
        let mut updated = self;
        updated.max_txs = Some(max);
        updated
    }
}
171
172impl<Tx: Send + Sync + 'static> CanValidateTransactions<Tx> for InMemoryMempool<Tx> {
173 fn validate_transaction(&self, tx: Tx) -> Result<(), TransactionValidationError> {
174 self.tx_validator.validate_transaction(tx)
175 }
176}
177
178impl<Tx: Send + Sync + 'static + cbor::Encode<()> + Clone> TxSubmissionMempool<Tx> for InMemoryMempool<Tx> {
179 fn insert(&self, tx: Tx, tx_origin: TxOrigin) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
180 let mut inner = self.inner.write();
181 let res = inner.insert(&self.config, tx, tx_origin);
182 if res.is_ok() {
183 inner.notify.notify_waiters();
184 }
185 res
186 }
187
188 fn get_tx(&self, tx_id: &TxId) -> Option<Tx> {
189 self.inner.read().get_tx(tx_id)
190 }
191
192 fn tx_ids_since(&self, from_seq: MempoolSeqNo, limit: u16) -> Vec<(TxId, u32, MempoolSeqNo)> {
193 self.inner.read().tx_ids_since(from_seq, limit)
194 }
195
196 fn wait_for_at_least(&self, seq_no: MempoolSeqNo) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
198 Box::pin(async move {
199 loop {
200 let (current_next_seq, notify) = {
203 let inner = self.inner.read();
204 (inner.next_seq, inner.notify.clone())
205 };
206 let notified = notify.notified();
207
208 if current_next_seq >= seq_no.0 {
211 return true;
212 }
213
214 notified.await;
216 }
217 })
218 }
219
220 fn get_txs_for_ids(&self, ids: &[TxId]) -> Vec<Tx> {
221 self.inner.read().get_txs_for_ids(ids)
222 }
223
224 fn last_seq_no(&self) -> MempoolSeqNo {
225 MempoolSeqNo(self.inner.read().next_seq - 1)
226 }
227}
228
229impl<Tx: Send + Sync + 'static + cbor::Encode<()> + Clone> Mempool<Tx> for InMemoryMempool<Tx> {
230 fn take(&self) -> Vec<Tx> {
231 let mut inner = self.inner.write();
232 let entries = mem::take(&mut inner.entries_by_id);
233 let _ = mem::take(&mut inner.entries_by_seq);
234 entries.into_values().map(|entry| entry.tx).collect()
235 }
236
237 fn acknowledge<TxKey: Ord, I>(&self, tx: &Tx, keys: fn(&Tx) -> I)
238 where
239 I: IntoIterator<Item = TxKey>,
240 Self: Sized,
241 {
242 let keys_to_remove = BTreeSet::from_iter(keys(tx));
243 let mut inner = self.inner.write();
244
245 let seq_nos_to_remove: Vec<MempoolSeqNo> = inner
247 .entries_by_id
248 .values()
249 .filter(|entry| keys(&entry.tx).into_iter().any(|k| keys_to_remove.contains(&k)))
250 .map(|entry| entry.seq_no)
251 .collect();
252 inner.entries_by_id.retain(|_, entry| !keys(&entry.tx).into_iter().any(|k| keys_to_remove.contains(&k)));
253 for seq_no in seq_nos_to_remove {
254 inner.entries_by_seq.remove(&seq_no);
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use std::{ops::Deref, slice, str::FromStr, time::Duration};

    use amaru_kernel::{Peer, cbor, cbor as minicbor};
    use assertables::assert_some_eq_x;
    use tokio::time::timeout;

    use super::*;

    /// Minimal CBOR-encodable transaction used to exercise the mempool.
    #[derive(Debug, PartialEq, Eq, Clone, cbor::Encode, cbor::Decode)]
    struct Tx(#[n(0)] String);

    impl Deref for Tx {
        type Target = String;
        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl FromStr for Tx {
        type Err = ();
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let owned = s.to_string();
            Ok(Tx(owned))
        }
    }

    #[tokio::test]
    async fn insert_a_transaction() -> anyhow::Result<()> {
        let config = MempoolConfig::default().with_max_txs(5);
        let mempool = InMemoryMempool::from_config(config);
        let tx = Tx::from_str("tx1").unwrap();
        let origin = TxOrigin::Remote(Peer::new("upstream"));

        let (tx_id, seq_nb) = mempool.insert(tx.clone(), origin).unwrap();

        // The transaction can be read back by id, alone and in bulk.
        assert_some_eq_x!(mempool.get_tx(&tx_id), tx.clone());
        let fetched = mempool.get_txs_for_ids(slice::from_ref(&tx_id));
        assert_eq!(fetched, vec![tx.clone()]);

        // 5 is the encoded size of `Tx("tx1")` as measured by the mempool (`to_cbor`).
        let listed = mempool.tx_ids_since(seq_nb, 100);
        assert_eq!(listed, vec![(tx_id, 5, seq_nb)]);

        // Waiting for an already-assigned seq no resolves; a far-future one does not.
        let reached = mempool.wait_for_at_least(seq_nb).await;
        assert!(reached, "should have at least seq no");
        let far_ahead = seq_nb.add(100);
        let waited = timeout(Duration::from_millis(100), mempool.wait_for_at_least(far_ahead)).await;
        assert!(waited.is_err(), "should timeout waiting for a seq no that is too high");

        Ok(())
    }
}