Skip to main content

amaru_protocols/tx_submission/
initiator.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Manages the transaction submission protocol from the initiator's perspective.
//!
//! This module implements the initiator side of the Cardano transaction submission protocol (N2N).
//! The initiator serves its mempool to the responder: it answers the responder's requests for
//! transaction IDs, and then its requests for the full transactions behind those IDs.
//!
//! # Protocol Flow
//!
//! The initiator follows this state machine:
//! - **Init**: Sends an initialization message, then moves to `Idle`
//! - **Idle**: Waits for the next request from the responder
//! - **TxIdsBlocking**: A blocking request for transaction IDs was received; the reply waits for the
//!   mempool to have data
//! - **TxIdsNonBlocking**: A non-blocking request for transaction IDs was received; the reply is sent
//!   immediately, even if it is empty
//! - **Txs**: A request for full transactions was received and is being answered
//! - **Done**: Terminal state, reached after sending `Done` or terminating on a protocol error
//!
//! # Key Components
//!
//! - [`TxSubmissionInitiator`]: The main state machine that tracks advertised transactions
//! - [`InitiatorAction`]: Actions produced by the initiator (send replies, report errors, or send `Done`)
//! - [`InitiatorResult`]: Requests decoded from network messages and handed to the initiator stage
//!
//! # Window Management
//!
//! The initiator maintains a sliding window of advertised transaction IDs. Transactions are:
//! - Added to the window when they're sent to the responder
//! - Removed from the window when they're acknowledged by the responder
//! - Tracked with their mempool sequence numbers; the last sequence number pulled from the mempool
//!   is remembered so that transactions are not re-advertised
//!
//! # Error Handling
//!
//! The protocol validates:
//! - Request counts don't exceed maximum protocol limits
//! - Acknowledgment counts don't exceed advertised transactions
//! - Only advertised transaction IDs are requested
//! - Appropriate blocking/non-blocking requests based on acknowledgment state
50use std::collections::VecDeque;
51use std::fmt::{Debug, Display};
52
53use ProtocolError::*;
54use amaru_kernel::{Transaction, utils::string::display_collection};
55use amaru_ouroboros::{MempoolSeqNo, TxSubmissionMempool};
56use amaru_ouroboros_traits::TxId;
57use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
58use tracing::instrument;
59
60use crate::{
61    mempool_effects::MemoryPool,
62    mux::MuxMessage,
63    protocol::{
64        Initiator, Inputs, Miniprotocol, Outcome, PROTO_N2N_TX_SUB, ProtocolState, StageState, miniprotocol, outcome,
65    },
66    tx_submission::{Blocking, Message, ProtocolError, State},
67};
68
/// Upper bound on the number of tx ids the responder may ask for in a single `RequestTxIds`;
/// larger requests are rejected with `MaxOutstandingTxIdsRequested`.
const MAX_REQUESTED_TX_IDS: u16 = 10;
70
71pub fn register_deserializers() -> DeserializerGuards {
72    vec![
73        pure_stage::register_data_deserializer::<Void>().boxed(),
74        pure_stage::register_data_deserializer::<TxSubmissionInitiator>().boxed(),
75        pure_stage::register_data_deserializer::<(State, TxSubmissionInitiator)>().boxed(),
76    ]
77}
78
/// Builds the initiator side of the node-to-node tx-submission miniprotocol
/// (`PROTO_N2N_TX_SUB`), driven by [`State`] and [`TxSubmissionInitiator`].
pub fn initiator() -> Miniprotocol<State, TxSubmissionInitiator, Initiator> {
    miniprotocol(PROTO_N2N_TX_SUB)
}
82
83impl StageState<State, Initiator> for TxSubmissionInitiator {
84    type LocalIn = Void;
85
86    async fn local(
87        self,
88        _proto: &State,
89        _input: Self::LocalIn,
90        _eff: &Effects<Inputs<Self::LocalIn>>,
91    ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
92        // Currently no local inputs are handled
93        Ok((None, self))
94    }
95
96    #[instrument(name = "tx_submission.initiator.stage", skip_all, fields(message_type = input.message_type()))]
97    async fn network(
98        mut self,
99        _proto: &State,
100        input: InitiatorResult,
101        eff: &Effects<Inputs<Self::LocalIn>>,
102    ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
103        let mempool: &dyn TxSubmissionMempool<Transaction> = &MemoryPool::new(eff.clone());
104
105        let action = match input {
106            InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::Yes } => {
107                self.request_tx_ids_blocking(mempool, ack, req).await?
108            }
109            InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::No } => {
110                self.request_tx_ids_non_blocking(mempool, ack, req)?
111            }
112            InitiatorResult::RequestTxs(tx_ids) => self.request_txs(mempool, tx_ids)?,
113        };
114        Ok((action, self))
115    }
116
117    fn muxer(&self) -> &StageRef<MuxMessage> {
118        &self.muxer
119    }
120}
121
122impl ProtocolState<Initiator> for State {
123    type WireMsg = Message;
124    type Action = InitiatorAction;
125    type Out = InitiatorResult;
126    type Error = ProtocolError;
127
128    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
129        Ok((outcome().send(Message::Init).want_next(), State::Idle))
130    }
131
132    #[instrument(name = "tx_submission.initiator.protocol", skip_all, fields(message_type = input.message_type()))]
133    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
134        Ok(match (self, input) {
135            (State::Idle, Message::RequestTxIdsBlocking(ack, req)) => {
136                tracing::debug!(%ack, %req, "received RequestTxIdsBlocking");
137                (
138                    outcome().result(InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::Yes }),
139                    State::TxIdsBlocking,
140                )
141            }
142            (State::Idle, Message::RequestTxIdsNonBlocking(ack, req)) => {
143                tracing::debug!(%ack, %req, "received RequestTxIdsNonBlocking");
144                (
145                    outcome().result(InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::No }),
146                    State::TxIdsNonBlocking,
147                )
148            }
149            (State::Idle, Message::RequestTxs(tx_ids)) => {
150                tracing::debug!(tx_ids_nb = tx_ids.len(), "received RequestTxs");
151                (outcome().result(InitiatorResult::RequestTxs(tx_ids)), State::Txs)
152            }
153            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
154        })
155    }
156
157    fn local(&self, action: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
158        Ok(match (self, action) {
159            (State::TxIdsBlocking, InitiatorAction::SendReplyTxIds(tx_ids)) => {
160                (outcome().send(Message::ReplyTxIds(tx_ids)).want_next(), State::Idle)
161            }
162            (State::TxIdsNonBlocking, InitiatorAction::SendReplyTxIds(tx_ids)) => {
163                (outcome().send(Message::ReplyTxIds(tx_ids)).want_next(), State::Idle)
164            }
165            (State::Txs, InitiatorAction::SendReplyTxs(txs)) => {
166                (outcome().send(Message::ReplyTxs(txs)).want_next(), State::Idle)
167            }
168            (State::TxIdsBlocking, InitiatorAction::Done) => (outcome().send(Message::Done), State::Done),
169            (_, InitiatorAction::Error(e)) => (outcome().terminate_with(e), State::Done),
170            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
171        })
172    }
173}
174
/// Actions the initiator stage hands back to the protocol state machine.
#[derive(Debug, PartialEq, Eq)]
pub enum InitiatorAction {
    /// Reply to a tx-id request with `(tx id, tx size)` pairs.
    SendReplyTxIds(Vec<(TxId, u32)>),
    /// Reply to a `RequestTxs` with the transactions found in the mempool.
    SendReplyTxs(Vec<Transaction>),
    /// Terminate the session because the responder violated the protocol.
    Error(ProtocolError),
    /// Send `Done`; the protocol state machine only accepts this in the `TxIdsBlocking` state.
    Done,
}
182
183impl Display for InitiatorAction {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        match self {
186            InitiatorAction::SendReplyTxIds(tx_ids) => {
187                write!(f, "SendReplyTxIds(len={})", tx_ids.len())
188            }
189            InitiatorAction::SendReplyTxs(txs) => write!(f, "SendReplyTxs(len={})", txs.len()),
190            InitiatorAction::Error(err) => write!(f, "Error({})", err),
191            InitiatorAction::Done => write!(f, "Done"),
192        }
193    }
194}
195
/// Request received from the responder, decoded by the protocol state machine and forwarded to
/// the initiator stage.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum InitiatorResult {
    /// A blocking or non-blocking tx-id request: acknowledge `ack` previously advertised ids and
    /// ask for `req` new ones.
    RequestTxIds { ack: u16, req: u16, blocking: Blocking },
    /// Ask for the full transactions behind these ids.
    RequestTxs(Vec<TxId>),
}
202
203impl InitiatorResult {
204    pub fn message_type(&self) -> &str {
205        match self {
206            Self::RequestTxIds { .. } => "RequestTxIds",
207            Self::RequestTxs(_) => "RequestTxs",
208        }
209    }
210}
211
212impl Display for InitiatorResult {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        match self {
215            InitiatorResult::RequestTxIds { ack, req, blocking } => {
216                write!(f, "RequestTxIds(ack: {}, req: {}, blocking: {:?})", ack, req, blocking)
217            }
218            InitiatorResult::RequestTxs(tx_ids) => {
219                write!(
220                    f,
221                    "RequestTxs(ids: [{}])",
222                    tx_ids.iter().map(|id| format!("{}", id)).collect::<Vec<_>>().join(", ")
223                )
224            }
225        }
226    }
227}
228
/// Stage state for the initiator side of tx submission.
///
/// Tracks which tx ids have been advertised to the responder and how far into the mempool this
/// peer has read.
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TxSubmissionInitiator {
    /// Tx ids already advertised to the responder but not yet acknowledged, oldest first, each
    /// paired with its mempool sequence number.
    window: VecDeque<(TxId, MempoolSeqNo)>,
    /// Last seq_no we have ever pulled from the mempool for this peer.
    /// None if we have not pulled anything yet.
    last_seq: Option<MempoolSeqNo>,
    /// Multiplexer stage this protocol sends through.
    muxer: StageRef<MuxMessage>,
}
238
239impl TxSubmissionInitiator {
240    pub fn new(muxer: StageRef<MuxMessage>) -> (State, Self) {
241        (State::Init, Self { window: VecDeque::new(), last_seq: None, muxer })
242    }
243
244    async fn request_tx_ids_blocking(
245        &mut self,
246        mempool: &dyn TxSubmissionMempool<Transaction>,
247        ack: u16,
248        req: u16,
249    ) -> anyhow::Result<Option<InitiatorAction>> {
250        // check the ack and req values
251        tracing::debug!(%ack, %req, "received RequestTxIdsBlocking");
252        if req == 0 {
253            return protocol_error(NoTxIdsRequested);
254        };
255        if let Some(value) = self.check_ack_req(ack, req) {
256            return protocol_error(value);
257        }
258        if (ack as usize) < self.window.len() {
259            return protocol_error(BlockingRequestMadeWhenTxsStillUnacknowledged);
260        }
261
262        // update the window by discarding acknowledged tx ids and update the last_seq
263        self.discard(ack);
264        if !mempool.wait_for_at_least(self.last_seq.unwrap_or_default().add(req as u64)).await {
265            return Ok(None);
266        }
267        let tx_ids = self.get_next_tx_ids(mempool, req)?;
268        Ok(Some(InitiatorAction::SendReplyTxIds(tx_ids)))
269    }
270
271    fn request_tx_ids_non_blocking(
272        &mut self,
273        mempool: &dyn TxSubmissionMempool<Transaction>,
274        ack: u16,
275        req: u16,
276    ) -> anyhow::Result<Option<InitiatorAction>> {
277        // check the ack and req values
278        tracing::debug!(%ack, %req, "received RequestTxIdsNonBlocking");
279        if ack == 0 && req == 0 {
280            return protocol_error(NoAckOrReqTxIdsRequested);
281        }
282        if let Some(error) = self.check_ack_req(ack, req) {
283            return protocol_error(error);
284        }
285        if ack as usize == self.window.len() {
286            return protocol_error(NonBlockingRequestMadeWhenAllTxsAcknowledged);
287        }
288
289        // update the window by discarding acknowledged tx ids and update the last_seq
290        self.discard(ack);
291        Ok(Some(InitiatorAction::SendReplyTxIds(self.get_next_tx_ids(mempool, req)?)))
292    }
293
294    fn request_txs(
295        &mut self,
296        mempool: &dyn TxSubmissionMempool<Transaction>,
297        tx_ids: Vec<TxId>,
298    ) -> anyhow::Result<Option<InitiatorAction>> {
299        tracing::debug!(tx_ids = display_collection(&tx_ids), "received RequestTxs");
300        if tx_ids.is_empty() {
301            return protocol_error(NoTxsRequested);
302        }
303        if tx_ids.iter().any(|id| !self.window.iter().any(|(wid, _)| wid == id)) {
304            return protocol_error(UnadvertisedTransactionIdsRequested(tx_ids));
305        }
306        let txs = mempool.get_txs_for_ids(tx_ids.as_slice());
307        if txs.is_empty() {
308            protocol_error(UnknownTxsRequested(tx_ids))
309        } else {
310            Ok(Some(InitiatorAction::SendReplyTxs(txs)))
311        }
312    }
313
314    /// Check that the ack and req values are valid for a request whether it is blocking or non blocking.
315    fn check_ack_req(&mut self, ack: u16, req: u16) -> Option<ProtocolError> {
316        if req > MAX_REQUESTED_TX_IDS {
317            Some(MaxOutstandingTxIdsRequested(req, MAX_REQUESTED_TX_IDS))
318        } else if ack as usize > self.window.len() {
319            Some(TooManyAcknowledgedTxs(ack, self.window.len() as u16))
320        } else {
321            None
322        }
323    }
324
325    /// Take notice of the acknowledged transactions, and send the next batch of tx ids.
326    fn get_next_tx_ids<Tx: Send + Debug + Sync + 'static>(
327        &mut self,
328        mempool: &dyn TxSubmissionMempool<Tx>,
329        required_next: u16,
330    ) -> anyhow::Result<Vec<(TxId, u32)>> {
331        let tx_ids = mempool.tx_ids_since(self.next_seq(), required_next);
332        let result = tx_ids.clone().into_iter().map(|(tx_id, tx_size, _)| (tx_id, tx_size)).collect();
333        self.update(tx_ids);
334        Ok(result)
335    }
336
337    /// We discard up to 'acknowledged' transactions from our window, in a FIFO manner.
338    fn discard(&mut self, acknowledged: u16) {
339        if self.window.len() >= acknowledged as usize {
340            self.window = self.window.drain(acknowledged as usize..).collect();
341        }
342    }
343
344    /// We update our window with tx ids retrieved from the mempool and just sent to the server.
345    fn update(&mut self, tx_ids: Vec<(TxId, u32, MempoolSeqNo)>) {
346        for (tx_id, _size, seq_no) in tx_ids {
347            self.window.push_back((tx_id, seq_no));
348            self.last_seq = Some(seq_no);
349        }
350    }
351
352    /// Compute the next sequence number to use when pulling from the mempool.
353    fn next_seq(&self) -> MempoolSeqNo {
354        match self.last_seq {
355            Some(seq) => seq.next(),
356            None => MempoolSeqNo(0),
357        }
358    }
359}
360
361fn protocol_error(error: ProtocolError) -> anyhow::Result<Option<InitiatorAction>> {
362    tracing::warn!("protocol error: {error}");
363    Ok(Some(InitiatorAction::Error(error)))
364}
365
366impl AsRef<StageRef<MuxMessage>> for TxSubmissionInitiator {
367    fn as_ref(&self) -> &StageRef<MuxMessage> {
368        &self.muxer
369    }
370}
371
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::tx_submission::{
        assert_actions_eq, create_transactions_in_mempool,
        tests::{SizedMempool, create_transactions},
    };

    // Happy path: the responder pulls all six transactions in batches of two, mixing blocking and
    // non-blocking tx-id requests.
    #[tokio::test]
    async fn serve_transactions() -> anyhow::Result<()> {
        // Create a mempool with some transactions
        let mempool = Arc::new(SizedMempool::with_capacity(6));
        let txs = create_transactions_in_mempool(mempool.clone(), 6);

        // Send requests to retrieve transactions and block until they are available.
        // In this case they are immediately available since we pre-populated the mempool.
        // Note that we acknowledge tx[0] first in a non-blocking request (line 1), then tx[1],
        // tx[2] and tx[3] in the next blocking request (line 2). The final blocking request
        // acknowledges tx[4] and tx[5].
        let results = vec![
            request_tx_ids(0, 2, Blocking::Yes),
            request_txs(&txs, &[0, 1]),
            request_tx_ids(1, 2, Blocking::No), // line 1
            request_txs(&txs, &[2, 3]),
            request_tx_ids(3, 2, Blocking::Yes), // line 2
            request_txs(&txs, &[4, 5]),
            request_tx_ids(2, 2, Blocking::Yes),
        ];

        let outcomes = run_stage(mempool, results).await?;

        // Check replies
        // We assert that we receive the expected ids and transactions 2 by 2, since we requested
        // batches of 2. The final blocking request yields no action: every transaction has already
        // been advertised, `wait_for_at_least` reports `false` and the initiator returns `None`.
        assert_actions_eq(
            &outcomes,
            &[
                reply_tx_ids(&txs, &[0, 1]),
                reply_txs(&txs, &[0, 1]),
                reply_tx_ids(&txs, &[2, 3]),
                reply_txs(&txs, &[2, 3]),
                reply_tx_ids(&txs, &[4, 5]),
                reply_txs(&txs, &[4, 5]),
            ],
        );

        Ok(())
    }

    // The initiator state survives across two runs, and transactions added to the mempool after
    // the first run are advertised in the second one.
    #[tokio::test]
    async fn serve_transactions_with_mempool_refilling() -> anyhow::Result<()> {
        // Create a mempool with some transactions
        let mempool = Arc::new(SizedMempool::with_capacity(6));
        let txs = create_transactions(6);

        for tx in txs.iter().take(2) {
            mempool.add(tx.clone())?;
        }

        // Send requests to retrieve transactions and block until they are available.
        // In this case they are immediately available since we pre-populated the mempool.
        // The non-blocking request at the end gets an empty reply: only 2 transactions exist yet.
        let results =
            vec![request_tx_ids(0, 2, Blocking::Yes), request_txs(&txs, &[0, 1]), request_tx_ids(1, 2, Blocking::No)];

        let (actions, initiator) = run_stage_and_return_state(mempool.clone(), results).await?;
        assert_actions_eq(&actions, &[reply_tx_ids(&txs, &[0, 1]), reply_txs(&txs, &[0, 1]), reply_tx_ids(&txs, &[])]);

        // Refill the mempool with more transactions
        for tx in &txs[2..] {
            mempool.add(tx.clone())?;
        }
        let messages = vec![
            request_tx_ids(1, 2, Blocking::Yes),
            request_txs(&txs, &[2, 3]),
            request_tx_ids(2, 2, Blocking::Yes),
            request_txs(&txs, &[4, 5]),
            request_tx_ids(2, 2, Blocking::Yes),
        ];

        let (actions, _) = run_stage_and_return_state_with(initiator, mempool, messages).await?;

        // Check replies
        // We assert that we receive the expected ids and transactions 2 by 2, since we requested
        // batches of 2. The final blocking request yields no action because the mempool has no
        // further transactions to advertise.
        assert_actions_eq(
            &actions,
            &[
                reply_tx_ids(&txs, &[2, 3]),
                reply_txs(&txs, &[2, 3]),
                reply_tx_ids(&txs, &[4, 5]),
                reply_txs(&txs, &[4, 5]),
            ],
        );
        Ok(())
    }

    // Requesting transactions whose ids were never advertised terminates the session.
    #[tokio::test]
    async fn request_txs_must_come_from_requested_ids() -> anyhow::Result<()> {
        // Create a mempool with some transactions
        let mempool = Arc::new(SizedMempool::with_capacity(6));
        let txs = create_transactions_in_mempool(mempool.clone(), 4);

        // Send requests to retrieve transactions and block until they are available.
        // In this case they are immediately available since we pre-populated the mempool.
        // The reply to the first message will be tx ids 0 and 1, which means that the responder
        // should then request transactions for those ids only.
        // In this test we receive a request for tx ids 2 and 3, which were not advertised yet,
        // so the initiator should terminate the session.
        let results = vec![request_tx_ids(0, 2, Blocking::Yes), request_txs(&txs, &[2, 3])];

        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(
            &actions,
            &[
                reply_tx_ids(&txs, &[0, 1]),
                error_action(UnadvertisedTransactionIdsRequested(vec![TxId::from(&txs[2]), TxId::from(&txs[3])])),
            ],
        );
        Ok(())
    }

    // A blocking tx-id request must ask for at least one id.
    #[tokio::test]
    async fn blocking_requested_ids_must_be_greater_than_0() -> anyhow::Result<()> {
        let mempool = Arc::new(SizedMempool::with_capacity(6));

        let results = vec![request_tx_ids(0, 0, Blocking::Yes)];
        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(&actions, &[error_action(NoTxIdsRequested)]);
        Ok(())
    }

    // A `RequestTxs` must name at least one transaction id.
    #[tokio::test]
    async fn blocking_requested_txs_must_be_greater_than_0() -> anyhow::Result<()> {
        let mempool = Arc::new(SizedMempool::with_capacity(4));
        let txs = create_transactions_in_mempool(mempool.clone(), 4);

        let results = vec![request_tx_ids(0, 2, Blocking::Yes), request_txs(&txs, &[])];

        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(&actions, &[reply_tx_ids(&txs, &[0, 1]), error_action(NoTxsRequested)]);
        Ok(())
    }

    // A non-blocking tx-id request must acknowledge or request at least one id.
    #[tokio::test]
    async fn non_blocking_ack_or_requested_ids_must_be_greater_than_0() -> anyhow::Result<()> {
        let mempool = Arc::new(SizedMempool::with_capacity(6));

        let results = vec![request_tx_ids(0, 0, Blocking::No)];
        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(&actions, &[error_action(NoAckOrReqTxIdsRequested)]);
        Ok(())
    }

    // Blocking requests above `MAX_REQUESTED_TX_IDS` are rejected.
    #[tokio::test]
    async fn blocking_requested_nb_must_be_less_than_protocol_limit() -> anyhow::Result<()> {
        let mempool = Arc::new(SizedMempool::with_capacity(6));

        let results = vec![request_tx_ids(0, 12, Blocking::Yes)];
        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(&actions, &[error_action(MaxOutstandingTxIdsRequested(12, MAX_REQUESTED_TX_IDS))]);
        Ok(())
    }

    // Non-blocking requests above `MAX_REQUESTED_TX_IDS` are rejected.
    #[tokio::test]
    async fn non_blocking_requested_nb_must_be_less_than_protocol_limit() -> anyhow::Result<()> {
        let mempool = Arc::new(SizedMempool::with_capacity(6));

        let results = vec![request_tx_ids(0, 12, Blocking::No)];
        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(&actions, &[error_action(MaxOutstandingTxIdsRequested(12, MAX_REQUESTED_TX_IDS))]);
        Ok(())
    }

    // Once a request acknowledges every advertised id, it must be a blocking one.
    #[tokio::test]
    async fn a_blocking_request_must_be_made_when_all_txs_are_acknowledged() -> anyhow::Result<()> {
        let mempool = Arc::new(SizedMempool::with_capacity(4));
        let txs = create_transactions_in_mempool(mempool.clone(), 4);

        let results = vec![
            request_tx_ids(0, 4, Blocking::Yes),
            request_txs(&txs, &[0, 1]),
            request_tx_ids(2, 4, Blocking::No),
            request_txs(&txs, &[2, 3]),
            request_tx_ids(2, 4, Blocking::No),
        ];
        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(
            &actions,
            &[
                reply_tx_ids(&txs, &[0, 1, 2, 3]),
                reply_txs(&txs, &[0, 1]),
                reply_tx_ids(&txs, &[]),
                reply_txs(&txs, &[2, 3]),
                error_action(NonBlockingRequestMadeWhenAllTxsAcknowledged),
            ],
        );
        Ok(())
    }

    // While some advertised ids remain unacknowledged, a blocking request is rejected.
    #[tokio::test]
    async fn a_non_blocking_request_must_be_made_when_some_txs_are_unacknowledged() -> anyhow::Result<()> {
        let mempool = Arc::new(SizedMempool::with_capacity(4));
        let txs = create_transactions_in_mempool(mempool.clone(), 4);

        let results =
            vec![request_tx_ids(0, 4, Blocking::Yes), request_txs(&txs, &[0, 1]), request_tx_ids(2, 4, Blocking::Yes)];
        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(
            &actions,
            &[
                reply_tx_ids(&txs, &[0, 1, 2, 3]),
                reply_txs(&txs, &[0, 1]),
                error_action(BlockingRequestMadeWhenTxsStillUnacknowledged),
            ],
        );
        Ok(())
    }

    // Acknowledging more ids than are in the window is rejected (blocking variant).
    #[tokio::test]
    async fn the_responder_cannot_acknowledge_more_than_the_current_unacknowledged_blocking() -> anyhow::Result<()> {
        let mempool = Arc::new(SizedMempool::with_capacity(4));
        let txs = create_transactions_in_mempool(mempool.clone(), 4);

        let results = vec![
            request_tx_ids(0, 4, Blocking::Yes),
            request_txs(&txs, &[0, 1]),
            request_tx_ids(2, 4, Blocking::No),
            request_txs(&txs, &[2, 3]),
            request_tx_ids(4, 4, Blocking::Yes),
        ];
        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(
            &actions,
            &[
                reply_tx_ids(&txs, &[0, 1, 2, 3]),
                reply_txs(&txs, &[0, 1]),
                reply_tx_ids(&txs, &[]),
                reply_txs(&txs, &[2, 3]),
                error_action(TooManyAcknowledgedTxs(4, 2)),
            ],
        );
        Ok(())
    }

    // Acknowledging more ids than are in the window is rejected (non-blocking variant).
    #[tokio::test]
    async fn the_responder_cannot_acknowledge_more_than_the_current_unacknowledged_non_blocking() -> anyhow::Result<()>
    {
        let mempool = Arc::new(SizedMempool::with_capacity(4));
        let txs = create_transactions_in_mempool(mempool.clone(), 4);

        let results = vec![
            request_tx_ids(0, 4, Blocking::Yes),
            request_txs(&txs, &[0, 1]),
            request_tx_ids(2, 4, Blocking::No),
            request_txs(&txs, &[2, 3]),
            request_tx_ids(4, 4, Blocking::No),
        ];
        let actions = run_stage(mempool, results).await?;
        assert_actions_eq(
            &actions,
            &[
                reply_tx_ids(&txs, &[0, 1, 2, 3]),
                reply_txs(&txs, &[0, 1]),
                reply_tx_ids(&txs, &[]),
                reply_txs(&txs, &[2, 3]),
                error_action(TooManyAcknowledgedTxs(4, 2)),
            ],
        );
        Ok(())
    }

    // Checks the initiator's protocol state machine against the tx-submission protocol spec.
    // The closure maps each message the initiator sends to the action that produces it, and
    // returns `None` for messages the initiator only receives.
    #[test]
    fn test_initiator_protocol() {
        crate::tx_submission::spec::<Initiator>().check(State::Init, |msg| match msg {
            Message::ReplyTxIds(tx_ids) => Some(InitiatorAction::SendReplyTxIds(tx_ids.clone())),
            Message::ReplyTxs(txs) => Some(InitiatorAction::SendReplyTxs(txs.clone())),
            Message::Done => Some(InitiatorAction::Done),
            Message::Init
            | Message::RequestTxs(_)
            | Message::RequestTxIdsBlocking(_, _)
            | Message::RequestTxIdsNonBlocking(_, _) => None,
        });
    }

    // HELPERS

    /// Runs `results` through a fresh initiator and returns the actions it produced.
    async fn run_stage(
        mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
        results: Vec<InitiatorResult>,
    ) -> anyhow::Result<Vec<InitiatorAction>> {
        let (actions, _initiator) = run_stage_and_return_state(mempool, results).await?;
        Ok(actions)
    }

    /// Runs `results` through a fresh initiator and returns the actions plus the final state.
    async fn run_stage_and_return_state(
        mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
        results: Vec<InitiatorResult>,
    ) -> anyhow::Result<(Vec<InitiatorAction>, TxSubmissionInitiator)> {
        run_stage_and_return_state_with(
            TxSubmissionInitiator::new(StageRef::named_for_tests("muxer")).1,
            mempool,
            results,
        )
        .await
    }

    /// Runs `results` through the given initiator, in order.
    ///
    /// Returns every produced action (steps that yield `None` are skipped) and the final state.
    async fn run_stage_and_return_state_with(
        mut initiator: TxSubmissionInitiator,
        mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
        results: Vec<InitiatorResult>,
    ) -> anyhow::Result<(Vec<InitiatorAction>, TxSubmissionInitiator)> {
        let mut actions = vec![];
        for r in results {
            let action = step(&mut initiator, r, mempool.as_ref()).await?;
            if let Some(action) = action {
                actions.push(action);
            }
        }
        Ok((actions, initiator))
    }

    /// Dispatches one request to the matching handler, mirroring `StageState::network` without
    /// needing an `Effects` instance.
    async fn step(
        initiator: &mut TxSubmissionInitiator,
        input: InitiatorResult,
        mempool: &dyn TxSubmissionMempool<Transaction>,
    ) -> anyhow::Result<Option<InitiatorAction>> {
        let action = match input {
            InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::Yes } => {
                initiator.request_tx_ids_blocking(mempool, ack, req).await?
            }
            InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::No } => {
                initiator.request_tx_ids_non_blocking(mempool, ack, req)?
            }
            InitiatorResult::RequestTxs(tx_ids) => initiator.request_txs(mempool, tx_ids)?,
        };
        Ok(action)
    }

    /// Expected `SendReplyTxIds` for the transactions at the given indices of `txs`.
    fn reply_tx_ids(txs: &[Transaction], ids: &[usize]) -> InitiatorAction {
        // Size the test mempool reports for each transaction built by the test helpers.
        let default_transaction_size = 49;
        InitiatorAction::SendReplyTxIds(
            ids.iter().map(|id| (TxId::from(&txs[*id]), default_transaction_size)).collect(),
        )
    }

    /// Expected `SendReplyTxs` for the transactions at the given indices of `txs`.
    fn reply_txs(txs: &[Transaction], ids: &[usize]) -> InitiatorAction {
        InitiatorAction::SendReplyTxs(ids.iter().map(|id| txs[*id].clone()).collect())
    }

    /// Builds a tx-id request from the responder.
    fn request_tx_ids(ack: u16, req: u16, blocking: Blocking) -> InitiatorResult {
        InitiatorResult::RequestTxIds { ack, req, blocking }
    }

    /// Builds a `RequestTxs` for the transactions at the given indices of `txs`.
    fn request_txs(txs: &[Transaction], ids: &[usize]) -> InitiatorResult {
        InitiatorResult::RequestTxs(ids.iter().map(|id| TxId::from(&txs[*id])).collect())
    }

    /// Expected action for a protocol violation.
    fn error_action(error: ProtocolError) -> InitiatorAction {
        InitiatorAction::Error(error)
    }
}