Skip to main content

amaru_protocols/handshake/
messages.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::fmt;
16
17use amaru_kernel::cbor;
18
19use crate::protocol_messages::{handshake::RefuseReason, version_number::VersionNumber, version_table::VersionTable};
20
21#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
22pub enum Message<D>
23where
24    D: fmt::Debug + Clone,
25{
26    Propose(VersionTable<D>),
27    Accept(VersionNumber, D),
28    Refuse(RefuseReason),
29    QueryReply(VersionTable<D>),
30}
31
32impl<D> Message<D>
33where
34    D: fmt::Debug + Clone,
35{
36    pub fn message_type(&self) -> &str {
37        match self {
38            Message::Propose(_) => "Propose",
39            Message::Accept(_, _) => "Accept",
40            Message::Refuse(_) => "Refuse",
41            Message::QueryReply(_) => "QueryReply",
42        }
43    }
44}
45
46impl<D> cbor::Encode<()> for Message<D>
47where
48    D: fmt::Debug + Clone + cbor::Encode<VersionNumber>,
49    VersionTable<D>: cbor::Encode<()>,
50{
51    fn encode<W: cbor::encode::Write>(
52        &self,
53        e: &mut cbor::Encoder<W>,
54        _ctx: &mut (),
55    ) -> Result<(), cbor::encode::Error<W::Error>> {
56        match self {
57            Message::Propose(version_table) => {
58                e.array(2)?.u16(0)?;
59                e.encode(version_table)?;
60            }
61            Message::Accept(version_number, version_data) => {
62                e.array(3)?.u16(1)?;
63                e.encode(version_number)?;
64                let mut ctx = *version_number;
65                e.encode_with(version_data, &mut ctx)?;
66            }
67            Message::Refuse(reason) => {
68                e.array(2)?.u16(2)?;
69                e.encode(reason)?;
70            }
71            Message::QueryReply(version_table) => {
72                e.array(2)?.u16(3)?;
73                e.encode(version_table)?;
74            }
75        };
76
77        Ok(())
78    }
79}
80
81impl<'b, D> cbor::Decode<'b, ()> for Message<D>
82where
83    D: cbor::Decode<'b, VersionNumber> + fmt::Debug + Clone,
84    VersionTable<D>: cbor::Decode<'b, ()>,
85{
86    fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
87        let len = d.array()?;
88
89        match d.u16()? {
90            0 => {
91                cbor::check_tagged_array_length(0, len, 2)?;
92                let version_table = d.decode()?;
93                Ok(Message::Propose(version_table))
94            }
95            1 => {
96                cbor::check_tagged_array_length(1, len, 3)?;
97                let version_number = d.decode()?;
98                let mut ctx = version_number;
99                let version_data = d.decode_with(&mut ctx)?;
100                Ok(Message::Accept(version_number, version_data))
101            }
102            2 => {
103                cbor::check_tagged_array_length(2, len, 2)?;
104                let reason: RefuseReason = d.decode()?;
105                Ok(Message::Refuse(reason))
106            }
107            3 => {
108                cbor::check_tagged_array_length(3, len, 2)?;
109                let version_table = d.decode()?;
110                Ok(Message::QueryReply(version_table))
111            }
112            n => Err(cbor::decode::Error::message(format!("unknown variant for handshake message: {}", n,))),
113        }
114    }
115}
116
/// Roundtrip property tests for handshake messages, plus regression tests for
/// the decoder's rejection of malformed input.
#[cfg(test)]
pub(crate) mod tests {
    use amaru_kernel::prop_cbor_roundtrip;
    use proptest::{prelude::*, prop_compose};

    use super::*;
    use crate::{
        handshake::messages::Message::*,
        protocol_messages::{
            handshake::tests::any_refuse_reason,
            version_data::{VersionData, tests::any_version_data},
            version_number::tests::any_version_number,
            version_table::tests::any_version_table,
        },
    };

    prop_cbor_roundtrip!(Message<VersionData>, any_message());

    /// Decodes `bytes` as a `Message<VersionData>` using the `cbor::Decode` impl.
    fn decode_message(bytes: &[u8]) -> Result<Message<VersionData>, cbor::decode::Error> {
        cbor::Decoder::new(bytes).decode()
    }

    #[test]
    fn decode_rejects_unknown_tag() {
        // CBOR `[4, 0]`: a two-element array whose tag (4) is not a known variant.
        let err = decode_message(&[0x82, 0x04, 0x00]).unwrap_err();
        assert!(err.to_string().contains("unknown variant for handshake message: 4"), "unexpected error: {err}");
    }

    #[test]
    fn decode_rejects_array_length_mismatch() {
        // CBOR `[0, {}, 0]`: a Propose (tag 0) framed as a 3-element array instead
        // of 2; the length check for tag 0 must reject it.
        assert!(decode_message(&[0x83, 0x00, 0xa0, 0x00]).is_err());
    }

    // HELPERS
    prop_compose! {
        fn any_propose_message()(version_table in any_version_table()) -> Message<VersionData> {
            Propose(version_table)
        }
    }

    prop_compose! {
        fn any_query_reply_message()(version_table in any_version_table()) -> Message<VersionData> {
            QueryReply(version_table)
        }
    }

    prop_compose! {
        fn any_accept_message()(version_number in any_version_number(), version_data in any_version_data()) -> Message<VersionData> {
            Accept(version_number, version_data)
        }
    }

    prop_compose! {
        fn any_refuse_message()(reason in any_refuse_reason()) -> Message<VersionData> {
            Refuse(reason)
        }
    }

    /// Strategy producing any handshake message variant with equal weight.
    pub fn any_message() -> impl Strategy<Value = Message<VersionData>> {
        prop_oneof![
            1 => any_query_reply_message(),
            1 => any_propose_message(),
            1 => any_accept_message(),
            1 => any_refuse_message(),
        ]
    }
}