Skip to main content

amaru_protocols/blockfetch/
responder.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Debug;
16
17use amaru_kernel::{BlockHeader, IsHeader, NonEmptyVec, Point, RawBlock};
18use amaru_ouroboros_traits::ChainStore;
19use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
20use tracing::instrument;
21
22use crate::{
23    blockfetch::{State, messages::Message},
24    mux::MuxMessage,
25    protocol::{
26        Inputs, Miniprotocol, Outcome, PROTO_N2N_BLOCK_FETCH, ProtocolState, Responder, StageState, miniprotocol,
27        outcome,
28    },
29    store_effects::Store,
30};
31
/// Register the data deserializers for the stage data this module persists: the bare
/// `BlockFetchResponder` and the `(State, BlockFetchResponder)` pair returned by
/// `BlockFetchResponder::new`.
///
/// The returned guards must be kept alive for as long as the registrations are needed
/// (presumably they unregister on drop — TODO confirm in `pure_stage`).
pub fn register_deserializers() -> DeserializerGuards {
    vec![
        pure_stage::register_data_deserializer::<BlockFetchResponder>().boxed(),
        pure_stage::register_data_deserializer::<(State, BlockFetchResponder)>().boxed(),
    ]
}
38
39pub fn responder() -> Miniprotocol<State, BlockFetchResponder, Responder> {
40    miniprotocol(PROTO_N2N_BLOCK_FETCH.responder())
41}
42
/// Stage state of the block-fetch responder.
///
/// It only holds a reference to the muxer stage, exposed through `StageState::muxer`.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct BlockFetchResponder {
    /// Muxer stage used by the protocol machinery to talk to the peer.
    muxer: StageRef<MuxMessage>,
}
47
/// A range of points whose blocks are to be served to the peer.
///
/// The points are ordered from the most recent to the oldest, and at least one point is
/// always present (enforced by `NonEmptyVec`).
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
pub struct PointsRange(NonEmptyVec<Point>);

/// Maximum number of blocks that can be streamed for a single request.
///
/// `PointsRange::request_range` refuses (returns `None` for) any range spanning more
/// blocks than this, which the responder answers with `NoBlocks`.
pub const MAX_FETCHED_BLOCKS: usize = 1000;
55
56impl PointsRange {
57    /// Create a points range with a single point
58    pub fn singleton(first: Point) -> PointsRange {
59        PointsRange(NonEmptyVec::singleton(first))
60    }
61
62    /// Create a points range from a vector of points.
63    pub fn from_vec(vec: Vec<Point>) -> Option<PointsRange> {
64        NonEmptyVec::try_from(vec).ok().map(PointsRange)
65    }
66
67    #[cfg(test)]
68    pub fn points(&self) -> Vec<Point> {
69        self.0.to_vec()
70    }
71
72    /// Load the first available block in the current range (the block is expected to be found).
73    /// Each time we attempt to fetch a block we pop its point from the current_range.
74    fn next_block(self, store: &dyn ChainStore<BlockHeader>) -> anyhow::Result<(RawBlock, Option<PointsRange>)> {
75        // points are stored from most recent to oldest, so we pop from the end
76        let (last, rest) = self.0.pop();
77        let last_hash = last.hash();
78        let stored_block =
79            store.load_block(&last_hash)?.ok_or_else(|| anyhow::anyhow!("block {} was pruned", last_hash))?;
80        Ok((stored_block, rest.map(PointsRange)))
81    }
82
83    /// Return a points range:
84    ///  - Check that `from` <= `through`
85    ///  - Check that there is a valid path of block from `from` to `through` in the chain store.
86    ///  - Check that we don't return too many headers to avoid getting over the protocol limits.
87    ///  - Return None if any of the above checks fail and return the points range otherwise.
88    pub fn request_range(
89        store: &dyn ChainStore<BlockHeader>,
90        from: Point,
91        through: Point,
92    ) -> anyhow::Result<Option<PointsRange>> {
93        // make sure that from <= through
94        if from > through {
95            tracing::debug!(%from, %through, "requested range is invalid: from > through");
96            return Ok(None);
97        };
98
99        if from == through {
100            return if store.load_block(&from.hash())?.is_some() {
101                Ok(Some(PointsRange::singleton(from)))
102            } else {
103                Ok(None)
104            };
105        }
106
107        let mut current_hash = through.hash();
108        let mut result = vec![];
109        loop {
110            if result.len() >= MAX_FETCHED_BLOCKS {
111                tracing::debug!(
112                    %from,
113                    %through,
114                    max_blocks = MAX_FETCHED_BLOCKS,
115                    "requested range exceeds maximum allowed blocks"
116                );
117                return Ok(None);
118            }
119            // check that the block exists
120            if store.load_block(&current_hash)?.is_none() {
121                return Ok(None);
122            }
123
124            // load the header for the current hash
125            if let Some(header) = store.load_header(&current_hash) {
126                result.push(header.point());
127                // if we found the from point, we're done
128                if current_hash == from.hash() {
129                    break;
130                }
131                // if we reached a slot before 'from', abort
132                if header.slot() < from.slot_or_default() {
133                    return Ok(None);
134                }
135                if let Some(parent_hash) = header.parent_hash() {
136                    current_hash = parent_hash
137                } else {
138                    return Ok(None);
139                }
140            } else {
141                return Ok(None);
142            }
143        }
144        Ok(PointsRange::from_vec(result))
145    }
146}
147
148impl BlockFetchResponder {
149    pub fn new(muxer: StageRef<MuxMessage>) -> (State, Self) {
150        (State::Idle, Self { muxer })
151    }
152}
153
/// Local message the responder stage sends to itself to stream blocks one at a time.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum StreamBlocks {
    /// Send the oldest block of this range, then continue with the remaining points.
    More(PointsRange),
    /// The range is exhausted: close the batch.
    Done,
}
160
161impl StageState<State, Responder> for BlockFetchResponder {
162    type LocalIn = StreamBlocks;
163
164    async fn local(
165        self,
166        _proto: &State,
167        input: Self::LocalIn,
168        eff: &Effects<Inputs<Self::LocalIn>>,
169    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
170        let store = Store::new(eff.clone());
171        match input {
172            StreamBlocks::Done => Ok((Some(ResponderAction::BatchDone), self)),
173            StreamBlocks::More(points_range) => {
174                let (block, points_range) = points_range.next_block(&store)?;
175                // recurse if there are more blocks to fetch or signal that streaming is done
176                if let Some(points_range) = points_range {
177                    eff.send(eff.me_ref(), Inputs::Local(StreamBlocks::More(points_range))).await;
178                } else {
179                    eff.send(eff.me_ref(), Inputs::Local(StreamBlocks::Done)).await;
180                }
181                Ok((Some(ResponderAction::Block(block)), self))
182            }
183        }
184    }
185
186    #[instrument(name = "blockfetch.responder.stage", skip_all, fields(message_type = input.message_type()))]
187    async fn network(
188        self,
189        _proto: &State,
190        input: ResponderResult,
191        eff: &Effects<Inputs<Self::LocalIn>>,
192    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
193        match input {
194            ResponderResult::RequestRange { from, through } => {
195                let store = Store::new(eff.clone());
196                if let Some(points_range) = PointsRange::request_range(&store, from, through)? {
197                    eff.send(eff.me_ref(), Inputs::Local(StreamBlocks::More(points_range))).await;
198                    Ok((Some(ResponderAction::StartBatch), self))
199                } else {
200                    Ok((Some(ResponderAction::NoBlocks), self))
201                }
202            }
203            ResponderResult::Done => Ok((None, self)),
204        }
205    }
206
207    fn muxer(&self) -> &StageRef<MuxMessage> {
208        &self.muxer
209    }
210}
211
212impl ProtocolState<Responder> for State {
213    type WireMsg = Message;
214    type Action = ResponderAction;
215    type Out = ResponderResult;
216    type Error = Void;
217
218    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
219        Ok((outcome().want_next(), *self))
220    }
221
222    #[instrument(name = "blockfetch.responder.protocol", skip_all, fields(message_type = input.message_type()))]
223    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
224        use Message::*;
225        match (self, input) {
226            (Self::Idle, RequestRange { from, through }) => {
227                Ok((outcome().result(ResponderResult::RequestRange { from, through }), Self::Busy))
228            }
229            (Self::Idle, ClientDone) => Ok((outcome().want_next().result(ResponderResult::Done), Self::Done)),
230            (state, msg) => anyhow::bail!("unexpected message in state {:?}: {:?}", state, msg),
231        }
232    }
233
234    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
235        use ResponderAction::*;
236        match (self, input) {
237            (Self::Busy, StartBatch) => Ok((outcome().send(Message::StartBatch), Self::Streaming)),
238            (Self::Busy, NoBlocks) => Ok((outcome().send(Message::NoBlocks).want_next(), Self::Idle)),
239            (Self::Streaming, Block(body)) => {
240                Ok((outcome().send(Message::Block { body: body.to_vec() }), Self::Streaming))
241            }
242            (Self::Streaming, BatchDone) => Ok((outcome().send(Message::BatchDone).want_next(), Self::Idle)),
243            (state, action) => {
244                anyhow::bail!("unexpected action in state {:?}: {:?}", state, action)
245            }
246        }
247    }
248}
249
/// Actions the responder stage asks the protocol state machine to perform.
///
/// Each action is sent to the client as the wire message of the same name (see the
/// `ProtocolState<Responder>` implementation for `State`).
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
pub enum ResponderAction {
    /// Open a batch of blocks (`Message::StartBatch`).
    StartBatch,
    /// The requested range cannot be served (`Message::NoBlocks`).
    NoBlocks,
    /// Send one block body (`Message::Block`).
    Block(RawBlock),
    /// All blocks of the batch have been sent (`Message::BatchDone`).
    BatchDone,
}
257
/// Decoded client requests that the protocol state machine hands to the responder stage.
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
pub enum ResponderResult {
    /// The client asked for the blocks from `from` to `through`, both included.
    RequestRange { from: Point, through: Point },
    /// The client terminated the protocol (`Message::ClientDone`).
    Done,
}
263
264impl ResponderResult {
265    pub fn message_type(&self) -> &'static str {
266        match self {
267            ResponderResult::RequestRange { .. } => "RequestRange",
268            ResponderResult::Done => "Done",
269        }
270    }
271}
272
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use amaru_kernel::{
        BlockHeader, EraName, IsHeader, Slot, TESTNET_ERA_HISTORY, any_fake_header, any_headers_chain,
        any_headers_chain_with_root,
        cardano::network_block::{NetworkBlock, make_encoded_block},
        utils::tests::run_strategy,
    };
    use amaru_ouroboros_traits::{ChainStore, in_memory_consensus_store::InMemConsensusStore};

    use super::*;
    use crate::protocol::Responder;

    /// Check the responder state machine against the block-fetch protocol spec, mapping each
    /// server message to the action that produces it.
    #[test]
    #[expect(clippy::wildcard_enum_match_arm)]
    fn test_responder_protocol() {
        crate::blockfetch::spec::<Responder>().check(State::Idle, |msg| match msg {
            Message::NoBlocks => Some(ResponderAction::NoBlocks),
            Message::StartBatch => Some(ResponderAction::StartBatch),
            Message::Block { body } => Some(ResponderAction::Block(RawBlock::from(body.as_slice()))),
            Message::BatchDone => Some(ResponderAction::BatchDone),
            _ => None,
        });
    }

    #[test]
    fn decode_network_block() {
        let as_hex = "820785828a1a002cc8f51a04994d195820f27eddec5e782552e6ef408cff7c4a27e505fe54c20a717027d97e1c91da9d7c5820064effe4fa426184a911159fa803a9c1092459cd0b8f3e584ef9513955be0f5558201e5d0dcf77643d89a94353493859a21b47672015fb652b51f922617e4b27da8982584042d0edd71e6cac29e45f61eabbcce4f803f2ff78bce9fa295d11cb7c3cddb60f7694faaea787183fd604267d8114b57453493c963c7485405838cd79a261013a5850bc8672b4ff2db478e5b21364bfa9f0a2f5265e5ac56b261ce3dcb7ac57301a8362573eef2ae23eb2540915704534d1c0af8eace59a25c130629af7600b175b5e234b376961e2fd12b37de5213e8eff0304582029571d16f081709b3c48651860077bebf9340abb3fc7133443c54f1f5a5edcf1845820ee1d7c2bd6978e3bc8a47fc478424a9efd797f16813164db292320e3728f6de5091902465840f69f8974108be5df23dd0dad2f0e888e5c1702c35c678f3b7a2802f272666ea8a7c9b9f6e786e761d4cb747159d68b7d8f43bceae6ab4e543795d8aded59c302820a005901c06063a37f6f01765b34bceb2651e40a69e3bc31b35fd6c952415175844132250cdcbafd19c39952f471f7318a5cc3e45f54dadc9067bb6d25dac8b76f0bea5106c2f45235fac710d3e78d259af37fd617ed9e372626c5b080359ba1bf5150df764365e0faedfe66ab7e338f7aec558e0a192f4f744b473fbe669013ade2cd144c7742c3ff1d78002af59b0f1b45807bce21f592d23596c54d37095b52a8f942c763f5f014aa161fc18123054a618e8ecb9256c392c3bebcb30e10b2c4bef64f4c3b0aea29a4378a53b6d061c9000b510c0bf76d87171fb357faeb54087718fea0ee33e048d4a1aa8a831f7f9148ebbbb2d79f58c61268e1e1369ae88e2369e65e57169cc477726944790423f9dee584fb9eceeee79a447c075ada7bceb6a28699f0721415d3d0ab8f20b77410bc5faf296ce126cb73b9aaab208b9844d95d127ccaefac37c323cc1957aad3350c2d176916593aa854be50e7c36857adcf51800d490ce082908c5a1aceb8fd51fffc67abaf2c09c1f957bc2e009b8a76394402211eac5ff26c2e5d69aa2c6f4a0e4f2ac28c1482b4706916a0c876d56952b1db18af64658f6249db7fe7e7e366fd2a0f869472d38edb6145404f556025ea0066228080a080";
        let bytes = hex::decode(as_hex).expect("valid hex");
        let network_block: NetworkBlock = minicbor::decode(&bytes).expect("a valid network block");
        assert_eq!(network_block.era_tag(), EraName::Conway);
    }

    #[test]
    fn test_request_range_invalid_from_greater_than_through() {
        let (store, headers) = make_store_with_chain(5);
        let result = PointsRange::request_range(&*store, headers[3].point(), headers[1].point()).unwrap();
        assert_eq!(result, None, "should return None when from > through");
    }

    #[test]
    fn test_request_range_single_point_block_exists() {
        let (store, headers) = make_store_with_chain(3);
        store_blocks(store.clone(), &headers[1..2]);

        let result = PointsRange::request_range(&*store, headers[1].point(), headers[1].point()).unwrap();
        assert_eq!(result, Some(PointsRange::singleton(headers[1].point())));
    }

    #[test]
    fn test_request_range_single_point_block_missing() {
        let (store, headers) = make_store_with_chain(3);
        let result = PointsRange::request_range(&*store, headers[1].point(), headers[1].point()).unwrap();
        assert_eq!(result, None, "should return None when from == through but block doesn't exist");
    }

    #[test]
    fn test_request_range_valid_chain() {
        let (store, headers) = make_store_with_chain(5);
        store_blocks(store.clone(), &headers);
        let result = PointsRange::request_range(&*store, headers[0].point(), headers[4].point()).unwrap();
        assert_eq!(
            result,
            PointsRange::from_vec(vec![
                headers[4].point(),
                headers[3].point(),
                headers[2].point(),
                headers[1].point(),
                headers[0].point(),
            ])
        );
    }

    #[test]
    fn test_request_range_missing_block_in_chain() {
        let (store, headers) = make_store_with_chain(5);

        // Store blocks for all headers except one in the middle
        for (i, h) in headers.iter().enumerate() {
            if i != 2 {
                // Skip storing block for index 2
                let raw_block = RawBlock::from(&[1u8, 2, 3][..]);
                store.store_block(&h.hash(), &raw_block).unwrap();
            }
        }

        let result = PointsRange::request_range(&*store, headers[0].point(), headers[4].point()).unwrap();
        assert_eq!(result, None, "should return None when a block is missing in the chain");
    }

    #[test]
    fn test_request_range_missing_header_in_chain() {
        let headers: Vec<BlockHeader> = run_strategy(any_headers_chain(5));
        let store = Arc::new(InMemConsensusStore::new());

        // Set anchor to the first header
        store.set_anchor_hash(&headers[0].hash()).unwrap();

        for (i, h) in headers.iter().enumerate() {
            if i != 2 {
                // Skip storing header for index 2
                store.store_header(h).unwrap();
                store.roll_forward_chain(&h.point()).unwrap();
                store.set_best_chain_hash(&h.hash()).unwrap();
            }
            // Store a block for EVERY header (including index 2), so that the missing header is
            // the only gap in the chain; otherwise the missing-block check would fire first and
            // the missing-header path would never be exercised.
            let raw_block = RawBlock::from(&[1u8, 2, 3][..]);
            store.store_block(&h.hash(), &raw_block).unwrap();
        }

        let result = PointsRange::request_range(&*store, headers[0].point(), headers[4].point()).unwrap();
        assert_eq!(result, None, "should return None when a header is missing in the chain");
    }

    #[test]
    fn test_request_range_no_parent_hash_before_from() {
        let genesis = Point::Specific(Slot::from(10), run_strategy(any_fake_header()).hash());
        let (store, headers) = make_store_with_chain_starting_from(5, genesis);
        // Store the blocks so that the traversal actually walks the chain back to its root
        // instead of stopping at the very first missing block.
        store_blocks(store.clone(), &headers);

        let result = PointsRange::request_range(
            &*store,
            Point::Specific(Slot::from(2), run_strategy(any_fake_header()).hash()),
            headers[3].point(),
        )
        .unwrap();
        assert_eq!(result, None, "should return None when we hit genesis before finding from");
    }

    #[test]
    fn test_request_range_slot_before_from_abort() {
        // Create a chain with 5 headers
        let (store, headers) = make_store_with_chain(5);
        store_blocks(store.clone(), &headers);

        // Create a 'from' point that has a slot within the chain range but with a non-existent hash.
        // When traversing backwards from 'through', we'll pass the slot of 'from' without finding it,
        // and then hit a block with a slot before 'from', triggering the abort condition.
        let from_slot = headers[2].slot();
        let non_existent_hash = run_strategy(any_fake_header()).hash();
        let from = Point::Specific(from_slot, non_existent_hash);

        let result = PointsRange::request_range(&*store, from, headers[4].point()).unwrap();
        assert_eq!(result, None, "should return None when we reach a slot before 'from' without finding 'from'");
    }

    #[test]
    fn test_request_range_exactly_max_blocks() {
        // Create a chain of exactly MAX_FETCHED_BLOCKS headers: the whole chain must be accepted
        let (store, headers) = make_store_with_chain(MAX_FETCHED_BLOCKS);
        store_blocks(store.clone(), &headers);

        let result =
            PointsRange::request_range(&*store, headers[0].point(), headers[MAX_FETCHED_BLOCKS - 1].point()).unwrap();

        assert_eq!(result.unwrap().points().len(), MAX_FETCHED_BLOCKS);
    }

    #[test]
    fn test_request_range_max_blocks_limit() {
        // Create a chain one header longer than MAX_FETCHED_BLOCKS
        let chain_length = MAX_FETCHED_BLOCKS + 1;
        let (store, headers) = make_store_with_chain(chain_length);
        store_blocks(store.clone(), &headers);

        let result =
            PointsRange::request_range(&*store, headers[0].point(), headers[chain_length - 1].point()).unwrap();
        assert_eq!(result, None, "should return None when the requested range exceeds MAX_BLOCKS limit");
    }

    #[test]
    fn test_next_block_single_point() {
        let (store, headers) = make_store_with_chain(3);
        store_blocks(store.clone(), &headers);

        let (block, remaining_range) = PointsRange::singleton(headers[1].point()).next_block(&*store).unwrap();

        // Should return the block for the single point
        let network_block: NetworkBlock = block.try_into().unwrap();
        assert_eq!(network_block.decode_header().unwrap().point(), headers[1].point());

        // Should have no remaining range
        assert_eq!(remaining_range, None);
    }

    #[test]
    fn test_next_block_multiple_points() {
        let (store, headers) = make_store_with_chain(5);
        store_blocks(store.clone(), &headers);

        let (block, remaining_range) =
            PointsRange::from_vec(vec![headers[2].point(), headers[1].point(), headers[0].point()])
                .unwrap()
                .next_block(&*store)
                .unwrap();

        // Should return the oldest block (points are ordered from most recent to oldest)
        let network_block: NetworkBlock = block.try_into().unwrap();
        assert_eq!(network_block.decode_header().unwrap().point(), headers[0].point());

        // Should have remaining points
        assert_eq!(remaining_range, PointsRange::from_vec(vec![headers[2].point(), headers[1].point()]));
    }

    // HELPERS

    /// Build an in-memory store holding the headers (but no blocks) of a fresh `n`-header
    /// chain rooted at `Point::Origin`.
    fn make_store_with_chain(n: usize) -> (Arc<InMemConsensusStore<BlockHeader>>, Vec<BlockHeader>) {
        make_store_with_chain_starting_from(n, Point::Origin)
    }

    /// Build an in-memory store holding the headers (but no blocks) of a fresh `n`-header
    /// chain rooted at `point`. The first header is the anchor, and the last one is the best chain tip.
    fn make_store_with_chain_starting_from(
        n: usize,
        point: Point,
    ) -> (Arc<InMemConsensusStore<BlockHeader>>, Vec<BlockHeader>) {
        let headers: Vec<BlockHeader> = run_strategy(any_headers_chain_with_root(n, point));
        let store = Arc::new(InMemConsensusStore::new());
        // Set anchor to the first header
        store.set_anchor_hash(&headers[0].hash()).unwrap();
        for h in &headers {
            store.store_header(h).unwrap();
            store.roll_forward_chain(&h.point()).unwrap();
            store.set_best_chain_hash(&h.hash()).unwrap();
        }
        (store, headers)
    }

    /// Store an encoded block for each of `headers`.
    fn store_blocks(store: Arc<InMemConsensusStore<BlockHeader>>, headers: &[BlockHeader]) {
        for h in headers {
            let raw_block = make_encoded_block(h, &TESTNET_ERA_HISTORY);
            store.store_block(&h.hash(), &raw_block).unwrap();
        }
    }
}