// amaru_protocols/connection.rs

// Copyright 2025 PRAGMA
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use std::sync::Arc;
16
17use amaru_kernel::{EraHistory, NetworkMagic, ORIGIN_HASH, Peer, Point, Tip};
18use amaru_ouroboros::{ConnectionId, ReadOnlyChainStore, TxOrigin};
19use pure_stage::{DeserializerGuards, Effects, StageRef, Void, register_data_deserializer};
20use tracing::instrument;
21
22use crate::{
23    blockfetch::{
24        self, BlockFetchMessage, Blocks, StreamBlocks, register_blockfetch_initiator, register_blockfetch_responder,
25    },
26    chainsync::{self, ChainSyncInitiatorMsg, register_chainsync_initiator, register_chainsync_responder},
27    handshake,
28    keepalive::register_keepalive,
29    manager::ManagerConfig,
30    mux::{self, HandlerMessage, MuxMessage},
31    protocol::{Inputs, PROTO_HANDSHAKE, Role},
32    protocol_messages::{
33        handshake::HandshakeResult, version_data::VersionData, version_number::VersionNumber,
34        version_table::VersionTable,
35    },
36    store_effects::Store,
37    tx_submission::register_tx_submission,
38};
39
40#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
41pub struct Connection {
42    params: Params,
43    state: State,
44}
45
46impl Connection {
47    pub fn new(
48        peer: Peer,
49        conn_id: ConnectionId,
50        role: Role,
51        config: ManagerConfig,
52        magic: NetworkMagic,
53        pipeline: StageRef<ChainSyncInitiatorMsg>,
54        era_history: Arc<EraHistory>,
55    ) -> Self {
56        Self { params: Params { peer, conn_id, role, config, magic, pipeline, era_history }, state: State::Initial }
57    }
58}
59
60#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
61struct Params {
62    peer: Peer,
63    conn_id: ConnectionId,
64    role: Role,
65    magic: NetworkMagic,
66    config: ManagerConfig,
67    pipeline: StageRef<ChainSyncInitiatorMsg>,
68    era_history: Arc<EraHistory>,
69}
70
71#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
72enum State {
73    Initial,
74    Handshake { muxer: StageRef<MuxMessage>, handshake: StageRef<Inputs<Void>> },
75    Initiator(StateInitiator),
76    Responder(StateResponder),
77}
78
79#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
80struct StateInitiator {
81    chainsync_initiator: StageRef<chainsync::InitiatorMessage>,
82    blockfetch_initiator: StageRef<blockfetch::BlockFetchMessage>,
83    version_number: VersionNumber,
84    version_data: VersionData,
85    muxer: StageRef<MuxMessage>,
86    handshake: StageRef<Inputs<Void>>,
87    keepalive: StageRef<HandlerMessage>,
88    tx_submission: StageRef<HandlerMessage>,
89}
90
91#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
92struct StateResponder {
93    chainsync_responder: StageRef<chainsync::ResponderMessage>,
94    muxer: StageRef<MuxMessage>,
95    handshake: StageRef<Inputs<Void>>,
96    keepalive: StageRef<HandlerMessage>,
97    tx_submission: StageRef<HandlerMessage>,
98    blockfetch_responder: StageRef<StreamBlocks>,
99}
100
101#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
102pub enum ConnectionMessage {
103    Initialize,
104    Disconnect,
105    Handshake(HandshakeResult),
106    FetchBlocks { from: Point, through: Point, cr: StageRef<Blocks> },
107    NewTip(Tip),
108    // LATER: make full duplex, etc.
109}
110
111impl ConnectionMessage {
112    fn message_type(&self) -> &'static str {
113        match self {
114            ConnectionMessage::Initialize => "Initialize",
115            ConnectionMessage::Disconnect => "Disconnect",
116            ConnectionMessage::Handshake(_) => "Handshake",
117            ConnectionMessage::FetchBlocks { .. } => "FetchBlocks",
118            ConnectionMessage::NewTip(_) => "NewTip",
119        }
120    }
121}
122
123#[instrument(name = "connection", skip_all, fields(message_type = msg.message_type(), conn_id = %params.conn_id, peer = %params.peer, role = ?params.role))]
124pub async fn stage(
125    Connection { params, state }: Connection,
126    msg: ConnectionMessage,
127    eff: Effects<ConnectionMessage>,
128) -> Connection {
129    let state = match (state, msg) {
130        (_, ConnectionMessage::Disconnect) => return eff.terminate().await,
131        (State::Initial, ConnectionMessage::Initialize) => do_initialize(&params, eff).await,
132        (State::Handshake { muxer, handshake }, ConnectionMessage::Handshake(handshake_result)) => {
133            do_handshake(&params, muxer, params.pipeline.clone(), handshake, handshake_result, eff).await
134        }
135        (State::Initiator(s), ConnectionMessage::FetchBlocks { from, through, cr }) => {
136            eff.send(&s.blockfetch_initiator, BlockFetchMessage::RequestRange { from, through, cr }).await;
137            State::Initiator(s)
138        }
139        (State::Responder(s), ConnectionMessage::NewTip(tip)) => {
140            eff.send(&s.chainsync_responder, chainsync::ResponderMessage::NewTip(tip)).await;
141            State::Responder(s)
142        }
143        (State::Initiator(s), ConnectionMessage::NewTip(_)) => {
144            // don't propagate new tip messages when using the initiator side of a connection.
145            State::Initiator(s)
146        }
147        (state @ (State::Initial | State::Handshake { .. }), msg @ ConnectionMessage::FetchBlocks { .. }) => {
148            // The peer might be still connecting. In that case we reschedule the message
149            // If the peer eventually can't be fully initialized, the caller timeout will trigger.
150            // We schedule after the reconnect delay (2s by default) which is shorter than the call
151            // timeout (5s) (whereas a full connection timeout is 10s).
152            eff.schedule_after(msg, params.config.reconnect_delay).await;
153            state
154        }
155        (state @ (State::Initial | State::Handshake { .. }), msg @ ConnectionMessage::NewTip(_)) => {
156            // The peer might be still connecting. Reschedule the NewTip message.
157            eff.schedule_after(msg, params.config.reconnect_delay).await;
158            state
159        }
160        x => unimplemented!("{x:?}"),
161    };
162    Connection { params, state }
163}
164
165async fn do_initialize(Params { conn_id, role, magic, .. }: &Params, eff: Effects<ConnectionMessage>) -> State {
166    let muxer = eff.stage("mux", mux::stage).await;
167    let muxer = eff.wire_up(muxer, mux::State::new(*conn_id, &[(PROTO_HANDSHAKE.erase(), 5760)], *role)).await;
168
169    let handshake_result = eff.contramap(eff.me(), "handshake_result", ConnectionMessage::Handshake).await;
170
171    let handshake = match role {
172        Role::Initiator => {
173            eff.wire_up(
174                eff.stage("handshake", handshake::initiator()).await,
175                handshake::HandshakeInitiator::new(
176                    muxer.clone(),
177                    handshake_result,
178                    VersionTable::v11_and_above(*magic, true),
179                ),
180            )
181            .await
182        }
183        Role::Responder => {
184            eff.wire_up(
185                eff.stage("handshake", handshake::responder()).await,
186                handshake::HandshakeResponder::new(
187                    muxer.clone(),
188                    handshake_result,
189                    // Use initiator_only_diffusion_mode = false so downstream peers
190                    // know we can serve as chainsync/blockfetch server
191                    VersionTable::v11_and_above(*magic, false),
192                ),
193            )
194            .await
195        }
196    };
197
198    let handler = eff.contramap(&handshake, "handshake_bytes", Inputs::Network).await;
199
200    let protocol = match role {
201        Role::Initiator => PROTO_HANDSHAKE.erase(),
202        Role::Responder => PROTO_HANDSHAKE.responder().erase(),
203    };
204    eff.send(&muxer, MuxMessage::Register { protocol, frame: mux::Frame::OneCborItem, handler, max_buffer: 5760 })
205        .await;
206
207    State::Handshake { muxer, handshake }
208}
209
210#[expect(clippy::expect_used)]
211async fn do_handshake(
212    Params { role, peer, conn_id, era_history, .. }: &Params,
213    muxer: StageRef<MuxMessage>,
214    pipeline: StageRef<ChainSyncInitiatorMsg>,
215    handshake: StageRef<Inputs<Void>>,
216    handshake_result: HandshakeResult,
217    eff: Effects<ConnectionMessage>,
218) -> State {
219    let (version_number, version_data) = match handshake_result {
220        HandshakeResult::Accepted(version_number, version_data) => (version_number, version_data),
221        HandshakeResult::Refused(refuse_reason) => {
222            tracing::error!(?refuse_reason, "handshake refused");
223            return eff.terminate().await;
224        }
225        HandshakeResult::Query(version_table) => {
226            tracing::info!(?version_table, "handshake query reply");
227            return eff.terminate().await;
228        }
229    };
230
231    let keepalive = register_keepalive(*role, muxer.clone(), &eff).await;
232    let tx_submission = register_tx_submission(*role, muxer.clone(), &eff, TxOrigin::Remote(peer.clone())).await;
233
234    if *role == Role::Initiator {
235        let chainsync_initiator = register_chainsync_initiator(&muxer, peer.clone(), *conn_id, pipeline, &eff).await;
236        let blockfetch_initiator =
237            register_blockfetch_initiator(&muxer, peer.clone(), *conn_id, era_history.clone(), &eff).await;
238        State::Initiator(StateInitiator {
239            chainsync_initiator,
240            blockfetch_initiator,
241            version_number,
242            version_data,
243            muxer,
244            handshake,
245            keepalive,
246            tx_submission,
247        })
248    } else {
249        let store = Store::new(eff.clone());
250        let upstream = store.get_best_chain_hash();
251        let upstream = if upstream == ORIGIN_HASH {
252            Tip::new(Point::Origin, 0.into())
253        } else {
254            let header = store.load_header(&upstream).expect("best chain hash not found");
255            header.tip()
256        };
257        let chainsync_responder = register_chainsync_responder(&muxer, upstream, peer.clone(), *conn_id, &eff).await;
258        let blockfetch_responder = register_blockfetch_responder(&muxer, &eff).await;
259
260        State::Responder(StateResponder {
261            chainsync_responder,
262            blockfetch_responder,
263            muxer,
264            handshake,
265            keepalive,
266            tx_submission,
267        })
268    }
269}
270
271pub fn register_deserializers() -> DeserializerGuards {
272    vec![
273        register_data_deserializer::<(ConnectionId, StageRef<mux::MuxMessage>, Role)>().boxed(),
274        register_data_deserializer::<Connection>().boxed(),
275        register_data_deserializer::<ConnectionMessage>().boxed(),
276    ]
277}
278
#[cfg(test)]
mod tests {
    use amaru_kernel::NetworkName;
    use pure_stage::{Effect, StageGraph, simulation::SimulationBuilder};

    use super::*;

    // Messages that arrive before the handshake completes must be rescheduled after
    // `reconnect_delay` and must leave the connection state untouched.

    #[test]
    fn test_fetch_blocks_in_initial_state_reschedules() {
        fetch_blocks_in_disconnected_state_reschedules(State::Initial);
    }

    #[test]
    fn test_fetch_blocks_in_handshake_state_reschedules() {
        fetch_blocks_in_disconnected_state_reschedules(handshake_state());
    }

    #[test]
    fn test_new_tip_in_initial_state_reschedules() {
        new_tip_in_disconnected_state_reschedules(State::Initial);
    }

    #[test]
    fn test_new_tip_in_handshake_state_reschedules() {
        new_tip_in_disconnected_state_reschedules(handshake_state());
    }

    fn fetch_blocks_in_disconnected_state_reschedules(connection_state: State) {
        assert_message_reschedules_in_disconnected_state(connection_state, |network| {
            let (cr, _rx) = network.output::<Blocks>("blocks_output", 10);
            ConnectionMessage::FetchBlocks { from: Point::Origin, through: Point::Origin, cr }
        });
    }

    fn new_tip_in_disconnected_state_reschedules(connection_state: State) {
        assert_message_reschedules_in_disconnected_state(connection_state, |_| {
            ConnectionMessage::NewTip(Tip::origin())
        });
    }

    /// Preloads the message built by `make_msg` into a connection in `connection_state`,
    /// checks that it is rescheduled no earlier than `reconnect_delay`, and checks that
    /// the state is unchanged afterwards.
    fn assert_message_reschedules_in_disconnected_state(
        connection_state: State,
        make_msg: impl FnOnce(&mut SimulationBuilder) -> ConnectionMessage,
    ) {
        let mut network = SimulationBuilder::default();

        let connection = network.stage("connection", stage);
        let connection = network.wire_up(connection, test_connection(connection_state.clone()));

        let msg = make_msg(&mut network);
        network.preload(&connection, [msg]).unwrap();

        let mut running = network.run();
        let started_at = running.now();

        let target = connection.name().clone();
        running.breakpoint("schedule", move |eff| {
            matches!(eff, Effect::Schedule { at_stage, .. } if *at_stage == target)
        });

        let effect = running.run_until_blocked().assert_breakpoint("schedule");

        let minimum_delay = ManagerConfig::default().reconnect_delay;
        let Effect::Schedule { id, .. } = &effect else {
            panic!("Expected Schedule effect");
        };
        let delay = id.time().checked_since(started_at).unwrap();
        assert!(delay >= minimum_delay);

        // Remove the breakpoint, then let the simulation process the scheduling effect.
        running.clear_breakpoint("schedule");
        running.handle_effect(effect);

        // With the message parked until its wake-up time, the simulation should be sleeping.
        running.run_until_sleeping_or_blocked().assert_sleeping();

        // Deferring the message must not have changed the connection state.
        let snapshot = running.get_state(&connection).unwrap();
        assert_eq!(snapshot.state, connection_state);
    }

    // HELPERS

    /// A `Handshake` state whose child stage references go nowhere.
    fn handshake_state() -> State {
        State::Handshake { muxer: StageRef::blackhole(), handshake: StageRef::blackhole() }
    }

    /// An initiator connection on preprod in the given `state`, with default manager config.
    fn test_connection(state: State) -> Connection {
        let era_history: &EraHistory = NetworkName::Preprod.into();
        let params = Params {
            peer: Peer::new("test-peer"),
            conn_id: ConnectionId::initial(),
            role: Role::Initiator,
            magic: NetworkMagic::PREPROD,
            config: ManagerConfig::default(),
            pipeline: StageRef::blackhole(),
            era_history: Arc::new(era_history.clone()),
        };
        Connection { params, state }
    }
}