// amaru_protocols/chainsync/responder.rs
// Copyright 2025 PRAGMA
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::cmp::Reverse;
16
17use amaru_kernel::{BlockHeader, EraName, IsHeader, Peer, Point, Tip};
18use amaru_ouroboros::{ConnectionId, ReadOnlyChainStore};
19use anyhow::{Context, ensure};
20use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
21use tracing::instrument;
22
23use crate::{
24    chainsync::messages::{HeaderContent, Message},
25    mux::MuxMessage,
26    protocol::{
27        Inputs, Miniprotocol, Outcome, PROTO_N2N_CHAIN_SYNC, ProtocolState, Responder, StageState, miniprotocol,
28        outcome,
29    },
30    store_effects::Store,
31};
32
33pub fn register_deserializers() -> DeserializerGuards {
34    vec![
35        pure_stage::register_data_deserializer::<ResponderMessage>().boxed(),
36        pure_stage::register_data_deserializer::<(ResponderState, ChainSyncResponder)>().boxed(),
37        pure_stage::register_data_deserializer::<ChainSyncResponder>().boxed(),
38    ]
39}
40
41pub fn responder() -> Miniprotocol<ResponderState, ChainSyncResponder, Responder> {
42    miniprotocol(PROTO_N2N_CHAIN_SYNC.responder())
43}
44
45#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
46pub enum ResponderMessage {
47    NewTip(Tip),
48}
49
50#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
51pub struct ChainSyncResponder {
52    upstream: Tip,
53    peer: Peer,
54    pointer: Point,
55    conn_id: ConnectionId,
56    muxer: StageRef<MuxMessage>,
57}
58
59impl ChainSyncResponder {
60    pub fn new(
61        upstream: Tip,
62        peer: Peer,
63        conn_id: ConnectionId,
64        muxer: StageRef<MuxMessage>,
65    ) -> (ResponderState, Self) {
66        (ResponderState::Idle { send_rollback: false }, Self { upstream, peer, pointer: Point::Origin, conn_id, muxer })
67    }
68}
69
70impl StageState<ResponderState, Responder> for ChainSyncResponder {
71    type LocalIn = ResponderMessage;
72
73    async fn local(
74        mut self,
75        proto: &ResponderState,
76        input: Self::LocalIn,
77        eff: &Effects<Inputs<Self::LocalIn>>,
78    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
79        match input {
80            ResponderMessage::NewTip(tip) => {
81                tracing::trace!(%tip, "New tip");
82                self.upstream = tip;
83                let action = next_header(*proto, &mut self.pointer, &Store::new(eff.clone()), self.upstream)
84                    .context("failed to get next header")?;
85                Ok((action, self))
86            }
87        }
88    }
89
90    #[instrument(name = "chainsync.responder.stage", skip_all, fields(message_type = input.message_type()))]
91    async fn network(
92        mut self,
93        proto: &ResponderState,
94        input: ResponderResult,
95        eff: &Effects<Inputs<Self::LocalIn>>,
96    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
97        match input {
98            ResponderResult::FindIntersect(points) => {
99                let action = intersect(points, &Store::new(eff.clone()), self.upstream)
100                    .context("failed to find intersection")?;
101                if let ResponderAction::IntersectFound(point, _tip) = &action {
102                    self.pointer = *point;
103                }
104                Ok((Some(action), self))
105            }
106            ResponderResult::RequestNext => {
107                let action = next_header(*proto, &mut self.pointer, &Store::new(eff.clone()), self.upstream)
108                    .context("failed to get next header")?;
109                Ok((action, self))
110            }
111            ResponderResult::Done => {
112                tracing::info!("peer stopped chainsync");
113                Ok((None, self))
114            }
115        }
116    }
117
118    fn muxer(&self) -> &StageRef<MuxMessage> {
119        &self.muxer
120    }
121}
122
123fn next_header(
124    state: ResponderState,
125    pointer: &mut Point,
126    store: &dyn ReadOnlyChainStore<BlockHeader>,
127    tip: Tip,
128) -> anyhow::Result<Option<ResponderAction>> {
129    match state {
130        ResponderState::CanAwait { send_rollback: true } => {
131            return Ok(Some(ResponderAction::RollBackward(*pointer, tip)));
132        }
133        ResponderState::MustReply | ResponderState::CanAwait { .. } => {}
134        ResponderState::Idle { .. } | ResponderState::Intersect | ResponderState::Done => {
135            return Ok(None);
136        }
137    };
138    if *pointer == tip.point() {
139        return Ok((matches!(state, ResponderState::CanAwait { .. })).then_some(ResponderAction::AwaitReply));
140    }
141
142    if store.load_from_best_chain(pointer).is_none() {
143        // client is on a different fork, we need to roll backward
144        let header = store.load_header(&pointer.hash()).ok_or_else(|| anyhow::anyhow!("remote pointer not found"))?;
145        for header in store.ancestors(header) {
146            if store.load_from_best_chain(&header.point()).is_some() {
147                *pointer = header.point();
148                return Ok(Some(ResponderAction::RollBackward(header.point(), tip)));
149            }
150        }
151        anyhow::bail!("no overlap found between client pointer chain and stored best chain");
152    }
153    // pointer is on the best chain, we need to roll forward
154    let Some(point) = store.next_best_chain(pointer) else {
155        return Ok(None);
156    };
157    let header =
158        store.load_header(&point.hash()).ok_or_else(|| anyhow::anyhow!("best-chain header not found: {}", point))?;
159    *pointer = point;
160    Ok(Some(ResponderAction::RollForward(HeaderContent::new(&header, EraName::Conway), tip)))
161}
162
163fn intersect(
164    mut points: Vec<Point>,
165    store: &dyn ReadOnlyChainStore<BlockHeader>,
166    tip: Tip,
167) -> anyhow::Result<ResponderAction> {
168    if points.is_empty() {
169        return Ok(ResponderAction::IntersectNotFound(tip));
170    }
171
172    points.sort_by_key(|p| Reverse(*p));
173
174    for point in &points {
175        if store.load_from_best_chain(point).is_some() {
176            return Ok(ResponderAction::IntersectFound(*point, tip));
177        }
178    }
179    Ok(ResponderAction::IntersectNotFound(tip))
180}
181
182#[derive(Debug, PartialEq, Eq)]
183pub enum ResponderAction {
184    IntersectFound(Point, Tip),
185    IntersectNotFound(Tip),
186    AwaitReply,
187    RollForward(HeaderContent, Tip),
188    RollBackward(Point, Tip),
189}
190
191#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
192pub enum ResponderResult {
193    FindIntersect(Vec<Point>),
194    RequestNext,
195    Done,
196}
197
198impl ResponderResult {
199    fn message_type(&self) -> &'static str {
200        match self {
201            ResponderResult::FindIntersect(_) => "FindIntersect",
202            ResponderResult::RequestNext => "RequestNext",
203            ResponderResult::Done => "Done",
204        }
205    }
206}
207
208#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Ord, PartialOrd)]
209pub enum ResponderState {
210    Idle { send_rollback: bool },
211    CanAwait { send_rollback: bool },
212    MustReply,
213    Intersect,
214    Done,
215}
216
217impl ProtocolState<Responder> for ResponderState {
218    type WireMsg = Message;
219    type Action = ResponderAction;
220    type Out = ResponderResult;
221    type Error = Void;
222
223    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
224        Ok((outcome().want_next(), *self))
225    }
226
227    #[instrument(name = "chainsync.responder.protocol", skip_all, fields(message_type = input.message_type()))]
228    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
229        use ResponderState::*;
230
231        Ok(match (self, input) {
232            (Idle { .. }, Message::FindIntersect(points)) => {
233                (outcome().result(ResponderResult::FindIntersect(points)), Intersect)
234            }
235            (Idle { send_rollback }, Message::RequestNext(1)) => {
236                (outcome().result(ResponderResult::RequestNext), CanAwait { send_rollback: *send_rollback })
237            }
238            (Idle { .. }, Message::Done) => (outcome().result(ResponderResult::Done), Done),
239            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
240        })
241    }
242
243    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
244        use ResponderState::*;
245
246        Ok(match (self, input) {
247            (Intersect, ResponderAction::IntersectFound(point, tip)) => {
248                (outcome().send(Message::IntersectFound(point, tip)).want_next(), Idle { send_rollback: true })
249            }
250            (Intersect, ResponderAction::IntersectNotFound(tip)) => {
251                (outcome().send(Message::IntersectNotFound(tip)).want_next(), Idle { send_rollback: false })
252            }
253            (CanAwait { send_rollback }, ResponderAction::AwaitReply) => {
254                ensure!(!*send_rollback, "cannot AwaitReply after intersect");
255                (outcome().send(Message::AwaitReply), MustReply)
256            }
257            (CanAwait { send_rollback }, ResponderAction::RollForward(content, tip)) => {
258                ensure!(!*send_rollback, "cannot RollForward after intersect");
259                (outcome().send(Message::RollForward(content, tip)).want_next(), Idle { send_rollback: false })
260            }
261            (MustReply, ResponderAction::RollForward(content, tip)) => {
262                (outcome().send(Message::RollForward(content, tip)).want_next(), Idle { send_rollback: false })
263            }
264            (CanAwait { .. } | MustReply, ResponderAction::RollBackward(point, tip)) => {
265                (outcome().send(Message::RollBackward(point, tip)).want_next(), Idle { send_rollback: false })
266            }
267            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
268        })
269    }
270}
271
// Unit tests for `intersect` against an in-memory chain store, and for the
// responder protocol state machine (transition table and refinement of the
// initiator's spec).
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use amaru_kernel::{BlockHeader, Hash, Slot, make_header, size::HEADER};
    use amaru_ouroboros_traits::{ChainStore, in_memory_consensus_store::InMemConsensusStore};

    use super::*;
    use crate::{chainsync::initiator::InitiatorState, protocol::ProtoSpec};

    #[test]
    fn intersect_finds_point_on_best_chain() {
        let (store, points) = build_chain_store(10, 0);
        let tip = make_tip(&points);

        let result = intersect(vec![points[5]], store.as_ref(), tip).unwrap();
        assert_eq!(result, ResponderAction::IntersectFound(points[5], tip));
    }

    #[test]
    fn intersect_returns_most_recent_matching_point() {
        let (store, points) = build_chain_store(10, 0);
        let tip = make_tip(&points);

        // points are sorted highest-first, so point[7] should be found first
        let result = intersect(vec![points[3], points[7]], store.as_ref(), tip).unwrap();
        assert_eq!(result, ResponderAction::IntersectFound(points[7], tip));
    }

    #[test]
    fn intersect_finds_point_before_anchor() {
        // Anchor at index 5, but point[2] is still on the best chain index
        let (store, points) = build_chain_store(10, 5);
        let tip = make_tip(&points);

        let result = intersect(vec![points[2]], store.as_ref(), tip).unwrap();
        assert_eq!(result, ResponderAction::IntersectFound(points[2], tip));
    }

    #[test]
    fn intersect_not_found_with_empty_points() {
        let (store, points) = build_chain_store(10, 0);
        let tip = make_tip(&points);

        let result = intersect(vec![], store.as_ref(), tip).unwrap();
        assert_eq!(result, ResponderAction::IntersectNotFound(tip));
    }

    #[test]
    fn intersect_not_found_with_unknown_points() {
        let (store, points) = build_chain_store(10, 0);
        let tip = make_tip(&points);

        // Slot 999 and an all-0xff hash are never produced by `build_chain_store(10, _)`.
        let unknown = Point::Specific(Slot::from(999), Hash::new([0xff; HEADER]));
        let result = intersect(vec![unknown], store.as_ref(), tip).unwrap();
        assert_eq!(result, ResponderAction::IntersectNotFound(tip));
    }

    #[expect(clippy::wildcard_enum_match_arm)]
    #[test]
    fn test_responder_protocol() {
        use Message::{
            AwaitReply, FindIntersect, IntersectFound, IntersectNotFound, RequestNext, RollBackward, RollForward,
        };
        use ResponderState::{CanAwait, Done, Idle, Intersect, MustReply};

        // canonical states and messages
        let idle = |send_rollback: bool| Idle { send_rollback };
        let can_await = |send_rollback: bool| CanAwait { send_rollback };
        let find_intersect = || FindIntersect(vec![Point::Origin]);
        let intersect_found = || IntersectFound(Point::Origin, Tip::origin());
        let intersect_not_found = || IntersectNotFound(Tip::origin());
        let roll_forward = || RollForward(HeaderContent::with_bytes(vec![], EraName::Conway), Tip::origin());
        let roll_backward = || RollBackward(Point::Origin, Tip::origin());

        let mut spec = ProtoSpec::default();
        // Transitions on messages the client sends from `Idle`
        // (mirrors `ProtocolState::network`), for both values of `send_rollback`.
        spec.init(idle(false), find_intersect(), Intersect);
        spec.init(idle(true), find_intersect(), Intersect);
        spec.init(idle(false), RequestNext(1), can_await(false));
        spec.init(idle(true), RequestNext(1), can_await(true));
        spec.init(idle(false), Message::Done, Done);
        spec.init(idle(true), Message::Done, Done);
        // Transitions on replies the responder sends (mirrors `ProtocolState::local`).
        // `can_await(true)` only lists `RollBackward`: the other replies are rejected
        // while a rollback is pending.
        spec.resp(Intersect, intersect_found(), idle(true));
        spec.resp(Intersect, intersect_not_found(), idle(false));
        spec.resp(can_await(false), AwaitReply, MustReply);
        spec.resp(can_await(false), roll_forward(), idle(false));
        spec.resp(can_await(false), roll_backward(), idle(false));
        spec.resp(can_await(true), roll_backward(), idle(false));
        spec.resp(MustReply, roll_forward(), idle(false));
        spec.resp(MustReply, roll_backward(), idle(false));

        // Map each responder reply back to the `ResponderAction` that produces it;
        // client messages map to `None`. NOTE(review): presumably `check` drives the
        // real `ProtocolState` from `idle(false)` and compares it against the table
        // above — confirm against `ProtoSpec::check`.
        spec.check(idle(false), |msg| match msg {
            AwaitReply => Some(ResponderAction::AwaitReply),
            RollForward(header_content, tip) => Some(ResponderAction::RollForward(header_content.clone(), *tip)),
            RollBackward(point, tip) => Some(ResponderAction::RollBackward(*point, *tip)),
            IntersectFound(point, tip) => Some(ResponderAction::IntersectFound(*point, *tip)),
            IntersectNotFound(tip) => Some(ResponderAction::IntersectNotFound(*tip)),
            _ => None,
        });

        // The responder spec is checked against the initiator's spec, with
        // `send_rollback` erased and the initiator's `CanAwait`/`MustReply` argument
        // fixed to 0.
        spec.assert_refines(&super::super::initiator::tests::spec(), |state| match state {
            Idle { .. } => InitiatorState::Idle,
            CanAwait { .. } => InitiatorState::CanAwait(0),
            MustReply => InitiatorState::MustReply(0),
            Intersect => InitiatorState::Intersect,
            Done => InitiatorState::Done,
        });
    }

    // HELPERS

    /// Build an in-memory chain store with `n` headers on the best chain,
    /// and set the anchor at `anchor_index`.
    ///
    /// Header `i` sits at slot `i`, links to header `i - 1` via `prev_hash`, and has
    /// the hash `[i as u8; HEADER]`. Because of that `u8` truncation, `n` must stay
    /// at most 256 for hashes to remain distinct. Returns the store and the points
    /// in chain order.
    fn build_chain_store(n: u64, anchor_index: u64) -> (Arc<InMemConsensusStore<BlockHeader>>, Vec<Point>) {
        let store = Arc::new(InMemConsensusStore::new());
        let mut points = Vec::new();
        let mut prev_hash = None;

        for slot in 0..n {
            let header_raw = make_header(slot, slot, prev_hash);
            let hash = Hash::new([slot as u8; HEADER]);
            let header = BlockHeader::new(header_raw, hash);
            store.store_header(&header).unwrap();
            let point = Point::Specific(Slot::from(slot), hash);
            store.roll_forward_chain(&point).unwrap();
            points.push(point);
            prev_hash = Some(hash);
        }

        store.set_anchor_hash(&points[anchor_index as usize].hash()).unwrap();
        store.set_best_chain_hash(&points.last().unwrap().hash()).unwrap();
        (store, points)
    }

    /// Tip at the last point of `points` (panics if empty); the second `Tip::new`
    /// argument is `0.into()` — presumably the block height, confirm against `Tip::new`.
    fn make_tip(points: &[Point]) -> Tip {
        let last = points.last().unwrap();
        Tip::new(*last, 0.into())
    }
}