nodedb_codec/vector_quant/
codec_envelope.rs1use crate::error::CodecError;
12use zerompk::{FromMessagePack, ToMessagePack};
13
/// Number of bytes in the magic tag that opens every envelope.
pub const MAGIC_LEN: usize = 5;

/// Total envelope header size: the magic tag followed by one `u8` version byte.
pub const HEADER_LEN: usize = MAGIC_LEN + std::mem::size_of::<u8>();
19
20pub fn encode<T: ToMessagePack>(
22 magic: &[u8; MAGIC_LEN],
23 version: u8,
24 value: &T,
25) -> Result<Vec<u8>, CodecError> {
26 let body = zerompk::to_msgpack_vec(value).map_err(|e| CodecError::Corrupt {
27 detail: e.to_string(),
28 })?;
29 let mut out = Vec::with_capacity(HEADER_LEN + body.len());
30 out.extend_from_slice(magic);
31 out.push(version);
32 out.extend_from_slice(&body);
33 Ok(out)
34}
35
36pub fn decode<T: for<'de> FromMessagePack<'de>>(
43 magic: &[u8; MAGIC_LEN],
44 expected_version: u8,
45 buf: &[u8],
46) -> Result<T, CodecError> {
47 if buf.len() < HEADER_LEN {
48 return Err(CodecError::Truncated {
49 expected: HEADER_LEN,
50 actual: buf.len(),
51 });
52 }
53 if &buf[..MAGIC_LEN] != magic {
54 return Err(CodecError::Corrupt {
55 detail: "bad magic".into(),
56 });
57 }
58 let got = buf[MAGIC_LEN];
59 if got != expected_version {
60 return Err(CodecError::Corrupt {
61 detail: format!("unsupported version {got}"),
62 });
63 }
64 zerompk::from_msgpack(&buf[HEADER_LEN..]).map_err(|e| CodecError::Corrupt {
65 detail: e.to_string(),
66 })
67}
68
69pub fn peek_version(magic: &[u8; MAGIC_LEN], buf: &[u8]) -> Option<u8> {
72 if buf.len() < HEADER_LEN {
73 return None;
74 }
75 if &buf[..MAGIC_LEN] != magic {
76 return None;
77 }
78 Some(buf[MAGIC_LEN])
79}
80
/// Unit tests for the envelope helpers: round-tripping, header validation,
/// and version peeking.
#[cfg(test)]
mod tests {
    use super::*;

    /// Magic tag used by every test envelope.
    const MAGIC: &[u8; MAGIC_LEN] = b"NDTST";

    #[test]
    fn roundtrip() {
        let payload: Vec<i32> = vec![1, -2, 3, -4];
        let bytes = encode(MAGIC, 7, &payload).unwrap();
        let restored: Vec<i32> = decode(MAGIC, 7, &bytes).unwrap();
        assert_eq!(restored, payload);
    }

    // NOTE: `matches!` only evaluates to a `bool`; it must be wrapped in
    // `assert!` or the test passes regardless of which error was returned.

    #[test]
    fn rejects_short_buffer() {
        let err = decode::<Vec<u8>>(MAGIC, 1, &[0u8; HEADER_LEN - 1]).unwrap_err();
        assert!(
            matches!(err, CodecError::Truncated { .. }),
            "expected Truncated, got {err:?}"
        );
    }

    #[test]
    fn rejects_bad_magic() {
        let bytes = encode(b"NDOTH", 1, &0u8).unwrap();
        let err = decode::<u8>(MAGIC, 1, &bytes).unwrap_err();
        assert!(
            matches!(err, CodecError::Corrupt { .. }),
            "expected Corrupt, got {err:?}"
        );
    }

    #[test]
    fn rejects_version_mismatch() {
        let bytes = encode(MAGIC, 1, &0u8).unwrap();
        let err = decode::<u8>(MAGIC, 2, &bytes).unwrap_err();
        assert!(
            matches!(err, CodecError::Corrupt { .. }),
            "expected Corrupt, got {err:?}"
        );
    }

    #[test]
    fn peek_version_returns_version_byte() {
        let bytes = encode(MAGIC, 9, &0u8).unwrap();
        assert_eq!(peek_version(MAGIC, &bytes), Some(9));
        assert_eq!(peek_version(b"NDOTH", &bytes), None);
        assert_eq!(peek_version(MAGIC, &[]), None);
    }
}