1use std::sync::Arc;
16
17use amaru_kernel::{EraHistory, NetworkMagic, ORIGIN_HASH, Peer, Point, Tip};
18use amaru_ouroboros::{ConnectionId, ReadOnlyChainStore, TxOrigin};
19use pure_stage::{DeserializerGuards, Effects, StageRef, Void, register_data_deserializer};
20use tracing::instrument;
21
22use crate::{
23 blockfetch::{
24 self, BlockFetchMessage, Blocks, StreamBlocks, register_blockfetch_initiator, register_blockfetch_responder,
25 },
26 chainsync::{self, ChainSyncInitiatorMsg, register_chainsync_initiator, register_chainsync_responder},
27 handshake,
28 keepalive::register_keepalive,
29 manager::ManagerConfig,
30 mux::{self, HandlerMessage, MuxMessage},
31 protocol::{Inputs, PROTO_HANDSHAKE, Role},
32 protocol_messages::{
33 handshake::HandshakeResult, version_data::VersionData, version_number::VersionNumber,
34 version_table::VersionTable,
35 },
36 store_effects::Store,
37 tx_submission::register_tx_submission,
38};
39
/// State of the `connection` stage: the parameters it was created with plus the current
/// phase of the connection lifecycle.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Connection {
    /// Parameters fixed at construction; carried unchanged through every state transition.
    params: Params,
    /// Current lifecycle phase; `Connection::new` starts it at `State::Initial`.
    state: State,
}
45
46impl Connection {
47 pub fn new(
48 peer: Peer,
49 conn_id: ConnectionId,
50 role: Role,
51 config: ManagerConfig,
52 magic: NetworkMagic,
53 pipeline: StageRef<ChainSyncInitiatorMsg>,
54 era_history: Arc<EraHistory>,
55 ) -> Self {
56 Self { params: Params { peer, conn_id, role, config, magic, pipeline, era_history }, state: State::Initial }
57 }
58}
59
/// Construction-time parameters of a connection, shared by all of its states.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct Params {
    /// Remote peer; passed to the chain-sync/block-fetch stages and used as the tx-submission origin.
    peer: Peer,
    /// Connection identifier given to the mux and to the chain-sync/block-fetch stages.
    conn_id: ConnectionId,
    /// Our side of the connection; selects the handshake initiator or responder and the
    /// set of mini-protocols registered after the handshake.
    role: Role,
    /// Network magic placed in the handshake version table.
    magic: NetworkMagic,
    /// Manager configuration; `reconnect_delay` is how long early messages are deferred.
    config: ManagerConfig,
    /// Stage handed to the chain-sync initiator when running as initiator.
    pipeline: StageRef<ChainSyncInitiatorMsg>,
    /// Era history handed to the block-fetch initiator.
    era_history: Arc<EraHistory>,
}
70
/// Lifecycle phase of a connection.
///
/// `Initial --Initialize--> Handshake --Handshake(Accepted)--> Initiator | Responder`
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
enum State {
    /// Not yet initialized; no mux or handshake stage exists.
    Initial,
    /// Mux and handshake stages are running; waiting for the `HandshakeResult`.
    Handshake { muxer: StageRef<MuxMessage>, handshake: StageRef<Inputs<Void>> },
    /// Handshake accepted; running the initiator-side mini-protocols.
    Initiator(StateInitiator),
    /// Handshake accepted; running the responder-side mini-protocols.
    Responder(StateResponder),
}
78
/// Stages owned by a connection running in the initiator role after a successful handshake.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct StateInitiator {
    /// Chain-sync initiator stage.
    chainsync_initiator: StageRef<chainsync::InitiatorMessage>,
    /// Block-fetch initiator stage; `ConnectionMessage::FetchBlocks` requests are forwarded here.
    blockfetch_initiator: StageRef<blockfetch::BlockFetchMessage>,
    /// Protocol version accepted during the handshake.
    version_number: VersionNumber,
    /// Version data accepted together with `version_number`.
    version_data: VersionData,
    /// Multiplexer stage for this connection.
    muxer: StageRef<MuxMessage>,
    /// Handshake stage spawned during initialization.
    handshake: StageRef<Inputs<Void>>,
    /// Keep-alive protocol handler.
    keepalive: StageRef<HandlerMessage>,
    /// Tx-submission protocol handler.
    tx_submission: StageRef<HandlerMessage>,
}
90
/// Stages owned by a connection running in the responder role after a successful handshake.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct StateResponder {
    /// Chain-sync responder stage; `ConnectionMessage::NewTip` is forwarded here.
    chainsync_responder: StageRef<chainsync::ResponderMessage>,
    /// Multiplexer stage for this connection.
    muxer: StageRef<MuxMessage>,
    /// Handshake stage spawned during initialization.
    handshake: StageRef<Inputs<Void>>,
    /// Keep-alive protocol handler.
    keepalive: StageRef<HandlerMessage>,
    /// Tx-submission protocol handler.
    tx_submission: StageRef<HandlerMessage>,
    /// Block-fetch responder stage.
    blockfetch_responder: StageRef<StreamBlocks>,
}
100
/// Messages accepted by the `connection` stage.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ConnectionMessage {
    /// Spawn the mux and handshake stages; handled only in the initial state.
    Initialize,
    /// Terminate the connection stage, regardless of its current state.
    Disconnect,
    /// Outcome reported back by the handshake stage.
    Handshake(HandshakeResult),
    /// Forwarded to the block-fetch initiator as `BlockFetchMessage::RequestRange { from, through, cr }`.
    /// Presumably `cr` is where the fetched `Blocks` end up — confirm in the blockfetch module.
    FetchBlocks { from: Point, through: Point, cr: StageRef<Blocks> },
    /// New local tip; forwarded to the chain-sync responder, ignored by initiators.
    NewTip(Tip),
}
110
111impl ConnectionMessage {
112 fn message_type(&self) -> &'static str {
113 match self {
114 ConnectionMessage::Initialize => "Initialize",
115 ConnectionMessage::Disconnect => "Disconnect",
116 ConnectionMessage::Handshake(_) => "Handshake",
117 ConnectionMessage::FetchBlocks { .. } => "FetchBlocks",
118 ConnectionMessage::NewTip(_) => "NewTip",
119 }
120 }
121}
122
123#[instrument(name = "connection", skip_all, fields(message_type = msg.message_type(), conn_id = %params.conn_id, peer = %params.peer, role = ?params.role))]
124pub async fn stage(
125 Connection { params, state }: Connection,
126 msg: ConnectionMessage,
127 eff: Effects<ConnectionMessage>,
128) -> Connection {
129 let state = match (state, msg) {
130 (_, ConnectionMessage::Disconnect) => return eff.terminate().await,
131 (State::Initial, ConnectionMessage::Initialize) => do_initialize(¶ms, eff).await,
132 (State::Handshake { muxer, handshake }, ConnectionMessage::Handshake(handshake_result)) => {
133 do_handshake(¶ms, muxer, params.pipeline.clone(), handshake, handshake_result, eff).await
134 }
135 (State::Initiator(s), ConnectionMessage::FetchBlocks { from, through, cr }) => {
136 eff.send(&s.blockfetch_initiator, BlockFetchMessage::RequestRange { from, through, cr }).await;
137 State::Initiator(s)
138 }
139 (State::Responder(s), ConnectionMessage::NewTip(tip)) => {
140 eff.send(&s.chainsync_responder, chainsync::ResponderMessage::NewTip(tip)).await;
141 State::Responder(s)
142 }
143 (State::Initiator(s), ConnectionMessage::NewTip(_)) => {
144 State::Initiator(s)
146 }
147 (state @ (State::Initial | State::Handshake { .. }), msg @ ConnectionMessage::FetchBlocks { .. }) => {
148 eff.schedule_after(msg, params.config.reconnect_delay).await;
153 state
154 }
155 (state @ (State::Initial | State::Handshake { .. }), msg @ ConnectionMessage::NewTip(_)) => {
156 eff.schedule_after(msg, params.config.reconnect_delay).await;
158 state
159 }
160 x => unimplemented!("{x:?}"),
161 };
162 Connection { params, state }
163}
164
165async fn do_initialize(Params { conn_id, role, magic, .. }: &Params, eff: Effects<ConnectionMessage>) -> State {
166 let muxer = eff.stage("mux", mux::stage).await;
167 let muxer = eff.wire_up(muxer, mux::State::new(*conn_id, &[(PROTO_HANDSHAKE.erase(), 5760)], *role)).await;
168
169 let handshake_result = eff.contramap(eff.me(), "handshake_result", ConnectionMessage::Handshake).await;
170
171 let handshake = match role {
172 Role::Initiator => {
173 eff.wire_up(
174 eff.stage("handshake", handshake::initiator()).await,
175 handshake::HandshakeInitiator::new(
176 muxer.clone(),
177 handshake_result,
178 VersionTable::v11_and_above(*magic, true),
179 ),
180 )
181 .await
182 }
183 Role::Responder => {
184 eff.wire_up(
185 eff.stage("handshake", handshake::responder()).await,
186 handshake::HandshakeResponder::new(
187 muxer.clone(),
188 handshake_result,
189 VersionTable::v11_and_above(*magic, false),
192 ),
193 )
194 .await
195 }
196 };
197
198 let handler = eff.contramap(&handshake, "handshake_bytes", Inputs::Network).await;
199
200 let protocol = match role {
201 Role::Initiator => PROTO_HANDSHAKE.erase(),
202 Role::Responder => PROTO_HANDSHAKE.responder().erase(),
203 };
204 eff.send(&muxer, MuxMessage::Register { protocol, frame: mux::Frame::OneCborItem, handler, max_buffer: 5760 })
205 .await;
206
207 State::Handshake { muxer, handshake }
208}
209
210#[expect(clippy::expect_used)]
211async fn do_handshake(
212 Params { role, peer, conn_id, era_history, .. }: &Params,
213 muxer: StageRef<MuxMessage>,
214 pipeline: StageRef<ChainSyncInitiatorMsg>,
215 handshake: StageRef<Inputs<Void>>,
216 handshake_result: HandshakeResult,
217 eff: Effects<ConnectionMessage>,
218) -> State {
219 let (version_number, version_data) = match handshake_result {
220 HandshakeResult::Accepted(version_number, version_data) => (version_number, version_data),
221 HandshakeResult::Refused(refuse_reason) => {
222 tracing::error!(?refuse_reason, "handshake refused");
223 return eff.terminate().await;
224 }
225 HandshakeResult::Query(version_table) => {
226 tracing::info!(?version_table, "handshake query reply");
227 return eff.terminate().await;
228 }
229 };
230
231 let keepalive = register_keepalive(*role, muxer.clone(), &eff).await;
232 let tx_submission = register_tx_submission(*role, muxer.clone(), &eff, TxOrigin::Remote(peer.clone())).await;
233
234 if *role == Role::Initiator {
235 let chainsync_initiator = register_chainsync_initiator(&muxer, peer.clone(), *conn_id, pipeline, &eff).await;
236 let blockfetch_initiator =
237 register_blockfetch_initiator(&muxer, peer.clone(), *conn_id, era_history.clone(), &eff).await;
238 State::Initiator(StateInitiator {
239 chainsync_initiator,
240 blockfetch_initiator,
241 version_number,
242 version_data,
243 muxer,
244 handshake,
245 keepalive,
246 tx_submission,
247 })
248 } else {
249 let store = Store::new(eff.clone());
250 let upstream = store.get_best_chain_hash();
251 let upstream = if upstream == ORIGIN_HASH {
252 Tip::new(Point::Origin, 0.into())
253 } else {
254 let header = store.load_header(&upstream).expect("best chain hash not found");
255 header.tip()
256 };
257 let chainsync_responder = register_chainsync_responder(&muxer, upstream, peer.clone(), *conn_id, &eff).await;
258 let blockfetch_responder = register_blockfetch_responder(&muxer, &eff).await;
259
260 State::Responder(StateResponder {
261 chainsync_responder,
262 blockfetch_responder,
263 muxer,
264 handshake,
265 keepalive,
266 tx_submission,
267 })
268 }
269}
270
/// Registers pure-stage data deserializers for the types this module sends or stores as stage
/// data.
///
/// The registered types are the `(ConnectionId, StageRef<MuxMessage>, Role)` tuple, the
/// [`Connection`] stage state and [`ConnectionMessage`].
///
/// NOTE(review): presumably the registrations stay in effect only while the returned guards are
/// held — confirm against `pure_stage::DeserializerGuards`.
pub fn register_deserializers() -> DeserializerGuards {
    vec![
        register_data_deserializer::<(ConnectionId, StageRef<mux::MuxMessage>, Role)>().boxed(),
        register_data_deserializer::<Connection>().boxed(),
        register_data_deserializer::<ConnectionMessage>().boxed(),
    ]
}
278
#[cfg(test)]
mod tests {
    use amaru_kernel::NetworkName;
    use pure_stage::{Effect, StageGraph, simulation::SimulationBuilder};

    use super::*;

    /// `Handshake` state whose muxer and handshake refs both point to a blackhole.
    fn handshake_state() -> State {
        State::Handshake { muxer: StageRef::blackhole(), handshake: StageRef::blackhole() }
    }

    #[test]
    fn test_fetch_blocks_in_initial_state_reschedules() {
        fetch_blocks_in_disconnected_state_reschedules(State::Initial);
    }

    #[test]
    fn test_fetch_blocks_in_handshake_state_reschedules() {
        fetch_blocks_in_disconnected_state_reschedules(handshake_state());
    }

    #[test]
    fn test_new_tip_in_initial_state_reschedules() {
        new_tip_in_disconnected_state_reschedules(State::Initial);
    }

    #[test]
    fn test_new_tip_in_handshake_state_reschedules() {
        new_tip_in_disconnected_state_reschedules(handshake_state());
    }

    fn fetch_blocks_in_disconnected_state_reschedules(connection_state: State) {
        assert_message_reschedules_in_disconnected_state(connection_state, |builder| {
            // Keep the receiver alive until the message has been built.
            let (blocks_output, _rx) = builder.output::<Blocks>("blocks_output", 10);
            ConnectionMessage::FetchBlocks { from: Point::Origin, through: Point::Origin, cr: blocks_output }
        });
    }

    fn new_tip_in_disconnected_state_reschedules(connection_state: State) {
        assert_message_reschedules_in_disconnected_state(connection_state, |_| {
            ConnectionMessage::NewTip(Tip::origin())
        });
    }

    /// Delivers one message to a connection stage in `connection_state`. It then checks that the
    /// stage schedules the message back to itself no earlier than `reconnect_delay`, and that
    /// the stage state is left untouched.
    fn assert_message_reschedules_in_disconnected_state(
        connection_state: State,
        make_msg: impl FnOnce(&mut SimulationBuilder) -> ConnectionMessage,
    ) {
        let mut builder = SimulationBuilder::default();

        let connection = builder.stage("connection", stage);
        let connection = builder.wire_up(connection, test_connection(connection_state.clone()));

        let msg = make_msg(&mut builder);
        builder.preload(&connection, [msg]).unwrap();

        let mut sim = builder.run();
        let started_at = sim.now();

        // Halt as soon as the connection stage schedules a message to itself.
        let target = connection.name().clone();
        sim.breakpoint("schedule", move |eff| {
            matches!(eff, Effect::Schedule { at_stage, .. } if *at_stage == target)
        });

        let effect = sim.run_until_blocked().assert_breakpoint("schedule");

        let min_delay = ManagerConfig::default().reconnect_delay;
        let Effect::Schedule { id, .. } = &effect else {
            panic!("Expected Schedule effect");
        };
        let delay = id.time().checked_since(started_at).unwrap();
        assert!(delay >= min_delay);

        sim.clear_breakpoint("schedule");
        sim.handle_effect(effect);

        sim.run_until_sleeping_or_blocked().assert_sleeping();

        let snapshot = sim.get_state(&connection).unwrap();
        assert_eq!(snapshot.state, connection_state);
    }

    /// Preprod initiator connection to `test-peer` in the given state, with a blackhole pipeline.
    fn test_connection(state: State) -> Connection {
        let era_history: &EraHistory = NetworkName::Preprod.into();
        let params = Params {
            peer: Peer::new("test-peer"),
            conn_id: ConnectionId::initial(),
            role: Role::Initiator,
            config: ManagerConfig::default(),
            magic: NetworkMagic::PREPROD,
            pipeline: StageRef::blackhole(),
            era_history: Arc::new(era_history.clone()),
        };
        Connection { params, state }
    }
}