1use pallas_codec::minicbor::{Decode, Decoder, Encode, Encoder, decode, encode};
7use std::collections::HashMap;
8
/// Version data attached to each version number during the handshake.
///
/// It holds a single network magic value, which the `Encode`/`Decode` impls
/// for this type write and read as one bare CBOR unsigned integer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForwardingVersionData {
    /// Network magic value carried in the handshake's version data.
    pub network_magic: u64,
}
15
16impl Encode<()> for ForwardingVersionData {
17 fn encode<W: encode::Write>(
18 &self,
19 e: &mut Encoder<W>,
20 _ctx: &mut (),
21 ) -> Result<(), encode::Error<W::Error>> {
22 e.u64(self.network_magic)?;
23 Ok(())
24 }
25}
26
27impl<'b> Decode<'b, ()> for ForwardingVersionData {
28 fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
29 let network_magic = d.u64()?;
30 Ok(ForwardingVersionData { network_magic })
31 }
32}
33
34pub type VersionTable = HashMap<u64, ForwardingVersionData>;
36
37pub fn version_table_v1(network_magic: u64) -> VersionTable {
39 let mut table = HashMap::new();
40 table.insert(1, ForwardingVersionData { network_magic });
41 table
42}
43
/// A message of the version-negotiation handshake.
///
/// Each message is encoded as a CBOR array whose first element is a `u16`
/// label: `0` = `Propose`, `1` = `Accept`, `2` = `Refuse` (see the
/// `Encode`/`Decode` impls for this type).
#[derive(Debug, Clone)]
pub enum HandshakeMessage {
    /// A table of candidate version numbers, each with its version data.
    Propose(VersionTable),
    /// A single accepted version number together with its version data.
    Accept(u64, ForwardingVersionData),
    /// A refusal carrying a list of version numbers.
    /// NOTE(review): encoded as a bare list `[2, [version, ...]]`; confirm
    /// this matches the refusal format the remote peer expects.
    Refuse(Vec<u64>),
}
54
55impl Encode<()> for HandshakeMessage {
56 fn encode<W: encode::Write>(
57 &self,
58 e: &mut Encoder<W>,
59 _ctx: &mut (),
60 ) -> Result<(), encode::Error<W::Error>> {
61 match self {
62 HandshakeMessage::Propose(versions) => {
63 e.array(2)?.u16(0)?;
64 e.map(versions.len() as u64)?;
65 for (version, data) in versions {
66 e.encode(version)?;
67 e.encode_with(data, _ctx)?;
68 }
69 }
70 HandshakeMessage::Accept(version, data) => {
71 e.array(3)?.u16(1)?;
72 e.encode(version)?;
73 e.encode_with(data, _ctx)?;
74 }
75 HandshakeMessage::Refuse(versions) => {
76 e.array(2)?.u16(2)?;
77 e.array(versions.len() as u64)?;
78 for v in versions {
79 e.encode(v)?;
80 }
81 }
82 }
83 Ok(())
84 }
85}
86
87impl<'b> Decode<'b, ()> for HandshakeMessage {
88 fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
89 d.array()?;
90 let label = d.u16()?;
91
92 match label {
93 0 => {
94 let map_len = d
95 .map()?
96 .ok_or_else(|| decode::Error::message("expected definite map"))?;
97 let mut versions = HashMap::new();
98 for _ in 0..map_len {
99 let version = d.decode()?;
100 let data = d.decode_with(_ctx)?;
101 versions.insert(version, data);
102 }
103 Ok(HandshakeMessage::Propose(versions))
104 }
105 1 => {
106 let version = d.decode()?;
107 let data = d.decode_with(_ctx)?;
108 Ok(HandshakeMessage::Accept(version, data))
109 }
110 2 => {
111 let arr_len = d
112 .array()?
113 .ok_or_else(|| decode::Error::message("expected definite array"))?;
114 let mut versions = Vec::new();
115 for _ in 0..arr_len {
116 versions.push(d.decode()?);
117 }
118 Ok(HandshakeMessage::Refuse(versions))
119 }
120 _ => Err(decode::Error::message("unknown handshake message")),
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use pallas_codec::minicbor;

    /// Encodes `value` into a fresh buffer using the unit context.
    fn encode<T: minicbor::Encode<()>>(value: &T) -> Vec<u8> {
        let mut buf = Vec::new();
        minicbor::encode_with(value, &mut buf, &mut ()).unwrap();
        buf
    }

    /// Decodes a `T` from `buf` using the unit context, panicking on failure.
    fn decode<T: for<'b> minicbor::Decode<'b, ()>>(buf: &[u8]) -> T {
        minicbor::decode_with(buf, &mut ()).unwrap()
    }

    #[test]
    fn version_data_round_trip() {
        let data = ForwardingVersionData {
            network_magic: 764824073,
        };
        let buf = encode(&data);
        let decoded: ForwardingVersionData = decode(&buf);
        assert_eq!(decoded.network_magic, 764824073);
    }

    #[test]
    fn version_table_v1_has_single_version_1() {
        let table = version_table_v1(12345);
        assert_eq!(table.len(), 1);
        assert!(table.contains_key(&1));
        assert_eq!(table[&1].network_magic, 12345);
    }

    #[test]
    fn propose_round_trip() {
        let msg = HandshakeMessage::Propose(version_table_v1(764824073));
        match decode::<HandshakeMessage>(&encode(&msg)) {
            HandshakeMessage::Propose(v) => {
                assert_eq!(v.len(), 1);
                assert_eq!(v[&1].network_magic, 764824073);
            }
            other => panic!("expected Propose, got {:?}", other),
        }
    }

    #[test]
    fn propose_multiple_versions_round_trip() {
        let mut versions = version_table_v1(764824073);
        versions.insert(2, ForwardingVersionData { network_magic: 1 });
        versions.insert(7, ForwardingVersionData { network_magic: 2 });
        let msg = HandshakeMessage::Propose(versions.clone());
        match decode::<HandshakeMessage>(&encode(&msg)) {
            HandshakeMessage::Propose(decoded) => {
                assert_eq!(decoded.len(), versions.len());
                for (version, data) in &versions {
                    assert_eq!(decoded[version].network_magic, data.network_magic);
                }
            }
            other => panic!("expected Propose, got {:?}", other),
        }
    }

    #[test]
    fn accept_round_trip() {
        let msg = HandshakeMessage::Accept(1, ForwardingVersionData { network_magic: 42 });
        match decode::<HandshakeMessage>(&encode(&msg)) {
            HandshakeMessage::Accept(ver, data) => {
                assert_eq!(ver, 1);
                assert_eq!(data.network_magic, 42);
            }
            other => panic!("expected Accept, got {:?}", other),
        }
    }

    #[test]
    fn refuse_round_trip() {
        // Deliberately unsorted: the version list must come back in the same order.
        let msg = HandshakeMessage::Refuse(vec![3, 1, 2]);
        match decode::<HandshakeMessage>(&encode(&msg)) {
            HandshakeMessage::Refuse(versions) => assert_eq!(versions, vec![3, 1, 2]),
            other => panic!("expected Refuse, got {:?}", other),
        }
    }

    #[test]
    fn refuse_empty_versions_round_trip() {
        let msg = HandshakeMessage::Refuse(vec![]);
        match decode::<HandshakeMessage>(&encode(&msg)) {
            HandshakeMessage::Refuse(v) => assert!(v.is_empty()),
            other => panic!("expected Refuse, got {:?}", other),
        }
    }

    #[test]
    fn unknown_label_is_rejected() {
        // CBOR `[9, 0]`: a two-element array whose label matches no message.
        let buf = [0x82u8, 0x09, 0x00];
        let result: Result<HandshakeMessage, _> = minicbor::decode_with(&buf[..], &mut ());
        assert!(result.is_err());
    }
}