libmoshpit/frames/
frame.rs

1// Copyright (c) 2025 moshpit developers
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or https://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use std::{fmt::Display, io::Cursor};
10
11use anyhow::Result;
12use bincode::{Decode, Encode, config::standard, decode_from_slice};
13use bytes::Buf as _;
14use tracing::trace;
15
16use crate::{
17    frames::{get_bytes, get_usize},
18    uuid::UuidWrapper,
19};
20
/// A moshpit frame.
///
/// Declaration order of the variants is load-bearing: the derived `Ord`/`PartialOrd`
/// compare variants by position, bincode's derived `Encode`/`Decode` identify a variant by
/// its index, and [`Frame::id`] assigns identifiers `0..=3` in this same order. Append new
/// variants at the end rather than reordering, and update [`Frame::id`] and the tag range
/// accepted by [`Frame::parse`] to match.
#[derive(Clone, Debug, Decode, Encode, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Frame {
    /// An initialization frame from moshpit.
    ///
    /// Carries a single byte payload.
    Initialize(Vec<u8>),
    /// A peer initialization frame from moshpits.
    ///
    /// Carries two byte payloads — presumably a public key and a salt; TODO confirm
    /// against the sender.
    PeerInitialize(Vec<u8>, Vec<u8>),
    /// A check message from moshpit.
    ///
    /// Carries a fixed 12-byte array (presumably a nonce — TODO confirm) and a byte payload.
    Check([u8; 12], Vec<u8>),
    /// A key agreement message from moshpits.
    ///
    /// Carries a [`UuidWrapper`].
    KeyAgreement(UuidWrapper),
}
33
34impl Frame {
35    /// Get the frame identifier.
36    #[must_use]
37    pub fn id(&self) -> u8 {
38        match self {
39            Frame::Initialize(_) => 0,
40            Frame::PeerInitialize(_, _) => 1,
41            Frame::Check(_, _) => 2,
42            Frame::KeyAgreement(_) => 3,
43        }
44    }
45
46    /// Parse a moshpit frame from the given byte source.
47    ///
48    /// # Errors
49    /// * Incomplete data.
50    ///
51    pub fn parse(src: &mut Cursor<&[u8]>) -> Result<Option<Self>> {
52        match get_u8(src) {
53            Some(0..=3) => {
54                if let Some(length_slice) = get_usize(src)? {
55                    let length = usize::from_be_bytes(length_slice.try_into()?);
56                    if let Some(data) = get_bytes(src, length)? {
57                        let (frame, _): (Frame, _) = decode_from_slice(data, standard())?;
58                        return Ok(Some(frame));
59                    }
60                }
61                Ok(None)
62            }
63            Some(_) => {
64                trace!("Unknown frame");
65                Ok(None)
66            }
67            None => {
68                trace!("Incomplete frame");
69                Ok(None)
70            }
71        }
72    }
73}
74
75fn get_u8(src: &mut Cursor<&[u8]>) -> Option<u8> {
76    if !src.has_remaining() {
77        return None;
78    }
79
80    Some(src.get_u8())
81}
82
83impl Display for Frame {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        match self {
86            Frame::Initialize(data) => write!(f, "Initialize({} bytes)", data.len()),
87            Frame::PeerInitialize(pk, salt) => write!(
88                f,
89                "PeerInitialize({} bytes, {} bytes)",
90                pk.len(),
91                salt.len(),
92            ),
93            Frame::Check(nonce, data) => {
94                write!(f, "Check({} bytes, {} bytes)", nonce.len(), data.len())
95            }
96            Frame::KeyAgreement(uuid) => write!(f, "KeyAgreement({uuid})"),
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use anyhow::Result;
    use bincode::{config::standard, encode_to_vec};

    use crate::frames::USIZE_LENGTH;

    use super::{Frame, get_bytes, get_u8, get_usize};

    const TEST_USIZE: usize = 12;

    /// Reads one byte with `get_u8` and asserts it is the `0` tag at the start of the buffer.
    fn validate_get_u8(cursor: &mut Cursor<&[u8]>) {
        let flag = get_u8(cursor).expect("a tag byte should be available");
        assert_eq!(flag, 0);
        assert_eq!(cursor.position(), 1);
    }

    /// Reads the length prefix following the tag byte and asserts its value.
    fn validate_get_usize(cursor: &mut Cursor<&[u8]>, expected: usize) -> Result<()> {
        let line = get_usize(cursor)?.expect("a complete length prefix should be available");
        let value = usize::from_be_bytes(line.try_into()?);
        assert_eq!(value, expected);
        assert_eq!(cursor.position(), u64::try_from(USIZE_LENGTH + 1)?);
        Ok(())
    }

    /// Reads `expected.len()` payload bytes following the tag and length prefix.
    fn validate_get_bytes(cursor: &mut Cursor<&[u8]>, expected: &[u8]) -> Result<()> {
        let bytes = get_bytes(cursor, expected.len())?.expect("the payload should be available");
        assert_eq!(bytes, expected);
        assert_eq!(
            cursor.position(),
            u64::try_from(USIZE_LENGTH + 1 + expected.len())?
        );
        Ok(())
    }

    enum Completeness {
        Complete,
        Incomplete,
    }

    enum DataKind {
        U8,
        Usize,
        Bytes,
    }

    /// Builds `(wire bytes, expected length value, expected payload)` for the helper tests.
    fn test_data(kind: DataKind, completeness: Completeness) -> (Vec<u8>, usize, Vec<u8>) {
        match (kind, completeness) {
            (DataKind::U8, Completeness::Complete) => (vec![0u8], 0, vec![]),
            (DataKind::U8, Completeness::Incomplete) => (vec![], 0, vec![]),
            (DataKind::Usize, Completeness::Complete) => {
                let val = TEST_USIZE;
                let data = val.to_be_bytes();
                ([&[0], data.as_slice()].concat(), val, vec![])
            }
            (DataKind::Usize, Completeness::Incomplete) => {
                let val = TEST_USIZE;
                let data = val.to_be_bytes();
                ([&[0], &data[..4]].concat(), val, vec![])
            }
            (DataKind::Bytes, Completeness::Complete) => {
                let data = b"hello";
                let length = data.len();
                let length_bytes = length.to_be_bytes();
                (
                    [&[0], length_bytes.as_slice(), data.as_slice()].concat(),
                    length,
                    data.to_vec(),
                )
            }
            (DataKind::Bytes, Completeness::Incomplete) => {
                let data = b"hello";
                let length = data.len() + 5; // Intentionally longer than the payload.
                let length_bytes = length.to_be_bytes();
                (
                    [&[0], length_bytes.as_slice(), data.as_slice()].concat(),
                    length,
                    data.to_vec(),
                )
            }
        }
    }

    /// Serializes `frame` in the layout `Frame::parse` reads:
    /// tag byte, big-endian `usize` length, bincode payload.
    fn wire_bytes(frame: &Frame) -> Result<Vec<u8>> {
        let payload = encode_to_vec(frame, standard())?;
        let mut all_data = vec![frame.id()];
        all_data.extend_from_slice(&payload.len().to_be_bytes());
        all_data.extend_from_slice(&payload);
        Ok(all_data)
    }

    #[test]
    fn test_get_u8() {
        let (all_data, _, _) = test_data(DataKind::U8, Completeness::Complete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
    }

    #[test]
    fn test_get_u8_incomplete() {
        let (all_data, _, _) = test_data(DataKind::U8, Completeness::Incomplete);
        let mut cursor = Cursor::new(&all_data[..]);
        assert!(get_u8(&mut cursor).is_none());
    }

    #[test]
    fn test_get_usize() -> Result<()> {
        let (all_data, expected_usize, _) = test_data(DataKind::Usize, Completeness::Complete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
        validate_get_usize(&mut cursor, expected_usize)?;
        Ok(())
    }

    #[test]
    fn test_get_usize_incomplete() -> Result<()> {
        let (all_data, _, _) = test_data(DataKind::Usize, Completeness::Incomplete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
        assert!(get_usize(&mut cursor)?.is_none());
        Ok(())
    }

    #[test]
    fn test_get_bytes() -> Result<()> {
        let (all_data, expected_usize, expected_bytes) =
            test_data(DataKind::Bytes, Completeness::Complete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
        validate_get_usize(&mut cursor, expected_usize)?;
        validate_get_bytes(&mut cursor, &expected_bytes)?;
        Ok(())
    }

    #[test]
    fn test_get_bytes_incomplete() -> Result<()> {
        let (all_data, expected_usize, _) = test_data(DataKind::Bytes, Completeness::Incomplete);
        let mut cursor = Cursor::new(&all_data[..]);
        validate_get_u8(&mut cursor);
        validate_get_usize(&mut cursor, expected_usize)?;
        assert!(get_bytes(&mut cursor, expected_usize)?.is_none());
        Ok(())
    }

    #[test]
    fn test_id() {
        assert_eq!(Frame::Initialize(vec![]).id(), 0);
        assert_eq!(Frame::PeerInitialize(vec![], vec![]).id(), 1);
        assert_eq!(Frame::Check([0u8; 12], vec![]).id(), 2);
    }

    #[test]
    fn test_display() {
        assert_eq!(
            Frame::Initialize(vec![1, 2, 3]).to_string(),
            "Initialize(3 bytes)"
        );
        assert_eq!(
            Frame::PeerInitialize(vec![1, 2], vec![3]).to_string(),
            "PeerInitialize(2 bytes, 1 bytes)"
        );
        assert_eq!(
            Frame::Check([0u8; 12], vec![1, 2, 3, 4]).to_string(),
            "Check(12 bytes, 4 bytes)"
        );
    }

    #[test]
    fn test_parse() -> Result<()> {
        let frame = Frame::Initialize(b"hello world".to_vec());
        let all_data = wire_bytes(&frame)?;
        let mut cursor = Cursor::new(&all_data[..]);
        let parsed_frame = Frame::parse(&mut cursor)?.expect("a complete frame should parse");
        assert_eq!(parsed_frame, frame);
        Ok(())
    }

    #[test]
    fn test_parse_round_trips_byte_variants() -> Result<()> {
        let frames = [
            Frame::Initialize(vec![1, 2, 3]),
            Frame::PeerInitialize(vec![4, 5], vec![6, 7, 8]),
            Frame::Check([9u8; 12], vec![10, 11]),
        ];
        for frame in frames {
            let all_data = wire_bytes(&frame)?;
            let mut cursor = Cursor::new(&all_data[..]);
            let parsed = Frame::parse(&mut cursor)?;
            // A complete frame must consume exactly its own bytes.
            assert_eq!(cursor.position(), u64::try_from(all_data.len())?);
            assert_eq!(parsed, Some(frame));
        }
        Ok(())
    }

    #[test]
    fn test_parse_truncated_length() -> Result<()> {
        // Tag byte plus fewer bytes than a full length prefix.
        let all_data = [0u8, 0, 0];
        let mut cursor = Cursor::new(&all_data[..]);
        assert!(Frame::parse(&mut cursor)?.is_none());
        Ok(())
    }

    #[test]
    fn test_parse_truncated_payload() -> Result<()> {
        let all_data = wire_bytes(&Frame::Initialize(b"hello world".to_vec()))?;
        let truncated = &all_data[..all_data.len() - 1];
        let mut cursor = Cursor::new(truncated);
        assert!(Frame::parse(&mut cursor)?.is_none());
        Ok(())
    }

    #[test]
    fn test_parse_unknown_frame() -> Result<()> {
        let all_data = [200u8];
        let mut cursor = Cursor::new(&all_data[..]);
        assert!(Frame::parse(&mut cursor)?.is_none());
        Ok(())
    }

    #[test]
    fn test_parse_empty() -> Result<()> {
        let all_data: [u8; 0] = [];
        let mut cursor = Cursor::new(&all_data[..]);
        assert!(Frame::parse(&mut cursor)?.is_none());
        Ok(())
    }
}