Skip to main content

amaru_protocols/blockfetch/
initiator.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::VecDeque, mem, sync::Arc};
16
17use amaru_kernel::{EraHistory, IsHeader, Peer, Point, RawBlock, cardano::network_block::NetworkBlock};
18use amaru_ouroboros::ConnectionId;
19use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
20use tracing::instrument;
21
22use crate::{
23    blockfetch::{State, messages::Message, responder::MAX_FETCHED_BLOCKS},
24    mux::MuxMessage,
25    protocol::{
26        Initiator, Inputs, Miniprotocol, Outcome, PROTO_N2N_BLOCK_FETCH, ProtocolState, StageState, miniprotocol,
27        outcome,
28    },
29};
30
/// Register `pure_stage` data deserializers for the serializable types of the block-fetch
/// initiator: the initiator itself, the `(State, BlockFetchInitiator)` pair, local
/// [`BlockFetchMessage`] requests and the [`Blocks`] replies.
///
/// NOTE(review): the returned guards presumably keep the registrations alive and should be held
/// by the caller for as long as the deserializers are needed — confirm against `pure_stage`.
pub fn register_deserializers() -> DeserializerGuards {
    vec![
        pure_stage::register_data_deserializer::<BlockFetchInitiator>().boxed(),
        pure_stage::register_data_deserializer::<(State, BlockFetchInitiator)>().boxed(),
        pure_stage::register_data_deserializer::<BlockFetchMessage>().boxed(),
        pure_stage::register_data_deserializer::<Blocks>().boxed(),
    ]
}
39
/// Build the initiator side of the block-fetch miniprotocol for `PROTO_N2N_BLOCK_FETCH`,
/// driven by the [`State`] machine and the [`BlockFetchInitiator`] stage state.
pub fn initiator() -> Miniprotocol<State, BlockFetchInitiator, Initiator> {
    miniprotocol(PROTO_N2N_BLOCK_FETCH)
}
43
/// Reply sent to the stage that requested a range of blocks.
///
/// An empty `blocks` vector is sent when the responder answered the request with `NoBlocks`.
/// `Debug` is implemented by hand so that only the number of blocks is printed.
#[derive(Default, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
pub struct Blocks {
    /// The fetched blocks, in the order they were received from the responder.
    pub blocks: Vec<NetworkBlock>,
}
48
49impl std::fmt::Debug for Blocks {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("Blocks").field("blocks", &self.blocks.len()).finish()
52    }
53}
54
/// Message that can be sent by an internal stage to request blocks for a range of points.
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum BlockFetchMessage {
    /// Fetch the blocks from `from` through `through` (the fetched range must start exactly at
    /// `from` and end exactly at `through`), replying to `cr` with the resulting [`Blocks`].
    RequestRange { from: Point, through: Point, cr: StageRef<Blocks> },
}
60
/// Stage-side state of the block-fetch initiator for a single peer connection.
///
/// Range requests coming from local stages are queued and sent to the peer one at a time; the
/// blocks streamed in for the request currently in flight are buffered until the batch ends.
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct BlockFetchInitiator {
    /// Reference to the muxer stage (exposed through `StageState::muxer`).
    muxer: StageRef<MuxMessage>,
    /// The peer blocks are fetched from.
    peer: Peer,
    /// Connection used to talk to `peer`.
    conn_id: ConnectionId,
    /// Queue of requests that have been received but not yet answered.
    ///
    /// Requests are only sent while the protocol is idle, so at most one request is in flight at
    /// any time, and it is always the front element: every completed batch is matched against,
    /// and popped from, the front of the queue.
    queue: VecDeque<(Point, Point, StageRef<Blocks>)>,
    /// Blocks received so far for the request in flight.
    blocks: Vec<NetworkBlock>,
    /// Era history used to validate the slots and era tags of the received blocks.
    era_history: Arc<EraHistory>,
}
74
75impl BlockFetchInitiator {
76    /// Create a new BlockFetchInitiator instance for a given peer, using a given connection.
77    /// Returns the initial state and the initiator instance.
78    ///
79    /// The EraHistory is needed to validate the received blocks slots and era tags.
80    pub fn new(
81        muxer: StageRef<MuxMessage>,
82        peer: Peer,
83        conn_id: ConnectionId,
84        era_history: Arc<EraHistory>,
85    ) -> (State, Self) {
86        (
87            State::Idle,
88            Self { muxer, peer, conn_id, queue: VecDeque::new(), blocks: Vec::new(), era_history: era_history.clone() },
89        )
90    }
91}
92
93/// Return true if the provided blocks form a valid chain from `from` to `through`.
94/// This includes checks that:
95/// - The first block matches the `from` point.
96/// - The last block matches the `through` point.
97/// - Each block's slot is strictly greater than the previous block's slot.
98/// - Each block's parent hash matches the hash of the previous block.
99#[expect(clippy::expect_used)]
100fn is_valid_block_range(
101    era_history: &EraHistory,
102    network_blocks: &[NetworkBlock],
103    from: Point,
104    through: Point,
105) -> bool {
106    assert!(!network_blocks.is_empty(), "some blocks should have been fetched from {from} to {through}");
107
108    // Extract headers from all blocks
109    let mut headers = Vec::with_capacity(network_blocks.len());
110    for (idx, network_block) in network_blocks.iter().enumerate() {
111        match network_block.decode_header() {
112            Ok(header) => {
113                if let Ok(expected_era_tag) = era_history.slot_to_era_tag(header.slot()) {
114                    if network_block.era_tag() == expected_era_tag {
115                        headers.push(header);
116                    } else {
117                        tracing::warn!(
118                            era_tag = %network_block.era_tag(),
119                            expected_era_tag = %expected_era_tag,
120                            slot = %header.slot(),
121                            "block slot does not map to expected era tag in range validation"
122                        );
123                        return false;
124                    }
125                } else {
126                    tracing::warn!(
127                        slot = %header.slot(),
128                        "the header slot should be in the era history"
129                    );
130                    return false;
131                }
132            }
133            Err(e) => {
134                tracing::warn!(
135                    block_index = idx,
136                    error = %e,
137                    "failed to extract header from block in range validation"
138                );
139                return false;
140            }
141        }
142    }
143
144    // Validate first block matches 'from' point
145    let first_point = headers.first().expect("non-empty headers").point();
146    if first_point != from {
147        tracing::debug!(
148            ?from,
149            actual = ?first_point,
150            "first block does not match 'from' point"
151        );
152        return false;
153    }
154
155    // Validate last block matches 'through' point
156    let last_point = headers.last().expect("non-empty headers").point();
157    if last_point != through {
158        tracing::debug!(
159            ?through,
160            actual = ?last_point,
161            "last block does not match 'through' point"
162        );
163        return false;
164    }
165
166    // Validate chain continuity: slots increase and parent hashes match
167    for window in headers.windows(2) {
168        let parent = &window[0];
169        let child = &window[1];
170
171        // Check slots are strictly increasing (gaps are OK)
172        if child.slot() <= parent.slot() {
173            tracing::debug!(
174                parent_point = ?parent.point(),
175                child_point = ?child.point(),
176                "blocks are not in ascending slot order"
177            );
178            return false;
179        }
180
181        // Check parent-child hash relationship
182        let expected_parent_hash = Some(parent.hash());
183        let actual_parent_hash = child.parent_hash();
184        if actual_parent_hash != expected_parent_hash {
185            tracing::debug!(
186                parent_hash = ?parent.hash(),
187                child_parent_hash = ?actual_parent_hash,
188                child_point = ?child.point(),
189                "child block's parent hash does not match previous block's hash"
190            );
191            return false;
192        }
193    }
194
195    true
196}
197
198impl StageState<State, Initiator> for BlockFetchInitiator {
199    type LocalIn = BlockFetchMessage;
200
201    async fn local(
202        mut self,
203        proto: &State,
204        input: Self::LocalIn,
205        _eff: &Effects<Inputs<Self::LocalIn>>,
206    ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
207        match input {
208            BlockFetchMessage::RequestRange { from, through, cr } => {
209                let action = (*proto == State::Idle).then_some(InitiatorAction::RequestRange { from, through });
210                self.queue.push_back((from, through, cr));
211                Ok((action, self))
212            }
213        }
214    }
215
216    #[instrument(name = "blockfetch.initiator.protocol", skip_all, fields(message_type = input.message_type()))]
217    #[expect(clippy::expect_used)]
218    async fn network(
219        mut self,
220        _proto: &State,
221        input: InitiatorResult,
222        eff: &Effects<Inputs<Self::LocalIn>>,
223    ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
224        let queued = match input {
225            InitiatorResult::Initialize => None,
226            InitiatorResult::NoBlocks => {
227                let (_, _, cr) = self.queue.pop_front().expect("queue must not be empty");
228                eff.send(&cr, Blocks { blocks: Vec::new() }).await;
229                self.queue.get(1)
230            }
231            InitiatorResult::Block(body) => {
232                if let Ok(network_block) = NetworkBlock::try_from(RawBlock::from(body.as_slice())) {
233                    if self.blocks.len() < MAX_FETCHED_BLOCKS {
234                        self.blocks.push(network_block);
235                    } else {
236                        tracing::warn!(
237                            "the responder sent more {MAX_FETCHED_BLOCKS} blocks; terminating the connection"
238                        );
239                        return eff.terminate().await;
240                    }
241                } else {
242                    tracing::warn!("received invalid block CBOR {}; terminating the connection", hex::encode(&body));
243                    return eff.terminate().await;
244                }
245                None
246            }
247            InitiatorResult::Done => {
248                let (from, through, cr) = self.queue.pop_front().expect("queue must not be empty");
249                let blocks = mem::take(&mut self.blocks);
250                if is_valid_block_range(self.era_history.as_ref(), &blocks, from, through) {
251                    eff.send(&cr, Blocks { blocks }).await;
252                } else {
253                    tracing::warn!(
254                        ?from,
255                        ?through,
256                        "received blocks do not form a valid range; terminating the connection"
257                    );
258                    return eff.terminate().await;
259                }
260                self.queue.get(1)
261            }
262        };
263        let action = queued.map(|(from, through, _)| InitiatorAction::RequestRange { from: *from, through: *through });
264        Ok((action, self))
265    }
266
267    fn muxer(&self) -> &StageRef<MuxMessage> {
268        &self.muxer
269    }
270}
271
272impl ProtocolState<Initiator> for State {
273    type WireMsg = Message;
274    type Action = InitiatorAction;
275    type Out = InitiatorResult;
276    type Error = Void;
277
278    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
279        Ok((outcome().result(InitiatorResult::Initialize), *self))
280    }
281
282    #[instrument(name = "blockfetch.initiator.stage", skip_all, fields(message_type = input.message_type()))]
283    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
284        use Message::*;
285        match (self, input) {
286            (Self::Busy, StartBatch) => Ok((outcome().want_next(), Self::Streaming)),
287            (Self::Busy, NoBlocks) => Ok((outcome().result(InitiatorResult::NoBlocks), Self::Idle)),
288            (Self::Streaming, Block { body }) => {
289                Ok((outcome().want_next().result(InitiatorResult::Block(body)), Self::Streaming))
290            }
291            (Self::Streaming, BatchDone) => Ok((outcome().result(InitiatorResult::Done), Self::Idle)),
292            (state, msg) => anyhow::bail!("unexpected message in state {:?}: {:?}", state, msg),
293        }
294    }
295
296    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
297        use InitiatorAction::*;
298        match (self, input) {
299            (Self::Idle, RequestRange { from, through }) => {
300                Ok((outcome().send(Message::RequestRange { from, through }).want_next(), Self::Busy))
301            }
302            (Self::Idle, ClientDone) => Ok((outcome().send(Message::ClientDone), Self::Done)),
303            (state, action) => {
304                anyhow::bail!("unexpected action in state {:?}: {:?}", state, action)
305            }
306        }
307    }
308}
309
/// Result of the initiator protocol step, to be used by the local stage.
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
pub enum InitiatorResult {
    /// The protocol has been initialised (emitted by `ProtocolState::init`).
    Initialize,
    /// The responder answered the current request with `NoBlocks`.
    NoBlocks,
    /// The raw bytes of one block of the current batch, expected to be block CBOR.
    Block(Vec<u8>),
    /// The responder closed the current batch with `BatchDone`.
    Done,
}
318
319impl InitiatorResult {
320    fn message_type(&self) -> &'static str {
321        match self {
322            Self::Initialize => "Initialize",
323            Self::NoBlocks => "NoBlocks",
324            Self::Block(_) => "Block",
325            Self::Done => "Done",
326        }
327    }
328}
329
/// Outcome action of the local stage, to be used by the initiator protocol stage.
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
pub enum InitiatorAction {
    /// Send a `RequestRange` message for the blocks from `from` through `through`.
    RequestRange { from: Point, through: Point },
    /// Send `ClientDone`, moving the protocol to its `Done` state.
    ClientDone,
}
336
/// Unit tests for the initiator protocol state machine and for block-range validation.
///
/// The helpers at the bottom are `pub`, presumably so that other test modules can build era
/// histories and network blocks the same way.
#[cfg(test)]
pub mod tests {
    use std::time::Duration;

    use amaru_kernel::{
        BlockHeader, Epoch, EraBound, EraName, EraParams, EraSummary, HeaderHash, IsHeader, Slot, any_headers_chain,
        cbor, make_header, utils::tests::run_strategy,
    };

    use super::*;
    use crate::protocol::Initiator;

    #[test]
    #[expect(clippy::wildcard_enum_match_arm)]
    fn test_initiator_protocol() {
        crate::blockfetch::spec::<Initiator>().check(State::Idle, |msg| match msg {
            Message::RequestRange { from, through } => {
                Some(InitiatorAction::RequestRange { from: *from, through: *through })
            }
            Message::ClientDone => Some(InitiatorAction::ClientDone),
            _ => None,
        });
    }

    #[test]
    fn test_valid_block_range_single_block() {
        let headers = run_strategy(any_headers_chain(1));
        let blocks = vec![make_network_block(&headers[0])];

        assert!(is_valid_block_range(&test_era_history(), &blocks, headers[0].point(), headers[0].point()));
    }

    #[test]
    fn test_valid_block_range_consecutive_blocks() {
        let headers = run_strategy(any_headers_chain(3));
        let blocks =
            vec![make_network_block(&headers[0]), make_network_block(&headers[1]), make_network_block(&headers[2])];

        assert!(is_valid_block_range(&test_era_history(), &blocks, headers[0].point(), headers[2].point()));
    }

    #[test]
    #[should_panic(expected = "some blocks should have been fetched")]
    fn test_empty_blocks_with_equal_range() {
        let headers = run_strategy(any_headers_chain(1));
        is_valid_block_range(&test_era_history(), &[], headers[0].point(), headers[0].point());
    }

    #[test]
    fn test_first_block_point_mismatch() {
        // Create blocks where the first block doesn't match 'from'
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);

        let header2 = make_header(2, 101, Some(block_header1.hash()));
        let block_header2 = BlockHeader::from(header2);
        let point2 = block_header2.point();

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        // Use a different 'from' point that doesn't match the first block
        let wrong_from = Point::Specific(99u64.into(), HeaderHash::from([99u8; 32]));
        assert!(!is_valid_block_range(&test_era_history(), &blocks, wrong_from, point2));
    }

    #[test]
    fn test_last_block_point_mismatch() {
        // Create blocks where the last block doesn't match 'through'
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        let header2 = make_header(2, 101, Some(block_header1.hash()));
        let block_header2 = BlockHeader::from(header2);

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        // Use a different 'through' point that doesn't match the last block
        let wrong_through = Point::Specific(102u64.into(), HeaderHash::from([102u8; 32]));
        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, wrong_through));
    }

    #[test]
    fn test_blocks_with_non_increasing_slots() {
        // Create blocks where slots are not strictly increasing
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        let header2 = make_header(2, 99, Some(block_header1.hash())); // Slot goes backward!
        let block_header2 = BlockHeader::from(header2);
        let point2 = block_header2.point();

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, point2));
    }

    #[test]
    fn test_blocks_with_equal_slots() {
        // Create blocks where slots are equal (should fail)
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        let header2 = make_header(2, 100, Some(block_header1.hash())); // Same slot!
        let block_header2 = BlockHeader::from(header2);
        let point2 = block_header2.point();

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, point2));
    }

    #[test]
    fn test_broken_parent_child_hash_chain() {
        // Create blocks where the parent hash doesn't match
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        // Create header2 with wrong parent hash (not matching block_header1's hash)
        let wrong_parent_hash = HeaderHash::from([99u8; 32]);
        let header2 = make_header(2, 101, Some(wrong_parent_hash));
        let block_header2 = BlockHeader::from(header2);
        // Use the actual point from block_header2 so we test parent-child hash validation
        let point2 = block_header2.point();

        let blocks = vec![make_network_block(&block_header1), make_network_block(&block_header2)];

        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, point2));
    }

    #[test]
    fn test_invalid_cbor_in_block() {
        // Create a valid first block and an invalid second block
        let header1 = make_header(1, 100, None);
        let block_header1 = BlockHeader::from(header1);
        let point1 = block_header1.point();

        let blocks = vec![make_network_block(&block_header1), make_invalid_network_block()];

        let point2 = Point::Specific(101u64.into(), HeaderHash::from([2u8; 32]));
        assert!(!is_valid_block_range(&test_era_history(), &blocks, point1, point2));
    }

    // HELPERS

    /// Create a simple era history for testing where all slots map to era index 0 (tag 1).
    pub fn test_era_history() -> Arc<EraHistory> {
        Arc::new(EraHistory::new(
            &[EraSummary {
                start: EraBound { time: Duration::from_secs(0), slot: Slot::from(0), epoch: Epoch::from(0) },
                end: None,
                params: EraParams::new(86400, Duration::from_secs(1), EraName::Conway).expect("valid era params"),
            }],
            Slot::from(2160 * 3),
        ))
    }

    /// Wrap `header` into a minimal, well-formed network block (see `make_raw_block`).
    pub fn make_network_block(header: &BlockHeader) -> NetworkBlock {
        NetworkBlock::try_from(make_raw_block(header)).expect("valid network block")
    }

    /// Build a network block whose era tag is readable but whose body cannot be decoded as a
    /// block with a header.
    pub fn make_invalid_network_block() -> NetworkBlock {
        let mut incomplete_bytes = Vec::new();
        let mut encoder = cbor::Encoder::new(&mut incomplete_bytes);
        encoder.array(2).expect("failed to encode array");
        encoder.u16(1).expect("failed to encode tag");
        encoder.array(1).expect("failed to encode inner array");
        encoder.null().expect("failed to encode placeholder");
        let raw_block = RawBlock::from(incomplete_bytes.as_slice());

        // from raw block should work since the era tag is present
        NetworkBlock::try_from(raw_block).expect("the era tag is present, so the conversion succeeds")
    }

    /// Encode `header` into a raw block with empty transaction bodies and witnesses, tagged with
    /// the era that `test_era_history` assigns to the header's slot.
    pub fn make_raw_block(header: &BlockHeader) -> RawBlock {
        let mut block_bytes = Vec::new();
        let mut encoder = cbor::Encoder::new(&mut block_bytes);

        // block format: [era_tag, [header, tx_bodies, witnesses, auxiliary_data?, invalid_transactions?]]
        encoder.array(2).expect("failed to encode array");
        let era_history = test_era_history();
        let era_tag =
            era_history.slot_to_era_tag(header.slot()).expect("the test era history covers the header slot");
        encoder.encode(era_tag).expect("failed to encode tag");
        encoder.array(5).expect("failed to encode inner array");
        encoder.encode(header.header()).expect("failed to encode header");
        encoder.array(0).expect("failed to encode tx bodies");
        encoder.array(0).expect("failed to encode witnesses");
        encoder.null().expect("failed to encode auxiliary data");
        encoder.null().expect("failed to encode invalid txs");

        RawBlock::from(block_bytes.as_slice())
    }
}