// amaru_protocols/handshake/initiator.rs

// Copyright 2025 PRAGMA
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
16use tracing::instrument;
17
18use crate::{
19    handshake::{State, messages::Message},
20    mux::MuxMessage,
21    protocol::{
22        Initiator, Inputs, Miniprotocol, Outcome, PROTO_HANDSHAKE, ProtocolState, StageState, miniprotocol, outcome,
23    },
24    protocol_messages::{handshake::HandshakeResult, version_data::VersionData, version_table::VersionTable},
25};
26
27pub fn register_deserializers() -> DeserializerGuards {
28    vec![
29        pure_stage::register_data_deserializer::<HandshakeInitiator>().boxed(),
30        pure_stage::register_data_deserializer::<(State, HandshakeInitiator)>().boxed(),
31    ]
32}
33
34pub fn initiator() -> Miniprotocol<State, HandshakeInitiator, Initiator> {
35    miniprotocol(PROTO_HANDSHAKE)
36}
37
38#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
39pub struct HandshakeInitiator {
40    muxer: StageRef<MuxMessage>,
41    connection: StageRef<HandshakeResult>,
42    our_versions: VersionTable<VersionData>,
43}
44
45impl HandshakeInitiator {
46    pub fn new(
47        muxer: StageRef<MuxMessage>,
48        connection: StageRef<HandshakeResult>,
49        version_table: VersionTable<VersionData>,
50    ) -> (State, Self) {
51        (State::Propose, Self { muxer, connection, our_versions: version_table })
52    }
53}
54
55impl StageState<State, Initiator> for HandshakeInitiator {
56    type LocalIn = Void;
57
58    async fn local(
59        self,
60        _proto: &State,
61        input: Self::LocalIn,
62        _eff: &Effects<Inputs<Self::LocalIn>>,
63    ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
64        match input {}
65    }
66
67    #[instrument(name = "handshake.initiator.stage", skip_all, fields(message_type = input.message_type()))]
68    async fn network(
69        self,
70        _proto: &State,
71        input: InitiatorResult,
72        eff: &Effects<Inputs<Self::LocalIn>>,
73    ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
74        Ok(match input {
75            InitiatorResult::Propose => {
76                tracing::debug!(?self.our_versions, "proposing versions");
77                (Some(InitiatorAction::Propose(self.our_versions.clone())), self)
78            }
79            InitiatorResult::Conclusion(handshake_result) => {
80                tracing::debug!(?handshake_result, "conclusion");
81                eff.send(&self.connection, handshake_result).await;
82                (None, self)
83            }
84            InitiatorResult::SimOpen(version_table) => {
85                tracing::debug!(?version_table, "simultaneous open");
86                let result = crate::handshake::compute_negotiation_result(
87                    crate::protocol::Role::Initiator,
88                    self.our_versions.clone(),
89                    version_table,
90                );
91                eff.send(&self.connection, result).await;
92                (None, self)
93            }
94        })
95    }
96
97    fn muxer(&self) -> &StageRef<MuxMessage> {
98        &self.muxer
99    }
100}
101
102impl ProtocolState<Initiator> for State {
103    type WireMsg = Message<VersionData>;
104    type Action = InitiatorAction;
105    type Out = InitiatorResult;
106    type Error = Void;
107
108    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
109        Ok((outcome().result(InitiatorResult::Propose), Self::Propose))
110    }
111
112    #[instrument(name = "handshake.initiator.protocol", skip_all, fields(message_type = input.message_type()))]
113    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
114        anyhow::ensure!(self == &Self::Confirm, "handshake initiator cannot receive in initial state");
115        Ok(match input {
116            Message::Propose(version_table) => {
117                // TCP simultaneous open
118                (outcome().result(InitiatorResult::SimOpen(version_table)), Self::Done)
119            }
120            Message::Accept(version_number, version_data) => (
121                outcome().result(InitiatorResult::Conclusion(HandshakeResult::Accepted(version_number, version_data))),
122                Self::Done,
123            ),
124            Message::Refuse(refuse_reason) => {
125                (outcome().result(InitiatorResult::Conclusion(HandshakeResult::Refused(refuse_reason))), Self::Done)
126            }
127            Message::QueryReply(version_table) => {
128                (outcome().result(InitiatorResult::Conclusion(HandshakeResult::Query(version_table))), Self::Done)
129            }
130        })
131    }
132
133    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
134        anyhow::ensure!(self == &Self::Propose, "handshake initiator cannot send in confirmation state");
135        Ok(match input {
136            InitiatorAction::Propose(version_table) => {
137                (outcome().send(Message::Propose(version_table)).want_next(), Self::Confirm)
138            }
139        })
140    }
141}
142
143#[derive(Debug, PartialEq)]
144pub enum InitiatorResult {
145    Propose,
146    Conclusion(HandshakeResult),
147    SimOpen(VersionTable<VersionData>),
148}
149
150impl InitiatorResult {
151    pub fn message_type(&self) -> &str {
152        match self {
153            Self::Propose => "Propose",
154            Self::Conclusion(_) => "Conclusion",
155            Self::SimOpen(_) => "SimOpen",
156        }
157    }
158}
159
160#[derive(Debug)]
161pub enum InitiatorAction {
162    Propose(VersionTable<VersionData>),
163}
164
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::protocol::Initiator;

    /// Runs `crate::handshake::spec` for the initiator, starting from `State::Propose`.
    ///
    /// The closure maps a wire message to the local action that produces it.
    /// Presumably it returns `None` for messages the initiator never sends — confirm
    /// against `spec().check`. Only `Propose` has a matching initiator action.
    #[test]
    fn test_initiator_protocol() {
        crate::handshake::spec::<Initiator>().check(State::Propose, |msg| {
            let Message::Propose(version_table) = msg else {
                return None;
            };
            Some(InitiatorAction::Propose(version_table.clone()))
        });
    }
}