1use std::cmp::Reverse;
16
17use amaru_kernel::{BlockHeader, EraName, IsHeader, Peer, Point, Tip};
18use amaru_ouroboros::{ConnectionId, ReadOnlyChainStore};
19use anyhow::{Context, ensure};
20use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
21use tracing::instrument;
22
23use crate::{
24 chainsync::messages::{HeaderContent, Message},
25 mux::MuxMessage,
26 protocol::{
27 Inputs, Miniprotocol, Outcome, PROTO_N2N_CHAIN_SYNC, ProtocolState, Responder, StageState, miniprotocol,
28 outcome,
29 },
30 store_effects::Store,
31};
32
33pub fn register_deserializers() -> DeserializerGuards {
34 vec![
35 pure_stage::register_data_deserializer::<ResponderMessage>().boxed(),
36 pure_stage::register_data_deserializer::<(ResponderState, ChainSyncResponder)>().boxed(),
37 pure_stage::register_data_deserializer::<ChainSyncResponder>().boxed(),
38 ]
39}
40
41pub fn responder() -> Miniprotocol<ResponderState, ChainSyncResponder, Responder> {
42 miniprotocol(PROTO_N2N_CHAIN_SYNC.responder())
43}
44
/// Local (non-network) inputs delivered to the chain-sync responder stage.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ResponderMessage {
    /// A new upstream tip. The stage records it and, if the client has an outstanding
    /// `RequestNext`, tries to produce the next reply.
    NewTip(Tip),
}
49
/// Stage-side state of the chain-sync responder for a single connection.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ChainSyncResponder {
    /// Latest tip, given at construction and replaced on every `ResponderMessage::NewTip`;
    /// it is attached to every reply sent to the client.
    upstream: Tip,
    /// The remote peer served by this responder.
    peer: Peer,
    /// The client's read pointer: the last point agreed with or sent to the client.
    /// It starts at `Point::Origin`, is set by `IntersectFound`, and is advanced or rewound by
    /// `next_header`.
    pointer: Point,
    /// Identifier of the underlying connection.
    conn_id: ConnectionId,
    /// Handle to the multiplexer stage, exposed through `StageState::muxer`.
    muxer: StageRef<MuxMessage>,
}
58
59impl ChainSyncResponder {
60 pub fn new(
61 upstream: Tip,
62 peer: Peer,
63 conn_id: ConnectionId,
64 muxer: StageRef<MuxMessage>,
65 ) -> (ResponderState, Self) {
66 (ResponderState::Idle { send_rollback: false }, Self { upstream, peer, pointer: Point::Origin, conn_id, muxer })
67 }
68}
69
70impl StageState<ResponderState, Responder> for ChainSyncResponder {
71 type LocalIn = ResponderMessage;
72
73 async fn local(
74 mut self,
75 proto: &ResponderState,
76 input: Self::LocalIn,
77 eff: &Effects<Inputs<Self::LocalIn>>,
78 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
79 match input {
80 ResponderMessage::NewTip(tip) => {
81 tracing::trace!(%tip, "New tip");
82 self.upstream = tip;
83 let action = next_header(*proto, &mut self.pointer, &Store::new(eff.clone()), self.upstream)
84 .context("failed to get next header")?;
85 Ok((action, self))
86 }
87 }
88 }
89
90 #[instrument(name = "chainsync.responder.stage", skip_all, fields(message_type = input.message_type()))]
91 async fn network(
92 mut self,
93 proto: &ResponderState,
94 input: ResponderResult,
95 eff: &Effects<Inputs<Self::LocalIn>>,
96 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
97 match input {
98 ResponderResult::FindIntersect(points) => {
99 let action = intersect(points, &Store::new(eff.clone()), self.upstream)
100 .context("failed to find intersection")?;
101 if let ResponderAction::IntersectFound(point, _tip) = &action {
102 self.pointer = *point;
103 }
104 Ok((Some(action), self))
105 }
106 ResponderResult::RequestNext => {
107 let action = next_header(*proto, &mut self.pointer, &Store::new(eff.clone()), self.upstream)
108 .context("failed to get next header")?;
109 Ok((action, self))
110 }
111 ResponderResult::Done => {
112 tracing::info!("peer stopped chainsync");
113 Ok((None, self))
114 }
115 }
116 }
117
118 fn muxer(&self) -> &StageRef<MuxMessage> {
119 &self.muxer
120 }
121}
122
123fn next_header(
124 state: ResponderState,
125 pointer: &mut Point,
126 store: &dyn ReadOnlyChainStore<BlockHeader>,
127 tip: Tip,
128) -> anyhow::Result<Option<ResponderAction>> {
129 match state {
130 ResponderState::CanAwait { send_rollback: true } => {
131 return Ok(Some(ResponderAction::RollBackward(*pointer, tip)));
132 }
133 ResponderState::MustReply | ResponderState::CanAwait { .. } => {}
134 ResponderState::Idle { .. } | ResponderState::Intersect | ResponderState::Done => {
135 return Ok(None);
136 }
137 };
138 if *pointer == tip.point() {
139 return Ok((matches!(state, ResponderState::CanAwait { .. })).then_some(ResponderAction::AwaitReply));
140 }
141
142 if store.load_from_best_chain(pointer).is_none() {
143 let header = store.load_header(&pointer.hash()).ok_or_else(|| anyhow::anyhow!("remote pointer not found"))?;
145 for header in store.ancestors(header) {
146 if store.load_from_best_chain(&header.point()).is_some() {
147 *pointer = header.point();
148 return Ok(Some(ResponderAction::RollBackward(header.point(), tip)));
149 }
150 }
151 anyhow::bail!("no overlap found between client pointer chain and stored best chain");
152 }
153 let Some(point) = store.next_best_chain(pointer) else {
155 return Ok(None);
156 };
157 let header =
158 store.load_header(&point.hash()).ok_or_else(|| anyhow::anyhow!("best-chain header not found: {}", point))?;
159 *pointer = point;
160 Ok(Some(ResponderAction::RollForward(HeaderContent::new(&header, EraName::Conway), tip)))
161}
162
163fn intersect(
164 mut points: Vec<Point>,
165 store: &dyn ReadOnlyChainStore<BlockHeader>,
166 tip: Tip,
167) -> anyhow::Result<ResponderAction> {
168 if points.is_empty() {
169 return Ok(ResponderAction::IntersectNotFound(tip));
170 }
171
172 points.sort_by_key(|p| Reverse(*p));
173
174 for point in &points {
175 if store.load_from_best_chain(point).is_some() {
176 return Ok(ResponderAction::IntersectFound(*point, tip));
177 }
178 }
179 Ok(ResponderAction::IntersectNotFound(tip))
180}
181
/// A decision made by the responder stage. The protocol state machine turns it into the
/// corresponding wire message (see `ProtocolState::local` for `ResponderState`).
#[derive(Debug, PartialEq, Eq)]
pub enum ResponderAction {
    /// Answer to `FindIntersect`: the intersection point and the current tip.
    IntersectFound(Point, Tip),
    /// Answer to `FindIntersect` when none of the candidates is on the best chain.
    IntersectNotFound(Tip),
    /// Tell the client there is nothing new yet. Only accepted in `CanAwait` without a pending
    /// rollback.
    AwaitReply,
    /// Send the next header to the client.
    RollForward(HeaderContent, Tip),
    /// Tell the client to roll back to the given point.
    RollBackward(Point, Tip),
}
190
/// Client requests, as decoded by the protocol state machine and handed to the responder stage.
// NOTE(review): this type is serde-serialized; depending on the format, variant order may be
// part of the encoding, so avoid reordering.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ResponderResult {
    /// The client sent `FindIntersect` with these candidate points.
    FindIntersect(Vec<Point>),
    /// The client sent `RequestNext`.
    RequestNext,
    /// The client sent `Done`.
    Done,
}
197
198impl ResponderResult {
199 fn message_type(&self) -> &'static str {
200 match self {
201 ResponderResult::FindIntersect(_) => "FindIntersect",
202 ResponderResult::RequestNext => "RequestNext",
203 ResponderResult::Done => "Done",
204 }
205 }
206}
207
/// Protocol-level state of the chain-sync responder.
// The derived `Ord`/`PartialOrd` follow declaration order, so do not reorder the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Ord, PartialOrd)]
pub enum ResponderState {
    /// Waiting for the client's next message. `send_rollback` is set after `IntersectFound`, so
    /// that the following `RequestNext` is answered with a `RollBackward` to the intersection.
    Idle { send_rollback: bool },
    /// `RequestNext` received; the responder may reply or send `AwaitReply`.
    /// `send_rollback` is carried over from `Idle`.
    CanAwait { send_rollback: bool },
    /// `AwaitReply` was sent; the next message must be `RollForward` or `RollBackward`.
    MustReply,
    /// `FindIntersect` received; waiting to send `IntersectFound` or `IntersectNotFound`.
    Intersect,
    /// The client sent `Done`; no further transitions are accepted.
    Done,
}
216
217impl ProtocolState<Responder> for ResponderState {
218 type WireMsg = Message;
219 type Action = ResponderAction;
220 type Out = ResponderResult;
221 type Error = Void;
222
223 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
224 Ok((outcome().want_next(), *self))
225 }
226
227 #[instrument(name = "chainsync.responder.protocol", skip_all, fields(message_type = input.message_type()))]
228 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
229 use ResponderState::*;
230
231 Ok(match (self, input) {
232 (Idle { .. }, Message::FindIntersect(points)) => {
233 (outcome().result(ResponderResult::FindIntersect(points)), Intersect)
234 }
235 (Idle { send_rollback }, Message::RequestNext(1)) => {
236 (outcome().result(ResponderResult::RequestNext), CanAwait { send_rollback: *send_rollback })
237 }
238 (Idle { .. }, Message::Done) => (outcome().result(ResponderResult::Done), Done),
239 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
240 })
241 }
242
243 fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
244 use ResponderState::*;
245
246 Ok(match (self, input) {
247 (Intersect, ResponderAction::IntersectFound(point, tip)) => {
248 (outcome().send(Message::IntersectFound(point, tip)).want_next(), Idle { send_rollback: true })
249 }
250 (Intersect, ResponderAction::IntersectNotFound(tip)) => {
251 (outcome().send(Message::IntersectNotFound(tip)).want_next(), Idle { send_rollback: false })
252 }
253 (CanAwait { send_rollback }, ResponderAction::AwaitReply) => {
254 ensure!(!*send_rollback, "cannot AwaitReply after intersect");
255 (outcome().send(Message::AwaitReply), MustReply)
256 }
257 (CanAwait { send_rollback }, ResponderAction::RollForward(content, tip)) => {
258 ensure!(!*send_rollback, "cannot RollForward after intersect");
259 (outcome().send(Message::RollForward(content, tip)).want_next(), Idle { send_rollback: false })
260 }
261 (MustReply, ResponderAction::RollForward(content, tip)) => {
262 (outcome().send(Message::RollForward(content, tip)).want_next(), Idle { send_rollback: false })
263 }
264 (CanAwait { .. } | MustReply, ResponderAction::RollBackward(point, tip)) => {
265 (outcome().send(Message::RollBackward(point, tip)).want_next(), Idle { send_rollback: false })
266 }
267 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
268 })
269 }
270}
271
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use amaru_kernel::{BlockHeader, Hash, Slot, make_header, size::HEADER};
    use amaru_ouroboros_traits::{ChainStore, in_memory_consensus_store::InMemConsensusStore};

    use super::*;
    use crate::{chainsync::initiator::InitiatorState, protocol::ProtoSpec};

    #[test]
    fn intersect_finds_point_on_best_chain() {
        let (store, points) = build_chain_store(10, 0);
        let tip = make_tip(&points);

        let result = intersect(vec![points[5]], store.as_ref(), tip).unwrap();
        assert_eq!(result, ResponderAction::IntersectFound(points[5], tip));
    }

    #[test]
    fn intersect_returns_most_recent_matching_point() {
        let (store, points) = build_chain_store(10, 0);
        let tip = make_tip(&points);

        let result = intersect(vec![points[3], points[7]], store.as_ref(), tip).unwrap();
        // Both candidates are on the best chain; the more recent one wins.
        assert_eq!(result, ResponderAction::IntersectFound(points[7], tip));
    }

    #[test]
    fn intersect_finds_point_before_anchor() {
        let (store, points) = build_chain_store(10, 5);
        // points[2] precedes the anchor (points[5]) but is still part of the best chain.
        let tip = make_tip(&points);

        let result = intersect(vec![points[2]], store.as_ref(), tip).unwrap();
        assert_eq!(result, ResponderAction::IntersectFound(points[2], tip));
    }

    #[test]
    fn intersect_not_found_with_empty_points() {
        let (store, points) = build_chain_store(10, 0);
        let tip = make_tip(&points);

        let result = intersect(vec![], store.as_ref(), tip).unwrap();
        assert_eq!(result, ResponderAction::IntersectNotFound(tip));
    }

    #[test]
    fn intersect_not_found_with_unknown_points() {
        let (store, points) = build_chain_store(10, 0);
        let tip = make_tip(&points);

        let unknown = Point::Specific(Slot::from(999), Hash::new([0xff; HEADER]));
        let result = intersect(vec![unknown], store.as_ref(), tip).unwrap();
        assert_eq!(result, ResponderAction::IntersectNotFound(tip));
    }

    #[expect(clippy::wildcard_enum_match_arm)]
    #[test]
    fn test_responder_protocol() {
        use Message::{
            AwaitReply, FindIntersect, IntersectFound, IntersectNotFound, RequestNext, RollBackward, RollForward,
        };
        use ResponderState::{CanAwait, Done, Idle, Intersect, MustReply};

        // Small constructors to keep the transition table below readable.
        let idle = |send_rollback: bool| Idle { send_rollback };
        let can_await = |send_rollback: bool| CanAwait { send_rollback };
        let find_intersect = || FindIntersect(vec![Point::Origin]);
        let intersect_found = || IntersectFound(Point::Origin, Tip::origin());
        let intersect_not_found = || IntersectNotFound(Tip::origin());
        let roll_forward = || RollForward(HeaderContent::with_bytes(vec![], EraName::Conway), Tip::origin());
        let roll_backward = || RollBackward(Point::Origin, Tip::origin());

        let mut spec = ProtoSpec::default();
        spec.init(idle(false), find_intersect(), Intersect);
        spec.init(idle(true), find_intersect(), Intersect);
        spec.init(idle(false), RequestNext(1), can_await(false));
        spec.init(idle(true), RequestNext(1), can_await(true));
        spec.init(idle(false), Message::Done, Done);
        spec.init(idle(true), Message::Done, Done);
        spec.resp(Intersect, intersect_found(), idle(true));
        spec.resp(Intersect, intersect_not_found(), idle(false));
        spec.resp(can_await(false), AwaitReply, MustReply);
        spec.resp(can_await(false), roll_forward(), idle(false));
        spec.resp(can_await(false), roll_backward(), idle(false));
        spec.resp(can_await(true), roll_backward(), idle(false));
        spec.resp(MustReply, roll_forward(), idle(false));
        spec.resp(MustReply, roll_backward(), idle(false));

        spec.check(idle(false), |msg| match msg {
            AwaitReply => Some(ResponderAction::AwaitReply),
            RollForward(header_content, tip) => Some(ResponderAction::RollForward(header_content.clone(), *tip)),
            RollBackward(point, tip) => Some(ResponderAction::RollBackward(*point, *tip)),
            IntersectFound(point, tip) => Some(ResponderAction::IntersectFound(*point, *tip)),
            IntersectNotFound(tip) => Some(ResponderAction::IntersectNotFound(*tip)),
            _ => None,
        });

        spec.assert_refines(&super::super::initiator::tests::spec(), |state| match state {
            Idle { .. } => InitiatorState::Idle,
            CanAwait { .. } => InitiatorState::CanAwait(0),
            MustReply => InitiatorState::MustReply(0),
            Intersect => InitiatorState::Intersect,
            Done => InitiatorState::Done,
        });
    }

    /// Builds an in-memory store holding a linear best chain of `n` headers at slots `0..n`.
    ///
    /// The header at slot `s` gets the hash `[s as u8; HEADER]`. The anchor is set to the header
    /// at `anchor_index`, and the best chain ends at the last header. Returns the store together
    /// with the points of all headers, in slot order.
    ///
    /// # Panics
    ///
    /// Panics if `anchor_index >= n`, or if `n > 256`. Hashes are derived from one byte of the
    /// slot, so longer chains would silently produce duplicate hashes.
    fn build_chain_store(n: u64, anchor_index: u64) -> (Arc<InMemConsensusStore<BlockHeader>>, Vec<Point>) {
        assert!(anchor_index < n, "anchor_index ({anchor_index}) must be below the chain length ({n})");

        let store = Arc::new(InMemConsensusStore::new());
        let mut points = Vec::new();
        let mut prev_hash = None;

        for slot in 0..n {
            // Refuse to wrap around: `slot as u8` would alias slot 256 onto slot 0.
            let hash_byte = u8::try_from(slot).expect("build_chain_store supports at most 256 headers");
            let header_raw = make_header(slot, slot, prev_hash);
            let hash = Hash::new([hash_byte; HEADER]);
            let header = BlockHeader::new(header_raw, hash);
            store.store_header(&header).unwrap();
            let point = Point::Specific(Slot::from(slot), hash);
            store.roll_forward_chain(&point).unwrap();
            points.push(point);
            prev_hash = Some(hash);
        }

        store.set_anchor_hash(&points[anchor_index as usize].hash()).unwrap();
        store.set_best_chain_hash(&points.last().unwrap().hash()).unwrap();
        (store, points)
    }

    /// A tip at the last of `points`, with block height 0.
    fn make_tip(points: &[Point]) -> Tip {
        let last = points.last().unwrap();
        Tip::new(*last, 0.into())
    }
}