Skip to main content

amaru_protocols/protocol/
mod.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    fmt::{Display, Formatter},
17    marker::PhantomData,
18    time::Duration,
19};
20
21use bytes::{Buf, BufMut, Bytes, BytesMut, TryGetError};
22
23mod check;
24mod miniprotocol;
25
26pub use check::ProtoSpec;
27pub use miniprotocol::{Inputs, Miniprotocol, Outcome, ProtocolState, StageState, miniprotocol, outcome};
28
29/// Input to a protocol step
30#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31pub enum Input<L, R> {
32    Local(L),
33    Remote(R),
34}
35
/// Timeout applied to network sends (one second).
///
/// NOTE(review): presumably applied per send operation — confirm at the call sites.
// TODO(network) find right value
pub const NETWORK_SEND_TIMEOUT: Duration = Duration::from_millis(1_000);
38
39#[derive(serde::Serialize, serde::Deserialize)]
40pub struct ProtocolId<T: RoleT>(u16, PhantomData<T>);
41
42impl<T: RoleT> ProtocolId<T> {
43    pub fn encode(self, buffer: &mut BytesMut) {
44        buffer.put_u16(self.0);
45    }
46
47    pub fn decode(buffer: &mut Bytes) -> Result<Self, TryGetError> {
48        Ok(Self(buffer.try_get_u16()?, PhantomData))
49    }
50}
51
52impl<T: RoleT> std::fmt::Display for ProtocolId<T> {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        write!(f, "{}", self.0)
55    }
56}
57
58impl<T: RoleT> std::hash::Hash for ProtocolId<T> {
59    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
60        self.0.hash(state);
61    }
62}
63
64impl<T: RoleT> Ord for ProtocolId<T> {
65    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
66        self.0.cmp(&other.0)
67    }
68}
69
70impl<T: RoleT> PartialOrd for ProtocolId<T> {
71    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
72        Some(self.cmp(other))
73    }
74}
75
76impl<T: RoleT> Eq for ProtocolId<T> {}
77
78impl<T: RoleT> PartialEq for ProtocolId<T> {
79    fn eq(&self, other: &Self) -> bool {
80        self.0 == other.0 && self.1 == other.1
81    }
82}
83
84impl<T: RoleT> std::fmt::Debug for ProtocolId<T> {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.debug_tuple("ProtocolId").field(&self.0).finish()
87    }
88}
89
90impl<T: RoleT> Copy for ProtocolId<T> {}
91
92impl<R: RoleT> Clone for ProtocolId<R> {
93    fn clone(&self) -> Self {
94        *self
95    }
96}
97
/// Top bit of a raw protocol id; when set, the id refers to the responder side.
const RESPONDER: u16 = 1 << 15;
99
100#[derive(Debug, PartialEq, Eq, Clone, Copy, serde::Serialize, serde::Deserialize)]
101pub enum Role {
102    Initiator,
103    Responder,
104}
105
106impl Display for Role {
107    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108        match self {
109            Role::Initiator => write!(f, "initiator"),
110            Role::Responder => write!(f, "responder"),
111        }
112    }
113}
114
115impl Role {
116    pub const fn opposite(self) -> Role {
117        match self {
118            Role::Initiator => Role::Responder,
119            Role::Responder => Role::Initiator,
120        }
121    }
122}
123
/// Private home of the `Sealed` supertrait used to seal `RoleT`.
mod sealed {
    /// Supertrait of `RoleT`. It is public but lives in a private module, so code outside
    /// this module tree cannot name it and therefore cannot implement `RoleT`.
    pub trait Sealed {}
}
127pub trait RoleT:
128    Clone
129    + Copy
130    + std::fmt::Debug
131    + std::hash::Hash
132    + std::cmp::Ord
133    + std::cmp::PartialOrd
134    + std::cmp::Eq
135    + std::cmp::PartialEq
136    + serde::Serialize
137    + serde::de::DeserializeOwned
138    + Send
139    + Sync
140    + 'static
141    + sealed::Sealed
142{
143    type Opposite: RoleT;
144
145    const ROLE: Option<Role>;
146}
147
148#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
149pub struct Initiator;
150impl sealed::Sealed for Initiator {}
151impl RoleT for Initiator {
152    type Opposite = Responder;
153
154    const ROLE: Option<Role> = Some(Role::Initiator);
155}
156
157#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
158pub struct Responder;
159impl sealed::Sealed for Responder {}
160impl RoleT for Responder {
161    type Opposite = Initiator;
162
163    const ROLE: Option<Role> = Some(Role::Responder);
164}
165
166#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
167pub struct Erased;
168impl sealed::Sealed for Erased {}
169impl RoleT for Erased {
170    type Opposite = Erased;
171
172    const ROLE: Option<Role> = None;
173}
174
175pub const PROTO_HANDSHAKE: ProtocolId<Initiator> = ProtocolId::<Initiator>(0, PhantomData);
176
177pub const PROTO_N2N_CHAIN_SYNC: ProtocolId<Initiator> = ProtocolId::<Initiator>(2, PhantomData);
178pub const PROTO_N2N_BLOCK_FETCH: ProtocolId<Initiator> = ProtocolId::<Initiator>(3, PhantomData);
179pub const PROTO_N2N_TX_SUB: ProtocolId<Initiator> = ProtocolId::<Initiator>(4, PhantomData);
180pub const PROTO_N2N_KEEP_ALIVE: ProtocolId<Initiator> = ProtocolId::<Initiator>(8, PhantomData);
181pub const PROTO_N2N_PEER_SHARE: ProtocolId<Initiator> = ProtocolId::<Initiator>(10, PhantomData);
182
183// The below are only for information regarding the allocated numbers, Amaru will not implement N2C protocols.
184
185// pub const PROTO_N2C_CHAIN_SYNC: ProtocolId<Initiator> = ProtocolId::<Initiator>(5, PhantomData);
186// pub const PROTO_N2C_TX_SUB: ProtocolId<Initiator> = ProtocolId::<Initiator>(6, PhantomData);
187// pub const PROTO_N2C_STATE_QUERY: ProtocolId<Initiator> = ProtocolId::<Initiator>(7, PhantomData);
188// pub const PROTO_N2C_TX_MON: ProtocolId<Initiator> = ProtocolId::<Initiator>(9, PhantomData);
189
190#[cfg(test)]
191pub const PROTO_TEST: ProtocolId<Initiator> = ProtocolId::<Initiator>(257, PhantomData);
192
193impl<R: RoleT> ProtocolId<R> {
194    pub const fn is_initiator(self) -> bool {
195        self.0 & RESPONDER == 0
196    }
197
198    pub const fn is_responder(self) -> bool {
199        !self.is_initiator()
200    }
201
202    pub const fn opposite(self) -> ProtocolId<R::Opposite> {
203        ProtocolId(self.0 ^ RESPONDER, PhantomData)
204    }
205
206    pub const fn erase(self) -> ProtocolId<Erased> {
207        ProtocolId(self.0, PhantomData)
208    }
209
210    pub const fn for_role(self, role: Role) -> ProtocolId<Erased> {
211        match (role, self.role()) {
212            (Role::Initiator, Role::Initiator) | (Role::Responder, Role::Responder) => self.erase(),
213            (Role::Initiator, Role::Responder) | (Role::Responder, Role::Initiator) => self.opposite().erase(),
214        }
215    }
216
217    pub const fn role(self) -> Role {
218        if let Some(role) = R::ROLE {
219            role
220        } else if self.is_initiator() {
221            Role::Initiator
222        } else {
223            Role::Responder
224        }
225    }
226}
227
228impl ProtocolId<Initiator> {
229    pub const fn responder(self) -> ProtocolId<Responder> {
230        ProtocolId(self.0 | RESPONDER, PhantomData)
231    }
232}
233
234impl ProtocolId<Responder> {
235    pub const fn initiator(self) -> ProtocolId<Initiator> {
236        ProtocolId(self.0 & !RESPONDER, PhantomData)
237    }
238}