amaru_protocols/tx_submission/
messages.rs1use std::fmt::Display;
16
17use amaru_kernel::{NonEmptyBytes, Transaction, cbor, to_cbor};
18use amaru_ouroboros_traits::TxId;
19
20use crate::tx_submission::Blocking;
21
/// A message of the tx-submission mini-protocol.
///
/// The CBOR wire form (see the `cbor::Encode` / `cbor::Decode` impls in this module) is an
/// array starting with a numeric label: `RequestTxIds*` = 0, `ReplyTxIds` = 1,
/// `RequestTxs` = 2, `ReplyTxs` = 3, `Done` = 4, `Init` = 6.
///
/// Variant declaration order is load-bearing: `Ord` ranks messages of different variants by
/// `discriminant()`, which follows this declaration order (`#[repr(u8)]` tag values 0..=6).
/// Reordering variants changes how messages sort.
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
#[repr(u8)]
pub enum Message {
    /// Encoded as `[6]`.
    Init,
    /// Transaction-id request with the blocking flag set; the fields are the `ack` and
    /// `req` counts (names as used by the codec). Encoded as `[0, Blocking::Yes, ack, req]`.
    RequestTxIdsBlocking(u16, u16),
    /// Same as `RequestTxIdsBlocking` but with the blocking flag cleared.
    /// Encoded as `[0, Blocking::No, ack, req]`.
    RequestTxIdsNonBlocking(u16, u16),
    /// Request for the transactions identified by the listed ids. Label 2.
    RequestTxs(Vec<TxId>),
    /// Reply listing transaction ids, each paired with a `u32`
    /// (NOTE(review): presumably the transaction size in bytes — confirm). Label 1.
    ReplyTxIds(Vec<(TxId, u32)>),
    /// Reply carrying full transactions. Label 3.
    ReplyTxs(Vec<Transaction>),
    /// Encoded as `[4]`.
    Done,
}
34
35impl Message {
36 fn discriminant(&self) -> u8 {
38 unsafe { *<*const _>::from(self).cast::<u8>() }
42 }
43}
44
45impl PartialOrd for Message {
46 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
47 Some(self.cmp(other))
48 }
49}
50
51impl Ord for Message {
52 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
53 self.discriminant().cmp(&other.discriminant()).then_with(|| match (self, other) {
54 (Message::Init, Message::Init) => std::cmp::Ordering::Equal,
55 (Message::RequestTxIdsBlocking(a1, b1), Message::RequestTxIdsBlocking(a2, b2)) => {
56 a1.cmp(a2).then_with(|| b1.cmp(b2))
57 }
58 (Message::RequestTxIdsNonBlocking(a1, b1), Message::RequestTxIdsNonBlocking(a2, b2)) => {
59 a1.cmp(a2).then_with(|| b1.cmp(b2))
60 }
61 (Message::RequestTxs(a1), Message::RequestTxs(a2)) => a1.cmp(a2),
62 (Message::ReplyTxIds(a1), Message::ReplyTxIds(a2)) => a1.cmp(a2),
63 (Message::ReplyTxs(a1), Message::ReplyTxs(a2)) => {
64 let left = to_cbor(a1);
65 let right = to_cbor(a2);
66 left.cmp(&right)
67 }
68 (Message::Done, Message::Done) => std::cmp::Ordering::Equal,
69 _ => unreachable!(),
70 })
71 }
72}
73
74impl Message {
75 pub fn message_type(&self) -> &'static str {
76 match self {
77 Message::Init => "Init",
78 Message::RequestTxIdsBlocking(_, _) => "RequestTxIdsBlocking",
79 Message::RequestTxIdsNonBlocking(_, _) => "RequestTxIdsNonBlocking",
80 Message::ReplyTxIds(_) => "ReplyTxIds",
81 Message::RequestTxs(_) => "RequestTxs",
82 Message::ReplyTxs(_) => "ReplyTxs",
83 Message::Done => "Done",
84 }
85 }
86}
87
88impl cbor::Encode<()> for Message {
89 fn encode<W: cbor::encode::Write>(
90 &self,
91 e: &mut cbor::Encoder<W>,
92 _ctx: &mut (),
93 ) -> Result<(), cbor::encode::Error<W::Error>> {
94 match self {
95 Message::RequestTxIdsBlocking(ack, req) => {
96 e.array(4)?.u16(0)?;
97 e.encode(Blocking::Yes)?;
98 e.u16(*ack)?;
99 e.u16(*req)?;
100 }
101 Message::RequestTxIdsNonBlocking(ack, req) => {
102 e.array(4)?.u16(0)?;
103 e.encode(Blocking::No)?;
104 e.u16(*ack)?;
105 e.u16(*req)?;
106 }
107 Message::ReplyTxIds(ids) => {
108 e.array(2)?.u16(1)?;
109 e.begin_array()?;
110 for id in ids {
111 e.encode(id)?;
112 }
113 e.end()?;
114 }
115 Message::RequestTxs(ids) => {
116 e.array(2)?.u16(2)?;
117 e.array(ids.len() as u64)?;
118 for id in ids {
119 e.encode(id)?;
120 }
121 }
122 Message::ReplyTxs(txs) => {
123 e.array(2)?.u16(3)?;
124 e.array(txs.len() as u64)?;
125 for tx in txs {
126 e.encode(tx)?;
127 }
128 }
129 Message::Done => {
130 e.array(1)?.u16(4)?;
131 }
132 Message::Init => {
133 e.array(1)?.u16(6)?;
134 }
135 }
136 Ok(())
137 }
138}
139
140impl<'b> cbor::Decode<'b, ()> for Message {
141 fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
142 let len = d.array()?;
143 let label = d.u16()?;
144
145 match label {
146 0 => {
147 cbor::check_tagged_array_length(0, len, 4)?;
148 let blocking = d.decode()?;
149 let ack = d.u16()?;
150 let req = d.u16()?;
151 match blocking {
152 Blocking::Yes => Ok(Message::RequestTxIdsBlocking(ack, req)),
153 Blocking::No => Ok(Message::RequestTxIdsNonBlocking(ack, req)),
154 }
155 }
156 1 => {
157 cbor::check_tagged_array_length(1, len, 2)?;
158 let items = d.decode()?;
159 Ok(Message::ReplyTxIds(items))
160 }
161 2 => {
162 cbor::check_tagged_array_length(2, len, 2)?;
163 let ids = d.decode()?;
164 Ok(Message::RequestTxs(ids))
165 }
166 3 => {
167 cbor::check_tagged_array_length(3, len, 2)?;
168 Ok(Message::ReplyTxs(d.array_iter()?.collect::<Result<_, _>>()?))
169 }
170 4 => {
171 cbor::check_tagged_array_length(4, len, 1)?;
172 Ok(Message::Done)
173 }
174 6 => {
175 cbor::check_tagged_array_length(6, len, 1)?;
176 Ok(Message::Init)
177 }
178 _ => Err(cbor::decode::Error::message("unknown variant for txsubmission message")),
179 }
180 }
181}
182
183impl Display for Message {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 match self {
186 Message::Init => write!(f, "Init"),
187 Message::RequestTxIdsBlocking(ack, req) => {
188 write!(f, "RequestTxIdsBlocking(ack: {}, req: {})", ack, req,)
189 }
190 Message::RequestTxIdsNonBlocking(ack, req) => {
191 write!(f, "RequestTxIdsNonBlocking(ack: {}, req: {})", ack, req,)
192 }
193 Message::ReplyTxIds(ids) => {
194 write!(
195 f,
196 "ReplyTxIds(ids: [{}])",
197 ids.iter().map(|(id, size)| format!("({}, {})", id, size)).collect::<Vec<_>>().join(", ")
198 )
199 }
200 Message::RequestTxs(ids) => {
201 write!(
202 f,
203 "RequestTxs(ids: [{}])",
204 ids.iter().map(|id| format!("{}", id)).collect::<Vec<_>>().join(", ")
205 )
206 }
207 Message::ReplyTxs(txs) => {
208 write!(
209 f,
210 "ReplyTxs(txs: [{}])",
211 txs.iter().map(|tx| format!("{}", TxId::from(tx))).collect::<Vec<_>>().join(", ")
212 )
213 }
214 Message::Done => write!(f, "Done"),
215 }
216 }
217}
218
/// Auxiliary tx-submission message type, separate from the wire-level [`Message`].
///
/// NOTE(review): no CBOR codec is defined for it in this file, and the meaning of the
/// variants is established by code elsewhere. Confirm before relying on the notes below.
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
pub enum TxSubmissionMessage {
    /// NOTE(review): presumably signals that the protocol handler was registered — confirm.
    Registered,
    /// Non-empty raw bytes, tagged by the variant name as coming from the network
    /// (presumably an undecoded `Message` payload — TODO confirm).
    FromNetwork(NonEmptyBytes),
}
225
#[cfg(test)]
mod tests {
    use amaru_kernel::{Hash, Transaction, prop_cbor_roundtrip};
    use prop::collection::vec;
    use proptest::{prelude::*, prop_compose};

    use super::*;
    use crate::tx_submission::tests::create_transaction;

    mod tx_id {
        use super::*;
        prop_cbor_roundtrip!(TxId, any_tx_id());
    }

    mod message {
        use super::*;
        prop_cbor_roundtrip!(Message, any_message());
    }

    mod blocking {
        use super::*;
        prop_cbor_roundtrip!(Blocking, any_blocking());
    }

    /// Messages of different variants must sort by declaration order, whatever their
    /// payloads: the variant is compared before any field.
    #[test]
    fn messages_are_ordered_by_variant_declaration_order() {
        let ordered = [
            Message::Init,
            Message::RequestTxIdsBlocking(u16::MAX, u16::MAX),
            Message::RequestTxIdsNonBlocking(0, 0),
            Message::RequestTxs(Vec::new()),
            Message::ReplyTxIds(Vec::new()),
            Message::ReplyTxs(Vec::new()),
            Message::Done,
        ];
        for pair in ordered.windows(2) {
            assert!(pair[0] < pair[1], "{} should sort before {}", pair[0], pair[1]);
        }
    }

    prop_compose! {
        pub fn any_tx_id()(bytes in any::<[u8; 32]>()) -> TxId {
            TxId::new(Hash::new(bytes))
        }
    }

    prop_compose! {
        pub fn any_blocking()(blocking in any::<bool>()) -> Blocking {
            if blocking { Blocking::Yes } else { Blocking::No }
        }
    }

    prop_compose! {
        fn any_ack_req()(ack in 0u16..=1000, req in 0u16..=1000) -> (u16, u16) {
            (ack, req)
        }
    }

    // Generate (id, size) pairs directly. Zipping two independently sized vectors would
    // truncate every list to the shorter of the two, skewing lengths towards zero.
    fn any_tx_id_and_sizes_vec() -> impl Strategy<Value = Vec<(TxId, u32)>> {
        vec((any_tx_id(), any::<u32>()), 0..20)
    }

    fn any_tx_id_vec() -> impl Strategy<Value = Vec<TxId>> {
        vec(any_tx_id(), 0..20)
    }

    fn any_tx_vec() -> impl Strategy<Value = Vec<Transaction>> {
        vec(any_tx(), 0..10)
    }

    prop_compose! {
        fn any_tx()(n in 0u64..=1000) -> Transaction {
            create_transaction(n)
        }
    }

    fn init_message() -> impl Strategy<Value = Message> {
        Just(Message::Init)
    }

    prop_compose! {
        fn request_tx_ids_message()((ack, req) in any_ack_req(), blocking in any_blocking()) -> Message {
            match blocking {
                Blocking::Yes => Message::RequestTxIdsBlocking(ack, req),
                Blocking::No => Message::RequestTxIdsNonBlocking(ack, req),
            }
        }
    }

    fn reply_tx_ids_message() -> impl Strategy<Value = Message> {
        any_tx_id_and_sizes_vec().prop_map(Message::ReplyTxIds)
    }

    fn request_txs_message() -> impl Strategy<Value = Message> {
        any_tx_id_vec().prop_map(Message::RequestTxs)
    }

    fn reply_txs_message() -> impl Strategy<Value = Message> {
        any_tx_vec().prop_map(Message::ReplyTxs)
    }

    fn done_message() -> impl Strategy<Value = Message> {
        Just(Message::Done)
    }

    pub fn any_message() -> impl Strategy<Value = Message> {
        prop_oneof![
            1 => init_message(),
            3 => request_tx_ids_message(),
            3 => reply_tx_ids_message(),
            3 => request_txs_message(),
            3 => reply_txs_message(),
            1 => done_message(),
        ]
    }
}