1use amaru_kernel::{BlockHeader, ORIGIN_HASH, Peer, Point, Tip};
16use amaru_ouroboros::{ConnectionId, ReadOnlyChainStore};
17use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
18use tracing::instrument;
19
20use crate::{
21 chainsync::messages::{HeaderContent, Message},
22 mux::MuxMessage,
23 protocol::{
24 Initiator, Inputs, Miniprotocol, Outcome, PROTO_N2N_CHAIN_SYNC, ProtocolState, StageState, miniprotocol,
25 outcome,
26 },
27 store_effects::Store,
28};
29
30pub fn register_deserializers() -> DeserializerGuards {
31 vec![
32 pure_stage::register_data_deserializer::<InitiatorMessage>().boxed(),
33 pure_stage::register_data_deserializer::<(InitiatorState, ChainSyncInitiator)>().boxed(),
34 pure_stage::register_data_deserializer::<ChainSyncInitiatorMsg>().boxed(),
35 pure_stage::register_data_deserializer::<ChainSyncInitiator>().boxed(),
36 ]
37}
38
39pub fn initiator() -> Miniprotocol<InitiatorState, ChainSyncInitiator, Initiator> {
40 miniprotocol(PROTO_N2N_CHAIN_SYNC)
41}
42
/// Requests that can be sent to the chain-sync initiator stage through the `handler`
/// reference carried in each [`ChainSyncInitiatorMsg`].
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum InitiatorMessage {
    /// Ask the peer for the next chain update; accepted while idle or while requests are
    /// already outstanding (pipelining).
    RequestNext,
    /// Terminate the protocol; only accepted while idle.
    Done,
}
49
50impl InitiatorMessage {
51 pub fn message_type(&self) -> &str {
52 match self {
53 InitiatorMessage::RequestNext => "RequestNext",
54 InitiatorMessage::Done => "Done",
55 }
56 }
57}
58
/// Notification sent to the downstream pipeline for every result the initiator protocol
/// reports, including the initial `Initialize`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct ChainSyncInitiatorMsg {
    /// Peer this chain-sync connection talks to.
    pub peer: Peer,
    /// Connection on which the result was produced.
    pub conn_id: ConnectionId,
    /// Reference through which the receiver can send [`InitiatorMessage`]s back to the
    /// initiator stage.
    pub handler: StageRef<InitiatorMessage>,
    /// The protocol result being forwarded.
    pub msg: InitiatorResult,
}
67
68impl ChainSyncInitiatorMsg {
69 pub fn message_type(&self) -> &str {
70 match self.msg {
71 InitiatorResult::Initialize => "Initialize",
72 InitiatorResult::IntersectFound(_, _) => "IntersectFound",
73 InitiatorResult::IntersectNotFound(_) => "IntersectNotFound",
74 InitiatorResult::RollForward(_, _) => "RollForward",
75 InitiatorResult::RollBackward(_, _) => "RollBackward",
76 }
77 }
78}
79
/// Stage state of the chain-sync initiator for a single peer connection.
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct ChainSyncInitiator {
    /// Most recent tip reported by the peer; `None` until the first result carrying a tip.
    upstream: Option<Tip>,
    /// Peer this connection talks to; copied into every forwarded message.
    peer: Peer,
    /// Connection identifier; copied into every forwarded message.
    conn_id: ConnectionId,
    /// Multiplexer stage, exposed through `StageState::muxer`.
    muxer: StageRef<MuxMessage>,
    /// Downstream stage that receives a [`ChainSyncInitiatorMsg`] for every protocol result.
    pipeline: StageRef<ChainSyncInitiatorMsg>,
    /// Reference accepting [`InitiatorMessage`]s for this stage; a blackhole until the
    /// protocol reports `Initialize`.
    me: StageRef<InitiatorMessage>,
}
89
90impl ChainSyncInitiator {
91 pub fn new(
92 peer: Peer,
93 conn_id: ConnectionId,
94 muxer: StageRef<MuxMessage>,
95 pipeline: StageRef<ChainSyncInitiatorMsg>,
96 ) -> (InitiatorState, Self) {
97 (InitiatorState::Idle, Self { upstream: None, peer, conn_id, muxer, pipeline, me: StageRef::blackhole() })
98 }
99}
100
101impl StageState<InitiatorState, Initiator> for ChainSyncInitiator {
102 type LocalIn = InitiatorMessage;
103
104 async fn local(
105 self,
106 proto: &InitiatorState,
107 input: Self::LocalIn,
108 _eff: &Effects<Inputs<Self::LocalIn>>,
109 ) -> anyhow::Result<(Option<<InitiatorState as ProtocolState<Initiator>>::Action>, Self)> {
110 use InitiatorState::*;
111
112 Ok(match (proto, input) {
113 (Idle, InitiatorMessage::RequestNext) => (Some(InitiatorAction::RequestNext), self),
114 (CanAwait(_) | MustReply(_), InitiatorMessage::RequestNext) => (Some(InitiatorAction::RequestNext), self),
115 (Idle, InitiatorMessage::Done) => (Some(InitiatorAction::Done), self),
116 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
117 })
118 }
119
120 #[instrument(name = "chainsync.initiator.stage", skip_all, fields(message_type = input.message_type()))]
121 async fn network(
122 mut self,
123 _proto: &InitiatorState,
124 input: <InitiatorState as ProtocolState<Initiator>>::Out,
125 eff: &Effects<Inputs<Self::LocalIn>>,
126 ) -> anyhow::Result<(Option<<InitiatorState as ProtocolState<Initiator>>::Action>, Self)> {
127 use InitiatorAction::*;
128 let action = match &input {
129 InitiatorResult::Initialize => {
130 self.me = eff.contramap(eff.me(), format!("{}-handler", eff.me().name()), Inputs::Local).await;
131 Some(Intersect(intersect_points(&Store::new(eff.clone()))))
132 }
133 InitiatorResult::IntersectFound(_, tip)
134 | InitiatorResult::IntersectNotFound(tip)
135 | InitiatorResult::RollForward(_, tip)
136 | InitiatorResult::RollBackward(_, tip) => {
137 self.upstream = Some(*tip);
138 None
139 }
140 };
141 eff.send(
142 &self.pipeline,
143 ChainSyncInitiatorMsg {
144 peer: self.peer.clone(),
145 conn_id: self.conn_id,
146 handler: self.me.clone(),
147 msg: input,
148 },
149 )
150 .await;
151 Ok((action, self))
152 }
153
154 fn muxer(&self) -> &StageRef<MuxMessage> {
155 &self.muxer
156 }
157}
158
159fn intersect_points(store: &dyn ReadOnlyChainStore<BlockHeader>) -> Vec<Point> {
160 let mut spacing = 1;
161 let mut points = Vec::new();
162 let best = store.get_best_chain_hash();
163 if best == ORIGIN_HASH {
164 return vec![Point::Origin];
165 }
166 #[expect(clippy::expect_used)]
167 let best = store.load_header(&best).expect("best chain hash is valid");
168 let best_point = best.tip().point();
169 points.push(best_point);
170
171 let mut last = best_point;
172 for (index, header) in store.ancestors(best).enumerate() {
173 last = header.tip().point();
174 if index + 1 == spacing {
175 points.push(last);
176 spacing *= 2;
177 }
178 }
179 if points.last() != Some(&last) {
180 points.push(last);
181 }
182 points
183}
184
/// Action the stage asks the protocol state machine to perform.
#[derive(Debug)]
pub enum InitiatorAction {
    /// Send `FindIntersect` with these candidate points; only valid while idle.
    Intersect(Vec<Point>),
    /// Send one more `RequestNext`; valid while idle or with requests outstanding.
    RequestNext,
    /// Send `Done`; only valid while idle.
    Done,
}
191
/// Result the protocol state machine reports to the stage.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum InitiatorResult {
    /// Produced by `init`, before anything has been sent to the peer.
    Initialize,
    /// The peer found an intersection at this point; the `Tip` is the peer's tip.
    IntersectFound(Point, Tip),
    /// The peer found none of the proposed points; the `Tip` is the peer's tip.
    IntersectNotFound(Tip),
    /// The peer sent the next header; the `Tip` is the peer's tip.
    RollForward(HeaderContent, Tip),
    /// The peer asks to roll back to this point; the `Tip` is the peer's tip.
    RollBackward(Point, Tip),
}
200
201impl InitiatorResult {
202 pub fn message_type(&self) -> &str {
203 match self {
204 InitiatorResult::Initialize => "Initialize",
205 InitiatorResult::IntersectFound(_, _) => "IntersectFound",
206 InitiatorResult::IntersectNotFound(_) => "IntersectNotFound",
207 InitiatorResult::RollForward(_, _) => "RollForward",
208 InitiatorResult::RollBackward(_, _) => "RollBackward",
209 }
210 }
211}
212
/// Client-side state of the chain-sync protocol, including the pipelining depth.
///
/// In `CanAwait(n)` and `MustReply(n)`, one `RequestNext` is awaiting its reply, plus `n`
/// further pipelined ones. A roll forward/backward with `n == 0` returns to `Idle`;
/// otherwise it moves to `CanAwait(n - 1)`.
///
/// Declaration order determines the derived `PartialOrd`/`Ord`; do not reorder variants
/// without checking users of the ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
pub enum InitiatorState {
    /// Nothing outstanding; the stage may intersect, request the next update, or finish.
    Idle,
    /// A `RequestNext` is outstanding and the peer may still answer `AwaitReply`.
    CanAwait(u8),
    /// The peer answered `AwaitReply`; the next message must be a roll forward/backward.
    MustReply(u8),
    /// `FindIntersect` was sent; waiting for `IntersectFound` or `IntersectNotFound`.
    Intersect,
    /// `Done` was sent; no further messages or actions are accepted.
    Done,
}
221
222impl ProtocolState<Initiator> for InitiatorState {
223 type WireMsg = Message;
224 type Action = InitiatorAction;
225 type Out = InitiatorResult;
226 type Error = Void;
227
228 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
229 Ok((outcome().result(InitiatorResult::Initialize), *self))
230 }
231
232 #[instrument(name = "chainsync.initiator.protocol", skip_all, fields(message_type = input.message_type()))]
233 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
234 use InitiatorState::*;
235
236 Ok(match (self, input) {
237 (Intersect, Message::IntersectFound(point, tip)) => (
238 outcome().send(Message::RequestNext(2)).want_next().result(InitiatorResult::IntersectFound(point, tip)),
241 CanAwait(1),
242 ),
243 (Intersect, Message::IntersectNotFound(tip)) => {
244 (outcome().result(InitiatorResult::IntersectNotFound(tip)), Idle)
245 }
246 (CanAwait(n), Message::AwaitReply) => (outcome().want_next(), MustReply(*n)),
247 (CanAwait(n) | MustReply(n), Message::RollForward(content, tip)) => (
248 outcome().result(InitiatorResult::RollForward(content, tip)),
249 if *n == 0 { Idle } else { CanAwait(*n - 1) },
250 ),
251 (CanAwait(n) | MustReply(n), Message::RollBackward(point, tip)) => (
252 outcome().result(InitiatorResult::RollBackward(point, tip)),
253 if *n == 0 { Idle } else { CanAwait(*n - 1) },
254 ),
255 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
256 })
257 }
258
259 fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
260 use InitiatorState::*;
261
262 Ok(match (self, input) {
263 (Idle, InitiatorAction::Intersect(points)) => {
264 (outcome().send(Message::FindIntersect(points)).want_next(), Intersect)
265 }
266 (Idle, InitiatorAction::RequestNext) => (outcome().send(Message::RequestNext(1)).want_next(), CanAwait(0)),
267 (CanAwait(n), InitiatorAction::RequestNext) => {
268 (outcome().send(Message::RequestNext(1)).want_next(), CanAwait(*n + 1))
269 }
270 (MustReply(n), InitiatorAction::RequestNext) => {
271 (outcome().send(Message::RequestNext(1)).want_next(), MustReply(*n + 1))
272 }
273 (Idle, InitiatorAction::Done) => (outcome().send(Message::Done), Done),
274 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
275 })
276 }
277}
278
#[cfg(test)]
#[expect(clippy::wildcard_enum_match_arm)]
pub mod tests {
    use InitiatorState::*;
    use Message::*;
    use amaru_kernel::{EraName, Hash, HeaderHash, RawBlock, Slot, make_header, size::HEADER};
    use amaru_ouroboros_traits::{Nonces, StoreError};

    use super::*;
    use crate::protocol::ProtoSpec;

    /// Abstract specification of the chain-sync initiator protocol, without pipelining.
    ///
    /// Judging by the messages used, `init` entries list what the initiator may send from
    /// a state, and `resp` entries list the peer's replies; each entry gives the resulting
    /// state.
    ///
    /// NOTE(review): this spec maps `IntersectFound` to `Idle`, while
    /// `InitiatorState::network` sends `RequestNext(2)` and moves to `CanAwait(1)`.
    /// This is presumably part of why `test_initiator_protocol` is ignored — confirm.
    pub fn spec() -> ProtoSpec<InitiatorState, Message, Initiator> {
        let find_intersect = || FindIntersect(vec![Point::Origin]);
        let intersect_found = || IntersectFound(Point::Origin, Tip::origin());
        let intersect_not_found = || IntersectNotFound(Tip::origin());
        let roll_forward = || RollForward(HeaderContent::with_bytes(vec![], EraName::Conway), Tip::origin());
        let roll_backward = || RollBackward(Point::Origin, Tip::origin());

        let mut spec = ProtoSpec::default();
        spec.init(Idle, find_intersect(), Intersect);
        spec.init(Idle, Message::Done, InitiatorState::Done);
        spec.init(Idle, Message::RequestNext(1), CanAwait(0));
        spec.resp(Intersect, intersect_found(), Idle);
        spec.resp(Intersect, intersect_not_found(), Idle);
        spec.resp(CanAwait(0), AwaitReply, MustReply(0));
        spec.resp(CanAwait(0), roll_forward(), Idle);
        spec.resp(CanAwait(0), roll_backward(), Idle);
        spec.resp(MustReply(0), roll_forward(), Idle);
        spec.resp(MustReply(0), roll_backward(), Idle);
        spec
    }

    /// Checks the protocol against `spec()`, mapping each initiator message to the stage
    /// action that produces it (`None` for peer messages).
    #[test]
    #[ignore = "pipelining cannot be tested yet"]
    fn test_initiator_protocol() {
        spec().check(Idle, |msg| match msg {
            FindIntersect(points) => Some(InitiatorAction::Intersect(points.clone())),
            RequestNext(1) => Some(InitiatorAction::RequestNext),
            Message::Done => Some(InitiatorAction::Done),
            _ => None,
        });
    }

    /// With a best point at slot 100 and ancestors at slots 99..=0, expects the best
    /// point, then ancestors at depths 1, 2, 4, ..., 64 (slots 99, 98, 96, 92, 84, 68, 36),
    /// then the oldest ancestor (slot 0).
    #[test]
    fn test_intersect_points_includes_best_point_and_are_spaced_with_a_factor_2() {
        let store = MockChainStoreForIntersectPoints::default();
        let points = intersect_points(&store);
        let slots = points.iter().map(|p| p.slot_or_default().into()).collect::<Vec<u64>>();
        assert_eq!(slots, vec![100, 99, 98, 96, 92, 84, 68, 36, 0]);
    }

    /// Minimal chain store for `intersect_points`; only the methods it calls are
    /// implemented, all others are `todo!()`.
    #[derive(Debug)]
    struct MockChainStoreForIntersectPoints {
        best_point: Point,
    }

    impl Default for MockChainStoreForIntersectPoints {
        // Best point at slot 100, with a hash made of 100u8 bytes.
        fn default() -> Self {
            Self { best_point: Point::Specific(Slot::from(100), Hash::new([100u8; HEADER])) }
        }
    }

    #[expect(clippy::todo)]
    impl ReadOnlyChainStore<BlockHeader> for MockChainStoreForIntersectPoints {
        fn get_best_chain_hash(&self) -> HeaderHash {
            self.best_point.hash()
        }

        // Ignores `_hash` and always returns a header for the best point.
        fn load_header(&self, _hash: &HeaderHash) -> Option<BlockHeader> {
            Some(BlockHeader::new(
                make_header(1, self.best_point.slot_or_default().into(), None),
                self.best_point.hash(),
            ))
        }

        // Ignores `_from` and yields headers at slots 99 down to 0 (newest first).
        fn ancestors<'a>(&'a self, _from: BlockHeader) -> Box<dyn Iterator<Item = BlockHeader> + 'a>
        where
            BlockHeader: 'a,
        {
            let mut ancestor_block_headers = vec![];
            for slot in 0..100 {
                // `slot < 100`, so the `as u8` cast cannot truncate.
                let header_hash = Hash::new([slot as u8; HEADER]);
                let block_header = BlockHeader::new(make_header(1, slot, None), header_hash);
                ancestor_block_headers.push(block_header);
            }
            ancestor_block_headers.reverse();
            Box::new(ancestor_block_headers.into_iter())
        }

        fn get_children(&self, _hash: &HeaderHash) -> Vec<HeaderHash> {
            todo!()
        }

        fn get_anchor_hash(&self) -> HeaderHash {
            todo!()
        }

        fn load_from_best_chain(&self, _point: &Point) -> Option<HeaderHash> {
            todo!()
        }

        fn next_best_chain(&self, _point: &Point) -> Option<Point> {
            todo!()
        }

        fn load_block(&self, _hash: &HeaderHash) -> Result<Option<RawBlock>, StoreError> {
            todo!()
        }

        fn get_nonces(&self, _header: &HeaderHash) -> Option<Nonces> {
            todo!()
        }

        fn has_header(&self, _hash: &HeaderHash) -> bool {
            todo!()
        }
    }
}