amaru_protocols/protocol/
mod.rs1use std::{
16 fmt::{Display, Formatter},
17 marker::PhantomData,
18 time::Duration,
19};
20
21use bytes::{Buf, BufMut, Bytes, BytesMut, TryGetError};
22
23mod check;
24mod miniprotocol;
25
26pub use check::ProtoSpec;
27pub use miniprotocol::{Inputs, Miniprotocol, Outcome, ProtocolState, StageState, miniprotocol, outcome};
28
29#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31pub enum Input<L, R> {
32 Local(L),
33 Remote(R),
34}
35
/// Timeout for network sends: one second. Where it is enforced lives in the callers.
pub const NETWORK_SEND_TIMEOUT: Duration = Duration::from_millis(1_000);
38
39#[derive(serde::Serialize, serde::Deserialize)]
40pub struct ProtocolId<T: RoleT>(u16, PhantomData<T>);
41
42impl<T: RoleT> ProtocolId<T> {
43 pub fn encode(self, buffer: &mut BytesMut) {
44 buffer.put_u16(self.0);
45 }
46
47 pub fn decode(buffer: &mut Bytes) -> Result<Self, TryGetError> {
48 Ok(Self(buffer.try_get_u16()?, PhantomData))
49 }
50}
51
52impl<T: RoleT> std::fmt::Display for ProtocolId<T> {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "{}", self.0)
55 }
56}
57
58impl<T: RoleT> std::hash::Hash for ProtocolId<T> {
59 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
60 self.0.hash(state);
61 }
62}
63
64impl<T: RoleT> Ord for ProtocolId<T> {
65 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
66 self.0.cmp(&other.0)
67 }
68}
69
70impl<T: RoleT> PartialOrd for ProtocolId<T> {
71 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
72 Some(self.cmp(other))
73 }
74}
75
76impl<T: RoleT> Eq for ProtocolId<T> {}
77
78impl<T: RoleT> PartialEq for ProtocolId<T> {
79 fn eq(&self, other: &Self) -> bool {
80 self.0 == other.0 && self.1 == other.1
81 }
82}
83
84impl<T: RoleT> std::fmt::Debug for ProtocolId<T> {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_tuple("ProtocolId").field(&self.0).finish()
87 }
88}
89
90impl<T: RoleT> Copy for ProtocolId<T> {}
91
92impl<R: RoleT> Clone for ProtocolId<R> {
93 fn clone(&self) -> Self {
94 *self
95 }
96}
97
/// Bit 15 of a raw protocol id: set on the responder side, clear on the initiator side.
const RESPONDER: u16 = 1 << 15;
99
100#[derive(Debug, PartialEq, Eq, Clone, Copy, serde::Serialize, serde::Deserialize)]
101pub enum Role {
102 Initiator,
103 Responder,
104}
105
106impl Display for Role {
107 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108 match self {
109 Role::Initiator => write!(f, "initiator"),
110 Role::Responder => write!(f, "responder"),
111 }
112 }
113}
114
115impl Role {
116 pub const fn opposite(self) -> Role {
117 match self {
118 Role::Initiator => Role::Responder,
119 Role::Responder => Role::Initiator,
120 }
121 }
122}
123
/// Private module holding the supertrait that seals `RoleT`.
mod sealed {
    /// Because the enclosing module is private, code outside the parent module tree
    /// cannot name this trait, and so cannot implement `RoleT`.
    pub trait Sealed {}
}
127pub trait RoleT:
128 Clone
129 + Copy
130 + std::fmt::Debug
131 + std::hash::Hash
132 + std::cmp::Ord
133 + std::cmp::PartialOrd
134 + std::cmp::Eq
135 + std::cmp::PartialEq
136 + serde::Serialize
137 + serde::de::DeserializeOwned
138 + Send
139 + Sync
140 + 'static
141 + sealed::Sealed
142{
143 type Opposite: RoleT;
144
145 const ROLE: Option<Role>;
146}
147
148#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
149pub struct Initiator;
150impl sealed::Sealed for Initiator {}
151impl RoleT for Initiator {
152 type Opposite = Responder;
153
154 const ROLE: Option<Role> = Some(Role::Initiator);
155}
156
157#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
158pub struct Responder;
159impl sealed::Sealed for Responder {}
160impl RoleT for Responder {
161 type Opposite = Initiator;
162
163 const ROLE: Option<Role> = Some(Role::Responder);
164}
165
166#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
167pub struct Erased;
168impl sealed::Sealed for Erased {}
169impl RoleT for Erased {
170 type Opposite = Erased;
171
172 const ROLE: Option<Role> = None;
173}
174
175pub const PROTO_HANDSHAKE: ProtocolId<Initiator> = ProtocolId::<Initiator>(0, PhantomData);
176
177pub const PROTO_N2N_CHAIN_SYNC: ProtocolId<Initiator> = ProtocolId::<Initiator>(2, PhantomData);
178pub const PROTO_N2N_BLOCK_FETCH: ProtocolId<Initiator> = ProtocolId::<Initiator>(3, PhantomData);
179pub const PROTO_N2N_TX_SUB: ProtocolId<Initiator> = ProtocolId::<Initiator>(4, PhantomData);
180pub const PROTO_N2N_KEEP_ALIVE: ProtocolId<Initiator> = ProtocolId::<Initiator>(8, PhantomData);
181pub const PROTO_N2N_PEER_SHARE: ProtocolId<Initiator> = ProtocolId::<Initiator>(10, PhantomData);
182
/// Protocol id used by tests (257, i.e. `0x0101`), initiator side; compiled only under
/// `cfg(test)`.
#[cfg(test)]
pub const PROTO_TEST: ProtocolId<Initiator> = ProtocolId(0x0101, PhantomData);
192
193impl<R: RoleT> ProtocolId<R> {
194 pub const fn is_initiator(self) -> bool {
195 self.0 & RESPONDER == 0
196 }
197
198 pub const fn is_responder(self) -> bool {
199 !self.is_initiator()
200 }
201
202 pub const fn opposite(self) -> ProtocolId<R::Opposite> {
203 ProtocolId(self.0 ^ RESPONDER, PhantomData)
204 }
205
206 pub const fn erase(self) -> ProtocolId<Erased> {
207 ProtocolId(self.0, PhantomData)
208 }
209
210 pub const fn for_role(self, role: Role) -> ProtocolId<Erased> {
211 match (role, self.role()) {
212 (Role::Initiator, Role::Initiator) | (Role::Responder, Role::Responder) => self.erase(),
213 (Role::Initiator, Role::Responder) | (Role::Responder, Role::Initiator) => self.opposite().erase(),
214 }
215 }
216
217 pub const fn role(self) -> Role {
218 if let Some(role) = R::ROLE {
219 role
220 } else if self.is_initiator() {
221 Role::Initiator
222 } else {
223 Role::Responder
224 }
225 }
226}
227
228impl ProtocolId<Initiator> {
229 pub const fn responder(self) -> ProtocolId<Responder> {
230 ProtocolId(self.0 | RESPONDER, PhantomData)
231 }
232}
233
234impl ProtocolId<Responder> {
235 pub const fn initiator(self) -> ProtocolId<Initiator> {
236 ProtocolId(self.0 & !RESPONDER, PhantomData)
237 }
238}