Skip to main content

hermod/mux/
handshake.rs

1//! Trace-forward protocol handshake implementation
2//!
3//! Implements the handshake for the trace-forward protocol, which negotiates
4//! the protocol version (ForwardingV_1) and network magic.
5
6use pallas_codec::minicbor::{Decode, Decoder, Encode, Encoder, decode, encode};
7use std::collections::HashMap;
8
/// Version data exchanged alongside a protocol version number during the
/// trace-forward handshake.
///
/// It carries only the network magic, which is encoded on the wire as a bare
/// CBOR unsigned integer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForwardingVersionData {
    /// The Cardano network magic number
    pub network_magic: u64,
}
15
16impl Encode<()> for ForwardingVersionData {
17    fn encode<W: encode::Write>(
18        &self,
19        e: &mut Encoder<W>,
20        _ctx: &mut (),
21    ) -> Result<(), encode::Error<W::Error>> {
22        e.u64(self.network_magic)?;
23        Ok(())
24    }
25}
26
27impl<'b> Decode<'b, ()> for ForwardingVersionData {
28    fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
29        let network_magic = d.u64()?;
30        Ok(ForwardingVersionData { network_magic })
31    }
32}
33
34/// Version table for handshake negotiation
35pub type VersionTable = HashMap<u64, ForwardingVersionData>;
36
37/// Creates a version table with ForwardingV_1
38pub fn version_table_v1(network_magic: u64) -> VersionTable {
39    let mut table = HashMap::new();
40    table.insert(1, ForwardingVersionData { network_magic });
41    table
42}
43
44/// Handshake message types
45#[derive(Debug, Clone)]
46pub enum HandshakeMessage {
47    /// Propose versions
48    Propose(VersionTable),
49    /// Accept a version
50    Accept(u64, ForwardingVersionData),
51    /// Refuse all versions
52    Refuse(Vec<u64>),
53}
54
55impl Encode<()> for HandshakeMessage {
56    fn encode<W: encode::Write>(
57        &self,
58        e: &mut Encoder<W>,
59        _ctx: &mut (),
60    ) -> Result<(), encode::Error<W::Error>> {
61        match self {
62            HandshakeMessage::Propose(versions) => {
63                e.array(2)?.u16(0)?;
64                e.map(versions.len() as u64)?;
65                for (version, data) in versions {
66                    e.encode(version)?;
67                    e.encode_with(data, _ctx)?;
68                }
69            }
70            HandshakeMessage::Accept(version, data) => {
71                e.array(3)?.u16(1)?;
72                e.encode(version)?;
73                e.encode_with(data, _ctx)?;
74            }
75            HandshakeMessage::Refuse(versions) => {
76                e.array(2)?.u16(2)?;
77                e.array(versions.len() as u64)?;
78                for v in versions {
79                    e.encode(v)?;
80                }
81            }
82        }
83        Ok(())
84    }
85}
86
87impl<'b> Decode<'b, ()> for HandshakeMessage {
88    fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
89        d.array()?;
90        let label = d.u16()?;
91
92        match label {
93            0 => {
94                let map_len = d
95                    .map()?
96                    .ok_or_else(|| decode::Error::message("expected definite map"))?;
97                let mut versions = HashMap::new();
98                for _ in 0..map_len {
99                    let version = d.decode()?;
100                    let data = d.decode_with(_ctx)?;
101                    versions.insert(version, data);
102                }
103                Ok(HandshakeMessage::Propose(versions))
104            }
105            1 => {
106                let version = d.decode()?;
107                let data = d.decode_with(_ctx)?;
108                Ok(HandshakeMessage::Accept(version, data))
109            }
110            2 => {
111                let arr_len = d
112                    .array()?
113                    .ok_or_else(|| decode::Error::message("expected definite array"))?;
114                let mut versions = Vec::new();
115                for _ in 0..arr_len {
116                    versions.push(d.decode()?);
117                }
118                Ok(HandshakeMessage::Refuse(versions))
119            }
120            _ => Err(decode::Error::message("unknown handshake message")),
121        }
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use pallas_codec::minicbor;

    /// Encodes `value` into a fresh buffer, panicking on failure.
    fn encode<T: minicbor::Encode<()>>(value: &T) -> Vec<u8> {
        let mut buf = Vec::new();
        minicbor::encode_with(value, &mut buf, &mut ()).unwrap();
        buf
    }

    /// Decodes a `T` from `buf`, panicking on failure.
    fn decode<T: for<'b> minicbor::Decode<'b, ()>>(buf: &[u8]) -> T {
        minicbor::decode_with(buf, &mut ()).unwrap()
    }

    #[test]
    fn version_data_round_trip() {
        let data = ForwardingVersionData {
            network_magic: 764824073,
        };
        let buf = encode(&data);
        let decoded: ForwardingVersionData = decode(&buf);
        assert_eq!(decoded.network_magic, 764824073);
    }

    #[test]
    fn version_table_v1_has_single_version_1() {
        let table = version_table_v1(12345);
        assert_eq!(table.len(), 1);
        assert!(table.contains_key(&1));
        assert_eq!(table[&1].network_magic, 12345);
    }

    #[test]
    fn propose_round_trip() {
        let versions = version_table_v1(764824073);
        let msg = HandshakeMessage::Propose(versions);
        let buf = encode(&msg);
        let decoded: HandshakeMessage = decode(&buf);
        match decoded {
            HandshakeMessage::Propose(v) => {
                assert_eq!(v.len(), 1);
                assert_eq!(v[&1].network_magic, 764824073);
            }
            other => panic!("expected Propose, got {:?}", other),
        }
    }

    #[test]
    fn propose_multiple_versions_round_trip() {
        let mut versions = version_table_v1(764824073);
        versions.insert(2, ForwardingVersionData { network_magic: 1 });
        let buf = encode(&HandshakeMessage::Propose(versions));
        match decode::<HandshakeMessage>(&buf) {
            HandshakeMessage::Propose(v) => {
                assert_eq!(v.len(), 2);
                assert_eq!(v[&1].network_magic, 764824073);
                assert_eq!(v[&2].network_magic, 1);
            }
            other => panic!("expected Propose, got {:?}", other),
        }
    }

    #[test]
    fn accept_round_trip() {
        let msg = HandshakeMessage::Accept(1, ForwardingVersionData { network_magic: 42 });
        let buf = encode(&msg);
        let decoded: HandshakeMessage = decode(&buf);
        match decoded {
            HandshakeMessage::Accept(ver, data) => {
                assert_eq!(ver, 1);
                assert_eq!(data.network_magic, 42);
            }
            other => panic!("expected Accept, got {:?}", other),
        }
    }

    #[test]
    fn refuse_round_trip() {
        // Deliberately unsorted: a `Vec` must round-trip in its exact order.
        let msg = HandshakeMessage::Refuse(vec![3, 1, 2]);
        let buf = encode(&msg);
        let decoded: HandshakeMessage = decode(&buf);
        match decoded {
            HandshakeMessage::Refuse(versions) => assert_eq!(versions, vec![3, 1, 2]),
            other => panic!("expected Refuse, got {:?}", other),
        }
    }

    #[test]
    fn refuse_empty_versions_round_trip() {
        let msg = HandshakeMessage::Refuse(vec![]);
        let buf = encode(&msg);
        match decode::<HandshakeMessage>(&buf) {
            HandshakeMessage::Refuse(v) => assert!(v.is_empty()),
            other => panic!("expected Refuse, got {:?}", other),
        }
    }

    #[test]
    fn decode_rejects_unknown_label() {
        // CBOR `[3, []]`: array(2), unsigned 3, empty array. Label 3 is not a
        // message this codec understands.
        let bytes = [0x82u8, 0x03, 0x80];
        let result: Result<HandshakeMessage, _> = minicbor::decode_with(&bytes[..], &mut ());
        assert!(result.is_err());
    }
}