Skip to main content

amaru_protocols/tx_submission/
responder.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeSet, VecDeque},
17    fmt::Display,
18};
19
20use ProtocolError::*;
21use amaru_kernel::Transaction;
22use amaru_ouroboros::TxSubmissionMempool;
23use amaru_ouroboros_traits::{TxId, TxOrigin};
24use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
25use tracing::instrument;
26
27use crate::{
28    mempool_effects::MemoryPool,
29    mux::MuxMessage,
30    protocol::{
31        Inputs, Miniprotocol, Outcome, PROTO_N2N_TX_SUB, ProtocolState, Responder, StageState, miniprotocol, outcome,
32    },
33    tx_submission::{Blocking, Message, ProtocolError, ResponderParams, State},
34};
35
36pub fn register_deserializers() -> DeserializerGuards {
37    vec![
38        pure_stage::register_data_deserializer::<TxSubmissionResponder>().boxed(),
39        pure_stage::register_data_deserializer::<(State, TxSubmissionResponder)>().boxed(),
40    ]
41}
42
43pub fn responder() -> Miniprotocol<State, TxSubmissionResponder, Responder> {
44    miniprotocol(PROTO_N2N_TX_SUB.responder())
45}
46
47impl StageState<State, Responder> for TxSubmissionResponder {
48    type LocalIn = Void;
49
50    async fn local(
51        self,
52        _proto: &State,
53        input: Self::LocalIn,
54        _eff: &Effects<Inputs<Self::LocalIn>>,
55    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
56        match input {}
57    }
58
59    #[instrument(name = "tx_submission.responder.stage", skip_all, fields(message_type = input.message_type()))]
60    async fn network(
61        mut self,
62        _proto: &State,
63        input: ResponderResult,
64        eff: &Effects<Inputs<Self::LocalIn>>,
65    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
66        let mempool: &dyn TxSubmissionMempool<Transaction> = &MemoryPool::new(eff.clone());
67
68        let action = match input {
69            ResponderResult::Init => {
70                tracing::trace!("received Init");
71                self.initialize_state(mempool)
72            }
73            ResponderResult::ReplyTxIds(tx_ids) => self.process_tx_ids_reply(mempool, tx_ids)?,
74            ResponderResult::ReplyTxs(txs) => self.process_txs_reply(mempool, txs, self.origin.clone())?,
75            ResponderResult::Done => None,
76        };
77        Ok((action, self))
78    }
79
80    fn muxer(&self) -> &StageRef<MuxMessage> {
81        &self.muxer
82    }
83}
84
85impl ProtocolState<Responder> for State {
86    type WireMsg = Message;
87    type Action = ResponderAction;
88    type Out = ResponderResult;
89    type Error = ProtocolError;
90
91    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
92        // Responder waits for Init message, doesn't send anything on init
93        Ok((outcome().want_next(), *self))
94    }
95
96    #[instrument(name = "tx_submission.responder.protocol", skip_all, fields(message_type = input.message_type()))]
97    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
98        Ok(match (self, input) {
99            (State::Init, Message::Init) => (outcome().result(ResponderResult::Init), State::Idle),
100            (State::TxIdsBlocking | State::TxIdsNonBlocking, Message::ReplyTxIds(tx_ids)) => {
101                (outcome().result(ResponderResult::ReplyTxIds(tx_ids)), State::Idle)
102            }
103            (State::Txs, Message::ReplyTxs(txs)) => (outcome().result(ResponderResult::ReplyTxs(txs)), State::Idle),
104            (State::TxIdsBlocking, Message::Done) => (outcome().result(ResponderResult::Done), State::Done),
105            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
106        })
107    }
108
109    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
110        Ok(match (self, input) {
111            (State::Idle, ResponderAction::SendRequestTxIds { ack, req, blocking }) => match blocking {
112                Blocking::Yes => {
113                    (outcome().send(Message::RequestTxIdsBlocking(ack, req)).want_next(), State::TxIdsBlocking)
114                }
115                Blocking::No => {
116                    (outcome().send(Message::RequestTxIdsNonBlocking(ack, req)).want_next(), State::TxIdsNonBlocking)
117                }
118            },
119            (State::Idle, ResponderAction::SendRequestTxs(tx_ids)) => {
120                (outcome().send(Message::RequestTxs(tx_ids)).want_next(), State::Txs)
121            }
122            (_, ResponderAction::Error(e)) => (outcome().terminate_with(e), State::Done),
123            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
124        })
125    }
126}
127
/// Result handed from the protocol state machine to the responder stage when a network
/// message is received.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ResponderResult {
    /// The peer sent `Message::Init` while the protocol was in `State::Init`.
    Init,
    /// The peer replied to a tx ids request; each entry pairs a tx id with a `u32`
    /// (stored in the responder's window as the tx size).
    ReplyTxIds(Vec<(TxId, u32)>),
    /// The peer replied to a request for transaction bodies.
    ReplyTxs(Vec<Transaction>),
    /// The peer sent `Message::Done` in response to a blocking tx ids request.
    Done,
}
136
137impl ResponderResult {
138    pub fn message_type(&self) -> &str {
139        match self {
140            ResponderResult::Init => "Init",
141            ResponderResult::ReplyTxIds(_) => "ReplyTxIds",
142            ResponderResult::ReplyTxs(_) => "ReplyTxs",
143            ResponderResult::Done => "Done",
144        }
145    }
146}
147
148impl Display for ResponderResult {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        match self {
151            ResponderResult::Init => write!(f, "Init"),
152            ResponderResult::ReplyTxIds(tx_ids) => {
153                write!(f, "ReplyTxIds(len: {})", tx_ids.len())
154            }
155            ResponderResult::ReplyTxs(txs) => write!(f, "ReplyTxs(len: {})", txs.len()),
156            ResponderResult::Done => write!(f, "Done"),
157        }
158    }
159}
160
/// State of the tx-submission responder stage: tracks which tx ids the peer has advertised,
/// which ones still need to be fetched, and which ones have been requested.
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TxSubmissionResponder {
    /// Responder parameters: batch sizes, window sizes, etc.
    params: ResponderParams,
    /// All tx_ids advertised but not yet acked (and their size).
    window: VecDeque<(TxId, u32)>,
    /// Tx ids we want to fetch but haven't yet requested.
    pending_fetch: VecDeque<TxId>,
    /// Tx ids whose bodies have been requested from the peer, kept as a set for quick lookup
    /// when checking that received transactions were actually requested.
    /// An id is inserted when it is moved out of `pending_fetch` into a request batch and
    /// removed when the corresponding tx body is received.
    inflight_fetch_set: BTreeSet<TxId>,
    /// The origin of the transactions we are fetching.
    origin: TxOrigin,
    /// The muxer stage this responder is attached to.
    muxer: StageRef<MuxMessage>,
}
177
178impl TxSubmissionResponder {
179    pub fn new(muxer: StageRef<MuxMessage>, params: ResponderParams, origin: TxOrigin) -> (State, Self) {
180        (
181            State::Init,
182            Self {
183                params,
184                window: VecDeque::new(),
185                pending_fetch: VecDeque::new(),
186                inflight_fetch_set: BTreeSet::new(),
187                origin,
188                muxer,
189            },
190        )
191    }
192
193    fn initialize_state(&mut self, mempool: &dyn TxSubmissionMempool<Transaction>) -> Option<ResponderAction> {
194        let (ack, req, blocking) = self.request_tx_ids(mempool);
195        Some(ResponderAction::SendRequestTxIds { ack, req, blocking })
196    }
197
198    fn process_tx_ids_reply(
199        &mut self,
200        mempool: &dyn TxSubmissionMempool<Transaction>,
201        tx_ids: Vec<(TxId, u32)>,
202    ) -> anyhow::Result<Option<ResponderAction>> {
203        if self.window.len() + tx_ids.len() > self.params.max_window.into() {
204            return protocol_error(TooManyTxIdsReceived(
205                tx_ids.len(),
206                self.window.len(),
207                self.params.max_window.into(),
208            ));
209        }
210        self.received_tx_ids(mempool, tx_ids);
211
212        let txs = self.txs_to_request();
213        if txs.is_empty() {
214            let (ack, req, blocking) = self.request_tx_ids(mempool);
215            Ok(Some(ResponderAction::SendRequestTxIds { ack, req, blocking }))
216        } else {
217            Ok(Some(ResponderAction::SendRequestTxs(txs)))
218        }
219    }
220
221    fn process_txs_reply(
222        &mut self,
223        mempool: &dyn TxSubmissionMempool<Transaction>,
224        txs: Vec<Transaction>,
225        origin: TxOrigin,
226    ) -> anyhow::Result<Option<ResponderAction>> {
227        if txs.len() > self.params.fetch_batch.into() {
228            return protocol_error(ReceivedTxsExceedsBatchSize(txs.len(), self.params.fetch_batch.into()));
229        }
230
231        // check for duplicate tx ids
232        let tx_ids = txs.iter().map(TxId::from).collect::<BTreeSet<_>>();
233        if tx_ids.len() != txs.len() {
234            // return the full list of tx ids including duplicates
235            let tx_ids = txs.iter().map(TxId::from).collect::<Vec<_>>();
236            return protocol_error(DuplicateTxIds(tx_ids));
237        }
238
239        // check that all received tx ids were in-flight
240        let not_in_flight =
241            tx_ids.iter().filter(|tx_id| !self.inflight_fetch_set.contains(tx_id)).cloned().collect::<Vec<_>>();
242        if !not_in_flight.is_empty() {
243            return protocol_error(SomeReceivedTxsNotInFlight(not_in_flight));
244        }
245
246        self.received_txs(mempool, txs, origin)?;
247        let (ack, req, blocking) = self.request_tx_ids(mempool);
248        Ok(Some(ResponderAction::SendRequestTxIds { ack, req, blocking }))
249    }
250
251    /// Prepare a request for tx ids, acknowledging already processed ones
252    /// and requesting as many as fit in the window.
253    #[allow(clippy::expect_used)]
254    fn request_tx_ids(&mut self, mempool: &dyn TxSubmissionMempool<Transaction>) -> (u16, u16, Blocking) {
255        // Acknowledge everything we’ve already processed.
256        let mut ack = 0_u16;
257
258        while let Some((tx_id, _size)) = self.window.front() {
259            let already_in_mempool = mempool.contains(tx_id);
260            if already_in_mempool {
261                // pop from window and ack it
262                if self.window.pop_front().is_some() {
263                    ack = ack.checked_add(1).expect("ack overflow: protocol invariant violated");
264                }
265            } else {
266                break;
267            }
268        }
269
270        // Request as many as we can fit in the window.
271        let req = self
272            .params
273            .max_window
274            .checked_sub(self.window.len() as u16)
275            .expect("req underflow: protocol invariant violated");
276
277        // We need to block if there are no more outstanding tx ids.
278        let blocking = if self.window.is_empty() { Blocking::Yes } else { Blocking::No };
279        (ack, req, blocking)
280    }
281
282    /// Register received tx ids, adding them to the window and to the pending fetch list
283    /// if they are not already in the mempool.
284    fn received_tx_ids<Tx: Send + Sync + 'static>(
285        &mut self,
286        mempool: &dyn TxSubmissionMempool<Tx>,
287        tx_ids: Vec<(TxId, u32)>,
288    ) {
289        for (tx_id, size) in tx_ids {
290            // We add the tx id to the window to acknowledge it on the next round.
291            self.window.push_back((tx_id, size));
292
293            // We only add to pending fetch if we haven't received it yet in the mempool.
294            if !mempool.contains(&tx_id) {
295                self.pending_fetch.push_back(tx_id);
296            }
297        }
298    }
299
300    /// Prepare a batch of tx ids for the txs to request.
301    fn txs_to_request(&mut self) -> Vec<TxId> {
302        let mut tx_ids = Vec::new();
303
304        while tx_ids.len() < self.params.fetch_batch.into() {
305            if let Some(id) = self.pending_fetch.pop_front() {
306                self.inflight_fetch_set.insert(id);
307                tx_ids.push(id);
308            } else {
309                break;
310            }
311        }
312
313        tx_ids
314    }
315
316    /// Process received txs, validating and inserting them into the mempool.
317    fn received_txs(
318        &mut self,
319        mempool: &dyn TxSubmissionMempool<Transaction>,
320        txs: Vec<Transaction>,
321        origin: TxOrigin,
322    ) -> anyhow::Result<()> {
323        for tx in txs {
324            let requested_id = TxId::from(&tx);
325            self.inflight_fetch_set.remove(&requested_id);
326            match mempool.validate_transaction(tx.clone()) {
327                Ok(_) => {
328                    tracing::debug!("insert transaction {} into the mempool", requested_id);
329                    mempool.insert(tx, origin.clone())?;
330                }
331                Err(e) => {
332                    tracing::warn!("received invalid transaction {}: {}", requested_id, e);
333                }
334            }
335        }
336        Ok(())
337    }
338}
339
340fn protocol_error(error: ProtocolError) -> anyhow::Result<Option<ResponderAction>> {
341    tracing::warn!("protocol error: {error}");
342    Ok(Some(ResponderAction::Error(error)))
343}
344
345impl AsRef<StageRef<MuxMessage>> for TxSubmissionResponder {
346    fn as_ref(&self) -> &StageRef<MuxMessage> {
347        &self.muxer
348    }
349}
350
/// Action decided by the responder stage and applied to the protocol state machine.
#[derive(Debug, PartialEq, Eq)]
pub enum ResponderAction {
    /// Request tx ids from the peer: `ack` is the number of ids being acknowledged, `req` the
    /// number of ids requested, and `blocking` selects the blocking or non-blocking request.
    SendRequestTxIds { ack: u16, req: u16, blocking: Blocking },
    /// Request the bodies of the given tx ids.
    SendRequestTxs(Vec<TxId>),
    /// Terminate the protocol with the given error.
    Error(ProtocolError),
}
357
358impl Display for ResponderAction {
359    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        match self {
361            ResponderAction::SendRequestTxIds { ack, req, blocking } => {
362                write!(f, "SendRequestTxIds(ack: {}, req: {}, blocking: {:?})", ack, req, blocking)
363            }
364            ResponderAction::SendRequestTxs(tx_ids) => {
365                write!(f, "SendRequestTxs(tx_ids: {:?})", tx_ids)
366            }
367            ResponderAction::Error(err) => write!(f, "Error({})", err),
368        }
369    }
370}
371
#[cfg(test)]
mod tests {

    use std::sync::Arc;

    use amaru_kernel::Transaction;
    use amaru_mempool::strategies::InMemoryMempool;

    use super::*;
    use crate::tx_submission::{assert_actions_eq, tests::create_transactions};

    /// Happy path: ids are advertised in two rounds and fetched two at a time.
    #[tokio::test]
    async fn test_responder() -> anyhow::Result<()> {
        let txs = create_transactions(6);

        // The mempool starts empty: every transaction has to be fetched from the initiator.
        let mempool = Arc::new(InMemoryMempool::default());

        // Replies from the initiator, fed in as if they answered the responder's previous requests.
        let inputs = vec![
            init(),
            reply_tx_ids(&txs, &[0, 1, 2]),
            reply_txs(&txs, &[0, 1]),
            reply_tx_ids(&txs, &[3, 4, 5]),
            reply_txs(&txs, &[2, 3]),
            reply_tx_ids(&txs, &[]),
            reply_txs(&txs, &[4, 5]),
            done(),
        ];

        let observed = run_stage(mempool.clone(), inputs).await?;

        let expected = [
            request_tx_ids(0, 10, Blocking::Yes),
            request_txs(&txs, &[0, 1]),
            request_tx_ids(2, 9, Blocking::No),
            request_txs(&txs, &[2, 3]),
            request_tx_ids(2, 8, Blocking::No),
            request_txs(&txs, &[4, 5]),
            request_tx_ids(2, 10, Blocking::Yes),
        ];
        assert_actions_eq(&observed, &expected);
        Ok(())
    }

    /// Advertising more ids than the window can hold is a protocol error.
    #[tokio::test]
    async fn the_returned_tx_ids_should_respect_the_window_size() -> anyhow::Result<()> {
        let txs = create_transactions(11);
        let mempool = Arc::new(InMemoryMempool::default());

        let every_id = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let inputs = vec![init(), reply_tx_ids(&txs, &every_id)];

        let observed = run_stage(mempool.clone(), inputs).await?;

        let expected = [request_tx_ids(0, 10, Blocking::Yes), error_action(TooManyTxIdsReceived(11, 0, 10))];
        assert_actions_eq(&observed, &expected);
        Ok(())
    }

    /// Replying with more transactions than the fetch batch size is a protocol error.
    #[tokio::test]
    async fn the_returned_txs_should_respect_the_batch_size() -> anyhow::Result<()> {
        let txs = create_transactions(6);
        let mempool = Arc::new(InMemoryMempool::default());

        let inputs = vec![
            init(),
            reply_tx_ids(&txs, &[0, 1, 2]),
            reply_txs(&txs, &[0]),
            reply_tx_ids(&txs, &[]),
            reply_txs(&txs, &[1, 2, 3]),
        ];

        let observed = run_stage(mempool.clone(), inputs).await?;

        let expected = [
            request_tx_ids(0, 10, Blocking::Yes),
            request_txs(&txs, &[0, 1]),
            request_tx_ids(1, 8, Blocking::No),
            request_txs(&txs, &[2]),
            error_action(ReceivedTxsExceedsBatchSize(3, 2)),
        ];
        assert_actions_eq(&observed, &expected);
        Ok(())
    }

    /// Replying with a transaction that is not in flight is a protocol error.
    #[tokio::test]
    async fn the_returned_txs_be_a_subset_of_the_inflight_txs() -> anyhow::Result<()> {
        let txs = create_transactions(6);
        let mempool = Arc::new(InMemoryMempool::default());

        let inputs = vec![
            init(),
            reply_tx_ids(&txs, &[0, 1, 2]),
            reply_txs(&txs, &[0]),
            reply_tx_ids(&txs, &[]),
            reply_txs(&txs, &[1, 3]),
        ];

        let observed = run_stage(mempool.clone(), inputs).await?;

        let unexpected_tx = TxId::from(&txs[3]);
        let expected = [
            request_tx_ids(0, 10, Blocking::Yes),
            request_txs(&txs, &[0, 1]),
            request_tx_ids(1, 8, Blocking::No),
            request_txs(&txs, &[2]),
            error_action(SomeReceivedTxsNotInFlight(vec![unexpected_tx])),
        ];
        assert_actions_eq(&observed, &expected);
        Ok(())
    }

    /// Check the responder state machine against the protocol spec: messages the responder
    /// sends map to the local action producing them, messages it receives map to `None`.
    #[test]
    fn test_responder_protocol() {
        let spec = crate::tx_submission::spec::<Responder>();
        spec.check(State::Init, |msg| match msg {
            Message::RequestTxIdsBlocking(ack, req) => {
                Some(ResponderAction::SendRequestTxIds { ack: *ack, req: *req, blocking: Blocking::Yes })
            }
            Message::RequestTxIdsNonBlocking(ack, req) => {
                Some(ResponderAction::SendRequestTxIds { ack: *ack, req: *req, blocking: Blocking::No })
            }
            Message::RequestTxs(txs) => Some(ResponderAction::SendRequestTxs(txs.clone())),
            Message::ReplyTxs(_) | Message::ReplyTxIds(_) | Message::Init | Message::Done => None,
        });
    }

    // HELPERS: running the stage logic

    /// Feed `results` to a default responder and collect the actions it produces.
    async fn run_stage(
        mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
        results: Vec<ResponderResult>,
    ) -> anyhow::Result<Vec<ResponderAction>> {
        run_stage_and_return_state(mempool, results).await.map(|(actions, _responder)| actions)
    }

    /// Same as `run_stage`, but also return the final responder state.
    async fn run_stage_and_return_state(
        mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
        results: Vec<ResponderResult>,
    ) -> anyhow::Result<(Vec<ResponderAction>, TxSubmissionResponder)> {
        let (_initial_state, responder) =
            TxSubmissionResponder::new(StageRef::named_for_tests("muxer"), ResponderParams::default(), TxOrigin::Local);
        run_stage_and_return_state_with(responder, mempool, results).await
    }

    /// Drive the given responder directly through its handlers, one result at a time,
    /// collecting every produced action. The first handler error aborts the run.
    async fn run_stage_and_return_state_with(
        mut responder: TxSubmissionResponder,
        mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
        results: Vec<ResponderResult>,
    ) -> anyhow::Result<(Vec<ResponderAction>, TxSubmissionResponder)> {
        let pool: &dyn TxSubmissionMempool<Transaction> = mempool.as_ref();
        let mut actions = Vec::new();
        for result in results {
            let produced = match result {
                ResponderResult::Done => None,
                ResponderResult::Init => responder.initialize_state(pool),
                ResponderResult::ReplyTxIds(tx_ids) => responder.process_tx_ids_reply(pool, tx_ids)?,
                ResponderResult::ReplyTxs(txs) => {
                    let origin = responder.origin.clone();
                    responder.process_txs_reply(pool, txs, origin)?
                }
            };
            actions.extend(produced);
        }
        Ok((actions, responder))
    }

    // HELPERS: building results and expected actions

    fn init() -> ResponderResult {
        ResponderResult::Init
    }

    fn done() -> ResponderResult {
        ResponderResult::Done
    }

    /// A tx ids reply advertising `txs[i]` for every `i` in `ids`, each with a size of 50.
    fn reply_tx_ids(txs: &[Transaction], ids: &[usize]) -> ResponderResult {
        let advertised = ids.iter().map(|&idx| (TxId::from(&txs[idx]), 50)).collect();
        ResponderResult::ReplyTxIds(advertised)
    }

    /// A txs reply carrying a clone of `txs[i]` for every `i` in `ids`.
    fn reply_txs(txs: &[Transaction], ids: &[usize]) -> ResponderResult {
        let bodies = ids.iter().map(|&idx| txs[idx].clone()).collect();
        ResponderResult::ReplyTxs(bodies)
    }

    fn request_tx_ids(ack: u16, req: u16, blocking: Blocking) -> ResponderAction {
        ResponderAction::SendRequestTxIds { ack, req, blocking }
    }

    /// The expected request for the bodies of `txs[i]` for every `i` in `ids`.
    fn request_txs(txs: &[Transaction], ids: &[usize]) -> ResponderAction {
        let wanted = ids.iter().map(|&idx| TxId::from(&txs[idx])).collect();
        ResponderAction::SendRequestTxs(wanted)
    }

    fn error_action(error: ProtocolError) -> ResponderAction {
        ResponderAction::Error(error)
    }
}