1use std::{fmt::Display, io::Cursor};
10
11use anyhow::Result;
12use bincode::{Decode, Encode, config::standard, decode_from_slice};
13use bytes::Buf as _;
14use tracing::trace;
15
16use crate::{
17 frames::{get_bytes, get_usize},
18 uuid::UuidWrapper,
19};
20
/// A single protocol frame.
///
/// [`Frame::parse`] reads a frame as a one-byte tag (see [`Frame::id`]), a big-endian
/// `usize` length, and then `length` bytes holding this enum's bincode encoding
/// (`standard()` configuration).
///
/// NOTE(review): variant order is load-bearing. bincode's derived `Encode`/`Decode`
/// identify variants by index, and the derived `Ord`/`PartialOrd` compare by declaration
/// order. Append new variants rather than reordering them, and keep [`Frame::id`] in sync.
#[derive(Clone, Debug, Decode, Encode, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Frame {
    /// Tag `0`. Carries one opaque byte payload.
    Initialize(Vec<u8>),
    /// Tag `1`. Carries two byte payloads, which the `Display` impl binds as `pk` and
    /// `salt`. These are presumably the peer's public key and a salt (TODO confirm).
    PeerInitialize(Vec<u8>, Vec<u8>),
    /// Tag `2`. A fixed 12-byte array (bound as `nonce` in `Display`) plus a byte payload.
    Check([u8; 12], Vec<u8>),
    /// Tag `3`. Carries a [`UuidWrapper`].
    KeyAgreement(UuidWrapper),
}
33
34impl Frame {
35 #[must_use]
37 pub fn id(&self) -> u8 {
38 match self {
39 Frame::Initialize(_) => 0,
40 Frame::PeerInitialize(_, _) => 1,
41 Frame::Check(_, _) => 2,
42 Frame::KeyAgreement(_) => 3,
43 }
44 }
45
46 pub fn parse(src: &mut Cursor<&[u8]>) -> Result<Option<Self>> {
52 match get_u8(src) {
53 Some(0..=3) => {
54 if let Some(length_slice) = get_usize(src)? {
55 let length = usize::from_be_bytes(length_slice.try_into()?);
56 if let Some(data) = get_bytes(src, length)? {
57 let (frame, _): (Frame, _) = decode_from_slice(data, standard())?;
58 return Ok(Some(frame));
59 }
60 }
61 Ok(None)
62 }
63 Some(_) => {
64 trace!("Unknown frame");
65 Ok(None)
66 }
67 None => {
68 trace!("Incomplete frame");
69 Ok(None)
70 }
71 }
72 }
73}
74
75fn get_u8(src: &mut Cursor<&[u8]>) -> Option<u8> {
76 if !src.has_remaining() {
77 return None;
78 }
79
80 Some(src.get_u8())
81}
82
83impl Display for Frame {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 match self {
86 Frame::Initialize(data) => write!(f, "Initialize({} bytes)", data.len()),
87 Frame::PeerInitialize(pk, salt) => write!(
88 f,
89 "PeerInitialize({} bytes, {} bytes)",
90 pk.len(),
91 salt.len(),
92 ),
93 Frame::Check(nonce, data) => {
94 write!(f, "Check({} bytes, {} bytes)", nonce.len(), data.len())
95 }
96 Frame::KeyAgreement(uuid) => write!(f, "KeyAgreement({uuid})"),
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use anyhow::Result;
    use bincode::{config::standard, encode_to_vec};

    use crate::frames::USIZE_LENGTH;

    use super::{Frame, get_bytes, get_u8, get_usize};

    /// Length value encoded by the `DataKind::Usize` fixtures.
    const TEST_USIZE: usize = 12;

    /// Reads the tag byte and checks that it is `0` and that exactly one byte was consumed.
    fn validate_get_u8(cursor: &mut Cursor<&[u8]>) {
        let flag = get_u8(cursor).expect("fixture should start with a tag byte");
        assert_eq!(flag, 0);
        assert_eq!(cursor.position(), 1);
    }

    /// Reads the length prefix that follows the tag byte.
    ///
    /// Checks both its decoded value and that the cursor now sits just past it.
    fn validate_get_usize(cursor: &mut Cursor<&[u8]>, expected: usize) -> Result<()> {
        let raw = get_usize(cursor)?.expect("fixture should contain a complete length prefix");
        let value = usize::from_be_bytes(raw.try_into()?);
        assert_eq!(value, expected);
        assert_eq!(cursor.position(), u64::try_from(USIZE_LENGTH + 1)?);
        Ok(())
    }

    /// Reads `expected.len()` payload bytes after the tag and length prefix.
    ///
    /// Checks both their contents and the final cursor position.
    fn validate_get_bytes(cursor: &mut Cursor<&[u8]>, expected: &[u8]) -> Result<()> {
        let bytes = get_bytes(cursor, expected.len())?
            .expect("fixture should contain the complete payload");
        assert_eq!(bytes, expected);
        assert_eq!(
            cursor.position(),
            u64::try_from(USIZE_LENGTH + 1 + expected.len())?
        );
        Ok(())
    }

    /// Whether a fixture holds every byte its header announces.
    enum Completeness {
        Complete,
        Incomplete,
    }

    /// Which reader a fixture is built to exercise.
    enum DataKind {
        U8,
        Usize,
        Bytes,
    }

    /// Builds a fixture for `kind`.
    ///
    /// Returns `(wire_bytes, expected_usize, expected_payload)`. Every non-empty fixture
    /// starts with a `0` tag byte, and `Usize`/`Bytes` fixtures follow it with a big-endian
    /// `usize` length. Incomplete fixtures work as follows:
    /// - `U8` is empty;
    /// - `Usize` truncates the length prefix to 4 bytes;
    /// - `Bytes` declares 5 more payload bytes than are present.
    fn test_data(kind: DataKind, completeness: Completeness) -> (Vec<u8>, usize, Vec<u8>) {
        match (kind, completeness) {
            (DataKind::U8, Completeness::Complete) => (vec![0u8], 0, vec![]),
            (DataKind::U8, Completeness::Incomplete) => (vec![], 0, vec![]),
            (DataKind::Usize, Completeness::Complete) => {
                let val = TEST_USIZE;
                let data = val.to_be_bytes();
                ([&[0], data.as_slice()].concat(), val, vec![])
            }
            (DataKind::Usize, Completeness::Incomplete) => {
                let val = TEST_USIZE;
                let data = val.to_be_bytes();
                ([&[0], &data[..4]].concat(), val, vec![])
            }
            (DataKind::Bytes, Completeness::Complete) => {
                let data = b"hello";
                let length = data.len();
                let length_bytes = length.to_be_bytes();
                (
                    [&[0], length_bytes.as_slice(), data.as_slice()].concat(),
                    length,
                    data.to_vec(),
                )
            }
            (DataKind::Bytes, Completeness::Incomplete) => {
                let data = b"hello";
                // Announce more bytes than are actually present.
                let length = data.len() + 5;
                let length_bytes = length.to_be_bytes();
                (
                    [&[0], length_bytes.as_slice(), data.as_slice()].concat(),
                    length,
                    data.to_vec(),
                )
            }
        }
    }

    /// Serializes `frame` the way [`Frame::parse`] reads it.
    ///
    /// The output is the tag byte, then a big-endian `usize` length, then the bincode payload.
    fn wire_frame(frame: &Frame) -> Result<Vec<u8>> {
        let payload = encode_to_vec(frame, standard())?;
        let mut bytes = vec![frame.id()];
        bytes.extend_from_slice(&payload.len().to_be_bytes());
        bytes.extend_from_slice(&payload);
        Ok(bytes)
    }

    #[test]
    fn test_get_u8() {
        let (all_data, _, _) = test_data(DataKind::U8, Completeness::Complete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
    }

    #[test]
    fn test_get_u8_incomplete() {
        let (all_data, _, _) = test_data(DataKind::U8, Completeness::Incomplete);
        let mut cursor = Cursor::new(&all_data[..]);
        assert!(get_u8(&mut cursor).is_none());
    }

    #[test]
    fn test_get_usize() -> Result<()> {
        let (all_data, expected_usize, _) = test_data(DataKind::Usize, Completeness::Complete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
        validate_get_usize(&mut cursor, expected_usize)?;
        Ok(())
    }

    #[test]
    fn test_get_usize_incomplete() -> Result<()> {
        let (all_data, _, _) = test_data(DataKind::Usize, Completeness::Incomplete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
        assert!(get_usize(&mut cursor)?.is_none());
        Ok(())
    }

    #[test]
    fn test_get_bytes() -> Result<()> {
        let (all_data, expected_usize, expected_bytes) =
            test_data(DataKind::Bytes, Completeness::Complete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
        validate_get_usize(&mut cursor, expected_usize)?;
        validate_get_bytes(&mut cursor, &expected_bytes)?;
        Ok(())
    }

    #[test]
    fn test_get_bytes_incomplete() -> Result<()> {
        let (all_data, expected_usize, _) = test_data(DataKind::Bytes, Completeness::Incomplete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
        validate_get_usize(&mut cursor, expected_usize)?;
        assert!(get_bytes(&mut cursor, expected_usize)?.is_none());
        Ok(())
    }

    #[test]
    fn test_parse() -> Result<()> {
        let frame = Frame::Initialize(b"hello world".to_vec());
        let all_data = wire_frame(&frame)?;
        let mut cursor = Cursor::new(&all_data[..]);
        let parsed_frame = Frame::parse(&mut cursor)?.expect("a complete frame should parse");
        assert_eq!(parsed_frame, frame);
        Ok(())
    }

    #[test]
    fn test_parse_round_trips_variants() -> Result<()> {
        let frames = [
            Frame::Initialize(vec![]),
            Frame::PeerInitialize(b"public".to_vec(), b"salt".to_vec()),
            Frame::Check([7; 12], b"payload".to_vec()),
        ];
        for frame in frames {
            let all_data = wire_frame(&frame)?;
            let mut cursor = Cursor::new(&all_data[..]);
            let parsed_frame = Frame::parse(&mut cursor)?.expect("a complete frame should parse");
            assert_eq!(parsed_frame, frame);
        }
        Ok(())
    }

    #[test]
    fn test_parse_truncated_payload() -> Result<()> {
        let mut all_data = wire_frame(&Frame::Initialize(b"hello world".to_vec()))?;
        // Drop the final payload byte so the declared length cannot be satisfied.
        all_data.pop();
        let mut cursor = Cursor::new(&all_data[..]);
        assert!(Frame::parse(&mut cursor)?.is_none());
        Ok(())
    }

    #[test]
    fn test_parse_unknown_frame() -> Result<()> {
        let all_data = [200u8];
        let mut cursor = Cursor::new(&all_data[..]);
        assert!(Frame::parse(&mut cursor)?.is_none());
        Ok(())
    }

    #[test]
    fn test_parse_empty() -> Result<()> {
        let all_data: [u8; 0] = [];
        let mut cursor = Cursor::new(&all_data[..]);
        assert!(Frame::parse(&mut cursor)?.is_none());
        Ok(())
    }

    #[test]
    fn test_id() {
        assert_eq!(Frame::Initialize(vec![]).id(), 0);
        assert_eq!(Frame::PeerInitialize(vec![], vec![]).id(), 1);
        assert_eq!(Frame::Check([0; 12], vec![]).id(), 2);
    }

    #[test]
    fn test_display() {
        assert_eq!(Frame::Initialize(vec![]).to_string(), "Initialize(0 bytes)");
        assert_eq!(
            Frame::PeerInitialize(vec![0; 32], vec![0; 16]).to_string(),
            "PeerInitialize(32 bytes, 16 bytes)"
        );
        assert_eq!(
            Frame::Check([0; 12], vec![1, 2, 3]).to_string(),
            "Check(12 bytes, 3 bytes)"
        );
    }
}