1use std::{
16 collections::{BTreeSet, VecDeque},
17 fmt::Display,
18};
19
20use ProtocolError::*;
21use amaru_kernel::Transaction;
22use amaru_ouroboros::TxSubmissionMempool;
23use amaru_ouroboros_traits::{TxId, TxOrigin};
24use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
25use tracing::instrument;
26
27use crate::{
28 mempool_effects::MemoryPool,
29 mux::MuxMessage,
30 protocol::{
31 Inputs, Miniprotocol, Outcome, PROTO_N2N_TX_SUB, ProtocolState, Responder, StageState, miniprotocol, outcome,
32 },
33 tx_submission::{Blocking, Message, ProtocolError, ResponderParams, State},
34};
35
36pub fn register_deserializers() -> DeserializerGuards {
37 vec![
38 pure_stage::register_data_deserializer::<TxSubmissionResponder>().boxed(),
39 pure_stage::register_data_deserializer::<(State, TxSubmissionResponder)>().boxed(),
40 ]
41}
42
43pub fn responder() -> Miniprotocol<State, TxSubmissionResponder, Responder> {
44 miniprotocol(PROTO_N2N_TX_SUB.responder())
45}
46
47impl StageState<State, Responder> for TxSubmissionResponder {
48 type LocalIn = Void;
49
50 async fn local(
51 self,
52 _proto: &State,
53 input: Self::LocalIn,
54 _eff: &Effects<Inputs<Self::LocalIn>>,
55 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
56 match input {}
57 }
58
59 #[instrument(name = "tx_submission.responder.stage", skip_all, fields(message_type = input.message_type()))]
60 async fn network(
61 mut self,
62 _proto: &State,
63 input: ResponderResult,
64 eff: &Effects<Inputs<Self::LocalIn>>,
65 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
66 let mempool: &dyn TxSubmissionMempool<Transaction> = &MemoryPool::new(eff.clone());
67
68 let action = match input {
69 ResponderResult::Init => {
70 tracing::trace!("received Init");
71 self.initialize_state(mempool)
72 }
73 ResponderResult::ReplyTxIds(tx_ids) => self.process_tx_ids_reply(mempool, tx_ids)?,
74 ResponderResult::ReplyTxs(txs) => self.process_txs_reply(mempool, txs, self.origin.clone())?,
75 ResponderResult::Done => None,
76 };
77 Ok((action, self))
78 }
79
80 fn muxer(&self) -> &StageRef<MuxMessage> {
81 &self.muxer
82 }
83}
84
85impl ProtocolState<Responder> for State {
86 type WireMsg = Message;
87 type Action = ResponderAction;
88 type Out = ResponderResult;
89 type Error = ProtocolError;
90
91 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
92 Ok((outcome().want_next(), *self))
94 }
95
96 #[instrument(name = "tx_submission.responder.protocol", skip_all, fields(message_type = input.message_type()))]
97 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
98 Ok(match (self, input) {
99 (State::Init, Message::Init) => (outcome().result(ResponderResult::Init), State::Idle),
100 (State::TxIdsBlocking | State::TxIdsNonBlocking, Message::ReplyTxIds(tx_ids)) => {
101 (outcome().result(ResponderResult::ReplyTxIds(tx_ids)), State::Idle)
102 }
103 (State::Txs, Message::ReplyTxs(txs)) => (outcome().result(ResponderResult::ReplyTxs(txs)), State::Idle),
104 (State::TxIdsBlocking, Message::Done) => (outcome().result(ResponderResult::Done), State::Done),
105 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
106 })
107 }
108
109 fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
110 Ok(match (self, input) {
111 (State::Idle, ResponderAction::SendRequestTxIds { ack, req, blocking }) => match blocking {
112 Blocking::Yes => {
113 (outcome().send(Message::RequestTxIdsBlocking(ack, req)).want_next(), State::TxIdsBlocking)
114 }
115 Blocking::No => {
116 (outcome().send(Message::RequestTxIdsNonBlocking(ack, req)).want_next(), State::TxIdsNonBlocking)
117 }
118 },
119 (State::Idle, ResponderAction::SendRequestTxs(tx_ids)) => {
120 (outcome().send(Message::RequestTxs(tx_ids)).want_next(), State::Txs)
121 }
122 (_, ResponderAction::Error(e)) => (outcome().terminate_with(e), State::Done),
123 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
124 })
125 }
126}
127
/// Result handed from the protocol state machine to the responder stage after
/// a message from the peer has been accepted.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ResponderResult {
    /// The peer sent its `Init` message.
    Init,
    /// The peer answered a tx-ids request with `(tx id, u32)` pairs; this file
    /// treats the `u32` as a size.
    ReplyTxIds(Vec<(TxId, u32)>),
    /// The peer answered a txs request with these transactions.
    ReplyTxs(Vec<Transaction>),
    /// The peer sent `Done`.
    Done,
}
136
137impl ResponderResult {
138 pub fn message_type(&self) -> &str {
139 match self {
140 ResponderResult::Init => "Init",
141 ResponderResult::ReplyTxIds(_) => "ReplyTxIds",
142 ResponderResult::ReplyTxs(_) => "ReplyTxs",
143 ResponderResult::Done => "Done",
144 }
145 }
146}
147
148impl Display for ResponderResult {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 match self {
151 ResponderResult::Init => write!(f, "Init"),
152 ResponderResult::ReplyTxIds(tx_ids) => {
153 write!(f, "ReplyTxIds(len: {})", tx_ids.len())
154 }
155 ResponderResult::ReplyTxs(txs) => write!(f, "ReplyTxs(len: {})", txs.len()),
156 ResponderResult::Done => write!(f, "Done"),
157 }
158 }
159}
160
/// Stage state of the tx-submission responder: tracks which tx ids the peer
/// advertised, which ones still have to be fetched and which ones are being fetched.
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct TxSubmissionResponder {
    /// Protocol limits; this file reads `max_window` and `fetch_batch` from it.
    params: ResponderParams,
    /// Advertised `(tx id, size)` pairs not acknowledged yet, in arrival order.
    window: VecDeque<(TxId, u32)>,
    /// Advertised ids that were not in the mempool on arrival and have not been requested yet.
    pending_fetch: VecDeque<TxId>,
    /// Ids that have been requested from the peer and whose transaction has not been received yet.
    inflight_fetch_set: BTreeSet<TxId>,
    /// Origin attached to every transaction inserted into the mempool.
    origin: TxOrigin,
    /// Handle of the muxer stage, exposed through `muxer()` and `AsRef`.
    muxer: StageRef<MuxMessage>,
}
177
178impl TxSubmissionResponder {
179 pub fn new(muxer: StageRef<MuxMessage>, params: ResponderParams, origin: TxOrigin) -> (State, Self) {
180 (
181 State::Init,
182 Self {
183 params,
184 window: VecDeque::new(),
185 pending_fetch: VecDeque::new(),
186 inflight_fetch_set: BTreeSet::new(),
187 origin,
188 muxer,
189 },
190 )
191 }
192
193 fn initialize_state(&mut self, mempool: &dyn TxSubmissionMempool<Transaction>) -> Option<ResponderAction> {
194 let (ack, req, blocking) = self.request_tx_ids(mempool);
195 Some(ResponderAction::SendRequestTxIds { ack, req, blocking })
196 }
197
198 fn process_tx_ids_reply(
199 &mut self,
200 mempool: &dyn TxSubmissionMempool<Transaction>,
201 tx_ids: Vec<(TxId, u32)>,
202 ) -> anyhow::Result<Option<ResponderAction>> {
203 if self.window.len() + tx_ids.len() > self.params.max_window.into() {
204 return protocol_error(TooManyTxIdsReceived(
205 tx_ids.len(),
206 self.window.len(),
207 self.params.max_window.into(),
208 ));
209 }
210 self.received_tx_ids(mempool, tx_ids);
211
212 let txs = self.txs_to_request();
213 if txs.is_empty() {
214 let (ack, req, blocking) = self.request_tx_ids(mempool);
215 Ok(Some(ResponderAction::SendRequestTxIds { ack, req, blocking }))
216 } else {
217 Ok(Some(ResponderAction::SendRequestTxs(txs)))
218 }
219 }
220
221 fn process_txs_reply(
222 &mut self,
223 mempool: &dyn TxSubmissionMempool<Transaction>,
224 txs: Vec<Transaction>,
225 origin: TxOrigin,
226 ) -> anyhow::Result<Option<ResponderAction>> {
227 if txs.len() > self.params.fetch_batch.into() {
228 return protocol_error(ReceivedTxsExceedsBatchSize(txs.len(), self.params.fetch_batch.into()));
229 }
230
231 let tx_ids = txs.iter().map(TxId::from).collect::<BTreeSet<_>>();
233 if tx_ids.len() != txs.len() {
234 let tx_ids = txs.iter().map(TxId::from).collect::<Vec<_>>();
236 return protocol_error(DuplicateTxIds(tx_ids));
237 }
238
239 let not_in_flight =
241 tx_ids.iter().filter(|tx_id| !self.inflight_fetch_set.contains(tx_id)).cloned().collect::<Vec<_>>();
242 if !not_in_flight.is_empty() {
243 return protocol_error(SomeReceivedTxsNotInFlight(not_in_flight));
244 }
245
246 self.received_txs(mempool, txs, origin)?;
247 let (ack, req, blocking) = self.request_tx_ids(mempool);
248 Ok(Some(ResponderAction::SendRequestTxIds { ack, req, blocking }))
249 }
250
251 #[allow(clippy::expect_used)]
254 fn request_tx_ids(&mut self, mempool: &dyn TxSubmissionMempool<Transaction>) -> (u16, u16, Blocking) {
255 let mut ack = 0_u16;
257
258 while let Some((tx_id, _size)) = self.window.front() {
259 let already_in_mempool = mempool.contains(tx_id);
260 if already_in_mempool {
261 if self.window.pop_front().is_some() {
263 ack = ack.checked_add(1).expect("ack overflow: protocol invariant violated");
264 }
265 } else {
266 break;
267 }
268 }
269
270 let req = self
272 .params
273 .max_window
274 .checked_sub(self.window.len() as u16)
275 .expect("req underflow: protocol invariant violated");
276
277 let blocking = if self.window.is_empty() { Blocking::Yes } else { Blocking::No };
279 (ack, req, blocking)
280 }
281
282 fn received_tx_ids<Tx: Send + Sync + 'static>(
285 &mut self,
286 mempool: &dyn TxSubmissionMempool<Tx>,
287 tx_ids: Vec<(TxId, u32)>,
288 ) {
289 for (tx_id, size) in tx_ids {
290 self.window.push_back((tx_id, size));
292
293 if !mempool.contains(&tx_id) {
295 self.pending_fetch.push_back(tx_id);
296 }
297 }
298 }
299
300 fn txs_to_request(&mut self) -> Vec<TxId> {
302 let mut tx_ids = Vec::new();
303
304 while tx_ids.len() < self.params.fetch_batch.into() {
305 if let Some(id) = self.pending_fetch.pop_front() {
306 self.inflight_fetch_set.insert(id);
307 tx_ids.push(id);
308 } else {
309 break;
310 }
311 }
312
313 tx_ids
314 }
315
316 fn received_txs(
318 &mut self,
319 mempool: &dyn TxSubmissionMempool<Transaction>,
320 txs: Vec<Transaction>,
321 origin: TxOrigin,
322 ) -> anyhow::Result<()> {
323 for tx in txs {
324 let requested_id = TxId::from(&tx);
325 self.inflight_fetch_set.remove(&requested_id);
326 match mempool.validate_transaction(tx.clone()) {
327 Ok(_) => {
328 tracing::debug!("insert transaction {} into the mempool", requested_id);
329 mempool.insert(tx, origin.clone())?;
330 }
331 Err(e) => {
332 tracing::warn!("received invalid transaction {}: {}", requested_id, e);
333 }
334 }
335 }
336 Ok(())
337 }
338}
339
340fn protocol_error(error: ProtocolError) -> anyhow::Result<Option<ResponderAction>> {
341 tracing::warn!("protocol error: {error}");
342 Ok(Some(ResponderAction::Error(error)))
343}
344
345impl AsRef<StageRef<MuxMessage>> for TxSubmissionResponder {
346 fn as_ref(&self) -> &StageRef<MuxMessage> {
347 &self.muxer
348 }
349}
350
/// Action decided by the responder stage and applied to the protocol state machine.
#[derive(Debug, PartialEq, Eq)]
pub enum ResponderAction {
    /// Ask the peer for tx ids: acknowledge `ack` ids, request up to `req` new
    /// ones, using the blocking or non-blocking request message.
    SendRequestTxIds { ack: u16, req: u16, blocking: Blocking },
    /// Ask the peer for the transactions with these ids.
    SendRequestTxs(Vec<TxId>),
    /// Terminate the protocol with this error.
    Error(ProtocolError),
}
357
358impl Display for ResponderAction {
359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 match self {
361 ResponderAction::SendRequestTxIds { ack, req, blocking } => {
362 write!(f, "SendRequestTxIds(ack: {}, req: {}, blocking: {:?})", ack, req, blocking)
363 }
364 ResponderAction::SendRequestTxs(tx_ids) => {
365 write!(f, "SendRequestTxs(tx_ids: {:?})", tx_ids)
366 }
367 ResponderAction::Error(err) => write!(f, "Error({})", err),
368 }
369 }
370}
371
372#[cfg(test)]
373mod tests {
374
375 use std::sync::Arc;
376
377 use amaru_kernel::Transaction;
378 use amaru_mempool::strategies::InMemoryMempool;
379
380 use super::*;
381 use crate::tx_submission::{assert_actions_eq, tests::create_transactions};
382
383 #[tokio::test]
384 async fn test_responder() -> anyhow::Result<()> {
385 let txs = create_transactions(6);
386
387 let mempool = Arc::new(InMemoryMempool::default());
390
391 let results = vec![
393 init(),
394 reply_tx_ids(&txs, &[0, 1, 2]),
395 reply_txs(&txs, &[0, 1]),
396 reply_tx_ids(&txs, &[3, 4, 5]),
397 reply_txs(&txs, &[2, 3]),
398 reply_tx_ids(&txs, &[]),
399 reply_txs(&txs, &[4, 5]),
400 done(),
401 ];
402
403 let actions = run_stage(mempool.clone(), results).await?;
404
405 assert_actions_eq(
406 &actions,
407 &[
408 request_tx_ids(0, 10, Blocking::Yes),
409 request_txs(&txs, &[0, 1]),
410 request_tx_ids(2, 9, Blocking::No),
411 request_txs(&txs, &[2, 3]),
412 request_tx_ids(2, 8, Blocking::No),
413 request_txs(&txs, &[4, 5]),
414 request_tx_ids(2, 10, Blocking::Yes),
415 ],
416 );
417 Ok(())
418 }
419
420 #[tokio::test]
421 async fn the_returned_tx_ids_should_respect_the_window_size() -> anyhow::Result<()> {
422 let txs = create_transactions(11);
423 let mempool = Arc::new(InMemoryMempool::default());
424
425 let results = vec![init(), reply_tx_ids(&txs, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])];
426
427 let actions = run_stage(mempool.clone(), results).await?;
428 assert_actions_eq(
429 &actions,
430 &[request_tx_ids(0, 10, Blocking::Yes), error_action(TooManyTxIdsReceived(11, 0, 10))],
431 );
432 Ok(())
433 }
434
435 #[tokio::test]
436 async fn the_returned_txs_should_respect_the_batch_size() -> anyhow::Result<()> {
437 let txs = create_transactions(6);
438 let mempool = Arc::new(InMemoryMempool::default());
439
440 let results = vec![
441 init(),
442 reply_tx_ids(&txs, &[0, 1, 2]),
443 reply_txs(&txs, &[0]),
444 reply_tx_ids(&txs, &[]),
445 reply_txs(&txs, &[1, 2, 3]),
446 ];
447
448 let outcomes = run_stage(mempool.clone(), results).await?;
449 assert_actions_eq(
450 &outcomes,
451 &[
452 request_tx_ids(0, 10, Blocking::Yes),
453 request_txs(&txs, &[0, 1]),
454 request_tx_ids(1, 8, Blocking::No),
455 request_txs(&txs, &[2]),
456 error_action(ReceivedTxsExceedsBatchSize(3, 2)),
457 ],
458 );
459 Ok(())
460 }
461
462 #[tokio::test]
463 async fn the_returned_txs_be_a_subset_of_the_inflight_txs() -> anyhow::Result<()> {
464 let txs = create_transactions(6);
465 let mempool = Arc::new(InMemoryMempool::default());
466
467 let results = vec![
468 init(),
469 reply_tx_ids(&txs, &[0, 1, 2]),
470 reply_txs(&txs, &[0]),
471 reply_tx_ids(&txs, &[]),
472 reply_txs(&txs, &[1, 3]),
473 ];
474
475 let actions = run_stage(mempool.clone(), results).await?;
476 assert_actions_eq(
477 &actions,
478 &[
479 request_tx_ids(0, 10, Blocking::Yes),
480 request_txs(&txs, &[0, 1]),
481 request_tx_ids(1, 8, Blocking::No),
482 request_txs(&txs, &[2]),
483 error_action(SomeReceivedTxsNotInFlight(vec![TxId::from(&txs[3])])),
484 ],
485 );
486 Ok(())
487 }
488
489 #[test]
490 fn test_responder_protocol() {
491 crate::tx_submission::spec::<Responder>().check(State::Init, |msg| match msg {
492 Message::RequestTxIdsBlocking(ack, req) => {
493 Some(ResponderAction::SendRequestTxIds { ack: *ack, req: *req, blocking: Blocking::Yes })
494 }
495 Message::RequestTxIdsNonBlocking(ack, req) => {
496 Some(ResponderAction::SendRequestTxIds { ack: *ack, req: *req, blocking: Blocking::No })
497 }
498 Message::RequestTxs(txs) => Some(ResponderAction::SendRequestTxs(txs.clone())),
499 Message::ReplyTxs(_) | Message::ReplyTxIds(_) | Message::Init | Message::Done => None,
500 });
501 }
502
503 async fn run_stage(
506 mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
507 results: Vec<ResponderResult>,
508 ) -> anyhow::Result<Vec<ResponderAction>> {
509 let (actions, _responder) = run_stage_and_return_state(mempool, results).await?;
510 Ok(actions)
511 }
512
513 async fn run_stage_and_return_state(
514 mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
515 results: Vec<ResponderResult>,
516 ) -> anyhow::Result<(Vec<ResponderAction>, TxSubmissionResponder)> {
517 run_stage_and_return_state_with(
518 TxSubmissionResponder::new(StageRef::named_for_tests("muxer"), ResponderParams::default(), TxOrigin::Local)
519 .1,
520 mempool,
521 results,
522 )
523 .await
524 }
525
526 async fn run_stage_and_return_state_with(
527 mut responder: TxSubmissionResponder,
528 mempool: Arc<dyn TxSubmissionMempool<Transaction>>,
529 results: Vec<ResponderResult>,
530 ) -> anyhow::Result<(Vec<ResponderAction>, TxSubmissionResponder)> {
531 let mut actions = vec![];
532 for r in results {
533 let action = match r {
534 ResponderResult::Init => responder.initialize_state(mempool.as_ref()),
535 ResponderResult::ReplyTxIds(tx_ids) => responder.process_tx_ids_reply(mempool.as_ref(), tx_ids)?,
536 ResponderResult::ReplyTxs(txs) => {
537 responder.process_txs_reply(mempool.as_ref(), txs, responder.origin.clone())?
538 }
539 ResponderResult::Done => None,
540 };
541 if let Some(action) = action {
542 actions.push(action)
543 };
544 }
545 Ok((actions, responder))
546 }
    /// Test input: the peer's `Init` result.
    fn init() -> ResponderResult {
        ResponderResult::Init
    }
552
    /// Test input: the peer's `Done` result.
    fn done() -> ResponderResult {
        ResponderResult::Done
    }
556
557 fn reply_tx_ids(txs: &[Transaction], ids: &[usize]) -> ResponderResult {
558 ResponderResult::ReplyTxIds(ids.iter().map(|id| (TxId::from(&txs[*id]), 50)).collect())
559 }
560
561 fn reply_txs(txs: &[Transaction], ids: &[usize]) -> ResponderResult {
562 ResponderResult::ReplyTxs(ids.iter().map(|id| txs[*id].clone()).collect())
563 }
564
    /// Expected action: a tx-ids request with the given `ack`, `req` and `blocking` values.
    fn request_tx_ids(ack: u16, req: u16, blocking: Blocking) -> ResponderAction {
        ResponderAction::SendRequestTxIds { ack, req, blocking }
    }
568
569 fn request_txs(txs: &[Transaction], ids: &[usize]) -> ResponderAction {
570 ResponderAction::SendRequestTxs(ids.iter().map(|id| TxId::from(&txs[*id])).collect())
571 }
572
    /// Expected action: termination with the given protocol error.
    fn error_action(error: ProtocolError) -> ResponderAction {
        ResponderAction::Error(error)
    }
576}