1use std::{
6 fmt::Debug,
7 hash::{DefaultHasher, Hash, Hasher},
8 mem,
9};
10
11use crate::{hash_type_id, Result, Sendable};
12
/// Five-byte magic prefix (`RSOCK`) that opens every serialized `PacketHeader`.
///
/// Decoding checks that incoming header bytes start with exactly this sequence.
const HEADER: [u8; 5] = [b'R', b'S', b'O', b'C', b'K'];
16
17#[derive(Clone, Copy, PartialEq, Eq, Hash)]
18#[repr(C)] pub struct PacketHeader<T>
26where
27 T: 'static + Sendable,
28{
29 header: [u8; 5],
31 has_checksum: bool,
32 checksum: u32,
33 pub payload_size: u32,
34 type_id: u32,
35 _phantom: std::marker::PhantomData<T>,
37}
38
39impl<T: Sendable> Debug for PacketHeader<T> {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("PacketHeader")
42 .field("header", &self.header)
43 .field("has_checksum", &self.has_checksum)
44 .field("checksum", &self.checksum)
45 .field("payload_size", &self.payload_size)
46 .field("type_id", &self.type_id)
47 .finish_non_exhaustive()
48 }
49}
50
51#[derive(Clone, Copy, Debug)]
54pub struct UnknownType;
55
56impl Sendable for UnknownType {
57 fn send(&self) -> Vec<u8> {
58 Vec::new()
59 }
60
61 fn recv(_: &mut dyn std::io::Read) -> Result<Self> {
62 Ok(UnknownType)
63 }
64}
65
66impl<T> PacketHeader<T>
67where
68 T: 'static + Sendable,
69{
70 pub fn auto() -> PacketHeader<T> {
72 PacketHeader {
73 header: HEADER,
74 checksum: 0,
75 has_checksum: false,
76 payload_size: std::mem::size_of::<T>() as u32,
77 type_id: hash_type_id::<T>(),
78 _phantom: std::marker::PhantomData,
79 }
80 }
81 pub unsafe fn new(payload_size: u32) -> PacketHeader<T> {
89 PacketHeader {
90 header: HEADER,
91 checksum: 0,
92 has_checksum: false,
93 payload_size,
94 type_id: hash_type_id::<T>(),
95 _phantom: std::marker::PhantomData,
96 }
97 }
98 pub(crate) fn calculate_checksum(&mut self, payload: &[u8]) {
100 let mut hasher = DefaultHasher::new();
101 hasher.write(payload);
102 self.checksum = hasher.finish() as u32;
103 self.has_checksum = true;
104 }
105 pub fn verify_checksum(&self, payload: &[u8]) -> bool {
107 if !self.has_checksum {
108 return true;
109 }
110 let mut hasher = DefaultHasher::new();
111 hasher.write(payload);
112 self.checksum == hasher.finish() as u32
113 }
114
115 pub fn to_bytes(&self) -> [u8; mem::size_of::<PacketHeader<UnknownType>>()] {
117 unsafe {
118 let bytes = std::mem::transmute_copy::<
120 PacketHeader<T>,
121 [u8; mem::size_of::<PacketHeader<UnknownType>>()],
122 >(self);
123 bytes
124 }
125 }
126
127 pub(crate) fn id(&self) -> u32 {
129 self.type_id
130 }
131}
132
133impl PacketHeader<UnknownType> {
134 pub unsafe fn into_ty<U: Copy + Sendable>(self) -> PacketHeader<U> {
139 assert_eq!(self.payload_size, std::mem::size_of::<U>() as u32);
140 assert_eq!(self.type_id, hash_type_id::<U>());
141
142 PacketHeader {
143 header: self.header,
144 checksum: self.checksum,
145 has_checksum: self.has_checksum,
146 payload_size: self.payload_size,
147 type_id: self.type_id,
148 _phantom: std::marker::PhantomData,
149 }
150 }
151 pub unsafe fn from_bytes_unchecked(bytes: &[u8]) -> PacketHeader<UnknownType> {
156 assert!(
157 bytes.len() == mem::size_of::<PacketHeader<UnknownType>>(),
158 "bytes.len() = {}",
159 bytes.len()
160 );
161 assert!(
162 bytes.starts_with(&HEADER),
163 "Header is not correct (Expected: {:?}, Got: {:?})",
164 HEADER,
165 &bytes[..5]
166 );
167 unsafe { *(bytes.as_ptr() as *const PacketHeader<UnknownType>) }
170 }
171 pub fn from_bytes(bytes: &[u8], data: &[u8]) -> Option<PacketHeader<UnknownType>> {
173 let header: PacketHeader<UnknownType> =
174 unsafe { PacketHeader::<UnknownType>::from_bytes_unchecked(bytes) };
175 assert_eq!(header.payload_size as usize, data.len());
176 let checksum_ok: bool = header.verify_checksum(data);
177 let len_ok: bool = bytes.len() == mem::size_of::<PacketHeader<UnknownType>>();
178 let header_ok: bool = bytes.starts_with(&HEADER);
179 if checksum_ok && len_ok && header_ok {
180 Some(header)
181 } else {
182 None
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use crate::hash_type_id;

    use super::*;

    /// A checksummed header survives a serialize/deserialize/retype round trip.
    #[test]
    fn test_packet_header() {
        let mut header: PacketHeader<u128> = PacketHeader::auto();
        let data = 32u128.send();
        header.calculate_checksum(&data);
        let bytes = header.to_bytes();
        let new_header = PacketHeader::from_bytes(&bytes, &data).unwrap();
        let ty_header = unsafe { new_header.into_ty::<u128>() };
        assert_eq!(header, ty_header);
    }

    /// `auto` derives the payload size and type id from `T`.
    #[test]
    fn test_new_auto() {
        let header: PacketHeader<u32> = PacketHeader::auto();
        assert_eq!(header.payload_size, 4);
        assert_eq!(header.type_id, hash_type_id::<u32>());
    }

    /// Serialized headers have the fixed wire size and start with the magic prefix.
    #[test]
    fn test_to_bytes_starts_with_magic() {
        let header: PacketHeader<u32> = PacketHeader::auto();
        let bytes = header.to_bytes();
        assert_eq!(
            bytes.len(),
            std::mem::size_of::<PacketHeader<UnknownType>>()
        );
        assert!(bytes.starts_with(&HEADER));
    }

    /// Headers without a checksum round-trip for any matching-length payload.
    #[test]
    fn test_round_trip_without_checksum() {
        let header: PacketHeader<u128> = PacketHeader::auto();
        let data = 32u128.send();
        let decoded = PacketHeader::from_bytes(&header.to_bytes(), &data).unwrap();
        assert_eq!(header, unsafe { decoded.into_ty::<u128>() });
    }

    /// A payload that differs from the checksummed one is rejected.
    #[test]
    fn test_checksum_mismatch_is_rejected() {
        let mut header: PacketHeader<u128> = PacketHeader::auto();
        header.calculate_checksum(&32u128.send());
        let tampered = 33u128.send();
        assert!(PacketHeader::from_bytes(&header.to_bytes(), &tampered).is_none());
    }
}