amaru_protocols/handshake/
responder.rs1use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
16use tracing::instrument;
17
18use crate::{
19 handshake::{State, messages::Message},
20 mux::MuxMessage,
21 protocol::{
22 Inputs, Miniprotocol, Outcome, PROTO_HANDSHAKE, ProtocolState, Responder, StageState, miniprotocol, outcome,
23 },
24 protocol_messages::{
25 handshake::{HandshakeResult, RefuseReason},
26 version_data::VersionData,
27 version_number::VersionNumber,
28 version_table::VersionTable,
29 },
30};
31
32pub fn register_deserializers() -> DeserializerGuards {
33 vec![
34 pure_stage::register_data_deserializer::<HandshakeResponder>().boxed(),
35 pure_stage::register_data_deserializer::<(State, HandshakeResponder)>().boxed(),
36 ]
37}
38
39pub fn responder() -> Miniprotocol<State, HandshakeResponder, Responder> {
40 miniprotocol(PROTO_HANDSHAKE.responder())
41}
42
43#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
44pub struct HandshakeResponder {
45 muxer: StageRef<MuxMessage>,
46 connection: StageRef<HandshakeResult>,
47 our_versions: VersionTable<VersionData>,
48}
49
50impl HandshakeResponder {
51 pub fn new(
52 muxer: StageRef<MuxMessage>,
53 connection: StageRef<HandshakeResult>,
54 version_table: VersionTable<VersionData>,
55 ) -> (State, Self) {
56 (State::Propose, Self { muxer, connection, our_versions: version_table })
57 }
58}
59
60impl StageState<State, Responder> for HandshakeResponder {
61 type LocalIn = Void;
62
63 async fn local(
64 self,
65 _proto: &State,
66 input: Self::LocalIn,
67 _eff: &Effects<Inputs<Self::LocalIn>>,
68 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
69 match input {}
70 }
71
72 #[instrument(name = "handshake.responder.stage", skip_all, fields(version_table = %input.0))]
73 async fn network(
74 self,
75 _proto: &State,
76 input: Proposal,
77 eff: &Effects<Inputs<Self::LocalIn>>,
78 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
79 let result = crate::handshake::compute_negotiation_result(
80 crate::protocol::Role::Responder,
81 self.our_versions.clone(),
82 input.0,
83 );
84 eff.send(&self.connection, result.clone()).await;
85 Ok((Some(result.into()), self))
86 }
87
88 fn muxer(&self) -> &StageRef<MuxMessage> {
89 &self.muxer
90 }
91}
92
93impl ProtocolState<Responder> for State {
94 type WireMsg = Message<VersionData>;
95 type Action = ResponderAction;
96 type Out = Proposal;
97 type Error = Void;
98
99 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
100 Ok((outcome().want_next(), Self::Propose))
101 }
102
103 #[instrument(name = "handshake.responder.protocol", skip_all, fields(message_type = input.message_type()))]
104 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
105 anyhow::ensure!(self == &Self::Propose, "handshake responder cannot receive in confirm state");
106 match (self, input) {
107 (Self::Propose, Message::Propose(version_table)) => {
108 Ok((outcome().result(Proposal(version_table)), Self::Confirm))
109 }
110 input => anyhow::bail!("invalid message from initiator: {:?}", input),
111 }
112 }
113
114 fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
115 anyhow::ensure!(self == &Self::Confirm, "handshake responder cannot send in propose state");
116 Ok(match input {
117 ResponderAction::Accept(version_number, version_data) => {
118 (outcome().send(Message::Accept(version_number, version_data)), Self::Done)
119 }
120 ResponderAction::Refuse(refuse_reason) => (outcome().send(Message::Refuse(refuse_reason)), Self::Done),
121 ResponderAction::Query(version_table) => (outcome().send(Message::QueryReply(version_table)), Self::Done),
122 })
123 }
124}
125
126#[derive(Debug)]
127pub enum ResponderAction {
128 Accept(VersionNumber, VersionData),
129 Refuse(RefuseReason),
130 Query(VersionTable<VersionData>),
131}
132
133#[derive(Debug, PartialEq)]
134pub struct Proposal(VersionTable<VersionData>);
135
136impl From<HandshakeResult> for ResponderAction {
137 fn from(result: HandshakeResult) -> Self {
138 match result {
139 HandshakeResult::Accepted(version_number, version_data) => {
140 ResponderAction::Accept(version_number, version_data)
141 }
142 HandshakeResult::Refused(reason) => ResponderAction::Refuse(reason),
143 HandshakeResult::Query(version_table) => ResponderAction::Query(version_table),
144 }
145 }
146}
147
#[cfg(test)]
pub mod tests {
    use crate::{
        handshake::{Message, State, responder::ResponderAction},
        protocol::Responder,
    };

    /// Checks the responder against the shared handshake spec, starting from
    /// `State::Propose`.
    ///
    /// The closure maps each outbound wire message back to the
    /// `ResponderAction` that produces it. Messages the responder never sends
    /// map to `None`.
    #[test]
    fn test_responder_protocol() {
        crate::handshake::spec::<Responder>().check(State::Propose, |wire| match wire {
            Message::Refuse(why) => Some(ResponderAction::Refuse(why.clone())),
            Message::Accept(number, data) => Some(ResponderAction::Accept(*number, data.clone())),
            Message::QueryReply(table) => Some(ResponderAction::Query(table.clone())),
            _ => None,
        });
    }
}
164}