1use std::fmt::Debug;
16
17use amaru_kernel::{BlockHeader, IsHeader, NonEmptyVec, Point, RawBlock};
18use amaru_ouroboros_traits::ChainStore;
19use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
20use tracing::instrument;
21
22use crate::{
23 blockfetch::{State, messages::Message},
24 mux::MuxMessage,
25 protocol::{
26 Inputs, Miniprotocol, Outcome, PROTO_N2N_BLOCK_FETCH, ProtocolState, Responder, StageState, miniprotocol,
27 outcome,
28 },
29 store_effects::Store,
30};
31
32pub fn register_deserializers() -> DeserializerGuards {
33 vec![
34 pure_stage::register_data_deserializer::<BlockFetchResponder>().boxed(),
35 pure_stage::register_data_deserializer::<(State, BlockFetchResponder)>().boxed(),
36 ]
37}
38
39pub fn responder() -> Miniprotocol<State, BlockFetchResponder, Responder> {
40 miniprotocol(PROTO_N2N_BLOCK_FETCH.responder())
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
44pub struct BlockFetchResponder {
45 muxer: StageRef<MuxMessage>,
46}
47
48#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
51pub struct PointsRange(NonEmptyVec<Point>);
52
/// Upper bound on the number of points a single requested range may contain.
pub const MAX_FETCHED_BLOCKS: usize = 1_000;
55
56impl PointsRange {
57 pub fn singleton(first: Point) -> PointsRange {
59 PointsRange(NonEmptyVec::singleton(first))
60 }
61
62 pub fn from_vec(vec: Vec<Point>) -> Option<PointsRange> {
64 NonEmptyVec::try_from(vec).ok().map(PointsRange)
65 }
66
67 #[cfg(test)]
68 pub fn points(&self) -> Vec<Point> {
69 self.0.to_vec()
70 }
71
72 fn next_block(self, store: &dyn ChainStore<BlockHeader>) -> anyhow::Result<(RawBlock, Option<PointsRange>)> {
75 let (last, rest) = self.0.pop();
77 let last_hash = last.hash();
78 let stored_block =
79 store.load_block(&last_hash)?.ok_or_else(|| anyhow::anyhow!("block {} was pruned", last_hash))?;
80 Ok((stored_block, rest.map(PointsRange)))
81 }
82
83 pub fn request_range(
89 store: &dyn ChainStore<BlockHeader>,
90 from: Point,
91 through: Point,
92 ) -> anyhow::Result<Option<PointsRange>> {
93 if from > through {
95 tracing::debug!(%from, %through, "requested range is invalid: from > through");
96 return Ok(None);
97 };
98
99 if from == through {
100 return if store.load_block(&from.hash())?.is_some() {
101 Ok(Some(PointsRange::singleton(from)))
102 } else {
103 Ok(None)
104 };
105 }
106
107 let mut current_hash = through.hash();
108 let mut result = vec![];
109 loop {
110 if result.len() >= MAX_FETCHED_BLOCKS {
111 tracing::debug!(
112 %from,
113 %through,
114 max_blocks = MAX_FETCHED_BLOCKS,
115 "requested range exceeds maximum allowed blocks"
116 );
117 return Ok(None);
118 }
119 if store.load_block(¤t_hash)?.is_none() {
121 return Ok(None);
122 }
123
124 if let Some(header) = store.load_header(¤t_hash) {
126 result.push(header.point());
127 if current_hash == from.hash() {
129 break;
130 }
131 if header.slot() < from.slot_or_default() {
133 return Ok(None);
134 }
135 if let Some(parent_hash) = header.parent_hash() {
136 current_hash = parent_hash
137 } else {
138 return Ok(None);
139 }
140 } else {
141 return Ok(None);
142 }
143 }
144 Ok(PointsRange::from_vec(result))
145 }
146}
147
148impl BlockFetchResponder {
149 pub fn new(muxer: StageRef<MuxMessage>) -> (State, Self) {
150 (State::Idle, Self { muxer })
151 }
152}
153
154#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
156pub enum StreamBlocks {
157 More(PointsRange),
158 Done,
159}
160
161impl StageState<State, Responder> for BlockFetchResponder {
162 type LocalIn = StreamBlocks;
163
164 async fn local(
165 self,
166 _proto: &State,
167 input: Self::LocalIn,
168 eff: &Effects<Inputs<Self::LocalIn>>,
169 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
170 let store = Store::new(eff.clone());
171 match input {
172 StreamBlocks::Done => Ok((Some(ResponderAction::BatchDone), self)),
173 StreamBlocks::More(points_range) => {
174 let (block, points_range) = points_range.next_block(&store)?;
175 if let Some(points_range) = points_range {
177 eff.send(eff.me_ref(), Inputs::Local(StreamBlocks::More(points_range))).await;
178 } else {
179 eff.send(eff.me_ref(), Inputs::Local(StreamBlocks::Done)).await;
180 }
181 Ok((Some(ResponderAction::Block(block)), self))
182 }
183 }
184 }
185
186 #[instrument(name = "blockfetch.responder.stage", skip_all, fields(message_type = input.message_type()))]
187 async fn network(
188 self,
189 _proto: &State,
190 input: ResponderResult,
191 eff: &Effects<Inputs<Self::LocalIn>>,
192 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
193 match input {
194 ResponderResult::RequestRange { from, through } => {
195 let store = Store::new(eff.clone());
196 if let Some(points_range) = PointsRange::request_range(&store, from, through)? {
197 eff.send(eff.me_ref(), Inputs::Local(StreamBlocks::More(points_range))).await;
198 Ok((Some(ResponderAction::StartBatch), self))
199 } else {
200 Ok((Some(ResponderAction::NoBlocks), self))
201 }
202 }
203 ResponderResult::Done => Ok((None, self)),
204 }
205 }
206
207 fn muxer(&self) -> &StageRef<MuxMessage> {
208 &self.muxer
209 }
210}
211
212impl ProtocolState<Responder> for State {
213 type WireMsg = Message;
214 type Action = ResponderAction;
215 type Out = ResponderResult;
216 type Error = Void;
217
218 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
219 Ok((outcome().want_next(), *self))
220 }
221
222 #[instrument(name = "blockfetch.responder.protocol", skip_all, fields(message_type = input.message_type()))]
223 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
224 use Message::*;
225 match (self, input) {
226 (Self::Idle, RequestRange { from, through }) => {
227 Ok((outcome().result(ResponderResult::RequestRange { from, through }), Self::Busy))
228 }
229 (Self::Idle, ClientDone) => Ok((outcome().want_next().result(ResponderResult::Done), Self::Done)),
230 (state, msg) => anyhow::bail!("unexpected message in state {:?}: {:?}", state, msg),
231 }
232 }
233
234 fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
235 use ResponderAction::*;
236 match (self, input) {
237 (Self::Busy, StartBatch) => Ok((outcome().send(Message::StartBatch), Self::Streaming)),
238 (Self::Busy, NoBlocks) => Ok((outcome().send(Message::NoBlocks).want_next(), Self::Idle)),
239 (Self::Streaming, Block(body)) => {
240 Ok((outcome().send(Message::Block { body: body.to_vec() }), Self::Streaming))
241 }
242 (Self::Streaming, BatchDone) => Ok((outcome().send(Message::BatchDone).want_next(), Self::Idle)),
243 (state, action) => {
244 anyhow::bail!("unexpected action in state {:?}: {:?}", state, action)
245 }
246 }
247 }
248}
249
250#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
251pub enum ResponderAction {
252 StartBatch,
253 NoBlocks,
254 Block(RawBlock),
255 BatchDone,
256}
257
258#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
259pub enum ResponderResult {
260 RequestRange { from: Point, through: Point },
261 Done,
262}
263
264impl ResponderResult {
265 pub fn message_type(&self) -> &'static str {
266 match self {
267 ResponderResult::RequestRange { .. } => "RequestRange",
268 ResponderResult::Done => "Done",
269 }
270 }
271}
272
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use amaru_kernel::{
        BlockHeader, EraName, IsHeader, Slot, TESTNET_ERA_HISTORY, any_fake_header, any_headers_chain,
        any_headers_chain_with_root,
        cardano::network_block::{NetworkBlock, make_encoded_block},
        utils::tests::run_strategy,
    };
    use amaru_ouroboros_traits::{ChainStore, in_memory_consensus_store::InMemConsensusStore};

    use super::*;
    use crate::protocol::Responder;

    /// Checks the responder state machine against the block-fetch protocol spec.
    #[test]
    #[expect(clippy::wildcard_enum_match_arm)]
    fn test_responder_protocol() {
        crate::blockfetch::spec::<Responder>().check(State::Idle, |msg| match msg {
            Message::StartBatch => Some(ResponderAction::StartBatch),
            Message::NoBlocks => Some(ResponderAction::NoBlocks),
            Message::Block { body } => Some(ResponderAction::Block(RawBlock::from(body.as_slice()))),
            Message::BatchDone => Some(ResponderAction::BatchDone),
            _ => None,
        });
    }

    /// Decodes a hex-encoded network block and checks its era tag.
    #[test]
    fn decode_network_block() {
        let as_hex = "820785828a1a002cc8f51a04994d195820f27eddec5e782552e6ef408cff7c4a27e505fe54c20a717027d97e1c91da9d7c5820064effe4fa426184a911159fa803a9c1092459cd0b8f3e584ef9513955be0f5558201e5d0dcf77643d89a94353493859a21b47672015fb652b51f922617e4b27da8982584042d0edd71e6cac29e45f61eabbcce4f803f2ff78bce9fa295d11cb7c3cddb60f7694faaea787183fd604267d8114b57453493c963c7485405838cd79a261013a5850bc8672b4ff2db478e5b21364bfa9f0a2f5265e5ac56b261ce3dcb7ac57301a8362573eef2ae23eb2540915704534d1c0af8eace59a25c130629af7600b175b5e234b376961e2fd12b37de5213e8eff0304582029571d16f081709b3c48651860077bebf9340abb3fc7133443c54f1f5a5edcf1845820ee1d7c2bd6978e3bc8a47fc478424a9efd797f16813164db292320e3728f6de5091902465840f69f8974108be5df23dd0dad2f0e888e5c1702c35c678f3b7a2802f272666ea8a7c9b9f6e786e761d4cb747159d68b7d8f43bceae6ab4e543795d8aded59c302820a005901c06063a37f6f01765b34bceb2651e40a69e3bc31b35fd6c952415175844132250cdcbafd19c39952f471f7318a5cc3e45f54dadc9067bb6d25dac8b76f0bea5106c2f45235fac710d3e78d259af37fd617ed9e372626c5b080359ba1bf5150df764365e0faedfe66ab7e338f7aec558e0a192f4f744b473fbe669013ade2cd144c7742c3ff1d78002af59b0f1b45807bce21f592d23596c54d37095b52a8f942c763f5f014aa161fc18123054a618e8ecb9256c392c3bebcb30e10b2c4bef64f4c3b0aea29a4378a53b6d061c9000b510c0bf76d87171fb357faeb54087718fea0ee33e048d4a1aa8a831f7f9148ebbbb2d79f58c61268e1e1369ae88e2369e65e57169cc477726944790423f9dee584fb9eceeee79a447c075ada7bceb6a28699f0721415d3d0ab8f20b77410bc5faf296ce126cb73b9aaab208b9844d95d127ccaefac37c323cc1957aad3350c2d176916593aa854be50e7c36857adcf51800d490ce082908c5a1aceb8fd51fffc67abaf2c09c1f957bc2e009b8a76394402211eac5ff26c2e5d69aa2c6f4a0e4f2ac28c1482b4706916a0c876d56952b1db18af64658f6249db7fe7e7e366fd2a0f869472d38edb6145404f556025ea0066228080a080";
        let raw = hex::decode(as_hex).expect("valid hex");
        let decoded: NetworkBlock = minicbor::decode(&raw).expect("a valid network block");
        assert_eq!(decoded.era_tag(), EraName::Conway);
    }

    #[test]
    fn test_request_range_invalid_from_greater_than_through() {
        let (store, headers) = make_store_with_chain(5);
        let from = headers[3].point();
        let through = headers[1].point();
        let result = PointsRange::request_range(&*store, from, through).unwrap();
        assert_eq!(result, None, "should return None when from > through");
    }

    #[test]
    fn test_request_range_single_point_block_exists() {
        let (store, headers) = make_store_with_chain(3);
        // Only the block of the requested point is stored.
        store_blocks(store.clone(), std::slice::from_ref(&headers[1]));

        let point = headers[1].point();
        let result = PointsRange::request_range(&*store, point, headers[1].point()).unwrap();
        assert_eq!(result, Some(PointsRange::singleton(headers[1].point())));
    }

    #[test]
    fn test_request_range_single_point_block_missing() {
        // Headers are stored, but no block is.
        let (store, headers) = make_store_with_chain(3);
        let point = headers[1].point();
        let result = PointsRange::request_range(&*store, point, headers[1].point()).unwrap();
        assert_eq!(result, None, "should return None when from == through but block doesn't exist");
    }

    #[test]
    fn test_request_range_valid_chain() {
        let (store, headers) = make_store_with_chain(5);
        store_blocks(store.clone(), &headers);
        let result = PointsRange::request_range(&*store, headers[0].point(), headers[4].point()).unwrap();

        // The range is expected from `through` (index 4) down to `from` (index 0).
        let expected = PointsRange::from_vec((0..5).rev().map(|i| headers[i].point()).collect());
        assert_eq!(result, expected);
    }

    #[test]
    fn test_request_range_missing_block_in_chain() {
        let (store, headers) = make_store_with_chain(5);

        // Store a block for every header except the one at index 2.
        for (_, header) in headers.iter().enumerate().filter(|(i, _)| *i != 2) {
            let raw_block = RawBlock::from(&[1u8, 2, 3][..]);
            store.store_block(&header.hash(), &raw_block).unwrap();
        }

        let result = PointsRange::request_range(&*store, headers[0].point(), headers[4].point()).unwrap();
        assert_eq!(result, None, "should return None when a block is missing in the chain");
    }

    #[test]
    fn test_request_range_missing_header_in_chain() {
        let headers: Vec<BlockHeader> = run_strategy(any_headers_chain(5));
        let store = Arc::new(InMemConsensusStore::new());

        store.set_anchor_hash(&headers[0].hash()).unwrap();

        // Store header and block for every index except 2.
        for (_, header) in headers.iter().enumerate().filter(|(i, _)| *i != 2) {
            store.store_header(header).unwrap();
            store.roll_forward_chain(&header.point()).unwrap();
            store.set_best_chain_hash(&header.hash()).unwrap();
            let raw_block = RawBlock::from(&[1u8, 2, 3][..]);
            store.store_block(&header.hash(), &raw_block).unwrap();
        }

        let result = PointsRange::request_range(&*store, headers[0].point(), headers[4].point()).unwrap();
        assert_eq!(result, None, "should return None when a header is missing in the chain");
    }

    #[test]
    fn test_request_range_no_parent_hash_before_from() {
        let genesis = Point::Specific(Slot::from(10), run_strategy(any_fake_header()).hash());
        let (store, headers) = make_store_with_chain_starting_from(5, genesis);

        // `from` is a point unknown to the store.
        let from = Point::Specific(Slot::from(2), run_strategy(any_fake_header()).hash());
        let through = headers[3].point();

        let result = PointsRange::request_range(&*store, from, through).unwrap();
        assert_eq!(result, None, "should return None when we hit genesis before finding from");
    }

    #[test]
    fn test_request_range_slot_before_from_abort() {
        let (store, headers) = make_store_with_chain(5);
        store_blocks(store.clone(), &headers);

        // `from` reuses the slot of an existing header with a hash unknown to the store.
        let unknown_hash = run_strategy(any_fake_header()).hash();
        let from = Point::Specific(headers[2].slot(), unknown_hash);

        let result = PointsRange::request_range(&*store, from, headers[4].point()).unwrap();
        assert_eq!(result, None, "should return None when we reach a slot before 'from' without finding 'from'");
    }

    #[test]
    fn test_request_range_exactly_max_blocks() {
        let (store, headers) = make_store_with_chain(MAX_FETCHED_BLOCKS);
        store_blocks(store.clone(), &headers);

        let from = headers[0].point();
        let through = headers[MAX_FETCHED_BLOCKS - 1].point();
        let result = PointsRange::request_range(&*store, from, through).unwrap();

        assert_eq!(result.unwrap().points().len(), MAX_FETCHED_BLOCKS);
    }

    #[test]
    fn test_request_range_max_blocks_limit() {
        // One block more than the limit allows.
        let chain_length = MAX_FETCHED_BLOCKS + 1;
        let (store, headers) = make_store_with_chain(chain_length);
        store_blocks(store.clone(), &headers);

        let from = headers[0].point();
        let through = headers[chain_length - 1].point();
        let result = PointsRange::request_range(&*store, from, through).unwrap();
        assert_eq!(result, None, "should return None when the requested range exceeds MAX_BLOCKS limit");
    }

    #[test]
    fn test_next_block_single_point() {
        let (store, headers) = make_store_with_chain(3);
        store_blocks(store.clone(), &headers);

        let range = PointsRange::singleton(headers[1].point());
        let (block, remaining_range) = range.next_block(&*store).unwrap();

        let network_block: NetworkBlock = block.try_into().unwrap();
        assert_eq!(network_block.decode_header().unwrap().point(), headers[1].point());

        // A singleton range is exhausted after one block.
        assert_eq!(remaining_range, None);
    }

    #[test]
    fn test_next_block_multiple_points() {
        let (store, headers) = make_store_with_chain(5);
        store_blocks(store.clone(), &headers);

        let range = PointsRange::from_vec(vec![headers[2].point(), headers[1].point(), headers[0].point()]).unwrap();
        let (block, remaining_range) = range.next_block(&*store).unwrap();

        // The last point of the range is served first.
        let network_block: NetworkBlock = block.try_into().unwrap();
        assert_eq!(network_block.decode_header().unwrap().point(), headers[0].point());

        let expected_rest = PointsRange::from_vec(vec![headers[2].point(), headers[1].point()]);
        assert_eq!(remaining_range, expected_rest);
    }

    /// Builds a store holding the headers of an `n`-long chain rooted at `Point::Origin`.
    fn make_store_with_chain(n: usize) -> (Arc<InMemConsensusStore<BlockHeader>>, Vec<BlockHeader>) {
        make_store_with_chain_starting_from(n, Point::Origin)
    }

    /// Builds a store holding the headers of an `n`-long chain rooted at `point`.
    ///
    /// Only headers are stored; use `store_blocks` to add the matching blocks.
    fn make_store_with_chain_starting_from(
        n: usize,
        point: Point,
    ) -> (Arc<InMemConsensusStore<BlockHeader>>, Vec<BlockHeader>) {
        let headers: Vec<BlockHeader> = run_strategy(any_headers_chain_with_root(n, point));
        let store = Arc::new(InMemConsensusStore::new());
        store.set_anchor_hash(&headers[0].hash()).unwrap();
        for header in headers.iter() {
            store.store_header(header).unwrap();
            store.roll_forward_chain(&header.point()).unwrap();
            store.set_best_chain_hash(&header.hash()).unwrap();
        }
        (store, headers)
    }

    /// Stores an encoded block for each of the given headers.
    fn store_blocks(store: Arc<InMemConsensusStore<BlockHeader>>, headers: &[BlockHeader]) {
        for header in headers.iter() {
            let encoded = make_encoded_block(header, &TESTNET_ERA_HISTORY);
            store.store_block(&header.hash(), &encoded).unwrap();
        }
    }
}