1use bytes::{Buf, BufMut, Bytes, BytesMut};
6
7use crate::error::{WireError, WireResult};
8
/// Magic number that opens every frame. Written big-endian, its four bytes
/// spell the ASCII text `"VDB "`.
pub const MAGIC: u32 = u32::from_be_bytes(*b"VDB ");

/// Wire-protocol version stamped into every header built here and required
/// of every header accepted by validation.
pub const PROTOCOL_VERSION: u16 = 1;

/// Encoded header length in bytes: magic (`u32`) + version (`u16`) +
/// length (`u32`) + checksum (`u32`).
pub const FRAME_HEADER_SIZE: usize =
    3 * std::mem::size_of::<u32>() + std::mem::size_of::<u16>();

/// Largest payload length (in bytes, 16 MiB) a header may declare and still
/// pass validation.
pub const MAX_PAYLOAD_SIZE: u32 = 1 << 24;
20
/// Fixed-size header that precedes every frame payload on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FrameHeader {
    /// Must equal `MAGIC` for the header to validate.
    pub magic: u32,
    /// Must equal `PROTOCOL_VERSION` for the header to validate.
    pub version: u16,
    /// Number of payload bytes following the header; validation rejects
    /// values above `MAX_PAYLOAD_SIZE`.
    pub length: u32,
    /// Checksum of the payload bytes (the header itself is not covered).
    pub checksum: u32,
}
33
34impl FrameHeader {
35 pub fn new(payload: &[u8]) -> Self {
37 let checksum = compute_checksum(payload);
38 Self {
39 magic: MAGIC,
40 version: PROTOCOL_VERSION,
41 length: payload.len() as u32,
42 checksum,
43 }
44 }
45
46 pub fn encode(&self, buf: &mut BytesMut) {
48 buf.put_u32(self.magic);
49 buf.put_u16(self.version);
50 buf.put_u32(self.length);
51 buf.put_u32(self.checksum);
52 }
53
54 pub fn decode(buf: &mut impl Buf) -> Option<Self> {
58 if buf.remaining() < FRAME_HEADER_SIZE {
59 return None;
60 }
61
62 Some(Self {
63 magic: buf.get_u32(),
64 version: buf.get_u16(),
65 length: buf.get_u32(),
66 checksum: buf.get_u32(),
67 })
68 }
69
70 pub fn validate(&self) -> WireResult<()> {
72 if self.magic != MAGIC {
73 return Err(WireError::InvalidMagic(self.magic));
74 }
75
76 if self.version != PROTOCOL_VERSION {
77 return Err(WireError::UnsupportedVersion(self.version));
78 }
79
80 if self.length > MAX_PAYLOAD_SIZE {
81 return Err(WireError::PayloadTooLarge {
82 size: self.length,
83 max: MAX_PAYLOAD_SIZE,
84 });
85 }
86
87 Ok(())
88 }
89}
90
91#[derive(Debug, Clone, PartialEq, Eq)]
93pub struct Frame {
94 pub header: FrameHeader,
96 pub payload: Bytes,
98}
99
100impl Frame {
101 pub fn new(payload: Bytes) -> Self {
103 let header = FrameHeader::new(&payload);
104 Self { header, payload }
105 }
106
107 pub fn encode(&self, buf: &mut BytesMut) {
109 self.header.encode(buf);
110 buf.put_slice(&self.payload);
111 }
112
113 pub fn encode_to_bytes(&self) -> Bytes {
115 let mut buf = BytesMut::with_capacity(FRAME_HEADER_SIZE + self.payload.len());
116 self.encode(&mut buf);
117 buf.freeze()
118 }
119
120 pub fn decode(buf: &mut BytesMut) -> WireResult<Option<Self>> {
128 if buf.len() < FRAME_HEADER_SIZE {
130 return Ok(None);
131 }
132
133 let header = {
135 let mut peek = buf.as_ref();
136 FrameHeader::decode(&mut peek).expect("checked length above")
137 };
138
139 header.validate()?;
141
142 let total_size = FRAME_HEADER_SIZE + header.length as usize;
144 if buf.len() < total_size {
145 return Ok(None);
146 }
147
148 buf.advance(FRAME_HEADER_SIZE);
150
151 let payload = buf.split_to(header.length as usize).freeze();
153
154 let actual_checksum = compute_checksum(&payload);
156 if actual_checksum != header.checksum {
157 return Err(WireError::ChecksumMismatch {
158 expected: header.checksum,
159 actual: actual_checksum,
160 });
161 }
162
163 Ok(Some(Self { header, payload }))
164 }
165
166 pub fn total_size(&self) -> usize {
168 FRAME_HEADER_SIZE + self.payload.len()
169 }
170}
171
172fn compute_checksum(data: &[u8]) -> u32 {
174 kimberlite_crypto::crc32(data)
175}
176
#[cfg(test)]
mod frame_tests {
    use super::*;

    /// Builds a raw header by hand so tests can inject arbitrary field values.
    fn raw_header(magic: u32, version: u16, length: u32, checksum: u32) -> BytesMut {
        let mut buf = BytesMut::with_capacity(FRAME_HEADER_SIZE);
        buf.put_u32(magic);
        buf.put_u16(version);
        buf.put_u32(length);
        buf.put_u32(checksum);
        buf
    }

    #[test]
    fn test_frame_roundtrip() {
        let payload = Bytes::from("hello, world!");
        let frame = Frame::new(payload.clone());

        let encoded = frame.encode_to_bytes();
        assert_eq!(encoded.len(), FRAME_HEADER_SIZE + payload.len());
        assert_eq!(encoded.len(), frame.total_size());

        let mut buf = BytesMut::from(&encoded[..]);
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();

        assert_eq!(decoded.payload, payload);
        assert_eq!(decoded, frame);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_empty_payload_roundtrip() {
        let frame = Frame::new(Bytes::new());
        assert_eq!(frame.header.length, 0);
        assert_eq!(frame.total_size(), FRAME_HEADER_SIZE);

        let mut buf = BytesMut::from(&frame.encode_to_bytes()[..]);
        let decoded = Frame::decode(&mut buf).unwrap().unwrap();

        assert!(decoded.payload.is_empty());
        assert_eq!(decoded, frame);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_back_to_back_frames() {
        let mut buf = BytesMut::new();
        Frame::new(Bytes::from("first")).encode(&mut buf);
        Frame::new(Bytes::from("second")).encode(&mut buf);

        let first = Frame::decode(&mut buf).unwrap().unwrap();
        assert_eq!(first.payload, Bytes::from("first"));

        let second = Frame::decode(&mut buf).unwrap().unwrap();
        assert_eq!(second.payload, Bytes::from("second"));

        assert!(buf.is_empty());
        assert!(Frame::decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn test_header_wire_layout() {
        let header = FrameHeader::new(b"abc");
        let mut buf = BytesMut::new();
        header.encode(&mut buf);

        assert_eq!(buf.len(), FRAME_HEADER_SIZE);
        assert_eq!(&buf[..4], b"VDB ");

        let decoded = FrameHeader::decode(&mut &buf[..]).unwrap();
        assert_eq!(decoded, header);
        assert_eq!(decoded.length, 3);
    }

    #[test]
    fn test_incomplete_header() {
        let mut buf = BytesMut::from(&[0u8; 5][..]);
        assert!(Frame::decode(&mut buf).unwrap().is_none());
        // Nothing may be consumed while waiting for more bytes.
        assert_eq!(buf.len(), 5);
    }

    #[test]
    fn test_incomplete_payload() {
        let payload = Bytes::from("test");
        let frame = Frame::new(payload);
        let encoded = frame.encode_to_bytes();

        let mut buf = BytesMut::from(&encoded[..FRAME_HEADER_SIZE + 2]);
        assert!(Frame::decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), FRAME_HEADER_SIZE + 2);
    }

    #[test]
    fn test_invalid_magic() {
        let mut buf = raw_header(0xDEAD_BEEF, PROTOCOL_VERSION, 4, 0);
        buf.put_slice(b"test");

        let result = Frame::decode(&mut buf);
        assert!(matches!(result, Err(WireError::InvalidMagic(0xDEAD_BEEF))));
        // Header validation errors leave the buffer untouched.
        assert_eq!(buf.len(), FRAME_HEADER_SIZE + 4);
    }

    #[test]
    fn test_unsupported_version() {
        let bad_version = PROTOCOL_VERSION + 1;
        let mut buf = raw_header(MAGIC, bad_version, 0, 0);

        let result = Frame::decode(&mut buf);
        assert!(matches!(
            result,
            Err(WireError::UnsupportedVersion(v)) if v == bad_version
        ));
    }

    #[test]
    fn test_payload_too_large() {
        // Rejected from the header alone, without waiting for payload bytes.
        let mut buf = raw_header(MAGIC, PROTOCOL_VERSION, MAX_PAYLOAD_SIZE + 1, 0);

        let result = Frame::decode(&mut buf);
        assert!(matches!(
            result,
            Err(WireError::PayloadTooLarge { size, max })
                if size == MAX_PAYLOAD_SIZE + 1 && max == MAX_PAYLOAD_SIZE
        ));
    }

    #[test]
    fn test_max_payload_header_is_accepted() {
        // Exactly MAX_PAYLOAD_SIZE is allowed; the decoder just waits for data.
        let mut buf = raw_header(MAGIC, PROTOCOL_VERSION, MAX_PAYLOAD_SIZE, 0);
        assert!(Frame::decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), FRAME_HEADER_SIZE);
    }

    #[test]
    fn test_checksum_mismatch() {
        let mut buf = raw_header(MAGIC, PROTOCOL_VERSION, 4, 0x00BA_DBAD);
        buf.put_slice(b"test");

        let result = Frame::decode(&mut buf);
        assert!(matches!(
            result,
            Err(WireError::ChecksumMismatch { expected: 0x00BA_DBAD, actual })
                if actual == compute_checksum(b"test")
        ));
    }

    #[test]
    fn test_header_constants() {
        assert_eq!(MAGIC, 0x5644_4220);
        assert_eq!(FRAME_HEADER_SIZE, 14);
        assert_eq!(MAX_PAYLOAD_SIZE, 16 * 1024 * 1024);
    }
}