Skip to main content

amaru_protocols/keepalive/
responder.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
16use tracing::instrument;
17
18use crate::{
19    keepalive::{
20        State,
21        messages::{Cookie, Message},
22    },
23    mux::MuxMessage,
24    protocol::{
25        Inputs, Miniprotocol, Outcome, PROTO_N2N_KEEP_ALIVE, ProtocolState, Responder, StageState, miniprotocol,
26        outcome,
27    },
28};
29
30pub fn register_deserializers() -> DeserializerGuards {
31    vec![
32        pure_stage::register_data_deserializer::<KeepAliveResponder>().boxed(),
33        pure_stage::register_data_deserializer::<(State, KeepAliveResponder)>().boxed(),
34    ]
35}
36
37pub fn responder() -> Miniprotocol<State, KeepAliveResponder, Responder> {
38    miniprotocol(PROTO_N2N_KEEP_ALIVE.responder())
39}
40
41#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
42pub struct KeepAliveResponder {
43    muxer: StageRef<MuxMessage>,
44}
45
46impl KeepAliveResponder {
47    pub fn new(muxer: StageRef<MuxMessage>) -> (State, Self) {
48        (State::Idle, Self { muxer })
49    }
50}
51
52impl StageState<State, Responder> for KeepAliveResponder {
53    type LocalIn = Void;
54
55    async fn local(
56        self,
57        _proto: &State,
58        input: Self::LocalIn,
59        _eff: &Effects<Inputs<Self::LocalIn>>,
60    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
61        match input {}
62    }
63
64    #[instrument(name = "keepalive.responder.stage", skip_all, fields(cookie = input.cookie.as_u16()))]
65    async fn network(
66        self,
67        _proto: &State,
68        input: ResponderResult,
69        _eff: &Effects<Inputs<Self::LocalIn>>,
70    ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
71        Ok((Some(ResponderAction::SendResponse(input.cookie)), self))
72    }
73
74    fn muxer(&self) -> &StageRef<MuxMessage> {
75        &self.muxer
76    }
77}
78
79impl ProtocolState<Responder> for State {
80    type WireMsg = Message;
81    type Action = ResponderAction;
82    type Out = ResponderResult;
83    type Error = Void;
84
85    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
86        Ok((outcome().want_next(), *self))
87    }
88
89    #[instrument(name = "keepalive.responder.protocol", skip_all, fields(message_type = input.message_type()))]
90    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
91        use State::*;
92
93        Ok(match (self, input) {
94            (Idle, Message::KeepAlive(cookie)) => (outcome().result(ResponderResult { cookie }), Waiting),
95            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
96        })
97    }
98
99    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
100        use State::*;
101
102        Ok(match (self, input) {
103            (Waiting, ResponderAction::SendResponse(cookie)) => {
104                (outcome().send(Message::ResponseKeepAlive(cookie)).want_next(), Idle)
105            }
106            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
107        })
108    }
109}
110
111#[derive(Debug)]
112pub enum ResponderAction {
113    SendResponse(Cookie),
114}
115
116#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
117pub struct ResponderResult {
118    pub cookie: Cookie,
119}
120
#[cfg(test)]
pub mod tests {
    use crate::{
        keepalive::{State, messages::Message, responder::ResponderAction},
        protocol::Responder,
    };

    /// Checks the responder state machine against the shared keep-alive spec, starting from `Idle`.
    #[test]
    fn test_responder_protocol() {
        // Map each spec message to the local action that emits it. Messages that the
        // responder only ever receives from the network map to `None`.
        let local_action_for = |msg: &Message| match msg {
            Message::ResponseKeepAlive(cookie) => Some(ResponderAction::SendResponse(*cookie)),
            Message::KeepAlive(_) | Message::Done => None,
        };

        crate::keepalive::spec::<Responder>().check(State::Idle, local_action_for);
    }
}