amaru_protocols/protocol_messages/
handshake.rs1use amaru_kernel::cbor::{self, check_tagged_array_length};
16
17use crate::protocol_messages::{version_data::VersionData, version_number::VersionNumber, version_table::VersionTable};
18
19#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
20pub enum HandshakeResult {
21 Accepted(VersionNumber, VersionData),
22 Refused(RefuseReason),
23 Query(VersionTable<VersionData>),
24}
25
26#[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
27pub enum RefuseReason {
28 VersionMismatch(Vec<VersionNumber>),
29 HandshakeDecodeError(VersionNumber, String),
30 Refused(VersionNumber, String),
31}
32
33impl cbor::Encode<()> for RefuseReason {
34 fn encode<W: cbor::encode::Write>(
35 &self,
36 e: &mut cbor::Encoder<W>,
37 _ctx: &mut (),
38 ) -> Result<(), cbor::encode::Error<W::Error>> {
39 match self {
40 RefuseReason::VersionMismatch(versions) => {
41 e.array(2)?;
42 e.u16(0)?;
43 e.array(versions.len() as u64)?;
44 for v in versions.iter() {
45 e.encode(v)?;
46 }
47
48 Ok(())
49 }
50 RefuseReason::HandshakeDecodeError(version, msg) => {
51 e.array(3)?;
52 e.u16(1)?;
53 e.encode(version)?;
54 e.str(msg)?;
55
56 Ok(())
57 }
58 RefuseReason::Refused(version, msg) => {
59 e.array(3)?;
60 e.u16(2)?;
61 e.encode(version)?;
62 e.str(msg)?;
63
64 Ok(())
65 }
66 }
67 }
68}
69
70impl<'b> cbor::Decode<'b, ()> for RefuseReason {
71 fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
72 let len = d.array()?;
73
74 match d.u16()? {
75 0 => {
76 check_tagged_array_length(0, len, 2)?;
77 let versions = d.array_iter::<VersionNumber>()?;
78 let versions = versions.collect::<Result<_, _>>()?;
79 Ok(RefuseReason::VersionMismatch(versions))
80 }
81 1 => {
82 check_tagged_array_length(1, len, 3)?;
83 let version = d.decode()?;
84 let msg = d.str()?;
85
86 Ok(RefuseReason::HandshakeDecodeError(version, msg.to_string()))
87 }
88 2 => {
89 check_tagged_array_length(2, len, 3)?;
90 let version = d.decode()?;
91 let msg = d.str()?;
92
93 Ok(RefuseReason::Refused(version, msg.to_string()))
94 }
95 _ => Err(cbor::decode::Error::message("unknown variant for refusereason")),
96 }
97 }
98}
99
#[cfg(test)]
pub(crate) mod tests {
    use amaru_kernel::prop_cbor_roundtrip;
    use proptest::{prelude::*, prop_compose};

    use super::*;
    use crate::protocol_messages::version_number::tests::any_version_number;

    prop_compose! {
        /// Strategy yielding `RefuseReason::VersionMismatch` with one or two versions.
        pub fn any_version_mismatch_reason()(
            supported in proptest::collection::vec(any_version_number(), 1..3),
        ) -> RefuseReason {
            RefuseReason::VersionMismatch(supported)
        }
    }

    prop_compose! {
        /// Strategy yielding `RefuseReason::HandshakeDecodeError` with an arbitrary
        /// version number and an arbitrary message.
        pub fn any_handshake_decode_error_reason()(
            version in any_version_number(),
            text in any::<String>(),
        ) -> RefuseReason {
            RefuseReason::HandshakeDecodeError(version, text)
        }
    }

    prop_compose! {
        /// Strategy yielding `RefuseReason::Refused` with an arbitrary version
        /// number and an arbitrary message.
        pub fn any_refused_reason()(
            version in any_version_number(),
            text in any::<String>(),
        ) -> RefuseReason {
            RefuseReason::Refused(version, text)
        }
    }

    /// Strategy yielding any `RefuseReason` variant, each with equal weight.
    pub fn any_refuse_reason() -> impl Strategy<Value = RefuseReason> {
        prop_oneof![
            any_version_mismatch_reason(),
            any_handshake_decode_error_reason(),
            any_refused_reason(),
        ]
    }

    prop_compose! {
        /// Strategy yielding an arbitrary `(u8, u64)` pair.
        pub fn any_byron_prefix()(first in any::<u8>(), second in any::<u64>()) -> (u8, u64) {
            (first, second)
        }
    }

    // Encoding then decoding any generated `RefuseReason` must round-trip.
    prop_cbor_roundtrip!(RefuseReason, any_refuse_reason());
}