Skip to main content

amaru_protocols/blockfetch/
messages.rs

// Copyright 2025 PRAGMA
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use amaru_kernel::{Point, cbor};
16
17#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
18pub enum Message {
19    RequestRange { from: Point, through: Point },
20    ClientDone,
21    StartBatch,
22    NoBlocks,
23    Block { body: Vec<u8> },
24    BatchDone,
25}
26
27impl Message {
28    pub fn message_type(&self) -> &str {
29        match self {
30            Message::RequestRange { .. } => "RequestRange",
31            Message::ClientDone => "ClientDone",
32            Message::StartBatch => "StartBatch",
33            Message::NoBlocks => "NoBlocks",
34            Message::Block { .. } => "Block",
35            Message::BatchDone => "BatchDone",
36        }
37    }
38}
39
40impl cbor::Encode<()> for Message {
41    fn encode<W: cbor::encode::Write>(
42        &self,
43        e: &mut cbor::Encoder<W>,
44        _ctx: &mut (),
45    ) -> Result<(), cbor::encode::Error<W::Error>> {
46        match self {
47            Message::RequestRange { from, through } => {
48                e.array(3)?.u16(0)?;
49                e.encode(from)?;
50                e.encode(through)?;
51                Ok(())
52            }
53            Message::ClientDone => {
54                e.array(1)?.u16(1)?;
55                Ok(())
56            }
57            Message::StartBatch => {
58                e.array(1)?.u16(2)?;
59                Ok(())
60            }
61            Message::NoBlocks => {
62                e.array(1)?.u16(3)?;
63                Ok(())
64            }
65            Message::Block { body } => {
66                e.array(2)?.u16(4)?;
67                e.tag(cbor::IanaTag::Cbor)?;
68                e.bytes(body)?;
69                Ok(())
70            }
71            Message::BatchDone => {
72                e.array(1)?.u16(5)?;
73                Ok(())
74            }
75        }
76    }
77}
78
79impl<'b> cbor::Decode<'b, ()> for Message {
80    fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
81        let len = d.array()?;
82        let label = d.u16()?;
83
84        match label {
85            0 => {
86                cbor::check_tagged_array_length(0, len, 3)?;
87                let from = d.decode()?;
88                let through = d.decode()?;
89                Ok(Message::RequestRange { from, through })
90            }
91            1 => {
92                cbor::check_tagged_array_length(1, len, 1)?;
93                Ok(Message::ClientDone)
94            }
95            2 => {
96                cbor::check_tagged_array_length(2, len, 1)?;
97                Ok(Message::StartBatch)
98            }
99            3 => {
100                cbor::check_tagged_array_length(3, len, 1)?;
101                Ok(Message::NoBlocks)
102            }
103            4 => {
104                cbor::check_tagged_array_length(4, len, 2)?;
105                let tag = d.tag()?;
106                if tag != cbor::IanaTag::Cbor.tag() {
107                    return Err(cbor::decode::Error::message(format!(
108                        "unexpected tag for Block: expected {}, got {}",
109                        cbor::IanaTag::Cbor.tag(),
110                        tag
111                    )));
112                }
113
114                let body = d.bytes()?;
115                Ok(Message::Block { body: Vec::from(body) })
116            }
117            5 => {
118                cbor::check_tagged_array_length(5, len, 1)?;
119                Ok(Message::BatchDone)
120            }
121            _ => Err(cbor::decode::Error::message("unknown variant for blockfetch message")),
122        }
123    }
124}
125
/// Roundtrip property tests for blockfetch messages.
#[cfg(test)]
pub(crate) mod tests {
    use amaru_kernel::{any_point, prop_cbor_roundtrip};
    use proptest::{prelude::*, prop_compose};

    use super::*;

    prop_cbor_roundtrip!(Message, any_message());

    // HELPERS

    /// `Block` messages with arbitrary body contents and lengths.
    ///
    /// Lengths span 0..1024 so the roundtrip covers each CBOR byte-string header
    /// width reachable at these sizes: length inline (< 24), 1-byte length
    /// (24..256) and 2-byte length (256..).
    fn block_message() -> impl Strategy<Value = Message> {
        prop::collection::vec(any::<u8>(), 0..1024).prop_map(|body| Message::Block { body })
    }

    fn no_blocks_message() -> impl Strategy<Value = Message> {
        Just(Message::NoBlocks)
    }

    fn batch_done_message() -> impl Strategy<Value = Message> {
        Just(Message::BatchDone)
    }

    fn start_batch_message() -> impl Strategy<Value = Message> {
        Just(Message::StartBatch)
    }

    fn client_done_message() -> impl Strategy<Value = Message> {
        Just(Message::ClientDone)
    }

    prop_compose! {
        fn request_range_message()(from in any_point(), through in any_point()) -> Message {
            Message::RequestRange { from, through }
        }
    }

    /// Any blockfetch message. `Block` is weighted lower (1 vs 3) than the others.
    pub fn any_message() -> impl Strategy<Value = Message> {
        prop_oneof![
            1 => block_message(),
            3 => no_blocks_message(),
            3 => start_batch_message(),
            3 => batch_done_message(),
            3 => client_done_message(),
            3 => request_range_message(),
        ]
    }
}