Skip to main content

amaru_protocols/protocol/
miniprotocol.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::future::Future;
16
17use amaru_kernel::{NonEmptyBytes, cbor};
18use pure_stage::{BoxFuture, Effects, SendData, StageRef, TryInStage, Void, err};
19
20use crate::{
21    mux::{HandlerMessage, MuxMessage},
22    protocol::{NETWORK_SEND_TIMEOUT, ProtocolId, RoleT},
23};
24
25/// An input to a miniprotocol handler stage.
26#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
27pub enum Inputs<L> {
28    Local(L),
29    Network(HandlerMessage),
30}
31
32/// Outcome of a protocol step
33#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
34pub struct Outcome<S, R, E> {
35    pub send: Option<S>,
36    pub result: Option<R>,
37    pub terminate_with: Option<E>,
38    pub want_next: bool,
39}
40
41impl<S, R, E> Outcome<S, R, E> {
42    pub fn send(self, send: S) -> Self {
43        Self { send: Some(send), result: self.result, terminate_with: self.terminate_with, want_next: self.want_next }
44    }
45
46    pub fn result(self, done: R) -> Self {
47        Self { send: self.send, result: Some(done), terminate_with: self.terminate_with, want_next: self.want_next }
48    }
49
50    pub fn want_next(self) -> Self {
51        Self { send: self.send, result: self.result, terminate_with: self.terminate_with, want_next: true }
52    }
53
54    pub fn terminate_with(self, e: E) -> Self {
55        Self { send: self.send, result: self.result, terminate_with: Some(e), want_next: self.want_next }
56    }
57
58    pub fn without_result(self) -> Outcome<S, Void, E> {
59        Outcome { send: self.send, result: None, terminate_with: self.terminate_with, want_next: self.want_next }
60    }
61}
62
63pub fn outcome<S, R, E>() -> Outcome<S, R, E> {
64    Outcome { send: None, result: None, terminate_with: None, want_next: false }
65}
66
67/// This tracks only the network protocol state, reacting to local decisions
68/// (`Action`) or incoming network messages (`WireMsg`). It may emit information
69/// via the `Out` type.
70pub trait ProtocolState<R: RoleT>: Sized + SendData {
71    type WireMsg: for<'de> cbor::Decode<'de, ()> + cbor::Encode<()> + Send;
72    type Action: std::fmt::Debug + Send;
73    type Out: std::fmt::Debug + PartialEq + Send;
74    type Error: std::fmt::Debug + std::fmt::Display + PartialEq + Send;
75
76    fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)>;
77
78    fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)>;
79
80    fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)>;
81}
82
83/// This tracks the stage state that is used to make decisions based on inputs
84/// from the local node (`LocalIn`) or incoming network messages (`NetworkIn`).
85/// It may emit network actions to be performed via the `Action` type.
86pub trait StageState<Proto: ProtocolState<R>, R: RoleT>: Sized + SendData {
87    type LocalIn: SendData;
88
89    fn local(
90        self,
91        proto: &Proto,
92        input: Self::LocalIn,
93        eff: &Effects<Inputs<Self::LocalIn>>,
94    ) -> impl Future<Output = anyhow::Result<(Option<Proto::Action>, Self)>> + Send;
95
96    fn network(
97        self,
98        proto: &Proto,
99        input: Proto::Out,
100        eff: &Effects<Inputs<Self::LocalIn>>,
101    ) -> impl Future<Output = anyhow::Result<(Option<Proto::Action>, Self)>> + Send;
102
103    fn muxer(&self) -> &StageRef<MuxMessage>;
104}
105
106pub type Miniprotocol<A, B, R>
107where
108    A: ProtocolState<R>,
109    B: StageState<A, R>,
110    R: RoleT,
111= impl Fn((A, B), Inputs<B::LocalIn>, Effects<Inputs<B::LocalIn>>) -> BoxFuture<'static, (A, B)> + Send + 'static;
112
113/// A miniprotocol is described using two states:
114/// - `S`: the protocol state that tracks the network protocol state
115/// - `S2`: the stage state that tracks the stage state
116///
117/// It is important to clearly separate these two, with `S2` being
118/// responsible for decision making and `S` only following the protocol.
119#[define_opaque(Miniprotocol)]
120pub fn miniprotocol<Proto, Stage, Role>(proto_id: ProtocolId<Role>) -> Miniprotocol<Proto, Stage, Role>
121where
122    Proto: ProtocolState<Role>,
123    Stage: StageState<Proto, Role>,
124    Role: RoleT,
125{
126    enum LocalOrNetwork<L, A> {
127        Local(L),
128        Network(A),
129        None,
130    }
131
132    move |(mut proto, mut stage), input, eff| {
133        Box::pin(async move {
134            // handle network input, if any
135            let local_or_network = match input {
136                Inputs::Network(wire_msg) => {
137                    let (outcome, s) = if let HandlerMessage::FromNetwork(wire_msg) = wire_msg {
138                        let wire_msg: Proto::WireMsg = cbor::decode(&wire_msg)
139                            .or_terminate(&eff, err("failed to decode message from network"))
140                            .await;
141                        proto.network(wire_msg).or_terminate(&eff, err("failed to step protocol state (network)")).await
142                    } else {
143                        proto.init().or_terminate(&eff, err("failed to initialize protocol state")).await
144                    };
145                    proto = s;
146                    if outcome.want_next {
147                        eff.send(stage.muxer(), MuxMessage::WantNext(proto_id.erase())).await;
148                    }
149                    if let Some(msg) = outcome.send {
150                        let msg = NonEmptyBytes::encode(&msg);
151                        eff.call(stage.muxer(), NETWORK_SEND_TIMEOUT, move |cr| {
152                            MuxMessage::Send(proto_id.erase(), msg, cr)
153                        })
154                        .await;
155                    }
156                    outcome.result.map(LocalOrNetwork::Network).unwrap_or(LocalOrNetwork::None)
157                }
158                Inputs::Local(input) => LocalOrNetwork::Local(input),
159            };
160
161            // run decision making, if there was new information
162            let action = match local_or_network {
163                LocalOrNetwork::Local(local) => {
164                    let (action, s) = stage
165                        .local(&proto, local, &eff)
166                        .await
167                        .or_terminate(&eff, err("failed to step stage state (local)"))
168                        .await;
169                    stage = s;
170                    action
171                }
172                LocalOrNetwork::Network(network) => {
173                    let (action, s) = stage
174                        .network(&proto, network, &eff)
175                        .await
176                        .or_terminate(&eff, err("failed to step stage state (network)"))
177                        .await;
178                    stage = s;
179                    action
180                }
181                LocalOrNetwork::None => None,
182            };
183
184            // send network messages, if required
185            if let Some(action) = action {
186                let (outcome, s) =
187                    proto.local(action).or_terminate(&eff, err("failed to step protocol state (local)")).await;
188                proto = s;
189                if let Some(e) = outcome.terminate_with {
190                    err("protocol error")(e).await;
191                    return eff.terminate().await;
192                }
193                if outcome.want_next {
194                    eff.send(stage.muxer(), MuxMessage::WantNext(proto_id.erase())).await;
195                }
196                if let Some(msg) = outcome.send {
197                    let msg = NonEmptyBytes::encode(&msg);
198                    eff.call(stage.muxer(), NETWORK_SEND_TIMEOUT, move |cr| {
199                        MuxMessage::Send(proto_id.erase(), msg, cr)
200                    })
201                    .await;
202                }
203            }
204
205            (proto, stage)
206        })
207    }
208}