amaru_protocols/keepalive/
responder.rs1use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
16use tracing::instrument;
17
18use crate::{
19 keepalive::{
20 State,
21 messages::{Cookie, Message},
22 },
23 mux::MuxMessage,
24 protocol::{
25 Inputs, Miniprotocol, Outcome, PROTO_N2N_KEEP_ALIVE, ProtocolState, Responder, StageState, miniprotocol,
26 outcome,
27 },
28};
29
30pub fn register_deserializers() -> DeserializerGuards {
31 vec![
32 pure_stage::register_data_deserializer::<KeepAliveResponder>().boxed(),
33 pure_stage::register_data_deserializer::<(State, KeepAliveResponder)>().boxed(),
34 ]
35}
36
37pub fn responder() -> Miniprotocol<State, KeepAliveResponder, Responder> {
38 miniprotocol(PROTO_N2N_KEEP_ALIVE.responder())
39}
40
41#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
42pub struct KeepAliveResponder {
43 muxer: StageRef<MuxMessage>,
44}
45
46impl KeepAliveResponder {
47 pub fn new(muxer: StageRef<MuxMessage>) -> (State, Self) {
48 (State::Idle, Self { muxer })
49 }
50}
51
52impl StageState<State, Responder> for KeepAliveResponder {
53 type LocalIn = Void;
54
55 async fn local(
56 self,
57 _proto: &State,
58 input: Self::LocalIn,
59 _eff: &Effects<Inputs<Self::LocalIn>>,
60 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
61 match input {}
62 }
63
64 #[instrument(name = "keepalive.responder.stage", skip_all, fields(cookie = input.cookie.as_u16()))]
65 async fn network(
66 self,
67 _proto: &State,
68 input: ResponderResult,
69 _eff: &Effects<Inputs<Self::LocalIn>>,
70 ) -> anyhow::Result<(Option<ResponderAction>, Self)> {
71 Ok((Some(ResponderAction::SendResponse(input.cookie)), self))
72 }
73
74 fn muxer(&self) -> &StageRef<MuxMessage> {
75 &self.muxer
76 }
77}
78
79impl ProtocolState<Responder> for State {
80 type WireMsg = Message;
81 type Action = ResponderAction;
82 type Out = ResponderResult;
83 type Error = Void;
84
85 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
86 Ok((outcome().want_next(), *self))
87 }
88
89 #[instrument(name = "keepalive.responder.protocol", skip_all, fields(message_type = input.message_type()))]
90 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
91 use State::*;
92
93 Ok(match (self, input) {
94 (Idle, Message::KeepAlive(cookie)) => (outcome().result(ResponderResult { cookie }), Waiting),
95 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
96 })
97 }
98
99 fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
100 use State::*;
101
102 Ok(match (self, input) {
103 (Waiting, ResponderAction::SendResponse(cookie)) => {
104 (outcome().send(Message::ResponseKeepAlive(cookie)).want_next(), Idle)
105 }
106 (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
107 })
108 }
109}
110
111#[derive(Debug)]
112pub enum ResponderAction {
113 SendResponse(Cookie),
114}
115
116#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
117pub struct ResponderResult {
118 pub cookie: Cookie,
119}
120
#[cfg(test)]
pub mod tests {
    use crate::{
        keepalive::{State, messages::Message, responder::ResponderAction},
        protocol::Responder,
    };

    /// Checks the responder state machine against the keep-alive spec,
    /// starting from `Idle`.
    ///
    /// The closure maps each `Message` to the local `ResponderAction` that
    /// emits it, or `None` if no local action does. NOTE(review): this is the
    /// assumed contract of `spec().check` — confirm in `keepalive::spec`.
    #[test]
    fn test_responder_protocol() {
        let spec = crate::keepalive::spec::<Responder>();
        spec.check(State::Idle, |msg| match msg {
            Message::ResponseKeepAlive(cookie) => Some(ResponderAction::SendResponse(*cookie)),
            Message::KeepAlive(_) | Message::Done => None,
        });
    }
}