amaru_protocols/chainsync/
messages.rs1use amaru_kernel::{BlockHeader, EraName, Point, Tip, cbor, to_cbor};
16use pure_stage::DeserializerGuards;
17
18pub fn register_deserializers() -> DeserializerGuards {
19 vec![pure_stage::register_data_deserializer::<Message>().boxed()]
20}
21
22#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, Ord, PartialOrd)]
23pub enum Message {
24 RequestNext(u8),
25 AwaitReply,
26 RollForward(HeaderContent, Tip),
27 RollBackward(Point, Tip),
28 FindIntersect(Vec<Point>),
29 IntersectFound(Point, Tip),
30 IntersectNotFound(Tip),
31 Done,
32}
33
34impl Message {
35 pub fn message_type(&self) -> &str {
36 match self {
37 Message::RequestNext(_) => "RequestNext",
38 Message::AwaitReply => "AwaitReply",
39 Message::RollForward(_, _) => "RollForward",
40 Message::RollBackward(_, _) => "RollBackward",
41 Message::FindIntersect(_) => "FindIntersect",
42 Message::IntersectFound(_, _) => "IntersectFound",
43 Message::IntersectNotFound(_) => "IntersectNotFound",
44 Message::Done => "Done",
45 }
46 }
47}
48
49#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, Ord, PartialOrd)]
50pub struct HeaderContent {
51 pub variant: EraName,
52 pub byron_prefix: Option<(u8, u64)>,
53 pub cbor: Vec<u8>,
54}
55
56impl HeaderContent {
57 pub fn new(header: &BlockHeader, era: EraName) -> Self {
58 Self { variant: era, byron_prefix: None, cbor: to_cbor(header) }
59 }
60
61 pub fn with_bytes(bytes: Vec<u8>, variant: EraName) -> Self {
62 Self { variant, byron_prefix: None, cbor: bytes }
63 }
64}
65
66impl cbor::Encode<()> for Message {
67 fn encode<W: cbor::encode::Write>(
68 &self,
69 e: &mut cbor::Encoder<W>,
70 _ctx: &mut (),
71 ) -> Result<(), cbor::encode::Error<W::Error>> {
72 match self {
73 Message::RequestNext(n) => {
74 for _ in 0..*n {
75 e.array(1)?.u16(0)?;
76 }
77 Ok(())
78 }
79 Message::AwaitReply => {
80 e.array(1)?.u16(1)?;
81 Ok(())
82 }
83 Message::RollForward(content, tip) => {
84 e.array(3)?.u16(2)?;
85 e.encode(content)?;
86 e.encode(tip)?;
87 Ok(())
88 }
89 Message::RollBackward(point, tip) => {
90 e.array(3)?.u16(3)?;
91 e.encode(point)?;
92 e.encode(tip)?;
93 Ok(())
94 }
95 Message::FindIntersect(points) => {
96 e.array(2)?.u16(4)?;
97 e.array(points.len() as u64)?;
98 for point in points.iter() {
99 e.encode(point)?;
100 }
101 Ok(())
102 }
103 Message::IntersectFound(point, tip) => {
104 e.array(3)?.u16(5)?;
105 e.encode(point)?;
106 e.encode(tip)?;
107 Ok(())
108 }
109 Message::IntersectNotFound(tip) => {
110 e.array(2)?.u16(6)?;
111 e.encode(tip)?;
112 Ok(())
113 }
114 Message::Done => {
115 e.array(1)?.u16(7)?;
116 Ok(())
117 }
118 }
119 }
120}
121
122impl<'b> cbor::Decode<'b, ()> for Message {
123 fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
124 let len = d.array()?;
125 let label = d.u16()?;
126
127 match label {
128 0 => {
129 cbor::check_tagged_array_length(0, len, 1)?;
130 Ok(Message::RequestNext(1))
131 }
132 1 => {
133 cbor::check_tagged_array_length(1, len, 1)?;
134 Ok(Message::AwaitReply)
135 }
136 2 => {
137 cbor::check_tagged_array_length(2, len, 3)?;
138 let content = d.decode()?;
139 let tip = d.decode()?;
140 Ok(Message::RollForward(content, tip))
141 }
142 3 => {
143 cbor::check_tagged_array_length(3, len, 3)?;
144 let point = d.decode()?;
145 let tip = d.decode()?;
146 Ok(Message::RollBackward(point, tip))
147 }
148 4 => {
149 cbor::check_tagged_array_length(4, len, 2)?;
150 let points = d.decode()?;
151 Ok(Message::FindIntersect(points))
152 }
153 5 => {
154 cbor::check_tagged_array_length(5, len, 3)?;
155 let point = d.decode()?;
156 let tip = d.decode()?;
157 Ok(Message::IntersectFound(point, tip))
158 }
159 6 => {
160 cbor::check_tagged_array_length(6, len, 2)?;
161 let tip = d.decode()?;
162 Ok(Message::IntersectNotFound(tip))
163 }
164 7 => {
165 cbor::check_tagged_array_length(7, len, 1)?;
166 Ok(Message::Done)
167 }
168 _ => Err(cbor::decode::Error::message("unknown variant for chainsync message")),
169 }
170 }
171}
172
173impl<'b> cbor::Decode<'b, ()> for HeaderContent {
174 fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
175 let len = d.array()?;
176 let variant = EraName::from_header_variant(d.u8()?).map_err(cbor::decode::Error::custom)?;
177
178 match variant {
179 EraName::Byron => {
180 cbor::check_tagged_array_length(0, len, 2)?;
181 let len = d.array()?;
182 cbor::check_tagged_array_length(0, len, 2)?;
183
184 let (a, b): (u8, u64) = d.decode()?;
187
188 d.tag()?;
189 let bytes = d.bytes()?;
190
191 Ok(HeaderContent { variant, byron_prefix: Some((a, b)), cbor: Vec::from(bytes) })
192 }
193 EraName::Shelley
194 | EraName::Allegra
195 | EraName::Mary
196 | EraName::Alonzo
197 | EraName::Babbage
198 | EraName::Conway
199 | EraName::Dijkstra => {
200 cbor::check_tagged_array_length(variant.header_variant().into(), len, 2)?;
201 d.tag()?;
202 let bytes = d.bytes()?;
203 Ok(HeaderContent { variant, byron_prefix: None, cbor: Vec::from(bytes) })
204 }
205 }
206 }
207}
208
209impl cbor::Encode<()> for HeaderContent {
210 fn encode<W: cbor::encode::Write>(
211 &self,
212 e: &mut cbor::Encoder<W>,
213 _ctx: &mut (),
214 ) -> Result<(), cbor::encode::Error<W::Error>> {
215 e.array(2)?;
216 e.u8(self.variant.header_variant())?;
217
218 if self.variant == EraName::Byron {
219 e.array(2)?;
220
221 if let Some((a, b)) = self.byron_prefix {
222 e.array(2)?;
223 e.u8(a)?;
224 e.u64(b)?;
225 } else {
226 return Err(cbor::encode::Error::message("header variant 0 but no byron prefix"));
227 }
228
229 e.tag(cbor::IanaTag::Cbor)?;
230 e.bytes(&self.cbor)?;
231 } else {
232 e.tag(cbor::IanaTag::Cbor)?;
233 e.bytes(&self.cbor)?;
234 }
235
236 Ok(())
237 }
238}
239
#[cfg(test)]
mod tests {
    use amaru_kernel::{any_era_name, any_point, any_tip, prop_cbor_roundtrip};
    use proptest::{prelude::*, prop_compose};

    use super::*;
    use crate::{chainsync::messages::Message::*, protocol_messages::handshake::tests::any_byron_prefix};

    mod header_content {
        use super::*;
        prop_cbor_roundtrip!(HeaderContent, any_header_content());
    }

    mod message {
        use super::*;
        prop_cbor_roundtrip!(Message, any_message());
    }

    /// Strategy over every chain-sync [`Message`] variant; `Done` is drawn less often.
    pub fn any_message() -> impl Strategy<Value = Message> {
        prop_oneof![
            1 => done_message(),
            3 => request_next_message(),
            3 => await_reply_message(),
            3 => roll_forward_message(),
            3 => roll_backward_message(),
            3 => find_intersect_message(),
            3 => intersect_found_message(),
            3 => intersect_not_found_message(),
        ]
    }

    prop_compose! {
        /// Arbitrary header content; only Byron headers carry a prefix, matching what
        /// the encoder accepts and the decoder produces.
        fn any_header_content()(
            variant in any_era_name(),
            byron_prefix in any_byron_prefix(),
            cbor in proptest::collection::vec(any::<u8>(), 0..10)
        ) -> HeaderContent {
            let byron_prefix = (variant == EraName::Byron).then_some(byron_prefix);
            HeaderContent { variant, byron_prefix, cbor }
        }
    }

    fn done_message() -> impl Strategy<Value = Message> {
        Just(Message::Done)
    }

    /// Only `RequestNext(1)` round-trips: one `[0]` frame always decodes to a count of 1.
    fn request_next_message() -> impl Strategy<Value = Message> {
        Just(Message::RequestNext(1))
    }

    fn await_reply_message() -> impl Strategy<Value = Message> {
        Just(Message::AwaitReply)
    }

    fn roll_forward_message() -> impl Strategy<Value = Message> {
        (any_header_content(), any_tip()).prop_map(|(content, tip)| RollForward(content, tip))
    }

    fn roll_backward_message() -> impl Strategy<Value = Message> {
        (any_point(), any_tip()).prop_map(|(point, tip)| RollBackward(point, tip))
    }

    fn find_intersect_message() -> impl Strategy<Value = Message> {
        proptest::collection::vec(any_point(), 0..3).prop_map(FindIntersect)
    }

    fn intersect_found_message() -> impl Strategy<Value = Message> {
        (any_point(), any_tip()).prop_map(|(point, tip)| IntersectFound(point, tip))
    }

    fn intersect_not_found_message() -> impl Strategy<Value = Message> {
        any_tip().prop_map(IntersectNotFound)
    }
}