Skip to main content

amaru_protocols/tx_submission/
messages.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt::Display;
16
17use amaru_kernel::{NonEmptyBytes, Transaction, cbor, to_cbor};
18use amaru_ouroboros_traits::TxId;
19
20use crate::tx_submission::Blocking;
21
22/// Messages for the txsubmission mini-protocol.
23#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
24#[repr(u8)]
25pub enum Message {
26    Init,
27    RequestTxIdsBlocking(u16, u16),
28    RequestTxIdsNonBlocking(u16, u16),
29    RequestTxs(Vec<TxId>),
30    ReplyTxIds(Vec<(TxId, u32)>),
31    ReplyTxs(Vec<Transaction>),
32    Done,
33}
34
35impl Message {
36    /// This is copied from the `std::mem` docs, it is the official way.
37    fn discriminant(&self) -> u8 {
38        // SAFETY: Because `Self` is marked `repr(u8)`, its layout is a `repr(C)` `union`
39        // between `repr(C)` structs, each of which has the `u8` discriminant as its first
40        // field, so we can read the discriminant without offsetting the pointer.
41        unsafe { *<*const _>::from(self).cast::<u8>() }
42    }
43}
44
45impl PartialOrd for Message {
46    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
47        Some(self.cmp(other))
48    }
49}
50
51impl Ord for Message {
52    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
53        self.discriminant().cmp(&other.discriminant()).then_with(|| match (self, other) {
54            (Message::Init, Message::Init) => std::cmp::Ordering::Equal,
55            (Message::RequestTxIdsBlocking(a1, b1), Message::RequestTxIdsBlocking(a2, b2)) => {
56                a1.cmp(a2).then_with(|| b1.cmp(b2))
57            }
58            (Message::RequestTxIdsNonBlocking(a1, b1), Message::RequestTxIdsNonBlocking(a2, b2)) => {
59                a1.cmp(a2).then_with(|| b1.cmp(b2))
60            }
61            (Message::RequestTxs(a1), Message::RequestTxs(a2)) => a1.cmp(a2),
62            (Message::ReplyTxIds(a1), Message::ReplyTxIds(a2)) => a1.cmp(a2),
63            (Message::ReplyTxs(a1), Message::ReplyTxs(a2)) => {
64                let left = to_cbor(a1);
65                let right = to_cbor(a2);
66                left.cmp(&right)
67            }
68            (Message::Done, Message::Done) => std::cmp::Ordering::Equal,
69            _ => unreachable!(),
70        })
71    }
72}
73
74impl Message {
75    pub fn message_type(&self) -> &'static str {
76        match self {
77            Message::Init => "Init",
78            Message::RequestTxIdsBlocking(_, _) => "RequestTxIdsBlocking",
79            Message::RequestTxIdsNonBlocking(_, _) => "RequestTxIdsNonBlocking",
80            Message::ReplyTxIds(_) => "ReplyTxIds",
81            Message::RequestTxs(_) => "RequestTxs",
82            Message::ReplyTxs(_) => "ReplyTxs",
83            Message::Done => "Done",
84        }
85    }
86}
87
88impl cbor::Encode<()> for Message {
89    fn encode<W: cbor::encode::Write>(
90        &self,
91        e: &mut cbor::Encoder<W>,
92        _ctx: &mut (),
93    ) -> Result<(), cbor::encode::Error<W::Error>> {
94        match self {
95            Message::RequestTxIdsBlocking(ack, req) => {
96                e.array(4)?.u16(0)?;
97                e.encode(Blocking::Yes)?;
98                e.u16(*ack)?;
99                e.u16(*req)?;
100            }
101            Message::RequestTxIdsNonBlocking(ack, req) => {
102                e.array(4)?.u16(0)?;
103                e.encode(Blocking::No)?;
104                e.u16(*ack)?;
105                e.u16(*req)?;
106            }
107            Message::ReplyTxIds(ids) => {
108                e.array(2)?.u16(1)?;
109                e.begin_array()?;
110                for id in ids {
111                    e.encode(id)?;
112                }
113                e.end()?;
114            }
115            Message::RequestTxs(ids) => {
116                e.array(2)?.u16(2)?;
117                e.array(ids.len() as u64)?;
118                for id in ids {
119                    e.encode(id)?;
120                }
121            }
122            Message::ReplyTxs(txs) => {
123                e.array(2)?.u16(3)?;
124                e.array(txs.len() as u64)?;
125                for tx in txs {
126                    e.encode(tx)?;
127                }
128            }
129            Message::Done => {
130                e.array(1)?.u16(4)?;
131            }
132            Message::Init => {
133                e.array(1)?.u16(6)?;
134            }
135        }
136        Ok(())
137    }
138}
139
140impl<'b> cbor::Decode<'b, ()> for Message {
141    fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
142        let len = d.array()?;
143        let label = d.u16()?;
144
145        match label {
146            0 => {
147                cbor::check_tagged_array_length(0, len, 4)?;
148                let blocking = d.decode()?;
149                let ack = d.u16()?;
150                let req = d.u16()?;
151                match blocking {
152                    Blocking::Yes => Ok(Message::RequestTxIdsBlocking(ack, req)),
153                    Blocking::No => Ok(Message::RequestTxIdsNonBlocking(ack, req)),
154                }
155            }
156            1 => {
157                cbor::check_tagged_array_length(1, len, 2)?;
158                let items = d.decode()?;
159                Ok(Message::ReplyTxIds(items))
160            }
161            2 => {
162                cbor::check_tagged_array_length(2, len, 2)?;
163                let ids = d.decode()?;
164                Ok(Message::RequestTxs(ids))
165            }
166            3 => {
167                cbor::check_tagged_array_length(3, len, 2)?;
168                Ok(Message::ReplyTxs(d.array_iter()?.collect::<Result<_, _>>()?))
169            }
170            4 => {
171                cbor::check_tagged_array_length(4, len, 1)?;
172                Ok(Message::Done)
173            }
174            6 => {
175                cbor::check_tagged_array_length(6, len, 1)?;
176                Ok(Message::Init)
177            }
178            _ => Err(cbor::decode::Error::message("unknown variant for txsubmission message")),
179        }
180    }
181}
182
183impl Display for Message {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        match self {
186            Message::Init => write!(f, "Init"),
187            Message::RequestTxIdsBlocking(ack, req) => {
188                write!(f, "RequestTxIdsBlocking(ack: {}, req: {})", ack, req,)
189            }
190            Message::RequestTxIdsNonBlocking(ack, req) => {
191                write!(f, "RequestTxIdsNonBlocking(ack: {}, req: {})", ack, req,)
192            }
193            Message::ReplyTxIds(ids) => {
194                write!(
195                    f,
196                    "ReplyTxIds(ids: [{}])",
197                    ids.iter().map(|(id, size)| format!("({}, {})", id, size)).collect::<Vec<_>>().join(", ")
198                )
199            }
200            Message::RequestTxs(ids) => {
201                write!(
202                    f,
203                    "RequestTxs(ids: [{}])",
204                    ids.iter().map(|id| format!("{}", id)).collect::<Vec<_>>().join(", ")
205                )
206            }
207            Message::ReplyTxs(txs) => {
208                write!(
209                    f,
210                    "ReplyTxs(txs: [{}])",
211                    txs.iter().map(|tx| format!("{}", TxId::from(tx))).collect::<Vec<_>>().join(", ")
212                )
213            }
214            Message::Done => write!(f, "Done"),
215        }
216    }
217}
218
219/// Messages coming directly from the muxer.
220#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
221pub enum TxSubmissionMessage {
222    Registered,
223    FromNetwork(NonEmptyBytes),
224}
225
/// Roundtrip property tests for txsubmission messages.
#[cfg(test)]
mod tests {
    use amaru_kernel::{Hash, Transaction, prop_cbor_roundtrip};
    use prop::collection::vec;
    use proptest::{prelude::*, prop_compose};

    use super::*;
    use crate::tx_submission::tests::create_transaction;

    mod tx_id {
        use super::*;
        prop_cbor_roundtrip!(TxId, any_tx_id());
    }
    mod message {
        use super::*;
        prop_cbor_roundtrip!(Message, any_message());
    }
    mod blocking {
        use super::*;
        prop_cbor_roundtrip!(Blocking, any_blocking());
    }

    // HELPERS

    // An arbitrary transaction id built from 32 random bytes.
    prop_compose! {
        pub fn any_tx_id()(
            bytes in any::<[u8; 32]>(),
        ) -> TxId {
            TxId::new(Hash::new(bytes))
        }
    }

    // Either blocking mode, with equal probability.
    prop_compose! {
        pub fn any_blocking()(
            yes in any::<bool>(),
        ) -> Blocking {
            if yes { Blocking::Yes } else { Blocking::No }
        }
    }

    // The full `u16` range, so the boundaries (0 and u16::MAX) and every CBOR
    // integer width a `u16` can take are exercised by the roundtrip.
    prop_compose! {
        fn any_ack_req()(ack in any::<u16>(), req in any::<u16>()) -> (u16, u16) {
            (ack, req)
        }
    }

    // Each id is generated together with its size, so the list length is drawn
    // once. Zipping two independently sized vectors would truncate to the
    // shorter one and bias toward small lists.
    prop_compose! {
        fn any_tx_id_and_sizes_vec()(pairs in vec((any_tx_id(), any::<u32>()), 0..20)) -> Vec<(TxId, u32)> {
            pairs
        }
    }

    prop_compose! {
        fn any_tx_id_vec()(ids in vec(any_tx_id(), 0..20)) -> Vec<TxId> {
            ids
        }
    }

    prop_compose! {
        fn any_tx_vec()(txs in vec(any_tx(), 0..10)) -> Vec<Transaction> {
            txs
        }
    }

    prop_compose! {
        fn any_tx()(n in 0u64..=1000) -> Transaction {
            create_transaction(n)
        }
    }

    fn init_message() -> impl Strategy<Value = Message> {
        Just(Message::Init)
    }

    prop_compose! {
        fn request_tx_ids_message()((ack, req) in any_ack_req(), blocking in any_blocking()) -> Message {
            match blocking {
                Blocking::Yes => Message::RequestTxIdsBlocking(ack, req),
                Blocking::No => Message::RequestTxIdsNonBlocking(ack, req),
            }
        }
    }

    prop_compose! {
        fn reply_tx_ids_message()(ids in any_tx_id_and_sizes_vec()) -> Message {
            Message::ReplyTxIds(ids)
        }
    }

    prop_compose! {
        fn request_txs_message()(ids in any_tx_id_vec()) -> Message {
            Message::RequestTxs(ids)
        }
    }

    prop_compose! {
        fn reply_txs_message()(txs in any_tx_vec()) -> Message {
            Message::ReplyTxs(txs)
        }
    }

    fn done_message() -> impl Strategy<Value = Message> {
        Just(Message::Done)
    }

    // Any message; payload-carrying variants are weighted 3x over the unit ones.
    pub fn any_message() -> impl Strategy<Value = Message> {
        prop_oneof![
            1 => init_message(),
            3 => request_tx_ids_message(),
            3 => reply_tx_ids_message(),
            3 => request_txs_message(),
            3 => reply_txs_message(),
            1 => done_message(),
        ]
    }
}