amaru_protocols/handshake/
initiator.rs1use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
16use tracing::instrument;
17
18use crate::{
19 handshake::{State, messages::Message},
20 mux::MuxMessage,
21 protocol::{
22 Initiator, Inputs, Miniprotocol, Outcome, PROTO_HANDSHAKE, ProtocolState, StageState, miniprotocol, outcome,
23 },
24 protocol_messages::{handshake::HandshakeResult, version_data::VersionData, version_table::VersionTable},
25};
26
27pub fn register_deserializers() -> DeserializerGuards {
28 vec![
29 pure_stage::register_data_deserializer::<HandshakeInitiator>().boxed(),
30 pure_stage::register_data_deserializer::<(State, HandshakeInitiator)>().boxed(),
31 ]
32}
33
34pub fn initiator() -> Miniprotocol<State, HandshakeInitiator, Initiator> {
35 miniprotocol(PROTO_HANDSHAKE)
36}
37
38#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
39pub struct HandshakeInitiator {
40 muxer: StageRef<MuxMessage>,
41 connection: StageRef<HandshakeResult>,
42 our_versions: VersionTable<VersionData>,
43}
44
45impl HandshakeInitiator {
46 pub fn new(
47 muxer: StageRef<MuxMessage>,
48 connection: StageRef<HandshakeResult>,
49 version_table: VersionTable<VersionData>,
50 ) -> (State, Self) {
51 (State::Propose, Self { muxer, connection, our_versions: version_table })
52 }
53}
54
55impl StageState<State, Initiator> for HandshakeInitiator {
56 type LocalIn = Void;
57
58 async fn local(
59 self,
60 _proto: &State,
61 input: Self::LocalIn,
62 _eff: &Effects<Inputs<Self::LocalIn>>,
63 ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
64 match input {}
65 }
66
67 #[instrument(name = "handshake.initiator.stage", skip_all, fields(message_type = input.message_type()))]
68 async fn network(
69 self,
70 _proto: &State,
71 input: InitiatorResult,
72 eff: &Effects<Inputs<Self::LocalIn>>,
73 ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
74 Ok(match input {
75 InitiatorResult::Propose => {
76 tracing::debug!(?self.our_versions, "proposing versions");
77 (Some(InitiatorAction::Propose(self.our_versions.clone())), self)
78 }
79 InitiatorResult::Conclusion(handshake_result) => {
80 tracing::debug!(?handshake_result, "conclusion");
81 eff.send(&self.connection, handshake_result).await;
82 (None, self)
83 }
84 InitiatorResult::SimOpen(version_table) => {
85 tracing::debug!(?version_table, "simultaneous open");
86 let result = crate::handshake::compute_negotiation_result(
87 crate::protocol::Role::Initiator,
88 self.our_versions.clone(),
89 version_table,
90 );
91 eff.send(&self.connection, result).await;
92 (None, self)
93 }
94 })
95 }
96
97 fn muxer(&self) -> &StageRef<MuxMessage> {
98 &self.muxer
99 }
100}
101
102impl ProtocolState<Initiator> for State {
103 type WireMsg = Message<VersionData>;
104 type Action = InitiatorAction;
105 type Out = InitiatorResult;
106 type Error = Void;
107
108 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
109 Ok((outcome().result(InitiatorResult::Propose), Self::Propose))
110 }
111
112 #[instrument(name = "handshake.initiator.protocol", skip_all, fields(message_type = input.message_type()))]
113 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
114 anyhow::ensure!(self == &Self::Confirm, "handshake initiator cannot receive in initial state");
115 Ok(match input {
116 Message::Propose(version_table) => {
117 (outcome().result(InitiatorResult::SimOpen(version_table)), Self::Done)
119 }
120 Message::Accept(version_number, version_data) => (
121 outcome().result(InitiatorResult::Conclusion(HandshakeResult::Accepted(version_number, version_data))),
122 Self::Done,
123 ),
124 Message::Refuse(refuse_reason) => {
125 (outcome().result(InitiatorResult::Conclusion(HandshakeResult::Refused(refuse_reason))), Self::Done)
126 }
127 Message::QueryReply(version_table) => {
128 (outcome().result(InitiatorResult::Conclusion(HandshakeResult::Query(version_table))), Self::Done)
129 }
130 })
131 }
132
133 fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
134 anyhow::ensure!(self == &Self::Propose, "handshake initiator cannot send in confirmation state");
135 Ok(match input {
136 InitiatorAction::Propose(version_table) => {
137 (outcome().send(Message::Propose(version_table)).want_next(), Self::Confirm)
138 }
139 })
140 }
141}
142
143#[derive(Debug, PartialEq)]
144pub enum InitiatorResult {
145 Propose,
146 Conclusion(HandshakeResult),
147 SimOpen(VersionTable<VersionData>),
148}
149
150impl InitiatorResult {
151 pub fn message_type(&self) -> &str {
152 match self {
153 Self::Propose => "Propose",
154 Self::Conclusion(_) => "Conclusion",
155 Self::SimOpen(_) => "SimOpen",
156 }
157 }
158}
159
160#[derive(Debug)]
161pub enum InitiatorAction {
162 Propose(VersionTable<VersionData>),
163}
164
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::protocol::Initiator;

    /// Checks the initiator state machine against the shared handshake spec,
    /// starting from `State::Propose`.
    #[test]
    fn test_initiator_protocol() {
        crate::handshake::spec::<Initiator>().check(State::Propose, |msg| {
            // Only a proposal maps to a local action; every other message has none.
            let Message::Propose(version_table) = msg else {
                return None;
            };
            Some(InitiatorAction::Propose(version_table.clone()))
        });
    }
}