Skip to main content

amaru_protocols/
mempool_effects.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{fmt::Debug, pin::Pin};
16
17use amaru_kernel::Transaction;
18use amaru_ouroboros::ResourceMempool;
19use amaru_ouroboros_traits::{
20    CanValidateTransactions, MempoolSeqNo, TransactionValidationError, TxId, TxOrigin, TxRejectReason,
21    TxSubmissionMempool,
22};
23use pure_stage::{BoxFuture, Effects, ExternalEffect, ExternalEffectAPI, ExternalEffectSync, Resources, SendData};
24use serde::{Deserialize, Serialize};
25
26/// Implementation of Mempool effects using pure_stage::Effects.
27///
28/// It supports operations
29///
30/// - for the tx submission protocol
31/// - for transaction validation
32///
33#[derive(Clone)]
34pub struct MemoryPool<T> {
35    effects: Effects<T>,
36}
37
38impl<T> MemoryPool<T> {
39    pub fn new(effects: Effects<T>) -> MemoryPool<T> {
40        MemoryPool { effects }
41    }
42
43    /// This function runs an external effect synchronously.
44    pub fn external_sync<E: ExternalEffectSync + serde::Serialize + 'static>(&self, effect: E) -> E::Response
45    where
46        T: SendData + Sync,
47    {
48        self.effects.external_sync(effect)
49    }
50}
51
52impl<T: SendData + Sync> CanValidateTransactions<Transaction> for MemoryPool<T> {
53    /// This effect uses the ledger to validate a transaction before adding it to the mempool.
54    fn validate_transaction(&self, tx: Transaction) -> Result<(), TransactionValidationError> {
55        self.effects.external_sync(ValidateTransaction(tx))
56    }
57}
58
59impl<T: SendData + Sync> TxSubmissionMempool<Transaction> for MemoryPool<T> {
60    /// This effect inserts a transaction into the mempool, specifying its origin.
61    /// A TxOrigin::Local origin indicates the transaction was created on the current node,
62    /// A TxOrigin::Remote(origin_peer) indicates the transaction was received from a remote peer
63    fn insert(&self, tx: Transaction, tx_origin: TxOrigin) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
64        self.external_sync(Insert::new(tx, tx_origin))
65    }
66
67    /// This effect retrieves a transaction by its id.
68    /// It returns None if the transaction is not found.
69    fn get_tx(&self, tx_id: &TxId) -> Option<Transaction> {
70        self.external_sync(GetTx::new(*tx_id))
71    }
72
73    /// This effect retrieves a list of transaction ids from a given sequence number (inclusive), up to a given limit.
74    fn tx_ids_since(&self, from_seq: MempoolSeqNo, limit: u16) -> Vec<(TxId, u32, MempoolSeqNo)> {
75        self.external_sync(TxIdsSince::new(from_seq, limit))
76    }
77
78    /// This effect waits until the mempool reaches at least the given sequence number.
79    fn wait_for_at_least(&self, seq_no: MempoolSeqNo) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
80        self.effects.external(WaitForAtLeast::new(seq_no))
81    }
82
83    /// This effect retrieves a list of transactions for the given ids.
84    fn get_txs_for_ids(&self, ids: &[TxId]) -> Vec<Transaction> {
85        self.external_sync(GetTxsForIds::new(ids))
86    }
87
88    /// This effect gets the last assigned sequence number in the mempool.
89    fn last_seq_no(&self) -> MempoolSeqNo {
90        self.external_sync(LastSeqNo)
91    }
92}
93
// EXTERNAL EFFECTS DEFINITIONS
95
96#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
97struct Insert {
98    tx: Transaction,
99    tx_origin: TxOrigin,
100}
101
102impl Insert {
103    pub fn new(tx: Transaction, tx_origin: TxOrigin) -> Self {
104        Self { tx, tx_origin }
105    }
106}
107
108impl ExternalEffect for Insert {
109    #[expect(clippy::expect_used)]
110    fn run(self: Box<Self>, resources: Resources) -> BoxFuture<'static, Box<dyn SendData>> {
111        Self::wrap_sync({
112            let mempool = resources.get::<ResourceMempool<Transaction>>().expect("ResourceMempool requires a mempool");
113            mempool.insert(self.tx, self.tx_origin)
114        })
115    }
116}
117
118impl ExternalEffectAPI for Insert {
119    type Response = Result<(TxId, MempoolSeqNo), TxRejectReason>;
120}
121
122impl ExternalEffectSync for Insert {}
123
124#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
125struct ValidateTransaction(Transaction);
126
127impl ExternalEffect for ValidateTransaction {
128    #[expect(clippy::expect_used)]
129    fn run(self: Box<Self>, resources: Resources) -> BoxFuture<'static, Box<dyn SendData>> {
130        Self::wrap_sync({
131            let mempool = resources.get::<ResourceMempool<Transaction>>().expect("ResourceMempool requires a mempool");
132            mempool.validate_transaction(self.0)
133        })
134    }
135}
136
137impl ExternalEffectAPI for ValidateTransaction {
138    type Response = Result<(), TransactionValidationError>;
139}
140
141impl ExternalEffectSync for ValidateTransaction {}
142
143#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
144struct GetTx {
145    tx_id: TxId,
146}
147
148impl GetTx {
149    pub fn new(tx_id: TxId) -> Self {
150        Self { tx_id }
151    }
152}
153
154impl ExternalEffect for GetTx {
155    #[expect(clippy::expect_used)]
156    fn run(self: Box<Self>, resources: Resources) -> BoxFuture<'static, Box<dyn SendData>> {
157        Self::wrap_sync({
158            let mempool = resources.get::<ResourceMempool<Transaction>>().expect("ResourceMempool requires a mempool");
159            mempool.get_tx(&self.tx_id)
160        })
161    }
162}
163
164impl ExternalEffectAPI for GetTx {
165    type Response = Option<Transaction>;
166}
167
168impl ExternalEffectSync for GetTx {}
169
170#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
171struct TxIdsSince {
172    mempool_seqno: MempoolSeqNo,
173    limit: u16,
174}
175
176impl TxIdsSince {
177    pub fn new(mempool_seqno: MempoolSeqNo, limit: u16) -> Self {
178        Self { mempool_seqno, limit }
179    }
180}
181
182impl ExternalEffect for TxIdsSince {
183    #[expect(clippy::expect_used)]
184    fn run(self: Box<Self>, resources: Resources) -> BoxFuture<'static, Box<dyn SendData>> {
185        Self::wrap_sync({
186            let mempool = resources.get::<ResourceMempool<Transaction>>().expect("ResourceMempool requires a mempool");
187            mempool.tx_ids_since(self.mempool_seqno, self.limit)
188        })
189    }
190}
191
192impl ExternalEffectAPI for TxIdsSince {
193    type Response = Vec<(TxId, u32, MempoolSeqNo)>;
194}
195
196impl ExternalEffectSync for TxIdsSince {}
197
198#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
199struct WaitForAtLeast {
200    seq_no: MempoolSeqNo,
201}
202
203impl WaitForAtLeast {
204    pub fn new(seq_no: MempoolSeqNo) -> Self {
205        Self { seq_no }
206    }
207}
208
209impl ExternalEffect for WaitForAtLeast {
210    #[expect(clippy::expect_used)]
211    fn run(self: Box<Self>, resources: Resources) -> BoxFuture<'static, Box<dyn SendData>> {
212        Self::wrap(async move {
213            let mempool =
214                resources.get::<ResourceMempool<Transaction>>().expect("ResourceMempool requires a mempool").clone();
215            mempool.wait_for_at_least(self.seq_no).await
216        })
217    }
218}
219
220impl ExternalEffectAPI for WaitForAtLeast {
221    type Response = bool;
222}
223
224#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
225struct GetTxsForIds {
226    tx_ids: Vec<TxId>,
227}
228
229impl GetTxsForIds {
230    pub fn new(ids: &[TxId]) -> Self {
231        Self { tx_ids: ids.to_vec() }
232    }
233}
234
235impl ExternalEffect for GetTxsForIds {
236    #[expect(clippy::expect_used)]
237    fn run(self: Box<Self>, resources: Resources) -> BoxFuture<'static, Box<dyn SendData>> {
238        Self::wrap_sync({
239            let mempool = resources.get::<ResourceMempool<Transaction>>().expect("ResourceMempool requires a mempool");
240            mempool.get_txs_for_ids(&self.tx_ids)
241        })
242    }
243}
244
245impl ExternalEffectAPI for GetTxsForIds {
246    type Response = Vec<Transaction>;
247}
248
249impl ExternalEffectSync for GetTxsForIds {}
250
251#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
252struct LastSeqNo;
253
254impl ExternalEffect for LastSeqNo {
255    #[expect(clippy::expect_used)]
256    fn run(self: Box<Self>, resources: Resources) -> BoxFuture<'static, Box<dyn SendData>> {
257        Self::wrap_sync({
258            let mempool = resources.get::<ResourceMempool<Transaction>>().expect("ResourceMempool requires a mempool");
259            mempool.last_seq_no()
260        })
261    }
262}
263
264impl ExternalEffectAPI for LastSeqNo {
265    type Response = MempoolSeqNo;
266}
267
268impl ExternalEffectSync for LastSeqNo {}
269
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use amaru_kernel::{Transaction, TransactionBody, WitnessSet};
    use amaru_ouroboros_traits::{
        CanValidateTransactions, MempoolSeqNo, TransactionValidationError, TxId, TxOrigin, TxRejectReason,
        TxSubmissionMempool,
    };

    /// Test double that answers every query with the same fixed transaction and `MempoolSeqNo(1)`.
    // No test in this module uses the double yet, hence the `dead_code` allowances.
    #[allow(dead_code)]
    pub struct ConstantMempool {
        tx: Transaction,
    }

    impl ConstantMempool {
        /// Builds the double around a transaction made from `TransactionBody::new([], [], 0)`
        /// and default witnesses.
        #[allow(dead_code)]
        pub fn new() -> Self {
            ConstantMempool {
                tx: Transaction {
                    body: TransactionBody::new([], [], 0),
                    witnesses: WitnessSet::default(),
                    is_expected_valid: true,
                    auxiliary_data: None,
                },
            }
        }
    }

    impl TxSubmissionMempool<Transaction> for ConstantMempool {
        fn insert(&self, tx: Transaction, _tx_origin: TxOrigin) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
            let id = TxId::from(&tx);
            Ok((id, MempoolSeqNo(1)))
        }

        fn get_tx(&self, _tx_id: &TxId) -> Option<Transaction> {
            let tx = self.tx.clone();
            Some(tx)
        }

        fn tx_ids_since(&self, _from_seq: MempoolSeqNo, _limit: u16) -> Vec<(TxId, u32, MempoolSeqNo)> {
            let entry = (TxId::from(&self.tx), 100, MempoolSeqNo(1));
            vec![entry]
        }

        fn wait_for_at_least(&self, _seq_no: MempoolSeqNo) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
            // Already-completed future: the double never actually waits.
            Box::pin(std::future::ready(true))
        }

        fn get_txs_for_ids(&self, _ids: &[TxId]) -> Vec<Transaction> {
            Vec::from([self.tx.clone()])
        }

        fn last_seq_no(&self) -> MempoolSeqNo {
            MempoolSeqNo(1)
        }
    }

    impl CanValidateTransactions<Transaction> for ConstantMempool {
        /// Accepts every transaction.
        fn validate_transaction(&self, _tx: Transaction) -> Result<(), TransactionValidationError> {
            Ok(())
        }
    }
}