Skip to main content

amaru_protocols/protocol_messages/
handshake.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use amaru_kernel::cbor::{self, check_tagged_array_length};
16
17use crate::protocol_messages::{version_data::VersionData, version_number::VersionNumber, version_table::VersionTable};
18
19#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
20pub enum HandshakeResult {
21    Accepted(VersionNumber, VersionData),
22    Refused(RefuseReason),
23    Query(VersionTable<VersionData>),
24}
25
26#[derive(Debug, PartialEq, Eq, Clone, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
27pub enum RefuseReason {
28    VersionMismatch(Vec<VersionNumber>),
29    HandshakeDecodeError(VersionNumber, String),
30    Refused(VersionNumber, String),
31}
32
33impl cbor::Encode<()> for RefuseReason {
34    fn encode<W: cbor::encode::Write>(
35        &self,
36        e: &mut cbor::Encoder<W>,
37        _ctx: &mut (),
38    ) -> Result<(), cbor::encode::Error<W::Error>> {
39        match self {
40            RefuseReason::VersionMismatch(versions) => {
41                e.array(2)?;
42                e.u16(0)?;
43                e.array(versions.len() as u64)?;
44                for v in versions.iter() {
45                    e.encode(v)?;
46                }
47
48                Ok(())
49            }
50            RefuseReason::HandshakeDecodeError(version, msg) => {
51                e.array(3)?;
52                e.u16(1)?;
53                e.encode(version)?;
54                e.str(msg)?;
55
56                Ok(())
57            }
58            RefuseReason::Refused(version, msg) => {
59                e.array(3)?;
60                e.u16(2)?;
61                e.encode(version)?;
62                e.str(msg)?;
63
64                Ok(())
65            }
66        }
67    }
68}
69
70impl<'b> cbor::Decode<'b, ()> for RefuseReason {
71    fn decode(d: &mut cbor::Decoder<'b>, _ctx: &mut ()) -> Result<Self, cbor::decode::Error> {
72        let len = d.array()?;
73
74        match d.u16()? {
75            0 => {
76                check_tagged_array_length(0, len, 2)?;
77                let versions = d.array_iter::<VersionNumber>()?;
78                let versions = versions.collect::<Result<_, _>>()?;
79                Ok(RefuseReason::VersionMismatch(versions))
80            }
81            1 => {
82                check_tagged_array_length(1, len, 3)?;
83                let version = d.decode()?;
84                let msg = d.str()?;
85
86                Ok(RefuseReason::HandshakeDecodeError(version, msg.to_string()))
87            }
88            2 => {
89                check_tagged_array_length(2, len, 3)?;
90                let version = d.decode()?;
91                let msg = d.str()?;
92
93                Ok(RefuseReason::Refused(version, msg.to_string()))
94            }
95            _ => Err(cbor::decode::Error::message("unknown variant for refusereason")),
96        }
97    }
98}
99
#[cfg(test)]
pub(crate) mod tests {
    use amaru_kernel::prop_cbor_roundtrip;
    use proptest::prelude::*;

    use super::*;
    use crate::protocol_messages::version_number::tests::any_version_number;

    // Every generated `RefuseReason` must survive a CBOR encode/decode round trip.
    prop_cbor_roundtrip!(RefuseReason, any_refuse_reason());

    /// Strategy for `RefuseReason::HandshakeDecodeError` with an arbitrary
    /// version number and an arbitrary (possibly empty) message.
    #[must_use = "strategies do nothing unless used"]
    pub fn any_handshake_decode_error_reason() -> impl Strategy<Value = RefuseReason> {
        (any_version_number(), any::<String>())
            .prop_map(|(version, text)| RefuseReason::HandshakeDecodeError(version, text))
    }

    /// Strategy for `RefuseReason::Refused` with an arbitrary version number
    /// and an arbitrary (possibly empty) message.
    #[must_use = "strategies do nothing unless used"]
    pub fn any_refused_reason() -> impl Strategy<Value = RefuseReason> {
        (any_version_number(), any::<String>())
            .prop_map(|(version, text)| RefuseReason::Refused(version, text))
    }

    /// Strategy for `RefuseReason::VersionMismatch` carrying one or two
    /// version numbers (the `1..3` length range never yields an empty list).
    #[must_use = "strategies do nothing unless used"]
    pub fn any_version_mismatch_reason() -> impl Strategy<Value = RefuseReason> {
        proptest::collection::vec(any_version_number(), 1..3)
            .prop_map(RefuseReason::VersionMismatch)
    }

    /// Strategy choosing uniformly among the three `RefuseReason` variants.
    pub fn any_refuse_reason() -> impl Strategy<Value = RefuseReason> {
        // Without explicit weights `prop_oneof!` gives every arm weight 1.
        prop_oneof![
            any_version_mismatch_reason(),
            any_handshake_decode_error_reason(),
            any_refused_reason(),
        ]
    }

    /// Strategy producing an arbitrary `(u8, u64)` pair.
    #[must_use = "strategies do nothing unless used"]
    pub fn any_byron_prefix() -> impl Strategy<Value = (u8, u64)> {
        (any::<u8>(), any::<u64>())
    }
}