Skip to main content

amaru_protocols/keepalive/
initiator.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use pure_stage::{DeserializerGuards, Effects, StageRef, Void};
18use tracing::instrument;
19
20use crate::{
21    keepalive::{
22        State,
23        messages::{Cookie, Message},
24    },
25    mux::MuxMessage,
26    protocol::{
27        Initiator, Inputs, Miniprotocol, Outcome, PROTO_N2N_KEEP_ALIVE, ProtocolState, StageState, miniprotocol,
28        outcome,
29    },
30};
31
32pub fn register_deserializers() -> DeserializerGuards {
33    vec![
34        pure_stage::register_data_deserializer::<InitiatorMessage>().boxed(),
35        pure_stage::register_data_deserializer::<(State, KeepAliveInitiator)>().boxed(),
36        pure_stage::register_data_deserializer::<KeepAliveInitiator>().boxed(),
37    ]
38}
39
40pub fn initiator() -> Miniprotocol<State, KeepAliveInitiator, Initiator> {
41    miniprotocol(PROTO_N2N_KEEP_ALIVE)
42}
43
44/// Message sent to the handler to trigger periodic keep-alive sends
45#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
46pub enum InitiatorMessage {
47    SendKeepAlive,
48}
49
50/// Message sent from the handler (for future use, e.g., RTT reporting)
51#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
52pub struct InitiatorResult {
53    pub cookie: Cookie,
54}
55
56#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
57pub struct KeepAliveInitiator {
58    cookie: Cookie,
59    muxer: StageRef<MuxMessage>,
60}
61
62impl KeepAliveInitiator {
63    pub fn new(muxer: StageRef<MuxMessage>) -> (State, Self) {
64        (State::Idle, Self { cookie: Cookie::new(), muxer })
65    }
66}
67
68impl StageState<State, Initiator> for KeepAliveInitiator {
69    type LocalIn = InitiatorMessage;
70
71    async fn local(
72        self,
73        proto: &State,
74        input: Self::LocalIn,
75        _eff: &Effects<Inputs<Self::LocalIn>>,
76    ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
77        use State::*;
78
79        match (proto, input) {
80            (Idle, InitiatorMessage::SendKeepAlive) => Ok((Some(InitiatorAction::SendKeepAlive(self.cookie)), self)),
81            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
82        }
83    }
84
85    #[instrument(name = "keepalive.initiator.stage", skip_all, fields(cookie = input.cookie.as_u16()))]
86    async fn network(
87        mut self,
88        _proto: &State,
89        input: InitiatorResult,
90        eff: &Effects<Inputs<Self::LocalIn>>,
91    ) -> anyhow::Result<(Option<InitiatorAction>, Self)> {
92        // After receiving a response, increment cookie and schedule next send
93        self.cookie = input.cookie.next();
94        let delay = if u16::from(input.cookie) == 0 {
95            // this is only for the very first keep-alive message, which the Haskell node expects within the first
96            // five seconds
97            Duration::from_secs(1)
98        } else {
99            Duration::from_secs(30)
100        };
101        eff.schedule_after(Inputs::Local(InitiatorMessage::SendKeepAlive), delay).await;
102        Ok((None, self))
103    }
104
105    fn muxer(&self) -> &StageRef<MuxMessage> {
106        &self.muxer
107    }
108}
109
110impl ProtocolState<Initiator> for State {
111    type WireMsg = Message;
112    type Action = InitiatorAction;
113    type Out = InitiatorResult;
114    type Error = Void;
115
116    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
117        // On init, trigger the first KeepAlive send via the StageState to set timers in motion
118        Ok((outcome().result(InitiatorResult { cookie: Cookie::new() }), *self))
119    }
120
121    #[instrument(name = "keepalive.initiator.protocol", skip_all, fields(message_type = input.message_type()))]
122    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)> {
123        use State::*;
124
125        Ok(match (self, input) {
126            (Waiting, Message::ResponseKeepAlive(cookie)) => (outcome().result(InitiatorResult { cookie }), Idle),
127            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
128        })
129    }
130
131    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)> {
132        use State::*;
133
134        Ok(match (self, input) {
135            (Idle, InitiatorAction::SendKeepAlive(cookie)) => {
136                (outcome().send(Message::KeepAlive(cookie)).want_next(), Waiting)
137            }
138            (this, input) => anyhow::bail!("invalid state: {:?} <- {:?}", this, input),
139        })
140    }
141}
142
143#[derive(Debug)]
144pub enum InitiatorAction {
145    SendKeepAlive(Cookie),
146}
147
#[cfg(test)]
#[expect(clippy::wildcard_enum_match_arm)]
pub mod tests {
    use super::InitiatorAction;
    use crate::{
        keepalive::{State, messages::Message, spec},
        protocol::Initiator,
    };

    /// Checks the initiator state machine against the keep-alive protocol specification,
    /// starting from `Idle`. The closure tells the checker which local action produces a
    /// given wire message; only `KeepAlive` is sent by the initiator.
    #[test]
    fn test_initiator_protocol() {
        spec::<Initiator>().check(State::Idle, |msg| {
            match msg {
                Message::KeepAlive(cookie) => Some(InitiatorAction::SendKeepAlive(*cookie)),
                _ => None,
            }
        });
    }
}