1use std::{collections::VecDeque, mem, sync::Arc};
16
17use amaru_kernel::{EraHistory, IsHeader, Peer, Point, RawBlock, cardano::network_block::NetworkBlock};
18use amaru_ouroboros::ConnectionId;
19use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
20use tracing::instrument;
21
22use crate::{
23 blockfetch::{State, messages::Message, responder::MAX_FETCHED_BLOCKS},
24 mux::MuxMessage,
25 protocol::{
26 Initiator, Inputs, Miniprotocol, Outcome, PROTO_N2N_BLOCK_FETCH, ProtocolState, StageState, miniprotocol,
27 outcome,
28 },
29};
30
31pub fn register_deserializers() -> DeserializerGuards {
32 vec![
33 pure_stage::register_data_deserializer::<BlockFetchInitiator>().boxed(),
34 pure_stage::register_data_deserializer::<(State, BlockFetchInitiator)>().boxed(),
35 pure_stage::register_data_deserializer::<BlockFetchMessage>().boxed(),
36 pure_stage::register_data_deserializer::<Blocks>().boxed(),
37 ]
38}
39
40pub fn initiator() -> Miniprotocol<State, BlockFetchInitiator, Initiator> {
41 miniprotocol(PROTO_N2N_BLOCK_FETCH)
42}
43
44#[derive(Default, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
45pub struct Blocks {
46 pub blocks: Vec<NetworkBlock>,
47}
48
49impl std::fmt::Debug for Blocks {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("Blocks").field("blocks", &self.blocks.len()).finish()
52 }
53}
54
55#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
57pub enum BlockFetchMessage {
58 RequestRange { from: Point, through: Point, cr: StageRef<Blocks> },
59}
60
61#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
62pub struct BlockFetchInitiator {
63 muxer: StageRef<MuxMessage>,
64 peer: Peer,
65 conn_id: ConnectionId,
66 queue: VecDeque<(Point, Point, StageRef<Blocks>)>,
71 blocks: Vec<NetworkBlock>,
72 era_history: Arc<EraHistory>,
73}
74
75impl BlockFetchInitiator {
76 pub fn new(
81 muxer: StageRef<MuxMessage>,
82 peer: Peer,
83 conn_id: ConnectionId,
84 era_history: Arc<EraHistory>,
85 ) -> (State, Self) {
86 (
87 State::Idle,
88 Self { muxer, peer, conn_id, queue: VecDeque::new(), blocks: Vec::new(), era_history: era_history.clone() },
89 )
90 }
91}
92
93#[expect(clippy::expect_used)]
100fn is_valid_block_range(
101 era_history: &EraHistory,
102 network_blocks: &[NetworkBlock],
103 from: Point,
104 through: Point,
105) -> bool {
106 assert!(!network_blocks.is_empty(), "some blocks should have been fetched from {from} to {through}");
107
108 let mut headers = Vec::with_capacity(network_blocks.len());
110 for (idx, network_block) in network_blocks.iter().enumerate() {
111 match network_block.decode_header() {
112 Ok(header) => {
113 if let Ok(expected_era_tag) = era_history.slot_to_era_tag(header.slot()) {
114 if network_block.era_tag() == expected_era_tag {
115 headers.push(header);
116 } else {
117 tracing::warn!(
118 era_tag = %network_block.era_tag(),
119 expected_era_tag = %expected_era_tag,
120 slot = %header.slot(),
121 "block slot does not map to expected era tag in range validation"
122 );
123 return false;
124 }
125 } else {
126 tracing::warn!(
127 slot = %header.slot(),
128 "the header slot should be in the era history"
129 );
130 return false;
131 }
132 }
133 Err(e) => {
134 tracing::warn!(
135 block_index = idx,
136 error = %e,
137 "failed to extract header from block in range validation"
138 );
139 return false;
140 }
141 }
142 }
143
144 let first_point = headers.first().expect("non-empty headers").point();
146 if first_point != from {
147 tracing::debug!(
148 ?from,
149 actual = ?first_point,
150 "first block does not match 'from' point"
151 );
152 return false;
153 }
154
155 let last_point = headers.last().expect("non-empty headers").point();
157 if last_point != through {
158 tracing::debug!(
159 ?through,
160 actual = ?last_point,
161 "last block does not match 'through' point"
162 );
163 return false;
164 }
165
166 for window in headers.windows(2) {
168 let parent = &window[0];
169 let child = &window[1];
170
171 if child.slot() <= parent.slot() {
173 tracing::debug!(
174 parent_point = ?parent.point(),
175 child_point = ?child.point(),
176 "blocks are not in ascending slot order"
177 );
178 return false;
179 }
180
181 let expected_parent_hash = Some(parent.hash());
183 let actual_parent_hash = child.parent_hash();
184 if actual_parent_hash != expected_parent_hash {
185 tracing::debug!(
186 parent_hash = ?parent.hash(),
187 child_parent_hash = ?actual_parent_hash,
188 child_point = ?child.point(),
189 "child block's parent hash does not match previous block's hash"
190 );
191 return false;
192 }
193 }
194
195 true
196}
197
198impl StageState<State, Initiator> for BlockFetchInitiator {
199 type LocalIn = BlockFetchMessage;
200
201 async fn local(
202 mut self,
203 proto: &State,
204 input: Self::LocalIn,
205 _eff: &Effects<Inputs<Self::LocalIn>>,
206 ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
207 match input {
208 BlockFetchMessage::RequestRange { from, through, cr } => {
209 let action = (*proto == State::Idle).then_some(InitiatorAction::RequestRange { from, through });
210 self.queue.push_back((from, through, cr));
211 Ok((action, self))
212 }
213 }
214 }
215
216 #[instrument(name = "blockfetch.initiator.protocol", skip_all, fields(message_type = input.message_type()))]
217 #[expect(clippy::expect_used)]
218 async fn network(
219 mut self,
220 _proto: &State,
221 input: InitiatorResult,
222 eff: &Effects<Inputs<Self::LocalIn>>,
223 ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
224 let queued = match input {
225 InitiatorResult::Initialize => None,
226 InitiatorResult::NoBlocks => {
227 let (_, _, cr) = self.queue.pop_front().expect("queue must not be empty");
228 eff.send(&cr, Blocks { blocks: Vec::new() }).await;
229 self.queue.get(1)
230 }
231 InitiatorResult::Block(body) => {
232 if let Ok(network_block) = NetworkBlock::try_from(RawBlock::from(body.as_slice())) {
233 if self.blocks.len() < MAX_FETCHED_BLOCKS {
234 self.blocks.push(network_block);
235 } else {
236 tracing::warn!(
237 "the responder sent more {MAX_FETCHED_BLOCKS} blocks; terminating the connection"
238 );
239 return eff.terminate().await;
240 }
241 } else {
242 tracing::warn!("received invalid block CBOR {}; terminating the connection", hex::encode(&body));
243 return eff.terminate().await;
244 }
245 None
246 }
247 InitiatorResult::Done => {
248 let (from, through, cr) = self.queue.pop_front().expect("queue must not be empty");
249 let blocks = mem::take(&mut self.blocks);
250 if is_valid_block_range(self.era_history.as_ref(), &blocks, from, through) {
251 eff.send(&cr, Blocks { blocks }).await;
252 } else {
253 tracing::warn!(
254 ?from,
255 ?through,
256 "received blocks do not form a valid range; terminating the connection"
257 );
258 return eff.terminate().await;
259 }
260 self.queue.get(1)
261 }
262 };
263 let action = queued.map(|(from, through, _)| InitiatorAction::RequestRange { from: *from, through: *through });
264 Ok((action, self))
265 }
266
267 fn muxer(&self) -> &StageRef<MuxMessage> {
268 &self.muxer
269 }
270}
271
272impl ProtocolState<Initiator> for State {
273 type WireMsg = Message;
274 type Action = InitiatorAction;
275 type Out = InitiatorResult;
276 type Error = Void;
277
278 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
279 Ok((outcome().result(InitiatorResult::Initialize), *self))
280 }
281
282 #[instrument(name = "blockfetch.initiator.stage", skip_all, fields(message_type = input.message_type()))]
283 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
284 use Message::*;
285 match (self, input) {
286 (Self::Busy, StartBatch) => Ok((outcome().want_next(), Self::Streaming)),
287 (Self::Busy, NoBlocks) => Ok((outcome().result(InitiatorResult::NoBlocks), Self::Idle)),
288 (Self::Streaming, Block { body }) => {
289 Ok((outcome().want_next().result(InitiatorResult::Block(body)), Self::Streaming))
290 }
291 (Self::Streaming, BatchDone) => Ok((outcome().result(InitiatorResult::Done), Self::Idle)),
292 (state, msg) => anyhow::bail!("unexpected message in state {:?}: {:?}", state, msg),
293 }
294 }
295
296 fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
297 use InitiatorAction::*;
298 match (self, input) {
299 (Self::Idle, RequestRange { from, through }) => {
300 Ok((outcome().send(Message::RequestRange { from, through }).want_next(), Self::Busy))
301 }
302 (Self::Idle, ClientDone) => Ok((outcome().send(Message::ClientDone), Self::Done)),
303 (state, action) => {
304 anyhow::bail!("unexpected action in state {:?}: {:?}", state, action)
305 }
306 }
307 }
308}
309
310#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
312pub enum InitiatorResult {
313 Initialize,
314 NoBlocks,
315 Block(Vec<u8>),
316 Done,
317}
318
319impl InitiatorResult {
320 fn message_type(&self) -> &'static str {
321 match self {
322 Self::Initialize => "Initialize",
323 Self::NoBlocks => "NoBlocks",
324 Self::Block(_) => "Block",
325 Self::Done => "Done",
326 }
327 }
328}
329
330#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
332pub enum InitiatorAction {
333 RequestRange { from: Point, through: Point },
334 ClientDone,
335}
336
#[cfg(test)]
pub mod tests {
    use std::time::Duration;

    use amaru_kernel::{
        BlockHeader, Epoch, EraBound, EraName, EraParams, EraSummary, HeaderHash, IsHeader, Slot, any_headers_chain,
        cbor, make_header, utils::tests::run_strategy,
    };

    use super::*;
    use crate::protocol::Initiator;

    #[test]
    #[expect(clippy::wildcard_enum_match_arm)]
    fn test_initiator_protocol() {
        crate::blockfetch::spec::<Initiator>().check(State::Idle, |msg| match msg {
            Message::RequestRange { from, through } => {
                Some(InitiatorAction::RequestRange { from: *from, through: *through })
            }
            Message::ClientDone => Some(InitiatorAction::ClientDone),
            _ => None,
        });
    }

    #[test]
    fn test_valid_block_range_single_block() {
        let headers = run_strategy(any_headers_chain(1));
        let blocks = vec![make_network_block(&headers[0])];

        assert!(is_valid_block_range(&test_era_history(), &blocks, headers[0].point(), headers[0].point()));
    }

    #[test]
    fn test_valid_block_range_consecutive_blocks() {
        let headers = run_strategy(any_headers_chain(3));
        let blocks =
            vec![make_network_block(&headers[0]), make_network_block(&headers[1]), make_network_block(&headers[2])];

        assert!(is_valid_block_range(&test_era_history(), &blocks, headers[0].point(), headers[2].point()));
    }

    #[test]
    #[should_panic(expected = "some blocks should have been fetched")]
    fn test_empty_blocks_with_equal_range() {
        let headers = run_strategy(any_headers_chain(1));
        is_valid_block_range(&test_era_history(), &[], headers[0].point(), headers[0].point());
    }

    #[test]
    fn test_first_block_point_mismatch() {
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);

        let header2 = make_header(2, 101, Some(block_header1.hash()));
        let block_header2 = BlockHeader::from(header2);
        let point2 = block_header2.point();

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        // Neither the slot nor the hash of `wrong_from` matches the first block.
        let wrong_from = Point::Specific(99u64.into(), HeaderHash::from([99u8; 32]));
        assert!(!is_valid_block_range(&test_era_history(), &blocks, wrong_from, point2));
    }

    #[test]
    fn test_last_block_point_mismatch() {
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        let header2 = make_header(2, 101, Some(block_header1.hash()));
        let block_header2 = BlockHeader::from(header2);

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        // Neither the slot nor the hash of `wrong_through` matches the last block.
        let wrong_through = Point::Specific(102u64.into(), HeaderHash::from([102u8; 32]));
        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, wrong_through));
    }

    #[test]
    fn test_blocks_with_non_increasing_slots() {
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        // The child's slot (99) is lower than its parent's (100).
        let header2 = make_header(2, 99, Some(block_header1.hash()));
        let block_header2 = BlockHeader::from(header2);
        let point2 = block_header2.point();

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, point2));
    }

    #[test]
    fn test_blocks_with_equal_slots() {
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        // The child shares its parent's slot, so slots are not strictly increasing.
        let header2 = make_header(2, 100, Some(block_header1.hash()));
        let block_header2 = BlockHeader::from(header2);
        let point2 = block_header2.point();

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, point2));
    }

    #[test]
    fn test_broken_parent_child_hash_chain() {
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        // The child points at a parent hash that is not `block_header1`'s.
        let wrong_parent_hash = HeaderHash::from([99u8; 32]);
        let header2 = make_header(2, 101, Some(wrong_parent_hash));
        let block_header2 = BlockHeader::from(header2);
        let point2 = block_header2.point();

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, point2));
    }

    #[test]
    fn test_invalid_cbor_in_block() {
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        let blocks = vec![make_network_block(&block_header1), make_invalid_network_block()];

        let point2 = Point::Specific(101u64.into(), HeaderHash::from([2u8; 32]));
        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, point2));
    }

    /// Era history with a single, open-ended Conway era starting at slot 0 / epoch 0.
    pub fn test_era_history() -> Arc<EraHistory> {
        Arc::new(EraHistory::new(
            &[EraSummary {
                start: EraBound { time: Duration::from_secs(0), slot: Slot::from(0), epoch: Epoch::from(0) },
                end: None,
                params: EraParams::new(86400, Duration::from_secs(1), EraName::Conway).expect("valid era params"),
            }],
            Slot::from(2160 * 3),
        ))
    }

    /// Wraps `header` in a minimal block (see [`make_raw_block`]) and parses it as a [`NetworkBlock`].
    pub fn make_network_block(header: &BlockHeader) -> NetworkBlock {
        NetworkBlock::try_from(make_raw_block(header)).expect("valid network block")
    }

    /// Builds `[1, [null]]`, which parses as a [`NetworkBlock`] but whose body has no decodable
    /// header (the range-validation test relies on header decoding failing for it).
    pub fn make_invalid_network_block() -> NetworkBlock {
        let mut incomplete_bytes = Vec::new();
        let mut encoder = cbor::Encoder::new(&mut incomplete_bytes);
        encoder.array(2).expect("failed to encode array");
        encoder.u16(1).expect("failed to encode tag");
        encoder.array(1).expect("failed to encode inner array");
        encoder.null().expect("failed to encode placeholder");
        let raw_block = RawBlock::from(incomplete_bytes.as_slice());

        NetworkBlock::try_from(raw_block).expect("the block envelope should parse even though its body is invalid")
    }

    /// Encodes `[era_tag, [header, [], [], null, null]]`, taking the era tag for the header's
    /// slot from [`test_era_history`].
    pub fn make_raw_block(header: &BlockHeader) -> RawBlock {
        let mut block_bytes = Vec::new();
        let mut encoder = cbor::Encoder::new(&mut block_bytes);

        encoder.array(2).expect("failed to encode array");
        let era_history = test_era_history();
        let era_tag = era_history.slot_to_era_tag(header.slot()).expect("slot should be covered by the test era history");
        encoder.encode(era_tag).expect("failed to encode tag");
        encoder.array(5).expect("failed to encode inner array");
        encoder.encode(header.header()).expect("failed to encode header");
        encoder.array(0).expect("failed to encode tx bodies");
        encoder.array(0).expect("failed to encode witnesses");
        encoder.null().expect("failed to encode auxiliary data");
        encoder.null().expect("failed to encode invalid txs");

        RawBlock::from(block_bytes.as_slice())
    }
}