// amaru_network/chain_sync_client.rs
// Copyright 2025 PRAGMA
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use amaru_kernel::{Peer, Point};
16use amaru_observability::trace;
17use pallas_network::miniprotocols::chainsync::{Client, ClientError, HeaderContent, NextResponse};
18use pallas_traverse::MultiEraHeader;
19
20use crate::point::{from_network_point, to_network_point};
21
/// An owned buffer of raw, undecoded header bytes.
///
/// NOTE(review): presumably the header's CBOR encoding as received from a peer — confirm
/// against the producers/consumers of this alias.
pub type RawHeader = Vec<u8>;
23
24#[derive(Debug, thiserror::Error)]
25pub enum ChainSyncClientError {
26    #[error("Failed to decode header: {0}")]
27    HeaderDecodeError(String),
28    #[error("Network error: {0}")]
29    NetworkError(ClientError),
30    #[error("No intersection found for points: {points:?}")]
31    NoIntersectionFound { points: Vec<Point> },
32}
33
34pub fn to_traverse(header: &HeaderContent) -> Result<MultiEraHeader<'_>, ChainSyncClientError> {
35    let out = match header.byron_prefix {
36        Some((subtag, _)) => MultiEraHeader::decode(header.variant, Some(subtag), &header.cbor),
37        None => MultiEraHeader::decode(header.variant, None, &header.cbor),
38    };
39
40    out.map_err(|e| ChainSyncClientError::HeaderDecodeError(e.to_string()))
41}
42
/// Handles chain synchronization network operations
pub struct ChainSyncClient {
    /// The remote peer this client talks to (its name is recorded in tracing spans).
    pub peer: Peer,
    /// Underlying chain-sync mini-protocol client, yielding [`HeaderContent`] responses.
    chain_sync: Client<HeaderContent>,
    /// Candidate points offered to the peer when looking for an intersection; a copy is
    /// attached to `NoIntersectionFound` errors.
    intersection: Vec<Point>,
}
49
50impl ChainSyncClient {
51    pub fn new(peer: Peer, chain_sync: Client<HeaderContent>, intersection: Vec<Point>) -> Self {
52        Self { peer, chain_sync, intersection }
53    }
54
55    #[trace(amaru::network::chainsync_client::FIND_INTERSECTION,
56        peer = self.peer.name.clone(),
57        intersection_slot = u64::from(self.intersection.last().map(|p| p.slot_or_default()).unwrap_or_default())
58    )]
59    pub async fn find_intersection(&mut self) -> Result<Point, ChainSyncClientError> {
60        let client = &mut self.chain_sync;
61        let (point, _) = client
62            .find_intersect(self.intersection.iter().cloned().map(to_network_point).collect())
63            .await
64            .map_err(ChainSyncClientError::NetworkError)?;
65
66        let intersection =
67            point.ok_or(ChainSyncClientError::NoIntersectionFound { points: self.intersection.clone() })?;
68        Ok(from_network_point(&intersection))
69    }
70
71    pub fn intersection(&self) -> &[Point] {
72        &self.intersection
73    }
74
75    pub async fn request_next(&mut self) -> Result<NextResponse<HeaderContent>, ChainSyncClientError> {
76        let client = &mut self.chain_sync;
77
78        client
79            .request_next()
80            .await
81            .inspect_err(|err| tracing::error!(reason = %err, "request next failed"))
82            .map_err(ChainSyncClientError::NetworkError)
83    }
84
85    pub async fn await_next(&mut self) -> Result<NextResponse<HeaderContent>, ChainSyncClientError> {
86        let client = &mut self.chain_sync;
87
88        match client.recv_while_must_reply().await {
89            Ok(result) => Ok(result),
90            Err(err) => {
91                tracing::error!(reason = %err, "failed while awaiting for next block");
92                Err(ChainSyncClientError::NetworkError(err))
93            }
94        }
95    }
96
97    pub fn has_agency(&self) -> bool {
98        self.chain_sync.has_agency()
99    }
100}