lazuli_core/
header.rs

//! Contains the [`PacketHeader`] struct, which is prepended to every packet sent over a socket.
//!
//! The header lets the receiver check that the payload has the expected type and size and, when a checksum is present, that it arrived intact.
4
5use std::{
6    fmt::Debug,
7    hash::{DefaultHasher, Hash, Hasher},
8    mem,
9};
10
11use crate::{hash_type_id, Result, Sendable};
12
// Magic bytes that open every serialized header. "RSOCK" was the project's development name.
// TODO: Maybe change this to lazi or something similar.
const HEADER: [u8; 5] = [b'R', b'S', b'O', b'C', b'K'];
16
17#[derive(Clone, Copy, PartialEq, Eq, Hash)]
18#[repr(C)] // This is important for the safety of the from_bytes_unchecked function.
19/// The header of a packet. When a packet is sent over a socket, it is prepended with this header.
20/// This contains the type_id of the payload, the size of the payload, and a checksum of the payload.
21/// The checksum is used to verify that the payload was received correctly.
22/// The type_id is used to determine the type of the payload.
23/// The payload_size is used to determine the size of the payload.
24// TODO: Remove the type parameter. It was never used, and ended up causing some issues.
25pub struct PacketHeader<T>
26where
27    T: 'static + Sendable,
28{
29    // should always be "RSOCK"
30    header: [u8; 5],
31    has_checksum: bool,
32    checksum: u32,
33    pub payload_size: u32,
34    type_id: u32,
35    // allow for some sort of type safety
36    _phantom: std::marker::PhantomData<T>,
37}
38
39impl<T: Sendable> Debug for PacketHeader<T> {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("PacketHeader")
42            .field("header", &self.header)
43            .field("has_checksum", &self.has_checksum)
44            .field("checksum", &self.checksum)
45            .field("payload_size", &self.payload_size)
46            .field("type_id", &self.type_id)
47            .finish_non_exhaustive()
48    }
49}
50
/// A zero-sized type standing in for a payload whose type is not known yet.
///
/// Headers parsed from raw bytes are typed as `PacketHeader<UnknownType>` until they are
/// converted to a concrete payload type with `into_ty`.
///
/// `PartialEq`, `Eq` and `Hash` are derived so that the derived impls on
/// `PacketHeader<UnknownType>` (which require them of the type parameter) are usable.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct UnknownType;
55
56impl Sendable for UnknownType {
57    fn send(&self) -> Vec<u8> {
58        Vec::new()
59    }
60
61    fn recv(_: &mut dyn std::io::Read) -> Result<Self> {
62        Ok(UnknownType)
63    }
64}
65
66impl<T> PacketHeader<T>
67where
68    T: 'static + Sendable,
69{
70    /// Creates a new PacketHeader with the type_id of T and the payload_size of T.
71    pub fn auto() -> PacketHeader<T> {
72        PacketHeader {
73            header: HEADER,
74            checksum: 0,
75            has_checksum: false,
76            payload_size: std::mem::size_of::<T>() as u32,
77            type_id: hash_type_id::<T>(),
78            _phantom: std::marker::PhantomData,
79        }
80    }
81    /// Creates a new PacketHeader with the specified length of the payload.
82    ///
83    /// This can be useful for types where the size of the payload is not constant. (e.g. Vec<T>, String, etc.)
84    /// This can also be useful for reference types.
85    ///
86    /// # Safety
87    /// The caller must ensure that the payload_size is correct, and that the sendable implementation accounts for the variable size of the payload.
88    pub unsafe fn new(payload_size: u32) -> PacketHeader<T> {
89        PacketHeader {
90            header: HEADER,
91            checksum: 0,
92            has_checksum: false,
93            payload_size,
94            type_id: hash_type_id::<T>(),
95            _phantom: std::marker::PhantomData,
96        }
97    }
98    /// Calculates the checksum of the payload. Sets the checksum field to the calculated checksum.
99    pub(crate) fn calculate_checksum(&mut self, payload: &[u8]) {
100        let mut hasher = DefaultHasher::new();
101        hasher.write(payload);
102        self.checksum = hasher.finish() as u32;
103        self.has_checksum = true;
104    }
105    /// Verifies the checksum of the payload.
106    pub fn verify_checksum(&self, payload: &[u8]) -> bool {
107        if !self.has_checksum {
108            return true;
109        }
110        let mut hasher = DefaultHasher::new();
111        hasher.write(payload);
112        self.checksum == hasher.finish() as u32
113    }
114
115    /// Converts the PacketHeader into a byte array.
116    pub fn to_bytes(&self) -> [u8; mem::size_of::<PacketHeader<UnknownType>>()] {
117        unsafe {
118            // SAFETY: We know that PacketHeader<T> is the same size as PacketHeader<UnknownType>
119            let bytes = std::mem::transmute_copy::<
120                PacketHeader<T>,
121                [u8; mem::size_of::<PacketHeader<UnknownType>>()],
122            >(self);
123            bytes
124        }
125    }
126
127    /// Gets the type_id of the payload.
128    pub(crate) fn id(&self) -> u32 {
129        self.type_id
130    }
131}
132
133impl PacketHeader<UnknownType> {
134    /// Converts the PacketHeader into a PacketHeader with a different type.
135    /// # Safety
136    /// The caller must ensure that the type_id and payload_size are correct.
137    /// The caller must also ensure that the type T is the correct type.
138    pub unsafe fn into_ty<U: Copy + Sendable>(self) -> PacketHeader<U> {
139        assert_eq!(self.payload_size, std::mem::size_of::<U>() as u32);
140        assert_eq!(self.type_id, hash_type_id::<U>());
141
142        PacketHeader {
143            header: self.header,
144            checksum: self.checksum,
145            has_checksum: self.has_checksum,
146            payload_size: self.payload_size,
147            type_id: self.type_id,
148            _phantom: std::marker::PhantomData,
149        }
150    }
151    /// Creates a new PacketHeader from a byte array.
152    /// # Safety
153    /// This function is unsafe because it creates a PacketHeader from a byte array without checking the checksum.
154    /// Use `PacketHeader::from_bytes` if you want to check the checksum.
155    pub unsafe fn from_bytes_unchecked(bytes: &[u8]) -> PacketHeader<UnknownType> {
156        assert!(
157            bytes.len() == mem::size_of::<PacketHeader<UnknownType>>(),
158            "bytes.len() = {}",
159            bytes.len()
160        );
161        assert!(
162            bytes.starts_with(&HEADER),
163            "Header is not correct (Expected: {:?}, Got: {:?})",
164            HEADER,
165            &bytes[..5]
166        );
167        // Safety: We just checked that the length of bytes is the same as the size of PacketHeader
168        // and that it starts with the HEADER.
169        unsafe { *(bytes.as_ptr() as *const PacketHeader<UnknownType>) }
170    }
171    /// Creates a new PacketHeader from a byte array.
172    pub fn from_bytes(bytes: &[u8], data: &[u8]) -> Option<PacketHeader<UnknownType>> {
173        let header: PacketHeader<UnknownType> =
174            unsafe { PacketHeader::<UnknownType>::from_bytes_unchecked(bytes) };
175        assert_eq!(header.payload_size as usize, data.len());
176        let checksum_ok: bool = header.verify_checksum(data);
177        let len_ok: bool = bytes.len() == mem::size_of::<PacketHeader<UnknownType>>();
178        let header_ok: bool = bytes.starts_with(&HEADER);
179        if checksum_ok && len_ok && header_ok {
180            Some(header)
181        } else {
182            None
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    use crate::hash_type_id;

    use super::*;

    /// A header with a checksum survives a `to_bytes` → `from_bytes` → `into_ty` round trip.
    #[test]
    fn test_packet_header() {
        let mut header: PacketHeader<u128> = PacketHeader::auto();
        let data = 32u128.send();
        header.calculate_checksum(&data);
        let bytes = header.to_bytes();
        let new_header = PacketHeader::from_bytes(&bytes, &data).unwrap();
        let ty_header = unsafe { new_header.into_ty::<u128>() };
        assert_eq!(header, ty_header);
    }

    /// `auto` takes the payload size and type id from the type parameter.
    #[test]
    fn test_new_auto() {
        let header: PacketHeader<u32> = PacketHeader::auto();
        assert_eq!(header.payload_size, 4);
        assert_eq!(header.type_id, hash_type_id::<u32>());
    }

    /// The serialized header is exactly one header long and begins with the magic bytes.
    #[test]
    fn test_to_bytes_layout() {
        let header: PacketHeader<u32> = PacketHeader::auto();
        let bytes = header.to_bytes();
        assert_eq!(bytes.len(), mem::size_of::<PacketHeader<UnknownType>>());
        assert!(bytes.starts_with(&HEADER));
    }

    /// Without a checksum, any payload of the declared length is accepted.
    #[test]
    fn test_round_trip_without_checksum() {
        let header: PacketHeader<u32> = PacketHeader::auto();
        let bytes = header.to_bytes();
        let parsed = PacketHeader::from_bytes(&bytes, &[0u8; 4]).unwrap();
        assert!(!parsed.has_checksum);
        assert_eq!(unsafe { parsed.into_ty::<u32>() }, header);
    }

    /// A payload that does not match the recorded checksum is rejected.
    #[test]
    fn test_checksum_mismatch_is_rejected() {
        let mut header: PacketHeader<u32> = PacketHeader::auto();
        header.calculate_checksum(&[1, 2, 3, 4]);
        let bytes = header.to_bytes();
        assert!(PacketHeader::from_bytes(&bytes, &[1, 2, 3, 4]).is_some());
        assert!(PacketHeader::from_bytes(&bytes, &[4, 3, 2, 1]).is_none());
    }
}
210}