1use std::collections::VecDeque;
51use std::fmt::{Debug, Display};
52
53use ProtocolError::*;
54use amaru_kernel::{Transaction, utils::string::display_collection};
55use amaru_ouroboros::{MempoolSeqNo, TxSubmissionMempool};
56use amaru_ouroboros_traits::TxId;
57use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
58use tracing::instrument;
59
60use crate::{
61 mempool_effects::MemoryPool,
62 mux::MuxMessage,
63 protocol::{
64 Initiator, Inputs, Miniprotocol, Outcome, PROTO_N2N_TX_SUB, ProtocolState, StageState, miniprotocol, outcome,
65 },
66 tx_submission::{Blocking, Message, ProtocolError, State},
67};
68
/// Largest `req` count accepted in a single tx-ids request; `check_ack_req` rejects anything
/// above it with `MaxOutstandingTxIdsRequested`.
const MAX_REQUESTED_TX_IDS: u16 = 10;
70
71pub fn register_deserializers() -> DeserializerGuards {
72 vec![
73 pure_stage::register_data_deserializer::<Void>().boxed(),
74 pure_stage::register_data_deserializer::<TxSubmissionInitiator>().boxed(),
75 pure_stage::register_data_deserializer::<(State, TxSubmissionInitiator)>().boxed(),
76 ]
77}
78
/// Builds the initiator-side miniprotocol for `PROTO_N2N_TX_SUB`.
pub fn initiator() -> Miniprotocol<State, TxSubmissionInitiator, Initiator> {
    miniprotocol(PROTO_N2N_TX_SUB)
}
82
83impl StageState<State, Initiator> for TxSubmissionInitiator {
84 type LocalIn = Void;
85
86 async fn local(
87 self,
88 _proto: &State,
89 _input: Self::LocalIn,
90 _eff: &Effects<Inputs<Self::LocalIn>>,
91 ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
92 Ok((None, self))
94 }
95
96 #[instrument(name = "tx_submission.initiator.stage", skip_all, fields(message_type = input.message_type()))]
97 async fn network(
98 mut self,
99 _proto: &State,
100 input: InitiatorResult,
101 eff: &Effects<Inputs<Self::LocalIn>>,
102 ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
103 let mempool: &dyn TxSubmissionMempool<Transaction> = &MemoryPool::new(eff.clone());
104
105 let action = match input {
106 InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::Yes } => {
107 self.request_tx_ids_blocking(mempool, ack, req).await?
108 }
109 InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::No } => {
110 self.request_tx_ids_non_blocking(mempool, ack, req)?
111 }
112 InitiatorResult::RequestTxs(tx_ids) => self.request_txs(mempool, tx_ids)?,
113 };
114 Ok((action, self))
115 }
116
117 fn muxer(&self) -> &StageRef<MuxMessage> {
118 &self.muxer
119 }
120}
121
122impl ProtocolState<Initiator> for State {
123 type WireMsg = Message;
124 type Action = InitiatorAction;
125 type Out = InitiatorResult;
126 type Error = ProtocolError;
127
128 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
129 Ok((outcome().send(Message::Init).want_next(), State::Idle))
130 }
131
132 #[instrument(name = "tx_submission.initiator.protocol", skip_all, fields(message_type = input.message_type()))]
133 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
134 Ok(match (self, input) {
135 (State::Idle, Message::RequestTxIdsBlocking(ack, req)) => {
136 tracing::debug!(%ack, %req, "received RequestTxIdsBlocking");
137 (
138 outcome().result(InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::Yes }),
139 State::TxIdsBlocking,
140 )
141 }
142 (State::Idle, Message::RequestTxIdsNonBlocking(ack, req)) => {
143 tracing::debug!(%ack, %req, "received RequestTxIdsNonBlocking");
144 (
145 outcome().result(InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::No }),
146 State::TxIdsNonBlocking,
147 )
148 }
149 (State::Idle, Message::RequestTxs(tx_ids)) => {
150 tracing::debug!(tx_ids_nb = tx_ids.len(), "received RequestTxs");
151 (outcome().result(InitiatorResult::RequestTxs(tx_ids)), State::Txs)
152 }
153 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
154 })
155 }
156
157 fn local(&self, action: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
158 Ok(match (self, action) {
159 (State::TxIdsBlocking, InitiatorAction::SendReplyTxIds(tx_ids)) => {
160 (outcome().send(Message::ReplyTxIds(tx_ids)).want_next(), State::Idle)
161 }
162 (State::TxIdsNonBlocking, InitiatorAction::SendReplyTxIds(tx_ids)) => {
163 (outcome().send(Message::ReplyTxIds(tx_ids)).want_next(), State::Idle)
164 }
165 (State::Txs, InitiatorAction::SendReplyTxs(txs)) => {
166 (outcome().send(Message::ReplyTxs(txs)).want_next(), State::Idle)
167 }
168 (State::TxIdsBlocking, InitiatorAction::Done) => (outcome().send(Message::Done), State::Done),
169 (_, InitiatorAction::Error(e)) => (outcome().terminate_with(e), State::Done),
170 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
171 })
172 }
173}
174
/// Actions the initiator stage hands to the protocol state machine.
#[derive(Debug, PartialEq, Eq)]
pub enum InitiatorAction {
    /// Answer a tx-ids request with `(transaction id, size)` pairs (sent as `Message::ReplyTxIds`).
    SendReplyTxIds(Vec<(TxId, u32)>),
    /// Answer a transactions request with the transactions themselves (sent as `Message::ReplyTxs`).
    SendReplyTxs(Vec<Transaction>),
    /// Terminate the protocol with the given error.
    Error(ProtocolError),
    /// Send `Message::Done`; only accepted while in the `TxIdsBlocking` state.
    Done,
}
182
183impl Display for InitiatorAction {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 match self {
186 InitiatorAction::SendReplyTxIds(tx_ids) => {
187 write!(f, "SendReplyTxIds(len={})", tx_ids.len())
188 }
189 InitiatorAction::SendReplyTxs(txs) => write!(f, "SendReplyTxs(len={})", txs.len()),
190 InitiatorAction::Error(err) => write!(f, "Error({})", err),
191 InitiatorAction::Done => write!(f, "Done"),
192 }
193 }
194}
195
/// Requests decoded from the peer's messages and forwarded to the initiator stage.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum InitiatorResult {
    /// The peer acknowledges `ack` previously advertised tx ids and asks for up to `req` new ones,
    /// in blocking or non-blocking mode.
    RequestTxIds { ack: u16, req: u16, blocking: Blocking },
    /// The peer asks for the transactions with the given ids.
    RequestTxs(Vec<TxId>),
}
202
203impl InitiatorResult {
204 pub fn message_type(&self) -> &str {
205 match self {
206 Self::RequestTxIds { .. } => "RequestTxIds",
207 Self::RequestTxs(_) => "RequestTxs",
208 }
209 }
210}
211
212impl Display for InitiatorResult {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 match self {
215 InitiatorResult::RequestTxIds { ack, req, blocking } => {
216 write!(f, "RequestTxIds(ack: {}, req: {}, blocking: {:?})", ack, req, blocking)
217 }
218 InitiatorResult::RequestTxs(tx_ids) => {
219 write!(
220 f,
221 "RequestTxs(ids: [{}])",
222 tx_ids.iter().map(|id| format!("{}", id)).collect::<Vec<_>>().join(", ")
223 )
224 }
225 }
226 }
227}
228
/// State kept by the tx-submission initiator stage between requests.
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TxSubmissionInitiator {
    /// Transaction ids already sent to the peer and not yet acknowledged, oldest first, each
    /// paired with its mempool sequence number.
    window: VecDeque<(TxId, MempoolSeqNo)>,
    /// Sequence number of the last transaction id pushed into `window`; `None` until the first
    /// id has been sent.
    last_seq: Option<MempoolSeqNo>,
    /// Reference to the muxer stage, exposed through `muxer()` and `as_ref()`.
    muxer: StageRef<MuxMessage>,
}
238
239impl TxSubmissionInitiator {
240 pub fn new(muxer: StageRef<MuxMessage>) -> (State, Self) {
241 (State::Init, Self { window: VecDeque::new(), last_seq: None, muxer })
242 }
243
244 async fn request_tx_ids_blocking(
245 &mut self,
246 mempool: &dyn TxSubmissionMempool<Transaction>,
247 ack: u16,
248 req: u16,
249 ) -> anyhow::Result<Option<InitiatorAction>> {
250 tracing::debug!(%ack, %req, "received RequestTxIdsBlocking");
252 if req == 0 {
253 return protocol_error(NoTxIdsRequested);
254 };
255 if let Some(value) = self.check_ack_req(ack, req) {
256 return protocol_error(value);
257 }
258 if (ack as usize) < self.window.len() {
259 return protocol_error(BlockingRequestMadeWhenTxsStillUnacknowledged);
260 }
261
262 self.discard(ack);
264 if !mempool.wait_for_at_least(self.last_seq.unwrap_or_default().add(req as u64)).await {
265 return Ok(None);
266 }
267 let tx_ids = self.get_next_tx_ids(mempool, req)?;
268 Ok(Some(InitiatorAction::SendReplyTxIds(tx_ids)))
269 }
270
271 fn request_tx_ids_non_blocking(
272 &mut self,
273 mempool: &dyn TxSubmissionMempool<Transaction>,
274 ack: u16,
275 req: u16,
276 ) -> anyhow::Result<Option<InitiatorAction>> {
277 tracing::debug!(%ack, %req, "received RequestTxIdsNonBlocking");
279 if ack == 0 && req == 0 {
280 return protocol_error(NoAckOrReqTxIdsRequested);
281 }
282 if let Some(error) = self.check_ack_req(ack, req) {
283 return protocol_error(error);
284 }
285 if ack as usize == self.window.len() {
286 return protocol_error(NonBlockingRequestMadeWhenAllTxsAcknowledged);
287 }
288
289 self.discard(ack);
291 Ok(Some(InitiatorAction::SendReplyTxIds(self.get_next_tx_ids(mempool, req)?)))
292 }
293
294 fn request_txs(
295 &mut self,
296 mempool: &dyn TxSubmissionMempool<Transaction>,
297 tx_ids: Vec<TxId>,
298 ) -> anyhow::Result<Option<InitiatorAction>> {
299 tracing::debug!(tx_ids = display_collection(&tx_ids), "received RequestTxs");
300 if tx_ids.is_empty() {
301 return protocol_error(NoTxsRequested);
302 }
303 if tx_ids.iter().any(|id| !self.window.iter().any(|(wid, _)| wid == id)) {
304 return protocol_error(UnadvertisedTransactionIdsRequested(tx_ids));
305 }
306 let txs = mempool.get_txs_for_ids(tx_ids.as_slice());
307 if txs.is_empty() {
308 protocol_error(UnknownTxsRequested(tx_ids))
309 } else {
310 Ok(Some(InitiatorAction::SendReplyTxs(txs)))
311 }
312 }
313
314 fn check_ack_req(&mut self, ack: u16, req: u16) -> Option<ProtocolError> {
316 if req > MAX_REQUESTED_TX_IDS {
317 Some(MaxOutstandingTxIdsRequested(req, MAX_REQUESTED_TX_IDS))
318 } else if ack as usize > self.window.len() {
319 Some(TooManyAcknowledgedTxs(ack, self.window.len() as u16))
320 } else {
321 None
322 }
323 }
324
325 fn get_next_tx_ids<Tx: Send + Debug + Sync + 'static>(
327 &mut self,
328 mempool: &dyn TxSubmissionMempool<Tx>,
329 required_next: u16,
330 ) -> anyhow::Result<Vec<(TxId, u32)>> {
331 let tx_ids = mempool.tx_ids_since(self.next_seq(), required_next);
332 let result = tx_ids.clone().into_iter().map(|(tx_id, tx_size, _)| (tx_id, tx_size)).collect();
333 self.update(tx_ids);
334 Ok(result)
335 }
336
337 fn discard(&mut self, acknowledged: u16) {
339 if self.window.len() >= acknowledged as usize {
340 self.window = self.window.drain(acknowledged as usize..).collect();
341 }
342 }
343
344 fn update(&mut self, tx_ids: Vec<(TxId, u32, MempoolSeqNo)>) {
346 for (tx_id, _size, seq_no) in tx_ids {
347 self.window.push_back((tx_id, seq_no));
348 self.last_seq = Some(seq_no);
349 }
350 }
351
352 fn next_seq(&self) -> MempoolSeqNo {
354 match self.last_seq {
355 Some(seq) => seq.next(),
356 None => MempoolSeqNo(0),
357 }
358 }
359}
360
361fn protocol_error(error: ProtocolError) -> anyhow::Result<Option<InitiatorAction>> {
362 tracing::warn!("protocol error: {error}");
363 Ok(Some(InitiatorAction::Error(error)))
364}
365
/// Gives access to the muxer stage reference held by the initiator.
impl AsRef<StageRef<MuxMessage>> for TxSubmissionInitiator {
    fn as_ref(&self) -> &StageRef<MuxMessage> {
        &self.muxer
    }
}
371
372#[cfg(test)]
373mod tests {
374 use std::sync::Arc;
375
376 use super::*;
377 use crate::tx_submission::{
378 assert_actions_eq, create_transactions_in_mempool,
379 tests::{SizedMempool, create_transactions},
380 };
381
382 #[tokio::test]
383 async fn serve_transactions() -> anyhow::Result<()> {
384 let mempool = Arc::new(SizedMempool::with_capacity(6));
386 let txs = create_transactions_in_mempool(mempool.clone(), 6);
387
388 let results = vec![
393 request_tx_ids(0, 2, Blocking::Yes),
394 request_txs(&txs, &[0, 1]),
395 request_tx_ids(1, 2, Blocking::No), request_txs(&txs, &[2, 3]),
397 request_tx_ids(3, 2, Blocking::Yes), request_txs(&txs, &[4, 5]),
399 request_tx_ids(2, 2, Blocking::Yes),
400 ];
401
402 let outcomes = run_stage(mempool, results).await?;
403
404 assert_actions_eq(
408 &outcomes,
409 &[
410 reply_tx_ids(&txs, &[0, 1]),
411 reply_txs(&txs, &[0, 1]),
412 reply_tx_ids(&txs, &[2, 3]),
413 reply_txs(&txs, &[2, 3]),
414 reply_tx_ids(&txs, &[4, 5]),
415 reply_txs(&txs, &[4, 5]),
416 ],
417 );
418
419 Ok(())
420 }
421
422 #[tokio::test]
423 async fn serve_transactions_with_mempool_refilling() -> anyhow::Result<()> {
424 let mempool = Arc::new(SizedMempool::with_capacity(6));
426 let txs = create_transactions(6);
427
428 for tx in txs.iter().take(2) {
429 mempool.add(tx.clone())?;
430 }
431
432 let results =
435 vec![request_tx_ids(0, 2, Blocking::Yes), request_txs(&txs, &[0, 1]), request_tx_ids(1, 2, Blocking::No)];
436
437 let (actions, initiator) = run_stage_and_return_state(mempool.clone(), results).await?;
438 assert_actions_eq(&actions, &[reply_tx_ids(&txs, &[0, 1]), reply_txs(&txs, &[0, 1]), reply_tx_ids(&txs, &[])]);
439
440 for tx in &txs[2..] {
442 mempool.add(tx.clone())?;
443 }
444 let messages = vec![
445 request_tx_ids(1, 2, Blocking::Yes),
446 request_txs(&txs, &[2, 3]),
447 request_tx_ids(2, 2, Blocking::Yes),
448 request_txs(&txs, &[4, 5]),
449 request_tx_ids(2, 2, Blocking::Yes),
450 ];
451
452 let (actions, _) = run_stage_and_return_state_with(initiator, mempool, messages).await?;
453
454 assert_actions_eq(
458 &actions,
459 &[
460 reply_tx_ids(&txs, &[2, 3]),
461 reply_txs(&txs, &[2, 3]),
462 reply_tx_ids(&txs, &[4, 5]),
463 reply_txs(&txs, &[4, 5]),
464 ],
465 );
466 Ok(())
467 }
468
469 #[tokio::test]
470 async fn request_txs_must_come_from_requested_ids() -> anyhow::Result<()> {
471 let mempool = Arc::new(SizedMempool::with_capacity(6));
473 let txs = create_transactions_in_mempool(mempool.clone(), 4);
474
475 let results = vec![request_tx_ids(0, 2, Blocking::Yes), request_txs(&txs, &[2, 3])];
482
483 let actions = run_stage(mempool, results).await?;
484 assert_actions_eq(
485 &actions,
486 &[
487 reply_tx_ids(&txs, &[0, 1]),
488 error_action(UnadvertisedTransactionIdsRequested(vec![TxId::from(&txs[2]), TxId::from(&txs[3])])),
489 ],
490 );
491 Ok(())
492 }
493
494 #[tokio::test]
495 async fn blocking_requested_ids_must_be_greater_than_0() -> anyhow::Result<()> {
496 let mempool = Arc::new(SizedMempool::with_capacity(6));
497
498 let results = vec![request_tx_ids(0, 0, Blocking::Yes)];
499 let actions = run_stage(mempool, results).await?;
500 assert_actions_eq(&actions, &[error_action(NoTxIdsRequested)]);
501 Ok(())
502 }
503
504 #[tokio::test]
505 async fn blocking_requested_txs_must_be_greater_than_0() -> anyhow::Result<()> {
506 let mempool = Arc::new(SizedMempool::with_capacity(4));
507 let txs = create_transactions_in_mempool(mempool.clone(), 4);
508
509 let results = vec![request_tx_ids(0, 2, Blocking::Yes), request_txs(&txs, &[])];
510
511 let actions = run_stage(mempool, results).await?;
512 assert_actions_eq(&actions, &[reply_tx_ids(&txs, &[0, 1]), error_action(NoTxsRequested)]);
513 Ok(())
514 }
515
516 #[tokio::test]
517 async fn non_blocking_ack_or_requested_ids_must_be_greater_than_0() -> anyhow::Result<()> {
518 let mempool = Arc::new(SizedMempool::with_capacity(6));
519
520 let results = vec![request_tx_ids(0, 0, Blocking::No)];
521 let actions = run_stage(mempool, results).await?;
522 assert_actions_eq(&actions, &[error_action(NoAckOrReqTxIdsRequested)]);
523 Ok(())
524 }
525
526 #[tokio::test]
527 async fn blocking_requested_nb_must_be_less_than_protocol_limit() -> anyhow::Result<()> {
528 let mempool = Arc::new(SizedMempool::with_capacity(6));
529
530 let results = vec![request_tx_ids(0, 12, Blocking::Yes)];
531 let actions = run_stage(mempool, results).await?;
532 assert_actions_eq(&actions, &[error_action(MaxOutstandingTxIdsRequested(12, MAX_REQUESTED_TX_IDS))]);
533 Ok(())
534 }
535
536 #[tokio::test]
537 async fn non_blocking_requested_nb_must_be_less_than_protocol_limit() -> anyhow::Result<()> {
538 let mempool = Arc::new(SizedMempool::with_capacity(6));
539
540 let results = vec![request_tx_ids(0, 12, Blocking::No)];
541 let actions = run_stage(mempool, results).await?;
542 assert_actions_eq(&actions, &[error_action(MaxOutstandingTxIdsRequested(12, MAX_REQUESTED_TX_IDS))]);
543 Ok(())
544 }
545
546 #[tokio::test]
547 async fn a_blocking_request_must_be_made_when_all_txs_are_acknowledged() -> anyhow::Result<()> {
548 let mempool = Arc::new(SizedMempool::with_capacity(4));
549 let txs = create_transactions_in_mempool(mempool.clone(), 4);
550
551 let results = vec![
552 request_tx_ids(0, 4, Blocking::Yes),
553 request_txs(&txs, &[0, 1]),
554 request_tx_ids(2, 4, Blocking::No),
555 request_txs(&txs, &[2, 3]),
556 request_tx_ids(2, 4, Blocking::No),
557 ];
558 let actions = run_stage(mempool, results).await?;
559 assert_actions_eq(
560 &actions,
561 &[
562 reply_tx_ids(&txs, &[0, 1, 2, 3]),
563 reply_txs(&txs, &[0, 1]),
564 reply_tx_ids(&txs, &[]),
565 reply_txs(&txs, &[2, 3]),
566 error_action(NonBlockingRequestMadeWhenAllTxsAcknowledged),
567 ],
568 );
569 Ok(())
570 }
571
572 #[tokio::test]
573 async fn a_non_blocking_request_must_be_made_when_some_txs_are_unacknowledged() -> anyhow::Result<()> {
574 let mempool = Arc::new(SizedMempool::with_capacity(4));
575 let txs = create_transactions_in_mempool(mempool.clone(), 4);
576
577 let results =
578 vec![request_tx_ids(0, 4, Blocking::Yes), request_txs(&txs, &[0, 1]), request_tx_ids(2, 4, Blocking::Yes)];
579 let actions = run_stage(mempool, results).await?;
580 assert_actions_eq(
581 &actions,
582 &[
583 reply_tx_ids(&txs, &[0, 1, 2, 3]),
584 reply_txs(&txs, &[0, 1]),
585 error_action(BlockingRequestMadeWhenTxsStillUnacknowledged),
586 ],
587 );
588 Ok(())
589 }
590
591 #[tokio::test]
592 async fn the_responder_cannot_acknowledge_more_than_the_current_unacknowledged_blocking() -> anyhow::Result<()> {
593 let mempool = Arc::new(SizedMempool::with_capacity(4));
594 let txs = create_transactions_in_mempool(mempool.clone(), 4);
595
596 let results = vec![
597 request_tx_ids(0, 4, Blocking::Yes),
598 request_txs(&txs, &[0, 1]),
599 request_tx_ids(2, 4, Blocking::No),
600 request_txs(&txs, &[2, 3]),
601 request_tx_ids(4, 4, Blocking::Yes),
602 ];
603 let actions = run_stage(mempool, results).await?;
604 assert_actions_eq(
605 &actions,
606 &[
607 reply_tx_ids(&txs, &[0, 1, 2, 3]),
608 reply_txs(&txs, &[0, 1]),
609 reply_tx_ids(&txs, &[]),
610 reply_txs(&txs, &[2, 3]),
611 error_action(TooManyAcknowledgedTxs(4, 2)),
612 ],
613 );
614 Ok(())
615 }
616
617 #[tokio::test]
618 async fn the_responder_cannot_acknowledge_more_than_the_current_unacknowledged_non_blocking() -> anyhow::Result<()>
619 {
620 let mempool = Arc::new(SizedMempool::with_capacity(4));
621 let txs = create_transactions_in_mempool(mempool.clone(), 4);
622
623 let results = vec![
624 request_tx_ids(0, 4, Blocking::Yes),
625 request_txs(&txs, &[0, 1]),
626 request_tx_ids(2, 4, Blocking::No),
627 request_txs(&txs, &[2, 3]),
628 request_tx_ids(4, 4, Blocking::No),
629 ];
630 let actions = run_stage(mempool, results).await?;
631 assert_actions_eq(
632 &actions,
633 &[
634 reply_tx_ids(&txs, &[0, 1, 2, 3]),
635 reply_txs(&txs, &[0, 1]),
636 reply_tx_ids(&txs, &[]),
637 reply_txs(&txs, &[2, 3]),
638 error_action(TooManyAcknowledgedTxs(4, 2)),
639 ],
640 );
641 Ok(())
642 }
643
644 #[test]
645 fn test_initiator_protocol() {
646 crate::tx_submission::spec::<Initiator>().check(State::Init, |msg| match msg {
647 Message::ReplyTxIds(tx_ids) => Some(InitiatorAction::SendReplyTxIds(tx_ids.clone())),
648 Message::ReplyTxs(txs) => Some(InitiatorAction::SendReplyTxs(txs.clone())),
649 Message::Done => Some(InitiatorAction::Done),
650 Message::Init
651 | Message::RequestTxs(_)
652 | Message::RequestTxIdsBlocking(_, _)
653 | Message::RequestTxIdsNonBlocking(_, _) => None,
654 });
655 }
656
657 async fn run_stage(
660 mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
661 results: Vec<InitiatorResult>,
662 ) -> anyhow::Result<Vec<InitiatorAction>> {
663 let (actions, _initiator) = run_stage_and_return_state(mempool, results).await?;
664 Ok(actions)
665 }
666
667 async fn run_stage_and_return_state(
668 mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
669 results: Vec<InitiatorResult>,
670 ) -> anyhow::Result<(Vec<InitiatorAction>, TxSubmissionInitiator)> {
671 run_stage_and_return_state_with(
672 TxSubmissionInitiator::new(StageRef::named_for_tests("muxer")).1,
673 mempool,
674 results,
675 )
676 .await
677 }
678
679 async fn run_stage_and_return_state_with(
680 mut initiator: TxSubmissionInitiator,
681 mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
682 results: Vec<InitiatorResult>,
683 ) -> anyhow::Result<(Vec<InitiatorAction>, TxSubmissionInitiator)> {
684 let mut actions = vec![];
685 for r in results {
686 let action = step(&mut initiator, r, mempool.as_ref()).await?;
687 if let Some(action) = action {
688 actions.push(action);
689 }
690 }
691 Ok((actions, initiator))
692 }
693
694 async fn step(
695 initiator: &mut TxSubmissionInitiator,
696 input: InitiatorResult,
697 mempool: &dyn TxSubmissionMempool<Transaction>,
698 ) -> anyhow::Result<Option<InitiatorAction>> {
699 let action = match input {
700 InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::Yes } => {
701 initiator.request_tx_ids_blocking(mempool, ack, req).await?
702 }
703 InitiatorResult::RequestTxIds { ack, req, blocking: Blocking::No } => {
704 initiator.request_tx_ids_non_blocking(mempool, ack, req)?
705 }
706 InitiatorResult::RequestTxs(tx_ids) => initiator.request_txs(mempool, tx_ids)?,
707 };
708 Ok(action)
709 }
710
711 fn reply_tx_ids(txs: &[Transaction], ids: &[usize]) -> InitiatorAction {
712 let default_transaction_size = 49;
713 InitiatorAction::SendReplyTxIds(
714 ids.iter().map(|id| (TxId::from(&txs[*id]), default_transaction_size)).collect(),
715 )
716 }
717
718 fn reply_txs(txs: &[Transaction], ids: &[usize]) -> InitiatorAction {
719 InitiatorAction::SendReplyTxs(ids.iter().map(|id| txs[*id].clone()).collect())
720 }
721
    /// Builds a `RequestTxIds` request with the given counters and blocking mode.
    fn request_tx_ids(ack: u16, req: u16, blocking: Blocking) -> InitiatorResult {
        InitiatorResult::RequestTxIds { ack, req, blocking }
    }
725
726 fn request_txs(txs: &[Transaction], ids: &[usize]) -> InitiatorResult {
727 InitiatorResult::RequestTxs(ids.iter().map(|id| TxId::from(&txs[*id])).collect())
728 }
729
    /// Wraps a protocol error in the `InitiatorAction::Error` action expected by the tests.
    fn error_action(error: ProtocolError) -> InitiatorAction {
        InitiatorAction::Error(error)
    }
733}