1use std::sync::Arc;
16
17use amaru_kernel::{EraHistory, NetworkMagic, ORIGIN_HASH, Peer, Point, Tip};
18use amaru_ouroboros::{ConnectionId, ReadOnlyChainStore, TxOrigin};
19use pure_stage::{DeserializerGuards, Effects, StageRef, Void, register_data_deserializer};
20use tracing::instrument;
21
22use crate::{
23 blockfetch::{
24 self, BlockFetchMessage, Blocks, StreamBlocks, register_blockfetch_initiator, register_blockfetch_responder,
25 },
26 chainsync::{self, ChainSyncInitiatorMsg, register_chainsync_initiator, register_chainsync_responder},
27 handshake,
28 keepalive::register_keepalive,
29 manager::ManagerConfig,
30 mux::{self, HandlerMessage, MuxMessage},
31 protocol::{Inputs, PROTO_HANDSHAKE, Role},
32 protocol_messages::{
33 handshake::HandshakeResult, version_data::VersionData, version_number::VersionNumber,
34 version_table::VersionTable,
35 },
36 store_effects::Store,
37 tx_submission::register_tx_submission,
38};
39
/// State threaded through the connection `stage` function for a single peer connection:
/// parameters fixed at construction plus the current lifecycle state.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Connection {
    /// Per-connection parameters set by `Connection::new`; `stage` never modifies them.
    params: Params,
    /// Current position in the connection lifecycle (see `State`).
    state: State,
}
45
46impl Connection {
47 pub fn new(
48 peer: Peer,
49 conn_id: ConnectionId,
50 role: Role,
51 config: ManagerConfig,
52 magic: NetworkMagic,
53 pipeline: StageRef<ChainSyncInitiatorMsg>,
54 era_history: Arc<EraHistory>,
55 ) -> Self {
56 Self { params: Params { peer, conn_id, role, config, magic, pipeline, era_history }, state: State::Initial }
57 }
58}
59
/// Parameters of a connection that stay the same for its whole lifetime.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct Params {
    /// Remote peer; recorded in tracing spans and handed to the mini-protocol handlers.
    peer: Peer,
    /// Identifier of this connection, passed to the muxer and protocol registrations.
    conn_id: ConnectionId,
    /// Selects the handshake side and which mini-protocol handlers are spawned.
    role: Role,
    /// Network magic used to build the handshake version table.
    magic: NetworkMagic,
    /// Manager configuration; `reconnect_delay` is used to re-schedule early `FetchBlocks`.
    config: ManagerConfig,
    /// Target handed to the chainsync initiator (initiator role only).
    pipeline: StageRef<ChainSyncInitiatorMsg>,
    /// Era history handed to the blockfetch initiator (initiator role only).
    era_history: Arc<EraHistory>,
}
70
/// Lifecycle of a connection as driven by the connection `stage`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
enum State {
    /// Freshly constructed; nothing spawned yet. `Initialize` moves to `Handshake`.
    Initial,
    /// Muxer and handshake stages are running; waiting for the `HandshakeResult`.
    Handshake { muxer: StageRef<MuxMessage>, handshake: StageRef<Inputs<Void>> },
    /// Handshake accepted and `Params::role` is `Role::Initiator`.
    Initiator(StateInitiator),
    /// Handshake accepted and `Params::role` is `Role::Responder`.
    Responder(StateResponder),
}
78
/// Stage handles held by an initiator-role connection after a successful handshake.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct StateInitiator {
    /// Chainsync initiator handler registered on the muxer.
    chainsync_initiator: StageRef<chainsync::InitiatorMessage>,
    /// Blockfetch initiator; `FetchBlocks` requests are forwarded here as `RequestRange`.
    blockfetch_initiator: StageRef<blockfetch::BlockFetchMessage>,
    /// Protocol version agreed in the handshake.
    version_number: VersionNumber,
    /// Version parameters agreed in the handshake.
    version_data: VersionData,
    /// Multiplexer stage for this connection.
    muxer: StageRef<MuxMessage>,
    /// Handshake stage spawned during initialization.
    handshake: StageRef<Inputs<Void>>,
    /// Keep-alive handler registered after the handshake.
    keepalive: StageRef<HandlerMessage>,
    /// Tx-submission handler registered after the handshake.
    tx_submission: StageRef<HandlerMessage>,
}
90
/// Stage handles held by a responder-role connection after a successful handshake.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct StateResponder {
    /// Chainsync responder; `NewTip` messages are forwarded here.
    chainsync_responder: StageRef<chainsync::ResponderMessage>,
    /// Multiplexer stage for this connection.
    muxer: StageRef<MuxMessage>,
    /// Handshake stage spawned during initialization.
    handshake: StageRef<Inputs<Void>>,
    /// Keep-alive handler registered after the handshake.
    keepalive: StageRef<HandlerMessage>,
    /// Tx-submission handler registered after the handshake.
    tx_submission: StageRef<HandlerMessage>,
    /// Blockfetch responder registered after the handshake.
    blockfetch_responder: StageRef<StreamBlocks>,
}
100
/// Input messages of the connection `stage`.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ConnectionMessage {
    /// Spawn the muxer and the handshake stage (handled in the `Initial` state).
    Initialize,
    /// Terminate the connection stage, whatever its current state.
    Disconnect,
    /// Outcome of the handshake, routed back from the handshake stage.
    Handshake(HandshakeResult),
    /// Fetch the block range `from..=through` via blockfetch (initiator role); presumably
    /// the blocks are delivered to `cr` — see `BlockFetchMessage::RequestRange`.
    /// Re-scheduled after `reconnect_delay` if the handshake has not completed yet.
    FetchBlocks { from: Point, through: Point, cr: StageRef<Blocks> },
    /// New local tip; forwarded to the chainsync responder (ignored by initiators).
    NewTip(Tip),
}
110
111impl ConnectionMessage {
112 fn message_type(&self) -> &'static str {
113 match self {
114 ConnectionMessage::Initialize => "Initialize",
115 ConnectionMessage::Disconnect => "Disconnect",
116 ConnectionMessage::Handshake(_) => "Handshake",
117 ConnectionMessage::FetchBlocks { .. } => "FetchBlocks",
118 ConnectionMessage::NewTip(_) => "NewTip",
119 }
120 }
121}
122
123#[instrument(name = "connection", skip_all, fields(message_type = msg.message_type(), conn_id = %params.conn_id, peer = %params.peer, role = ?params.role))]
124pub async fn stage(
125 Connection { params, state }: Connection,
126 msg: ConnectionMessage,
127 eff: Effects<ConnectionMessage>,
128) -> Connection {
129 let state = match (state, msg) {
130 (_, ConnectionMessage::Disconnect) => return eff.terminate().await,
131 (State::Initial, ConnectionMessage::Initialize) => do_initialize(¶ms, eff).await,
132 (State::Handshake { muxer, handshake }, ConnectionMessage::Handshake(handshake_result)) => {
133 do_handshake(¶ms, muxer, params.pipeline.clone(), handshake, handshake_result, eff).await
134 }
135 (State::Initiator(s), ConnectionMessage::FetchBlocks { from, through, cr }) => {
136 eff.send(&s.blockfetch_initiator, BlockFetchMessage::RequestRange { from, through, cr }).await;
137 State::Initiator(s)
138 }
139 (State::Responder(s), ConnectionMessage::NewTip(tip)) => {
140 eff.send(&s.chainsync_responder, chainsync::ResponderMessage::NewTip(tip)).await;
141 State::Responder(s)
142 }
143 (State::Initiator(s), ConnectionMessage::NewTip(_)) => {
144 State::Initiator(s)
146 }
147 (state @ (State::Initial | State::Handshake { .. }), msg @ ConnectionMessage::FetchBlocks { .. }) => {
148 eff.schedule_after(msg, params.config.reconnect_delay).await;
153 state
154 }
155 x => unimplemented!("{x:?}"),
156 };
157 Connection { params, state }
158}
159
160async fn do_initialize(Params { conn_id, role, magic, .. }: &Params, eff: Effects<ConnectionMessage>) -> State {
161 let muxer = eff.stage("mux", mux::stage).await;
162 let muxer = eff.wire_up(muxer, mux::State::new(*conn_id, &[(PROTO_HANDSHAKE.erase(), 5760)], *role)).await;
163
164 let handshake_result = eff.contramap(eff.me(), "handshake_result", ConnectionMessage::Handshake).await;
165
166 let handshake = match role {
167 Role::Initiator => {
168 eff.wire_up(
169 eff.stage("handshake", handshake::initiator()).await,
170 handshake::HandshakeInitiator::new(
171 muxer.clone(),
172 handshake_result,
173 VersionTable::v11_and_above(*magic, true),
174 ),
175 )
176 .await
177 }
178 Role::Responder => {
179 eff.wire_up(
180 eff.stage("handshake", handshake::responder()).await,
181 handshake::HandshakeResponder::new(
182 muxer.clone(),
183 handshake_result,
184 VersionTable::v11_and_above(*magic, false),
187 ),
188 )
189 .await
190 }
191 };
192
193 let handler = eff.contramap(&handshake, "handshake_bytes", Inputs::Network).await;
194
195 let protocol = match role {
196 Role::Initiator => PROTO_HANDSHAKE.erase(),
197 Role::Responder => PROTO_HANDSHAKE.responder().erase(),
198 };
199 eff.send(&muxer, MuxMessage::Register { protocol, frame: mux::Frame::OneCborItem, handler, max_buffer: 5760 })
200 .await;
201
202 State::Handshake { muxer, handshake }
203}
204
205#[expect(clippy::expect_used)]
206async fn do_handshake(
207 Params { role, peer, conn_id, era_history, .. }: &Params,
208 muxer: StageRef<MuxMessage>,
209 pipeline: StageRef<ChainSyncInitiatorMsg>,
210 handshake: StageRef<Inputs<Void>>,
211 handshake_result: HandshakeResult,
212 eff: Effects<ConnectionMessage>,
213) -> State {
214 let (version_number, version_data) = match handshake_result {
215 HandshakeResult::Accepted(version_number, version_data) => (version_number, version_data),
216 HandshakeResult::Refused(refuse_reason) => {
217 tracing::error!(?refuse_reason, "handshake refused");
218 return eff.terminate().await;
219 }
220 HandshakeResult::Query(version_table) => {
221 tracing::info!(?version_table, "handshake query reply");
222 return eff.terminate().await;
223 }
224 };
225
226 let keepalive = register_keepalive(*role, muxer.clone(), &eff).await;
227 let tx_submission = register_tx_submission(*role, muxer.clone(), &eff, TxOrigin::Remote(peer.clone())).await;
228
229 if *role == Role::Initiator {
230 let chainsync_initiator = register_chainsync_initiator(&muxer, peer.clone(), *conn_id, pipeline, &eff).await;
231 let blockfetch_initiator =
232 register_blockfetch_initiator(&muxer, peer.clone(), *conn_id, era_history.clone(), &eff).await;
233 State::Initiator(StateInitiator {
234 chainsync_initiator,
235 blockfetch_initiator,
236 version_number,
237 version_data,
238 muxer,
239 handshake,
240 keepalive,
241 tx_submission,
242 })
243 } else {
244 let store = Store::new(eff.clone());
245 let upstream = store.get_best_chain_hash();
246 let upstream = if upstream == ORIGIN_HASH {
247 Tip::new(Point::Origin, 0.into())
248 } else {
249 let header = store.load_header(&upstream).expect("best chain hash not found");
250 header.tip()
251 };
252 let chainsync_responder = register_chainsync_responder(&muxer, upstream, peer.clone(), *conn_id, &eff).await;
253 let blockfetch_responder = register_blockfetch_responder(&muxer, &eff).await;
254
255 State::Responder(StateResponder {
256 chainsync_responder,
257 blockfetch_responder,
258 muxer,
259 handshake,
260 keepalive,
261 tx_submission,
262 })
263 }
264}
265
266pub fn register_deserializers() -> DeserializerGuards {
267 vec![
268 register_data_deserializer::<(ConnectionId, StageRef<mux::MuxMessage>, Role)>().boxed(),
269 register_data_deserializer::<Connection>().boxed(),
270 register_data_deserializer::<ConnectionMessage>().boxed(),
271 ]
272}
273
#[cfg(test)]
mod tests {
    use amaru_kernel::NetworkName;
    use pure_stage::{Effect, StageGraph, simulation::SimulationBuilder};

    use super::*;

    /// `FetchBlocks` arriving before `Initialize` must be re-scheduled.
    #[test]
    fn test_fetch_blocks_in_initial_state_reschedules() {
        fetch_blocks_in_disconnected_state_reschedules(State::Initial);
    }

    /// `FetchBlocks` arriving while the handshake is in progress must be re-scheduled.
    #[test]
    fn test_fetch_blocks_in_handshake_state_reschedules() {
        let handshake_state = State::Handshake { muxer: StageRef::blackhole(), handshake: StageRef::blackhole() };
        fetch_blocks_in_disconnected_state_reschedules(handshake_state);
    }

    /// Builds an initiator-role connection on Preprod in `state`, with a blackhole pipeline
    /// and the default manager configuration.
    fn test_connection(state: State) -> Connection {
        let era_history: &EraHistory = NetworkName::Preprod.into();
        Connection {
            params: Params {
                peer: Peer::new("test-peer"),
                conn_id: ConnectionId::initial(),
                role: Role::Initiator,
                config: ManagerConfig::default(),
                magic: NetworkMagic::PREPROD,
                pipeline: StageRef::blackhole(),
                era_history: Arc::new(era_history.clone()),
            },
            state,
        }
    }

    /// Sends one `FetchBlocks` to a connection in `connection_state` and checks that the stage
    /// schedules it for itself at least `reconnect_delay` later, leaving its state unchanged.
    fn fetch_blocks_in_disconnected_state_reschedules(connection_state: State) {
        let mut network = SimulationBuilder::default();

        let connection_stage = network.stage("connection", stage);
        let connection_stage = network.wire_up(connection_stage, test_connection(connection_state.clone()));

        let (blocks_output, _rx) = network.output::<Blocks>("blocks_output", 10);

        let fetch_msg =
            ConnectionMessage::FetchBlocks { from: Point::Origin, through: Point::Origin, cr: blocks_output };

        network.preload(&connection_stage, [fetch_msg]).unwrap();

        let mut running = network.run();
        let start_time = running.now();

        // Stop the simulation when the connection stage schedules a message for itself.
        let stage_name = connection_stage.name().clone();
        running.breakpoint(
            "schedule",
            move |eff| matches!(eff, Effect::Schedule { at_stage, .. } if *at_stage == stage_name),
        );

        let effect = running.run_until_blocked().assert_breakpoint("schedule");

        // The schedule id carries the target time; it must be at least `reconnect_delay` away.
        let reconnect_delay = ManagerConfig::default().reconnect_delay;
        if let Effect::Schedule { id, .. } = &effect {
            let delay = id.time().checked_since(start_time).unwrap();
            assert!(delay >= reconnect_delay);
        } else {
            panic!("Expected Schedule effect");
        }

        // Let the intercepted schedule effect proceed without re-triggering the breakpoint.
        running.clear_breakpoint("schedule");
        running.handle_effect(effect);

        // Presumably sleeping because only the scheduled timer is pending — TODO confirm.
        running.run_until_sleeping_or_blocked().assert_sleeping();

        // Re-scheduling must not advance the connection state.
        let state = running.get_state(&connection_stage).unwrap();
        assert_eq!(state.state, connection_state);
    }
}