amaru_protocols/protocol/
miniprotocol.rs1use std::future::Future;
16
17use amaru_kernel::{NonEmptyBytes, cbor};
18use pure_stage::{BoxFuture, Effects, SendData, StageRef, TryInStage, Void, err};
19
20use crate::{
21 mux::{HandlerMessage, MuxMessage},
22 protocol::{NETWORK_SEND_TIMEOUT, ProtocolId, RoleT},
23};
24
25#[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
27pub enum Inputs<L> {
28 Local(L),
29 Network(HandlerMessage),
30}
31
32#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
34pub struct Outcome<S, R, E> {
35 pub send: Option<S>,
36 pub result: Option<R>,
37 pub terminate_with: Option<E>,
38 pub want_next: bool,
39}
40
41impl<S, R, E> Outcome<S, R, E> {
42 pub fn send(self, send: S) -> Self {
43 Self { send: Some(send), result: self.result, terminate_with: self.terminate_with, want_next: self.want_next }
44 }
45
46 pub fn result(self, done: R) -> Self {
47 Self { send: self.send, result: Some(done), terminate_with: self.terminate_with, want_next: self.want_next }
48 }
49
50 pub fn want_next(self) -> Self {
51 Self { send: self.send, result: self.result, terminate_with: self.terminate_with, want_next: true }
52 }
53
54 pub fn terminate_with(self, e: E) -> Self {
55 Self { send: self.send, result: self.result, terminate_with: Some(e), want_next: self.want_next }
56 }
57
58 pub fn without_result(self) -> Outcome<S, Void, E> {
59 Outcome { send: self.send, result: None, terminate_with: self.terminate_with, want_next: self.want_next }
60 }
61}
62
63pub fn outcome<S, R, E>() -> Outcome<S, R, E> {
64 Outcome { send: None, result: None, terminate_with: None, want_next: false }
65}
66
67pub trait ProtocolState<R: RoleT>: Sized + SendData {
71 type WireMsg: for<'de> cbor::Decode<'de, ()> + cbor::Encode<()> + Send;
72 type Action: std::fmt::Debug + Send;
73 type Out: std::fmt::Debug + PartialEq + Send;
74 type Error: std::fmt::Debug + std::fmt::Display + PartialEq + Send;
75
76 fn init(&self) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)>;
77
78 fn network(&self, input: Self::WireMsg) -> anyhow::Result<(Outcome<Self::WireMsg, Self::Out, Self::Error>, Self)>;
79
80 fn local(&self, input: Self::Action) -> anyhow::Result<(Outcome<Self::WireMsg, Void, Self::Error>, Self)>;
81}
82
83pub trait StageState<Proto: ProtocolState<R>, R: RoleT>: Sized + SendData {
87 type LocalIn: SendData;
88
89 fn local(
90 self,
91 proto: &Proto,
92 input: Self::LocalIn,
93 eff: &Effects<Inputs<Self::LocalIn>>,
94 ) -> impl Future<Output = anyhow::Result<(Option<Proto::Action>, Self)>> + Send;
95
96 fn network(
97 self,
98 proto: &Proto,
99 input: Proto::Out,
100 eff: &Effects<Inputs<Self::LocalIn>>,
101 ) -> impl Future<Output = anyhow::Result<(Option<Proto::Action>, Self)>> + Send;
102
103 fn muxer(&self) -> &StageRef<MuxMessage>;
104}
105
106pub type Miniprotocol<A, B, R>
107where
108 A: ProtocolState<R>,
109 B: StageState<A, R>,
110 R: RoleT,
111= impl Fn((A, B), Inputs<B::LocalIn>, Effects<Inputs<B::LocalIn>>) -> BoxFuture<'static, (A, B)> + Send + 'static;
112
113#[define_opaque(Miniprotocol)]
120pub fn miniprotocol<Proto, Stage, Role>(proto_id: ProtocolId<Role>) -> Miniprotocol<Proto, Stage, Role>
121where
122 Proto: ProtocolState<Role>,
123 Stage: StageState<Proto, Role>,
124 Role: RoleT,
125{
126 enum LocalOrNetwork<L, A> {
127 Local(L),
128 Network(A),
129 None,
130 }
131
132 move |(mut proto, mut stage), input, eff| {
133 Box::pin(async move {
134 let local_or_network = match input {
136 Inputs::Network(wire_msg) => {
137 let (outcome, s) = if let HandlerMessage::FromNetwork(wire_msg) = wire_msg {
138 let wire_msg: Proto::WireMsg = cbor::decode(&wire_msg)
139 .or_terminate(&eff, err("failed to decode message from network"))
140 .await;
141 proto.network(wire_msg).or_terminate(&eff, err("failed to step protocol state (network)")).await
142 } else {
143 proto.init().or_terminate(&eff, err("failed to initialize protocol state")).await
144 };
145 proto = s;
146 if outcome.want_next {
147 eff.send(stage.muxer(), MuxMessage::WantNext(proto_id.erase())).await;
148 }
149 if let Some(msg) = outcome.send {
150 let msg = NonEmptyBytes::encode(&msg);
151 eff.call(stage.muxer(), NETWORK_SEND_TIMEOUT, move |cr| {
152 MuxMessage::Send(proto_id.erase(), msg, cr)
153 })
154 .await;
155 }
156 outcome.result.map(LocalOrNetwork::Network).unwrap_or(LocalOrNetwork::None)
157 }
158 Inputs::Local(input) => LocalOrNetwork::Local(input),
159 };
160
161 let action = match local_or_network {
163 LocalOrNetwork::Local(local) => {
164 let (action, s) = stage
165 .local(&proto, local, &eff)
166 .await
167 .or_terminate(&eff, err("failed to step stage state (local)"))
168 .await;
169 stage = s;
170 action
171 }
172 LocalOrNetwork::Network(network) => {
173 let (action, s) = stage
174 .network(&proto, network, &eff)
175 .await
176 .or_terminate(&eff, err("failed to step stage state (network)"))
177 .await;
178 stage = s;
179 action
180 }
181 LocalOrNetwork::None => None,
182 };
183
184 if let Some(action) = action {
186 let (outcome, s) =
187 proto.local(action).or_terminate(&eff, err("failed to step protocol state (local)")).await;
188 proto = s;
189 if let Some(e) = outcome.terminate_with {
190 err("protocol error")(e).await;
191 return eff.terminate().await;
192 }
193 if outcome.want_next {
194 eff.send(stage.muxer(), MuxMessage::WantNext(proto_id.erase())).await;
195 }
196 if let Some(msg) = outcome.send {
197 let msg = NonEmptyBytes::encode(&msg);
198 eff.call(stage.muxer(), NETWORK_SEND_TIMEOUT, move |cr| {
199 MuxMessage::Send(proto_id.erase(), msg, cr)
200 })
201 .await;
202 }
203 }
204
205 (proto, stage)
206 })
207 }
208}