amaru_protocols/handshake/
messages.rs1use std::fmt;
16
17use amaru_kernel::cbor;
18
19use crate::protocol_messages::{handshake::RefuseReason, version_number::VersionNumber, version_table::VersionTable};
20
21#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
22pub enum Message<D>
23where
24 D: fmt::Debug + Clone,
25{
26 Propose(VersionTable<D>),
27 Accept(VersionNumber, D),
28 Refuse(RefuseReason),
29 QueryReply(VersionTable<D>),
30}
31
32impl<D> Message<D>
33where
34 D: fmt::Debug + Clone,
35{
36 pub fn message_type(&self) -> &str {
37 match self {
38 Message::Propose(_) => "Propose",
39 Message::Accept(_, _) => "Accept",
40 Message::Refuse(_) => "Refuse",
41 Message::QueryReply(_) => "QueryReply",
42 }
43 }
44}
45
46impl<D> cbor::Encode<()> for Message<D>
47where
48 D: fmt::Debug + Clone + cbor::Encode<VersionNumber>,
49 VersionTable<D>: cbor::Encode<()>,
50{
51 fn encode<W: cbor::encode::Write>(
52 &self,
53 e: &mut cbor::Encoder<W>,
54 _ctx: &mut (),
55 ) -> Result<(), cbor::encode::Error<W::Error>> {
56 match self {
57 Message::Propose(version_table) => {
58 e.array(2)?.u16(0)?;
59 e.encode(version_table)?;
60 }
61 Message::Accept(version_number, version_data) => {
62 e.array(3)?.u16(1)?;
63 e.encode(version_number)?;
64 let mut ctx = *version_number;
65 e.encode_with(version_data, &mut ctx)?;
66 }
67 Message::Refuse(reason) => {
68 e.array(2)?.u16(2)?;
69 e.encode(reason)?;
70 }
71 Message::QueryReply(version_table) => {
72 e.array(2)?.u16(3)?;
73 e.encode(version_table)?;
74 }
75 };
76
77 Ok(())
78 }
79}
80
81impl<'b, D> cbor::Decode<'b, ()> for Message<D>
82where
83 D: cbor::Decode<'b, VersionNumber> + fmt::Debug + Clone,
84 VersionTable<D>: cbor::Decode<'b, ()>,
85{
86 fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
87 let len = d.array()?;
88
89 match d.u16()? {
90 0 => {
91 cbor::check_tagged_array_length(0, len, 2)?;
92 let version_table = d.decode()?;
93 Ok(Message::Propose(version_table))
94 }
95 1 => {
96 cbor::check_tagged_array_length(1, len, 3)?;
97 let version_number = d.decode()?;
98 let mut ctx = version_number;
99 let version_data = d.decode_with(&mut ctx)?;
100 Ok(Message::Accept(version_number, version_data))
101 }
102 2 => {
103 cbor::check_tagged_array_length(2, len, 2)?;
104 let reason: RefuseReason = d.decode()?;
105 Ok(Message::Refuse(reason))
106 }
107 3 => {
108 cbor::check_tagged_array_length(3, len, 2)?;
109 let version_table = d.decode()?;
110 Ok(Message::QueryReply(version_table))
111 }
112 n => Err(cbor::decode::Error::message(format!("unknown variant for handshake message: {}", n,))),
113 }
114 }
115}
116
#[cfg(test)]
pub(crate) mod tests {
    use amaru_kernel::prop_cbor_roundtrip;
    use proptest::prelude::*;

    use super::*;
    use crate::protocol_messages::{
        handshake::tests::any_refuse_reason,
        version_data::{VersionData, tests::any_version_data},
        version_number::tests::any_version_number,
        version_table::tests::any_version_table,
    };

    // Property test: every generated message survives a CBOR encode/decode round trip.
    prop_cbor_roundtrip!(Message<VersionData>, any_message());

    /// Generates `Propose` messages from arbitrary version tables.
    fn any_propose_message() -> impl Strategy<Value = Message<VersionData>> {
        any_version_table().prop_map(Message::Propose)
    }

    /// Generates `QueryReply` messages from arbitrary version tables.
    fn any_query_reply_message() -> impl Strategy<Value = Message<VersionData>> {
        any_version_table().prop_map(Message::QueryReply)
    }

    /// Generates `Accept` messages from an arbitrary version number and
    /// independently generated version data.
    fn any_accept_message() -> impl Strategy<Value = Message<VersionData>> {
        (any_version_number(), any_version_data())
            .prop_map(|(number, data)| Message::Accept(number, data))
    }

    /// Generates `Refuse` messages from arbitrary refusal reasons.
    fn any_refuse_message() -> impl Strategy<Value = Message<VersionData>> {
        any_refuse_reason().prop_map(Message::Refuse)
    }

    /// Generates any handshake message, picking each variant with equal weight.
    pub fn any_message() -> impl Strategy<Value = Message<VersionData>> {
        prop_oneof![
            1 => any_query_reply_message(),
            1 => any_propose_message(),
            1 => any_accept_message(),
            1 => any_refuse_message(),
        ]
    }
}