Skip to main content

amaru_protocols/chainsync/
initiator.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use amaru_kernel::{BlockHeader, ORIGIN_HASH, Peer, Point, Tip};
16use amaru_ouroboros::{ConnectionId, ReadOnlyChainStore};
17use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
18use tracing::instrument;
19
20use crate::{
21    chainsync::messages::{HeaderContent, Message},
22    mux::MuxMessage,
23    protocol::{
24        Initiator, Inputs, Miniprotocol, Outcome, PROTO_N2N_CHAIN_SYNC, ProtocolState, StageState, miniprotocol,
25        outcome,
26    },
27    store_effects::Store,
28};
29
30pub fn register_deserializers() -> DeserializerGuards {
31    vec![
32        pure_stage::register_data_deserializer::<InitiatorMessage>().boxed(),
33        pure_stage::register_data_deserializer::<(InitiatorState, ChainSyncInitiator)>().boxed(),
34        pure_stage::register_data_deserializer::<ChainSyncInitiatorMsg>().boxed(),
35        pure_stage::register_data_deserializer::<ChainSyncInitiator>().boxed(),
36    ]
37}
38
/// Builds the chain-sync initiator [`Miniprotocol`] for `PROTO_N2N_CHAIN_SYNC`.
///
/// The protocol state is [`InitiatorState`] and the stage state is [`ChainSyncInitiator`].
pub fn initiator() -> Miniprotocol<InitiatorState, ChainSyncInitiator, Initiator> {
    miniprotocol(PROTO_N2N_CHAIN_SYNC)
}
42
/// Message sent to the handler from the consensus pipeline
// Variant order is part of the serialized representation; do not reorder.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum InitiatorMessage {
    /// Send another `RequestNext` to the peer; accepted while idle or while requests are
    /// already in flight.
    RequestNext,
    /// Send `Done` to the peer; only accepted while idle.
    Done,
}
49
50impl InitiatorMessage {
51    pub fn message_type(&self) -> &str {
52        match self {
53            InitiatorMessage::RequestNext => "RequestNext",
54            InitiatorMessage::Done => "Done",
55        }
56    }
57}
58
/// Message sent from the handler to the consensus pipeline
///
/// Wraps an [`InitiatorResult`] together with what the pipeline needs to know where it came
/// from and how to reply to the handler.
// Field order is part of the serialized representation; do not reorder.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ChainSyncInitiatorMsg {
    /// Peer this handler is exchanging chain-sync messages with.
    pub peer: Peer,
    /// Connection the handler runs on.
    pub conn_id: ConnectionId,
    /// Where the pipeline sends [`InitiatorMessage`]s for this handler.
    pub handler: StageRef<InitiatorMessage>,
    /// The protocol result being forwarded.
    pub msg: InitiatorResult,
}
67
68impl ChainSyncInitiatorMsg {
69    pub fn message_type(&self) -> &str {
70        match self.msg {
71            InitiatorResult::Initialize => "Initialize",
72            InitiatorResult::IntersectFound(_, _) => "IntersectFound",
73            InitiatorResult::IntersectNotFound(_) => "IntersectNotFound",
74            InitiatorResult::RollForward(_, _) => "RollForward",
75            InitiatorResult::RollBackward(_, _) => "RollBackward",
76        }
77    }
78}
79
/// Stage state of the chain-sync initiator handler for one connection.
// Field order is part of the serialized representation; do not reorder.
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ChainSyncInitiator {
    /// Latest tip reported by the peer; `None` until a result carrying a tip is received.
    upstream: Option<Tip>,
    /// Peer this handler talks to; copied into every message sent to the pipeline.
    peer: Peer,
    /// Connection this handler runs on; copied into every message sent to the pipeline.
    conn_id: ConnectionId,
    /// Multiplexer stage reference, exposed through `StageState::muxer`.
    muxer: StageRef<MuxMessage>,
    /// Consensus pipeline stage receiving every forwarded [`ChainSyncInitiatorMsg`].
    pipeline: StageRef<ChainSyncInitiatorMsg>,
    /// Reference the pipeline uses to send [`InitiatorMessage`]s back to this handler.
    /// Starts as a blackhole and is replaced when `InitiatorResult::Initialize` is handled.
    me: StageRef<InitiatorMessage>,
}
89
90impl ChainSyncInitiator {
91    pub fn new(
92        peer: Peer,
93        conn_id: ConnectionId,
94        muxer: StageRef<MuxMessage>,
95        pipeline: StageRef<ChainSyncInitiatorMsg>,
96    ) -> (InitiatorState, Self) {
97        (InitiatorState::Idle, Self { upstream: None, peer, conn_id, muxer, pipeline, me: StageRef::blackhole() })
98    }
99}
100
101impl StageState<InitiatorState, Initiator> for ChainSyncInitiator {
102    type LocalIn = InitiatorMessage;
103
104    async fn local(
105        self,
106        proto: &InitiatorState,
107        input: Self::LocalIn,
108        _eff: &Effects<Inputs<Self::LocalIn>>,
109    ) -> anyhow::Result<(Option<<InitiatorState as ProtocolState<Initiator>>::Action>, Self)> {
110        use InitiatorState::*;
111
112        Ok(match (proto, input) {
113            (Idle, InitiatorMessage::RequestNext) => (Some(InitiatorAction::RequestNext), self),
114            (CanAwait(_) | MustReply(_), InitiatorMessage::RequestNext) => (Some(InitiatorAction::RequestNext), self),
115            (Idle, InitiatorMessage::Done) => (Some(InitiatorAction::Done), self),
116            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
117        })
118    }
119
120    #[instrument(name = "chainsync.initiator.stage", skip_all, fields(message_type = input.message_type()))]
121    async fn network(
122        mut self,
123        _proto: &InitiatorState,
124        input: <InitiatorState as ProtocolState<Initiator>>::Out,
125        eff: &Effects<Inputs<Self::LocalIn>>,
126    ) -> anyhow::Result<(Option<<InitiatorState as ProtocolState<Initiator>>::Action>, Self)> {
127        use InitiatorAction::*;
128        let action = match &input {
129            InitiatorResult::Initialize => {
130                self.me = eff.contramap(eff.me(), format!("{}-handler", eff.me().name()), Inputs::Local).await;
131                Some(Intersect(intersect_points(&Store::new(eff.clone()))))
132            }
133            InitiatorResult::IntersectFound(_, tip)
134            | InitiatorResult::IntersectNotFound(tip)
135            | InitiatorResult::RollForward(_, tip)
136            | InitiatorResult::RollBackward(_, tip) => {
137                self.upstream = Some(*tip);
138                None
139            }
140        };
141        eff.send(
142            &self.pipeline,
143            ChainSyncInitiatorMsg {
144                peer: self.peer.clone(),
145                conn_id: self.conn_id,
146                handler: self.me.clone(),
147                msg: input,
148            },
149        )
150        .await;
151        Ok((action, self))
152    }
153
154    fn muxer(&self) -> &StageRef<MuxMessage> {
155        &self.muxer
156    }
157}
158
159fn intersect_points(store: &dyn ReadOnlyChainStore<BlockHeader>) -> Vec<Point> {
160    let mut spacing = 1;
161    let mut points = Vec::new();
162    let best = store.get_best_chain_hash();
163    if best == ORIGIN_HASH {
164        return vec![Point::Origin];
165    }
166    #[expect(clippy::expect_used)]
167    let best = store.load_header(&best).expect("best chain hash is valid");
168    let best_point = best.tip().point();
169    points.push(best_point);
170
171    let mut last = best_point;
172    for (index, header) in store.ancestors(best).enumerate() {
173        last = header.tip().point();
174        if index + 1 == spacing {
175            points.push(last);
176            spacing *= 2;
177        }
178    }
179    if points.last() != Some(&last) {
180        points.push(last);
181    }
182    points
183}
184
/// Action requested by the stage; `InitiatorState`'s `local` turns it into a wire message.
#[derive(Debug)]
pub enum InitiatorAction {
    /// Send `FindIntersect` with these candidate points.
    Intersect(Vec<Point>),
    /// Send one more `RequestNext(1)`.
    RequestNext,
    /// Send `Done`; the state machine then moves to `InitiatorState::Done`.
    Done,
}
191
/// Result produced by the protocol state machine and forwarded to the consensus pipeline.
// Variant order is part of the serialized representation; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum InitiatorResult {
    /// Emitted by `init` when the protocol starts; the stage reacts by searching for an
    /// intersection.
    Initialize,
    /// The peer found an intersection at this point; also carries the peer's tip.
    IntersectFound(Point, Tip),
    /// The peer found none of the offered points; carries the peer's tip.
    IntersectNotFound(Tip),
    /// The peer rolled forward with this header content; also carries the peer's tip.
    RollForward(HeaderContent, Tip),
    /// The peer rolled back to this point; also carries the peer's tip.
    RollBackward(Point, Tip),
}
200
201impl InitiatorResult {
202    pub fn message_type(&self) -> &str {
203        match self {
204            InitiatorResult::Initialize => "Initialize",
205            InitiatorResult::IntersectFound(_, _) => "IntersectFound",
206            InitiatorResult::IntersectNotFound(_) => "IntersectNotFound",
207            InitiatorResult::RollForward(_, _) => "RollForward",
208            InitiatorResult::RollBackward(_, _) => "RollBackward",
209        }
210    }
211}
212
/// State of the chain-sync initiator state machine.
///
/// The payload of `CanAwait(n)` and `MustReply(n)` is the number of outstanding
/// `RequestNext` messages minus one. A single request in flight is `CanAwait(0)`, and each
/// roll forward/backward either decrements the payload or, when it is 0, returns to `Idle`.
// Variant order is part of the serialized representation and of the derived `Ord`; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
pub enum InitiatorState {
    /// No request outstanding. The stage may search for an intersection, request the next
    /// header, or finish.
    Idle,
    /// `n + 1` requests outstanding; the peer may still answer with `AwaitReply`.
    CanAwait(u8),
    /// `n + 1` requests outstanding, entered when the peer sends `AwaitReply`. Only a roll
    /// forward or roll backward is accepted next.
    MustReply(u8),
    /// `FindIntersect` sent; waiting for `IntersectFound` or `IntersectNotFound`.
    Intersect,
    /// `Done` sent; no transition leaves this state in this module.
    Done,
}
221
222impl ProtocolState<Initiator> for InitiatorState {
223    type WireMsg = Message;
224    type Action = InitiatorAction;
225    type Out = InitiatorResult;
226    type Error = Void;
227
228    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
229        Ok((outcome().result(InitiatorResult::Initialize), *self))
230    }
231
232    #[instrument(name = "chainsync.initiator.protocol", skip_all, fields(message_type = input.message_type()))]
233    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
234        use InitiatorState::*;
235
236        Ok(match (self, input) {
237            (Intersect, Message::IntersectFound(point, tip)) => (
238                // only for this first time do we sent two requests
239                // this initiates the desired pipelining behaviour
240                outcome().send(Message::RequestNext(2)).want_next().result(InitiatorResult::IntersectFound(point, tip)),
241                CanAwait(1),
242            ),
243            (Intersect, Message::IntersectNotFound(tip)) => {
244                (outcome().result(InitiatorResult::IntersectNotFound(tip)), Idle)
245            }
246            (CanAwait(n), Message::AwaitReply) => (outcome().want_next(), MustReply(*n)),
247            (CanAwait(n) | MustReply(n), Message::RollForward(content, tip)) => (
248                outcome().result(InitiatorResult::RollForward(content, tip)),
249                if *n == 0 { Idle } else { CanAwait(*n - 1) },
250            ),
251            (CanAwait(n) | MustReply(n), Message::RollBackward(point, tip)) => (
252                outcome().result(InitiatorResult::RollBackward(point, tip)),
253                if *n == 0 { Idle } else { CanAwait(*n - 1) },
254            ),
255            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
256        })
257    }
258
259    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
260        use InitiatorState::*;
261
262        Ok(match (self, input) {
263            (Idle, InitiatorAction::Intersect(points)) => {
264                (outcome().send(Message::FindIntersect(points)).want_next(), Intersect)
265            }
266            (Idle, InitiatorAction::RequestNext) => (outcome().send(Message::RequestNext(1)).want_next(), CanAwait(0)),
267            (CanAwait(n), InitiatorAction::RequestNext) => {
268                (outcome().send(Message::RequestNext(1)).want_next(), CanAwait(*n + 1))
269            }
270            (MustReply(n), InitiatorAction::RequestNext) => {
271                (outcome().send(Message::RequestNext(1)).want_next(), MustReply(*n + 1))
272            }
273            (Idle, InitiatorAction::Done) => (outcome().send(Message::Done), Done),
274            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
275        })
276    }
277}
278
#[cfg(test)]
#[expect(clippy::wildcard_enum_match_arm)]
pub mod tests {
    use InitiatorState::*;
    use Message::*;
    use amaru_kernel::{EraName, Hash, HeaderHash, RawBlock, Slot, make_header, size::HEADER};
    use amaru_ouroboros_traits::{Nonces, StoreError};

    use super::*;
    use crate::protocol::ProtoSpec;

    /// Builds the protocol specification (state, message, next state) used to check the
    /// initiator state machine.
    ///
    /// NOTE(review): this spec maps `Intersect + IntersectFound` to `Idle`. The state machine
    /// instead moves to `CanAwait(1)` because it pipelines two requests. The check using it
    /// is `#[ignore]`d for pipelining reasons, so confirm the spec before re-enabling it.
    pub fn spec() -> ProtoSpec<InitiatorState, Message, Initiator> {
        // canonical states and messages
        let find_intersect = || FindIntersect(vec![Point::Origin]);
        let intersect_found = || IntersectFound(Point::Origin, Tip::origin());
        let intersect_not_found = || IntersectNotFound(Tip::origin());
        let roll_forward = || RollForward(HeaderContent::with_bytes(vec![], EraName::Conway), Tip::origin());
        let roll_backward = || RollBackward(Point::Origin, Tip::origin());

        let mut spec = ProtoSpec::default();
        spec.init(Idle, find_intersect(), Intersect);
        spec.init(Idle, Message::Done, InitiatorState::Done);
        spec.init(Idle, Message::RequestNext(1), CanAwait(0));
        spec.resp(Intersect, intersect_found(), Idle);
        spec.resp(Intersect, intersect_not_found(), Idle);
        spec.resp(CanAwait(0), AwaitReply, MustReply(0));
        spec.resp(CanAwait(0), roll_forward(), Idle);
        spec.resp(CanAwait(0), roll_backward(), Idle);
        spec.resp(MustReply(0), roll_forward(), Idle);
        spec.resp(MustReply(0), roll_backward(), Idle);
        spec
    }

    #[test]
    #[ignore = "pipelining cannot be tested yet"]
    fn test_initiator_protocol() {
        // Maps each spec message sent by the initiator back to the action that produces it.
        spec().check(Idle, |msg| match msg {
            FindIntersect(points) => Some(InitiatorAction::Intersect(points.clone())),
            RequestNext(1) => Some(InitiatorAction::RequestNext),
            Message::Done => Some(InitiatorAction::Done),
            _ => None,
        });
    }

    #[test]
    fn test_intersect_points_includes_best_point_and_are_spaced_with_a_factor_2() {
        let store = MockChainStoreForIntersectPoints::default();
        let points = intersect_points(&store);
        let slots = points.iter().map(|p| p.slot_or_default().into()).collect::<Vec<u64>>();
        // The expected slots contain the best point (100), then the ancestors 1, 2, 4, …, 64
        // headers behind it, and finally the oldest ancestor (slot 0).
        assert_eq!(slots, vec![100, 99, 98, 96, 92, 84, 68, 36, 0]);
    }

    /// This chain store contains a chain of 101 headers with slots 0 to 100, where 100 is the best point.
    ///
    /// `load_header` always returns the best header, and `ancestors` always yields slots 99
    /// down to 0, regardless of their arguments. The other trait methods are `todo!()`
    /// because `intersect_points` does not call them.
    #[derive(Debug)]
    struct MockChainStoreForIntersectPoints {
        best_point: Point,
    }

    impl Default for MockChainStoreForIntersectPoints {
        fn default() -> Self {
            Self { best_point: Point::Specific(Slot::from(100), Hash::new([100u8; HEADER])) }
        }
    }

    #[expect(clippy::todo)]
    impl ReadOnlyChainStore<BlockHeader> for MockChainStoreForIntersectPoints {
        fn get_best_chain_hash(&self) -> HeaderHash {
            self.best_point.hash()
        }

        fn load_header(&self, _hash: &HeaderHash) -> Option<BlockHeader> {
            Some(BlockHeader::new(
                make_header(1, self.best_point.slot_or_default().into(), None),
                self.best_point.hash(),
            ))
        }

        fn ancestors<'a>(&'a self, _from: BlockHeader) -> Box<dyn Iterator<Item = BlockHeader> + 'a>
        where
            BlockHeader: 'a,
        {
            // Headers for slots 0..=99, each with hash bytes equal to its slot, yielded newest first.
            let mut ancestor_block_headers = vec![];
            for slot in 0..100 {
                let header_hash = Hash::new([slot as u8; HEADER]);
                let block_header = BlockHeader::new(make_header(1, slot, None), header_hash);
                ancestor_block_headers.push(block_header);
            }
            ancestor_block_headers.reverse();
            Box::new(ancestor_block_headers.into_iter())
        }

        fn get_children(&self, _hash: &HeaderHash) -> Vec<HeaderHash> {
            todo!()
        }

        fn get_anchor_hash(&self) -> HeaderHash {
            todo!()
        }

        fn load_from_best_chain(&self, _point: &Point) -> Option<HeaderHash> {
            todo!()
        }

        fn next_best_chain(&self, _point: &Point) -> Option<Point> {
            todo!()
        }

        fn load_block(&self, _hash: &HeaderHash) -> Result<Option<RawBlock>, StoreError> {
            todo!()
        }

        fn get_nonces(&self, _header: &HeaderHash) -> Option<Nonces> {
            todo!()
        }

        fn has_header(&self, _hash: &HeaderHash) -> bool {
            todo!()
        }
    }
}