Skip to main content

amaru_protocols/
connection.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use amaru_kernel::{EraHistory, NetworkMagic, ORIGIN_HASH, Peer, Point, Tip};
18use amaru_ouroboros::{ConnectionId, ReadOnlyChainStore, TxOrigin};
19use pure_stage::{DeserializerGuards, Effects, StageRef, Void, register_data_deserializer};
20use tracing::instrument;
21
22use crate::{
23    blockfetch::{
24        self, BlockFetchMessage, Blocks, StreamBlocks, register_blockfetch_initiator, register_blockfetch_responder,
25    },
26    chainsync::{self, ChainSyncInitiatorMsg, register_chainsync_initiator, register_chainsync_responder},
27    handshake,
28    keepalive::register_keepalive,
29    manager::ManagerConfig,
30    mux::{self, HandlerMessage, MuxMessage},
31    protocol::{Inputs, PROTO_HANDSHAKE, Role},
32    protocol_messages::{
33        handshake::HandshakeResult, version_data::VersionData, version_number::VersionNumber,
34        version_table::VersionTable,
35    },
36    store_effects::Store,
37    tx_submission::register_tx_submission,
38};
39
40#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
41pub struct Connection {
42    params: Params,
43    state: State,
44}
45
46impl Connection {
47    pub fn new(
48        peer: Peer,
49        conn_id: ConnectionId,
50        role: Role,
51        config: ManagerConfig,
52        magic: NetworkMagic,
53        pipeline: StageRef<ChainSyncInitiatorMsg>,
54        era_history: Arc<EraHistory>,
55    ) -> Self {
56        Self { params: Params { peer, conn_id, role, config, magic, pipeline, era_history }, state: State::Initial }
57    }
58}
59
/// Parameters of a connection, fixed when the connection stage is created and never
/// modified by `stage`.
///
/// NOTE: this type is serialized (serde derive); field order is part of its serialized
/// layout, so do not reorder fields casually.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct Params {
    /// Peer on the other end; recorded in the tracing span, used as the `TxOrigin::Remote`
    /// for tx-submission and handed to the chainsync/blockfetch handlers.
    peer: Peer,
    /// Connection identifier, passed to the muxer state and the chainsync/blockfetch handlers.
    conn_id: ConnectionId,
    /// Our role on this connection; selects which handshake stage and which mini-protocol
    /// handlers are started.
    role: Role,
    /// Network magic used to build the handshake version table.
    magic: NetworkMagic,
    /// Manager configuration; `reconnect_delay` is the retry delay for `FetchBlocks`
    /// received before the handshake has completed.
    config: ManagerConfig,
    /// Stage handed to the chainsync initiator (initiator role only).
    pipeline: StageRef<ChainSyncInitiatorMsg>,
    /// Era history handed to the blockfetch initiator (initiator role only).
    era_history: Arc<EraHistory>,
}
70
/// Life-cycle phase of a connection.
///
/// `Initial` --Initialize--> `Handshake` --Handshake(Accepted)--> `Initiator` | `Responder`.
/// A refused or query handshake result terminates the stage instead.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
enum State {
    /// Created but not started; waiting for `ConnectionMessage::Initialize`.
    Initial,
    /// Muxer and handshake stages are running; waiting for the handshake result.
    Handshake { muxer: StageRef<MuxMessage>, handshake: StageRef<Inputs<Void>> },
    /// Handshake accepted and we act as the initiator on this connection.
    Initiator(StateInitiator),
    /// Handshake accepted and we act as the responder on this connection.
    Responder(StateResponder),
}
78
/// Stages owned by an initiator-side connection after a successful handshake.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct StateInitiator {
    /// Chainsync initiator handler registered on the muxer.
    chainsync_initiator: StageRef<chainsync::InitiatorMessage>,
    /// Blockfetch initiator handler; `ConnectionMessage::FetchBlocks` is forwarded here.
    blockfetch_initiator: StageRef<blockfetch::BlockFetchMessage>,
    /// Version number from the accepted handshake.
    version_number: VersionNumber,
    /// Version data from the accepted handshake.
    version_data: VersionData,
    /// Multiplexer for this connection.
    muxer: StageRef<MuxMessage>,
    /// Handshake stage started during initialization.
    handshake: StageRef<Inputs<Void>>,
    /// Keep-alive handler registered on the muxer.
    keepalive: StageRef<HandlerMessage>,
    /// Tx-submission handler registered on the muxer.
    tx_submission: StageRef<HandlerMessage>,
}
90
/// Stages owned by a responder-side connection after a successful handshake.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
struct StateResponder {
    /// Chainsync responder handler; `ConnectionMessage::NewTip` is forwarded here.
    chainsync_responder: StageRef<chainsync::ResponderMessage>,
    /// Multiplexer for this connection.
    muxer: StageRef<MuxMessage>,
    /// Handshake stage started during initialization.
    handshake: StageRef<Inputs<Void>>,
    /// Keep-alive handler registered on the muxer.
    keepalive: StageRef<HandlerMessage>,
    /// Tx-submission handler registered on the muxer.
    tx_submission: StageRef<HandlerMessage>,
    /// Blockfetch responder handler registered on the muxer.
    blockfetch_responder: StageRef<StreamBlocks>,
}
100
/// Messages accepted by the connection stage.
#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum ConnectionMessage {
    /// Start the muxer and handshake stages; handled in the initial state.
    Initialize,
    /// Terminate the connection stage, whatever its current state.
    Disconnect,
    /// Outcome of the handshake, delivered back by the handshake stage.
    Handshake(HandshakeResult),
    /// Forwarded to the blockfetch initiator as `RequestRange { from, through, cr }`;
    /// rescheduled while the handshake has not completed yet.
    /// NOTE(review): `cr` presumably receives the fetched blocks — confirm in blockfetch.
    FetchBlocks { from: Point, through: Point, cr: StageRef<Blocks> },
    /// New chain tip; forwarded to the chainsync responder on responder connections and
    /// ignored on initiator connections.
    NewTip(Tip),
    // LATER: make full duplex, etc.
}
110
111impl ConnectionMessage {
112    fn message_type(&self) -> &'static str {
113        match self {
114            ConnectionMessage::Initialize => "Initialize",
115            ConnectionMessage::Disconnect => "Disconnect",
116            ConnectionMessage::Handshake(_) => "Handshake",
117            ConnectionMessage::FetchBlocks { .. } => "FetchBlocks",
118            ConnectionMessage::NewTip(_) => "NewTip",
119        }
120    }
121}
122
123#[instrument(name = "connection", skip_all, fields(message_type = msg.message_type(), conn_id = %params.conn_id, peer = %params.peer, role = ?params.role))]
124pub async fn stage(
125    Connection { params, state }: Connection,
126    msg: ConnectionMessage,
127    eff: Effects<ConnectionMessage>,
128) -> Connection {
129    let state = match (state, msg) {
130        (_, ConnectionMessage::Disconnect) => return eff.terminate().await,
131        (State::Initial, ConnectionMessage::Initialize) => do_initialize(&params, eff).await,
132        (State::Handshake { muxer, handshake }, ConnectionMessage::Handshake(handshake_result)) => {
133            do_handshake(&params, muxer, params.pipeline.clone(), handshake, handshake_result, eff).await
134        }
135        (State::Initiator(s), ConnectionMessage::FetchBlocks { from, through, cr }) => {
136            eff.send(&s.blockfetch_initiator, BlockFetchMessage::RequestRange { from, through, cr }).await;
137            State::Initiator(s)
138        }
139        (State::Responder(s), ConnectionMessage::NewTip(tip)) => {
140            eff.send(&s.chainsync_responder, chainsync::ResponderMessage::NewTip(tip)).await;
141            State::Responder(s)
142        }
143        (State::Initiator(s), ConnectionMessage::NewTip(_)) => {
144            // don't propagate new tip messages when using the initiator side of a connection.
145            State::Initiator(s)
146        }
147        (state @ (State::Initial | State::Handshake { .. }), msg @ ConnectionMessage::FetchBlocks { .. }) => {
148            // The peer might be still connecting. In that case we reschedule the message
149            // If the peer eventually can't be fully initialized, the caller timeout will trigger.
150            // We schedule after the reconnect delay (2s by default) which is shorter than the call
151            // timeout (5s) (whereas a full connection timeout is 10s).
152            eff.schedule_after(msg, params.config.reconnect_delay).await;
153            state
154        }
155        x => unimplemented!("{x:?}"),
156    };
157    Connection { params, state }
158}
159
160async fn do_initialize(Params { conn_id, role, magic, .. }: &Params, eff: Effects<ConnectionMessage>) -> State {
161    let muxer = eff.stage("mux", mux::stage).await;
162    let muxer = eff.wire_up(muxer, mux::State::new(*conn_id, &[(PROTO_HANDSHAKE.erase(), 5760)], *role)).await;
163
164    let handshake_result = eff.contramap(eff.me(), "handshake_result", ConnectionMessage::Handshake).await;
165
166    let handshake = match role {
167        Role::Initiator => {
168            eff.wire_up(
169                eff.stage("handshake", handshake::initiator()).await,
170                handshake::HandshakeInitiator::new(
171                    muxer.clone(),
172                    handshake_result,
173                    VersionTable::v11_and_above(*magic, true),
174                ),
175            )
176            .await
177        }
178        Role::Responder => {
179            eff.wire_up(
180                eff.stage("handshake", handshake::responder()).await,
181                handshake::HandshakeResponder::new(
182                    muxer.clone(),
183                    handshake_result,
184                    // Use initiator_only_diffusion_mode = false so downstream peers
185                    // know we can serve as chainsync/blockfetch server
186                    VersionTable::v11_and_above(*magic, false),
187                ),
188            )
189            .await
190        }
191    };
192
193    let handler = eff.contramap(&handshake, "handshake_bytes", Inputs::Network).await;
194
195    let protocol = match role {
196        Role::Initiator => PROTO_HANDSHAKE.erase(),
197        Role::Responder => PROTO_HANDSHAKE.responder().erase(),
198    };
199    eff.send(&muxer, MuxMessage::Register { protocol, frame: mux::Frame::OneCborItem, handler, max_buffer: 5760 })
200        .await;
201
202    State::Handshake { muxer, handshake }
203}
204
205#[expect(clippy::expect_used)]
206async fn do_handshake(
207    Params { role, peer, conn_id, era_history, .. }: &Params,
208    muxer: StageRef<MuxMessage>,
209    pipeline: StageRef<ChainSyncInitiatorMsg>,
210    handshake: StageRef<Inputs<Void>>,
211    handshake_result: HandshakeResult,
212    eff: Effects<ConnectionMessage>,
213) -> State {
214    let (version_number, version_data) = match handshake_result {
215        HandshakeResult::Accepted(version_number, version_data) => (version_number, version_data),
216        HandshakeResult::Refused(refuse_reason) => {
217            tracing::error!(?refuse_reason, "handshake refused");
218            return eff.terminate().await;
219        }
220        HandshakeResult::Query(version_table) => {
221            tracing::info!(?version_table, "handshake query reply");
222            return eff.terminate().await;
223        }
224    };
225
226    let keepalive = register_keepalive(*role, muxer.clone(), &eff).await;
227    let tx_submission = register_tx_submission(*role, muxer.clone(), &eff, TxOrigin::Remote(peer.clone())).await;
228
229    if *role == Role::Initiator {
230        let chainsync_initiator = register_chainsync_initiator(&muxer, peer.clone(), *conn_id, pipeline, &eff).await;
231        let blockfetch_initiator =
232            register_blockfetch_initiator(&muxer, peer.clone(), *conn_id, era_history.clone(), &eff).await;
233        State::Initiator(StateInitiator {
234            chainsync_initiator,
235            blockfetch_initiator,
236            version_number,
237            version_data,
238            muxer,
239            handshake,
240            keepalive,
241            tx_submission,
242        })
243    } else {
244        let store = Store::new(eff.clone());
245        let upstream = store.get_best_chain_hash();
246        let upstream = if upstream == ORIGIN_HASH {
247            Tip::new(Point::Origin, 0.into())
248        } else {
249            let header = store.load_header(&upstream).expect("best chain hash not found");
250            header.tip()
251        };
252        let chainsync_responder = register_chainsync_responder(&muxer, upstream, peer.clone(), *conn_id, &eff).await;
253        let blockfetch_responder = register_blockfetch_responder(&muxer, &eff).await;
254
255        State::Responder(StateResponder {
256            chainsync_responder,
257            blockfetch_responder,
258            muxer,
259            handshake,
260            keepalive,
261            tx_submission,
262        })
263    }
264}
265
266pub fn register_deserializers() -> DeserializerGuards {
267    vec![
268        register_data_deserializer::<(ConnectionId, StageRef<mux::MuxMessage>, Role)>().boxed(),
269        register_data_deserializer::<Connection>().boxed(),
270        register_data_deserializer::<ConnectionMessage>().boxed(),
271    ]
272}
273
#[cfg(test)]
mod tests {
    use amaru_kernel::NetworkName;
    use pure_stage::{Effect, StageGraph, simulation::SimulationBuilder};

    use super::*;

    #[test]
    fn test_fetch_blocks_in_initial_state_reschedules() {
        fetch_blocks_in_disconnected_state_reschedules(State::Initial);
    }

    #[test]
    fn test_fetch_blocks_in_handshake_state_reschedules() {
        fetch_blocks_in_disconnected_state_reschedules(State::Handshake {
            muxer: StageRef::blackhole(),
            handshake: StageRef::blackhole(),
        });
    }

    // HELPERS

    /// Build an initiator connection on Preprod with default manager config, in `state`.
    fn test_connection(state: State) -> Connection {
        let era_history: &EraHistory = NetworkName::Preprod.into();
        let params = Params {
            peer: Peer::new("test-peer"),
            conn_id: ConnectionId::initial(),
            role: Role::Initiator,
            config: ManagerConfig::default(),
            magic: NetworkMagic::PREPROD,
            pipeline: StageRef::blackhole(),
            era_history: Arc::new(era_history.clone()),
        };
        Connection { params, state }
    }

    /// Send `FetchBlocks` to a connection whose handshake has not completed, and check that
    /// the message is rescheduled no earlier than `reconnect_delay` and that the connection
    /// state is left unchanged.
    fn fetch_blocks_in_disconnected_state_reschedules(expected_state: State) {
        let mut network = SimulationBuilder::default();

        let connection = network.stage("connection", stage);
        let connection = network.wire_up(connection, test_connection(expected_state.clone()));
        let (blocks_output, _rx) = network.output::<Blocks>("blocks_output", 10);

        let request = ConnectionMessage::FetchBlocks { from: Point::Origin, through: Point::Origin, cr: blocks_output };
        network.preload(&connection, [request]).unwrap();

        let mut running = network.run();
        let start_time = running.now();

        // Stop as soon as the connection stage schedules something.
        let connection_name = connection.name().clone();
        running.breakpoint("schedule", move |eff| {
            matches!(eff, Effect::Schedule { at_stage, .. } if *at_stage == connection_name)
        });

        let effect = running.run_until_blocked().assert_breakpoint("schedule");

        let Effect::Schedule { id, .. } = &effect else {
            panic!("Expected Schedule effect");
        };
        let delay = id.time().checked_since(start_time).unwrap();
        assert!(delay >= ManagerConfig::default().reconnect_delay);

        // Clear the breakpoint before continuing
        running.clear_breakpoint("schedule");
        running.handle_effect(effect);

        // Expect the simulation to go to sleep waiting for the scheduled wake-up.
        running.run_until_sleeping_or_blocked().assert_sleeping();

        // Rescheduling must leave the connection state untouched.
        assert_eq!(running.get_state(&connection).unwrap().state, expected_state);
    }
}