Skip to main content

amaru_mempool/strategies/
in_memory_mempool.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, BTreeSet},
17    mem,
18    pin::Pin,
19    sync::Arc,
20};
21
22use amaru_kernel::{cbor, to_cbor};
23use amaru_ouroboros_traits::{
24    CanValidateTransactions, MempoolSeqNo, TransactionValidationError, TxId, TxOrigin, TxRejectReason,
25    TxSubmissionMempool, mempool::Mempool,
26};
27use tokio::sync::Notify;
28
29/// A temporary in-memory mempool implementation to support the transaction submission protocol.
30///
31/// It stores transactions in memory, indexed by their TxId and by a sequence number assigned
32/// at insertion time.
33///
34/// The validation of the transactions are delegated to a `CanValidateTransactions` implementation.
35///
36#[derive(Clone)]
37pub struct InMemoryMempool<Tx> {
38    config: MempoolConfig,
39    inner: Arc<parking_lot::RwLock<MempoolInner<Tx>>>,
40    tx_validator: Arc<dyn CanValidateTransactions<Tx>>,
41}
42
43impl<Tx> Default for InMemoryMempool<Tx> {
44    fn default() -> Self {
45        Self::from_config(MempoolConfig::default())
46    }
47}
48
49impl<Tx> InMemoryMempool<Tx> {
50    pub fn new(config: MempoolConfig, tx_validator: Arc<dyn CanValidateTransactions<Tx>>) -> Self {
51        InMemoryMempool { config, inner: Arc::new(parking_lot::RwLock::new(MempoolInner::default())), tx_validator }
52    }
53
54    pub fn from_config(config: MempoolConfig) -> Self {
55        Self::new(config, Arc::new(DefaultCanValidateTransactions))
56    }
57}
58
59/// A default transactions validator.
60#[derive(Clone, Debug, Default)]
61pub struct DefaultCanValidateTransactions;
62
63impl<Tx> CanValidateTransactions<Tx> for DefaultCanValidateTransactions {
64    fn validate_transaction(&self, _tx: Tx) -> Result<(), TransactionValidationError> {
65        Ok(())
66    }
67}
68
69#[derive(Debug)]
70struct MempoolInner<Tx> {
71    next_seq: u64,
72    entries_by_id: BTreeMap<TxId, MempoolEntry<Tx>>,
73    entries_by_seq: BTreeMap<MempoolSeqNo, TxId>,
74    notify: Arc<Notify>,
75}
76
77impl<Tx> Default for MempoolInner<Tx> {
78    fn default() -> Self {
79        MempoolInner {
80            next_seq: 1,
81            entries_by_id: Default::default(),
82            entries_by_seq: Default::default(),
83            notify: Arc::new(Notify::new()),
84        }
85    }
86}
87
88impl<Tx: cbor::Encode<()> + Clone> MempoolInner<Tx> {
89    /// Inserts a new transaction into the mempool.
90    /// The transaction id is a hash of the transaction body.
91    fn insert(
92        &mut self,
93        config: &MempoolConfig,
94        tx: Tx,
95        tx_origin: TxOrigin,
96    ) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
97        if let Some(max_txs) = config.max_txs
98            && self.entries_by_id.len() >= max_txs
99        {
100            return Err(TxRejectReason::MempoolFull);
101        }
102
103        let tx_id = TxId::from(&tx);
104        if self.entries_by_id.contains_key(&tx_id) {
105            return Err(TxRejectReason::Duplicate);
106        }
107
108        let seq_no = MempoolSeqNo(self.next_seq);
109        self.next_seq += 1;
110
111        let tx_size = to_cbor(&tx).len() as u32;
112        let entry = MempoolEntry { seq_no, tx_id, tx, tx_size, origin: tx_origin };
113
114        self.entries_by_id.insert(tx_id, entry);
115        self.entries_by_seq.insert(seq_no, tx_id);
116        Ok((tx_id, seq_no))
117    }
118
119    /// Retrieves a transaction by its id.
120    fn get_tx(&self, tx_id: &TxId) -> Option<Tx> {
121        self.entries_by_id.get(tx_id).map(|entry| entry.tx.clone())
122    }
123
124    /// Retrieves all the transaction ids since a given sequence number, up to a limit.
125    fn tx_ids_since(&self, from_seq: MempoolSeqNo, limit: u16) -> Vec<(TxId, u32, MempoolSeqNo)> {
126        let mut result: Vec<(TxId, u32, MempoolSeqNo)> = self
127            .entries_by_seq
128            .range(from_seq..)
129            .take(limit as usize)
130            .map(|(seq, tx_id)| {
131                let Some(entry) = self.entries_by_id.get(tx_id) else {
132                    panic!("Inconsistent mempool state: entry missing for tx_id {:?}", tx_id)
133                };
134                (*tx_id, entry.tx_size, *seq)
135            })
136            .collect();
137        result.sort_by_key(|(_, _, seq_no)| *seq_no);
138        result
139    }
140
141    /// Retrieves transactions for the given ids, sorted by their sequence number.
142    fn get_txs_for_ids(&self, ids: &[TxId]) -> Vec<Tx> {
143        // Make sure that the result are sorted by seq_no
144        let mut result: Vec<(&TxId, &MempoolEntry<Tx>)> =
145            self.entries_by_id.iter().filter(|(key, _)| ids.contains(*key)).collect();
146        result.sort_by_key(|(_, entry)| entry.seq_no);
147        result.into_iter().map(|(_, entry)| entry.tx.clone()).collect()
148    }
149}
150
151#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
152pub struct MempoolEntry<Tx> {
153    seq_no: MempoolSeqNo,
154    tx_id: TxId,
155    tx: Tx,
156    tx_size: u32,
157    origin: TxOrigin,
158}
159
/// Tunables for an [`InMemoryMempool`].
#[derive(Debug, Clone, Default)]
pub struct MempoolConfig {
    /// Maximum number of pooled transactions; `None` (the default) means no cap.
    max_txs: Option<usize>,
}

impl MempoolConfig {
    /// Returns this configuration with the transaction cap set to `max`.
    pub fn with_max_txs(self, max: usize) -> Self {
        Self { max_txs: Some(max), ..self }
    }
}
171
172impl<Tx: Send + Sync + 'static> CanValidateTransactions<Tx> for InMemoryMempool<Tx> {
173    fn validate_transaction(&self, tx: Tx) -> Result<(), TransactionValidationError> {
174        self.tx_validator.validate_transaction(tx)
175    }
176}
177
178impl<Tx: Send + Sync + 'static + cbor::Encode<()> + Clone> TxSubmissionMempool<Tx> for InMemoryMempool<Tx> {
179    fn insert(&self, tx: Tx, tx_origin: TxOrigin) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
180        let mut inner = self.inner.write();
181        let res = inner.insert(&self.config, tx, tx_origin);
182        if res.is_ok() {
183            inner.notify.notify_waiters();
184        }
185        res
186    }
187
188    fn get_tx(&self, tx_id: &TxId) -> Option<Tx> {
189        self.inner.read().get_tx(tx_id)
190    }
191
192    fn tx_ids_since(&self, from_seq: MempoolSeqNo, limit: u16) -> Vec<(TxId, u32, MempoolSeqNo)> {
193        self.inner.read().tx_ids_since(from_seq, limit)
194    }
195
196    /// Waits until the mempool reaches at least the given sequence number.
197    fn wait_for_at_least(&self, seq_no: MempoolSeqNo) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
198        Box::pin(async move {
199            loop {
200                // Prepare a notification future first to avoid races where we miss a notify
201                // between the check and awaiting.
202                let (current_next_seq, notify) = {
203                    let inner = self.inner.read();
204                    (inner.next_seq, inner.notify.clone())
205                };
206                let notified = notify.notified();
207
208                // Check if we already reached the requested sequence number.
209                // (No lock guard is held across the await.)
210                if current_next_seq >= seq_no.0 {
211                    return true;
212                }
213
214                // Wait until someone inserts a new transaction and notifies us.
215                notified.await;
216            }
217        })
218    }
219
220    fn get_txs_for_ids(&self, ids: &[TxId]) -> Vec<Tx> {
221        self.inner.read().get_txs_for_ids(ids)
222    }
223
224    fn last_seq_no(&self) -> MempoolSeqNo {
225        MempoolSeqNo(self.inner.read().next_seq - 1)
226    }
227}
228
229impl<Tx: Send + Sync + 'static + cbor::Encode<()> + Clone> Mempool<Tx> for InMemoryMempool<Tx> {
230    fn take(&self) -> Vec<Tx> {
231        let mut inner = self.inner.write();
232        let entries = mem::take(&mut inner.entries_by_id);
233        let _ = mem::take(&mut inner.entries_by_seq);
234        entries.into_values().map(|entry| entry.tx).collect()
235    }
236
237    fn acknowledge<TxKey: Ord, I>(&self, tx: &Tx, keys: fn(&Tx) -> I)
238    where
239        I: IntoIterator<Item = TxKey>,
240        Self: Sized,
241    {
242        let keys_to_remove = BTreeSet::from_iter(keys(tx));
243        let mut inner = self.inner.write();
244
245        // remove entries matching the keys criteria in both maps
246        let seq_nos_to_remove: Vec<MempoolSeqNo> = inner
247            .entries_by_id
248            .values()
249            .filter(|entry| keys(&entry.tx).into_iter().any(|k| keys_to_remove.contains(&k)))
250            .map(|entry| entry.seq_no)
251            .collect();
252        inner.entries_by_id.retain(|_, entry| !keys(&entry.tx).into_iter().any(|k| keys_to_remove.contains(&k)));
253        for seq_no in seq_nos_to_remove {
254            inner.entries_by_seq.remove(&seq_no);
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use std::{ops::Deref, slice, str::FromStr, time::Duration};

    use amaru_kernel::{Peer, cbor, cbor as minicbor};
    use assertables::assert_some_eq_x;
    use tokio::time::timeout;

    use super::*;

    #[tokio::test]
    async fn insert_a_transaction() -> anyhow::Result<()> {
        let mempool = InMemoryMempool::from_config(MempoolConfig::default().with_max_txs(5));
        let tx = Tx::from_str("tx1").unwrap();
        let (tx_id, seq_nb) = mempool.insert(tx.clone(), TxOrigin::Remote(Peer::new("upstream"))).unwrap();

        assert_some_eq_x!(mempool.get_tx(&tx_id), tx.clone());
        assert_eq!(mempool.get_txs_for_ids(slice::from_ref(&tx_id)), vec![tx.clone()]);
        assert_eq!(mempool.tx_ids_since(seq_nb, 100), vec![(tx_id, 5, seq_nb)]);
        assert!(mempool.wait_for_at_least(seq_nb).await, "should have at least seq no");
        assert!(
            timeout(Duration::from_millis(100), mempool.wait_for_at_least(seq_nb.add(100))).await.is_err(),
            "should timeout waiting for a seq no that is too high"
        );
        Ok(())
    }

    /// Known ids come back once each, ordered by insertion; unknown ids are ignored.
    #[test]
    fn get_txs_for_ids_returns_known_ids_once_in_seq_order() {
        let mempool: InMemoryMempool<Tx> = InMemoryMempool::from_config(MempoolConfig::default());
        let (tx_a, tx_b, tx_c) = (make_tx("a"), make_tx("b"), make_tx("c"));
        let (id_a, _) = mempool.insert(tx_a.clone(), upstream()).unwrap();
        mempool.insert(tx_b, upstream()).unwrap();
        let (id_c, _) = mempool.insert(tx_c.clone(), upstream()).unwrap();
        let unknown = TxId::from(&make_tx("not pooled"));

        assert_eq!(mempool.get_txs_for_ids(&[id_c, unknown, id_a, id_c]), vec![tx_a, tx_c]);
    }

    /// Acknowledging a transaction removes it from both the id and the sequence indexes.
    #[test]
    fn acknowledge_removes_matching_transactions_from_both_indexes() {
        let mempool: InMemoryMempool<Tx> = InMemoryMempool::from_config(MempoolConfig::default());
        let (tx1, tx2) = (make_tx("tx1"), make_tx("tx2"));
        let (id1, seq1) = mempool.insert(tx1.clone(), upstream()).unwrap();
        let (id2, seq2) = mempool.insert(tx2.clone(), upstream()).unwrap();

        mempool.acknowledge(&tx1, tx_keys);

        assert_eq!(mempool.get_tx(&id1), None);
        assert_some_eq_x!(mempool.get_tx(&id2), tx2.clone());
        assert_eq!(mempool.tx_ids_since(seq1, 100), vec![(id2, 5, seq2)]);
        assert_eq!(mempool.take(), vec![tx2]);
    }

    /// A pending waiter is woken by subsequent insertions.
    #[tokio::test]
    async fn wait_for_at_least_is_woken_by_inserts() {
        let mempool: InMemoryMempool<Tx> = InMemoryMempool::from_config(MempoolConfig::default());
        let waiter = mempool.wait_for_at_least(MempoolSeqNo(3));
        let inserter = async {
            mempool.insert(make_tx("tx1"), upstream()).unwrap();
            mempool.insert(make_tx("tx2"), upstream()).unwrap();
        };

        let (reached, ()) = timeout(Duration::from_secs(1), async { tokio::join!(waiter, inserter) })
            .await
            .expect("the waiter should be woken once the second transaction is inserted");
        assert!(reached);
    }

    // HELPERS
    #[derive(Debug, PartialEq, Eq, Clone, cbor::Encode, cbor::Decode)]
    struct Tx(#[n(0)] String);

    impl Deref for Tx {
        type Target = String;
        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl FromStr for Tx {
        type Err = ();
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(Tx(s.to_string()))
        }
    }

    fn make_tx(s: &str) -> Tx {
        Tx::from_str(s).unwrap()
    }

    fn upstream() -> TxOrigin {
        TxOrigin::Remote(Peer::new("upstream"))
    }

    /// Each test transaction is keyed by its string payload.
    fn tx_keys(tx: &Tx) -> Vec<String> {
        vec![tx.0.clone()]
    }
}
299        fn from_str(s: &str) -> Result<Self, Self::Err> {
300            Ok(Tx(s.to_string()))
301        }
302    }
303}