Skip to main content

amaru_protocols/handshake/responder.rs

// Copyright 2025 PRAGMA
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
16use tracing::instrument;
17
18use crate::{
19    handshake::{State, messages::Message},
20    mux::MuxMessage,
21    protocol::{
22        Inputs, Miniprotocol, Outcome, PROTO_HANDSHAKE, ProtocolState, Responder, StageState, miniprotocol, outcome,
23    },
24    protocol_messages::{
25        handshake::{HandshakeResult, RefuseReason},
26        version_data::VersionData,
27        version_number::VersionNumber,
28        version_table::VersionTable,
29    },
30};
31
32pub fn register_deserializers() -> DeserializerGuards {
33    vec![
34        pure_stage::register_data_deserializer::<HandshakeResponder>().boxed(),
35        pure_stage::register_data_deserializer::<(State, HandshakeResponder)>().boxed(),
36    ]
37}
38
39pub fn responder() -> Miniprotocol<State, HandshakeResponder, Responder> {
40    miniprotocol(PROTO_HANDSHAKE.responder())
41}
42
43#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
44pub struct HandshakeResponder {
45    muxer: StageRef<MuxMessage>,
46    connection: StageRef<HandshakeResult>,
47    our_versions: VersionTable<VersionData>,
48}
49
50impl HandshakeResponder {
51    pub fn new(
52        muxer: StageRef<MuxMessage>,
53        connection: StageRef<HandshakeResult>,
54        version_table: VersionTable<VersionData>,
55    ) -> (State, Self) {
56        (State::Propose, Self { muxer, connection, our_versions: version_table })
57    }
58}
59
60impl StageState<State, Responder> for HandshakeResponder {
61    type LocalIn = Void;
62
63    async fn local(
64        self,
65        _proto: &State,
66        input: Self::LocalIn,
67        _eff: &Effects<Inputs<Self::LocalIn>>,
68    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
69        match input {}
70    }
71
72    #[instrument(name = "handshake.responder.stage", skip_all, fields(version_table = %input.0))]
73    async fn network(
74        self,
75        _proto: &State,
76        input: Proposal,
77        eff: &Effects<Inputs<Self::LocalIn>>,
78    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
79        let result = crate::handshake::compute_negotiation_result(
80            crate::protocol::Role::Responder,
81            self.our_versions.clone(),
82            input.0,
83        );
84        eff.send(&self.connection, result.clone()).await;
85        Ok((Some(result.into()), self))
86    }
87
88    fn muxer(&self) -> &StageRef<MuxMessage> {
89        &self.muxer
90    }
91}
92
93impl ProtocolState<Responder> for State {
94    type WireMsg = Message<VersionData>;
95    type Action = ResponderAction;
96    type Out = Proposal;
97    type Error = Void;
98
99    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
100        Ok((outcome().want_next(), Self::Propose))
101    }
102
103    #[instrument(name = "handshake.responder.protocol", skip_all, fields(message_type = input.message_type()))]
104    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
105        anyhow::ensure!(self == &Self::Propose, "handshake responder cannot receive in confirm state");
106        match (self, input) {
107            (Self::Propose, Message::Propose(version_table)) => {
108                Ok((outcome().result(Proposal(version_table)), Self::Confirm))
109            }
110            input => anyhow::bail!("invalid message from initiator: {:?}", input),
111        }
112    }
113
114    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
115        anyhow::ensure!(self == &Self::Confirm, "handshake responder cannot send in propose state");
116        Ok(match input {
117            ResponderAction::Accept(version_number, version_data) => {
118                (outcome().send(Message::Accept(version_number, version_data)), Self::Done)
119            }
120            ResponderAction::Refuse(refuse_reason) => (outcome().send(Message::Refuse(refuse_reason)), Self::Done),
121            ResponderAction::Query(version_table) => (outcome().send(Message::QueryReply(version_table)), Self::Done),
122        })
123    }
124}
125
126#[derive(Debug)]
127pub enum ResponderAction {
128    Accept(VersionNumber, VersionData),
129    Refuse(RefuseReason),
130    Query(VersionTable<VersionData>),
131}
132
133#[derive(Debug, PartialEq)]
134pub struct Proposal(VersionTable<VersionData>);
135
136impl From<HandshakeResult> for ResponderAction {
137    fn from(result: HandshakeResult) -> Self {
138        match result {
139            HandshakeResult::Accepted(version_number, version_data) => {
140                ResponderAction::Accept(version_number, version_data)
141            }
142            HandshakeResult::Refused(reason) => ResponderAction::Refuse(reason),
143            HandshakeResult::Query(version_table) => ResponderAction::Query(version_table),
144        }
145    }
146}
147
#[cfg(test)]
pub mod tests {
    use crate::{
        handshake::{Message, State, responder::ResponderAction},
        protocol::Responder,
    };

    /// Runs the shared handshake spec against the responder, starting from `State::Propose`.
    ///
    /// The closure maps each message the responder can send to the action that emits it.
    /// Any other message (e.g. `Propose`) maps to `None`.
    #[test]
    fn test_responder_protocol() {
        crate::handshake::spec::<Responder>().check(State::Propose, |wire| match wire {
            Message::Accept(number, data) => Some(ResponderAction::Accept(*number, data.clone())),
            Message::Refuse(why) => Some(ResponderAction::Refuse(why.clone())),
            Message::QueryReply(table) => Some(ResponderAction::Query(table.clone())),
            _ => None,
        });
    }
}
164}