Skip to main content

amaru_protocols/chainsync/
messages.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use amaru_kernel::{BlockHeader, EraName, Point, Tip, cbor, to_cbor};
16use pure_stage::DeserializerGuards;
17
18pub fn register_deserializers() -> DeserializerGuards {
19    vec![pure_stage::register_data_deserializer::<Message>().boxed()]
20}
21
22#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, Ord, PartialOrd)]
23pub enum Message {
24    RequestNext(u8),
25    AwaitReply,
26    RollForward(HeaderContent, Tip),
27    RollBackward(Point, Tip),
28    FindIntersect(Vec<Point>),
29    IntersectFound(Point, Tip),
30    IntersectNotFound(Tip),
31    Done,
32}
33
34impl Message {
35    pub fn message_type(&self) -> &str {
36        match self {
37            Message::RequestNext(_) => "RequestNext",
38            Message::AwaitReply => "AwaitReply",
39            Message::RollForward(_, _) => "RollForward",
40            Message::RollBackward(_, _) => "RollBackward",
41            Message::FindIntersect(_) => "FindIntersect",
42            Message::IntersectFound(_, _) => "IntersectFound",
43            Message::IntersectNotFound(_) => "IntersectNotFound",
44            Message::Done => "Done",
45        }
46    }
47}
48
49#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, Ord, PartialOrd)]
50pub struct HeaderContent {
51    pub variant: EraName,
52    pub byron_prefix: Option<(u8, u64)>,
53    pub cbor: Vec<u8>,
54}
55
56impl HeaderContent {
57    pub fn new(header: &BlockHeader, era: EraName) -> Self {
58        Self { variant: era, byron_prefix: None, cbor: to_cbor(header) }
59    }
60
61    pub fn with_bytes(bytes: Vec<u8>, variant: EraName) -> Self {
62        Self { variant, byron_prefix: None, cbor: bytes }
63    }
64}
65
66impl cbor::Encode<()> for Message {
67    fn encode<W: cbor::encode::Write>(
68        &self,
69        e: &mut cbor::Encoder<W>,
70        _ctx: &mut (),
71    ) -> Result<(), cbor::encode::Error<W::Error>> {
72        match self {
73            Message::RequestNext(n) => {
74                for _ in 0..*n {
75                    e.array(1)?.u16(0)?;
76                }
77                Ok(())
78            }
79            Message::AwaitReply => {
80                e.array(1)?.u16(1)?;
81                Ok(())
82            }
83            Message::RollForward(content, tip) => {
84                e.array(3)?.u16(2)?;
85                e.encode(content)?;
86                e.encode(tip)?;
87                Ok(())
88            }
89            Message::RollBackward(point, tip) => {
90                e.array(3)?.u16(3)?;
91                e.encode(point)?;
92                e.encode(tip)?;
93                Ok(())
94            }
95            Message::FindIntersect(points) => {
96                e.array(2)?.u16(4)?;
97                e.array(points.len() as u64)?;
98                for point in points.iter() {
99                    e.encode(point)?;
100                }
101                Ok(())
102            }
103            Message::IntersectFound(point, tip) => {
104                e.array(3)?.u16(5)?;
105                e.encode(point)?;
106                e.encode(tip)?;
107                Ok(())
108            }
109            Message::IntersectNotFound(tip) => {
110                e.array(2)?.u16(6)?;
111                e.encode(tip)?;
112                Ok(())
113            }
114            Message::Done => {
115                e.array(1)?.u16(7)?;
116                Ok(())
117            }
118        }
119    }
120}
121
122impl<'b> cbor::Decode<'b, ()> for Message {
123    fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
124        let len = d.array()?;
125        let label = d.u16()?;
126
127        match label {
128            0 => {
129                cbor::check_tagged_array_length(0, len, 1)?;
130                Ok(Message::RequestNext(1))
131            }
132            1 => {
133                cbor::check_tagged_array_length(1, len, 1)?;
134                Ok(Message::AwaitReply)
135            }
136            2 => {
137                cbor::check_tagged_array_length(2, len, 3)?;
138                let content = d.decode()?;
139                let tip = d.decode()?;
140                Ok(Message::RollForward(content, tip))
141            }
142            3 => {
143                cbor::check_tagged_array_length(3, len, 3)?;
144                let point = d.decode()?;
145                let tip = d.decode()?;
146                Ok(Message::RollBackward(point, tip))
147            }
148            4 => {
149                cbor::check_tagged_array_length(4, len, 2)?;
150                let points = d.decode()?;
151                Ok(Message::FindIntersect(points))
152            }
153            5 => {
154                cbor::check_tagged_array_length(5, len, 3)?;
155                let point = d.decode()?;
156                let tip = d.decode()?;
157                Ok(Message::IntersectFound(point, tip))
158            }
159            6 => {
160                cbor::check_tagged_array_length(6, len, 2)?;
161                let tip = d.decode()?;
162                Ok(Message::IntersectNotFound(tip))
163            }
164            7 => {
165                cbor::check_tagged_array_length(7, len, 1)?;
166                Ok(Message::Done)
167            }
168            _ => Err(cbor::decode::Error::message("unknown variant for chainsync message")),
169        }
170    }
171}
172
173impl<'b> cbor::Decode<'b, ()> for HeaderContent {
174    fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
175        let len = d.array()?;
176        let variant = EraName::from_header_variant(d.u8()?).map_err(cbor::decode::Error::custom)?;
177
178        match variant {
179            EraName::Byron => {
180                cbor::check_tagged_array_length(0, len, 2)?;
181                let len = d.array()?;
182                cbor::check_tagged_array_length(0, len, 2)?;
183
184                // can't find a reference anywhere about the structure of these values, but they
185                // seem to provide the Byron-specific variant of the header
186                let (a, b): (u8, u64) = d.decode()?;
187
188                d.tag()?;
189                let bytes = d.bytes()?;
190
191                Ok(HeaderContent { variant, byron_prefix: Some((a, b)), cbor: Vec::from(bytes) })
192            }
193            EraName::Shelley
194            | EraName::Allegra
195            | EraName::Mary
196            | EraName::Alonzo
197            | EraName::Babbage
198            | EraName::Conway
199            | EraName::Dijkstra => {
200                cbor::check_tagged_array_length(variant.header_variant().into(), len, 2)?;
201                d.tag()?;
202                let bytes = d.bytes()?;
203                Ok(HeaderContent { variant, byron_prefix: None, cbor: Vec::from(bytes) })
204            }
205        }
206    }
207}
208
209impl cbor::Encode<()> for HeaderContent {
210    fn encode<W: cbor::encode::Write>(
211        &self,
212        e: &mut cbor::Encoder<W>,
213        _ctx: &mut (),
214    ) -> Result<(), cbor::encode::Error<W::Error>> {
215        e.array(2)?;
216        e.u8(self.variant.header_variant())?;
217
218        if self.variant == EraName::Byron {
219            e.array(2)?;
220
221            if let Some((a, b)) = self.byron_prefix {
222                e.array(2)?;
223                e.u8(a)?;
224                e.u64(b)?;
225            } else {
226                return Err(cbor::encode::Error::message("header variant 0 but no byron prefix"));
227            }
228
229            e.tag(cbor::IanaTag::Cbor)?;
230            e.bytes(&self.cbor)?;
231        } else {
232            e.tag(cbor::IanaTag::Cbor)?;
233            e.bytes(&self.cbor)?;
234        }
235
236        Ok(())
237    }
238}
239
/// Property-based CBOR roundtrip tests for chain-sync messages and header content.
#[cfg(test)]
mod tests {
    use amaru_kernel::{any_era_name, any_point, any_tip, prop_cbor_roundtrip};
    use proptest::{prelude::*, prop_compose};

    use super::*;
    use crate::{chainsync::messages::Message::*, protocol_messages::handshake::tests::any_byron_prefix};

    mod header_content {
        use super::*;
        prop_cbor_roundtrip!(HeaderContent, any_header_content());
    }

    mod message {
        use super::*;
        prop_cbor_roundtrip!(Message, any_message());
    }

    // STRATEGIES

    // Opaque header bytes of length 0..10; the codec never inspects their content.
    prop_compose! {
        fn any_header_bytes()(bytes in proptest::collection::vec(any::<u8>(), 0..10)) -> Vec<u8> {
            bytes
        }
    }

    // Header content whose Byron prefix is present exactly when the era is Byron,
    // matching what the encoder accepts and the decoder produces.
    prop_compose! {
        fn any_header_content()(variant in any_era_name(), prefix in any_byron_prefix(), header_bytes in any_header_bytes()) -> HeaderContent {
            let byron_prefix = (variant == EraName::Byron).then_some(prefix);
            HeaderContent { variant, byron_prefix, cbor: header_bytes }
        }
    }

    prop_compose! {
        fn roll_forward_message()(content in any_header_content(), tip in any_tip()) -> Message {
            RollForward(content, tip)
        }
    }

    prop_compose! {
        fn roll_backward_message()(point in any_point(), tip in any_tip()) -> Message {
            RollBackward(point, tip)
        }
    }

    prop_compose! {
        fn find_intersect_message()(candidates in proptest::collection::vec(any_point(), 0..3)) -> Message {
            FindIntersect(candidates)
        }
    }

    prop_compose! {
        fn intersect_found_message()(point in any_point(), tip in any_tip()) -> Message {
            IntersectFound(point, tip)
        }
    }

    prop_compose! {
        fn intersect_not_found_message()(tip in any_tip()) -> Message {
            IntersectNotFound(tip)
        }
    }

    /// Any chain-sync message that survives a CBOR roundtrip.
    ///
    /// `RequestNext` is generated only with a count of 1, since that is the only
    /// count the decoder can reproduce.
    pub fn any_message() -> impl Strategy<Value = Message> {
        prop_oneof![
            1 => Just(Done),
            3 => Just(RequestNext(1)),
            3 => Just(AwaitReply),
            3 => roll_forward_message(),
            3 => roll_backward_message(),
            3 => find_intersect_message(),
            3 => intersect_found_message(),
            3 => intersect_not_found_message(),
        ]
    }
}