Skip to main content

nodedb_codec/vector_quant/
codec_envelope.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Self-describing envelope used by quantization codecs that persist their
4//! calibrated parameters (BBQ, RaBitQ, …).
5//!
6//! Wire layout: `magic` (5 bytes) + `version` (1 byte) + MessagePack body
7//! produced by `zerompk`. Codecs choose their own 5-byte magic (e.g.
8//! `b"NDBBQ"`, `b"NDRBQ"`) so a mismatched buffer fails fast with a typed
9//! [`CodecError`] instead of an opaque MessagePack decode error.
10
11use crate::error::CodecError;
12use zerompk::{FromMessagePack, ToMessagePack};
13
/// Number of bytes taken by the codec-chosen magic tag at the very start of
/// every envelope.
pub const MAGIC_LEN: usize = 5;

/// Number of bytes taken by the version tag that directly follows the magic.
const VERSION_LEN: usize = 1;

/// Size of the fixed envelope header (magic tag followed by the version
/// byte); the MessagePack body starts at this offset.
pub const HEADER_LEN: usize = MAGIC_LEN + VERSION_LEN;
19
20/// Serialize `value` into a versioned, magic-tagged byte buffer.
21pub fn encode<T: ToMessagePack>(
22    magic: &[u8; MAGIC_LEN],
23    version: u8,
24    value: &T,
25) -> Result<Vec<u8>, CodecError> {
26    let body = zerompk::to_msgpack_vec(value).map_err(|e| CodecError::Corrupt {
27        detail: e.to_string(),
28    })?;
29    let mut out = Vec::with_capacity(HEADER_LEN + body.len());
30    out.extend_from_slice(magic);
31    out.push(version);
32    out.extend_from_slice(&body);
33    Ok(out)
34}
35
36/// Validate the envelope header and deserialize the body into `T`.
37///
38/// `expected_version` is checked exactly — bumping a codec's on-disk format
39/// is an explicit decision; silently accepting older bodies has bitten us
40/// before. Callers that need a window of compatible versions should call
41/// [`peek_version`] and dispatch.
42pub fn decode<T: for<'de> FromMessagePack<'de>>(
43    magic: &[u8; MAGIC_LEN],
44    expected_version: u8,
45    buf: &[u8],
46) -> Result<T, CodecError> {
47    if buf.len() < HEADER_LEN {
48        return Err(CodecError::Truncated {
49            expected: HEADER_LEN,
50            actual: buf.len(),
51        });
52    }
53    if &buf[..MAGIC_LEN] != magic {
54        return Err(CodecError::Corrupt {
55            detail: "bad magic".into(),
56        });
57    }
58    let got = buf[MAGIC_LEN];
59    if got != expected_version {
60        return Err(CodecError::Corrupt {
61            detail: format!("unsupported version {got}"),
62        });
63    }
64    zerompk::from_msgpack(&buf[HEADER_LEN..]).map_err(|e| CodecError::Corrupt {
65        detail: e.to_string(),
66    })
67}
68
69/// Read the version byte without decoding the body. Returns `None` if the
70/// buffer is too short or the magic does not match.
71pub fn peek_version(magic: &[u8; MAGIC_LEN], buf: &[u8]) -> Option<u8> {
72    if buf.len() < HEADER_LEN {
73        return None;
74    }
75    if &buf[..MAGIC_LEN] != magic {
76        return None;
77    }
78    Some(buf[MAGIC_LEN])
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    const MAGIC: &[u8; MAGIC_LEN] = b"NDTST";

    #[test]
    fn roundtrip() {
        let payload: Vec<i32> = vec![1, -2, 3, -4];
        let bytes = encode(MAGIC, 7, &payload).unwrap();
        let restored: Vec<i32> = decode(MAGIC, 7, &bytes).unwrap();
        assert_eq!(restored, payload);
    }

    #[test]
    fn encode_writes_header() {
        let bytes = encode(MAGIC, 7, &0u8).unwrap();
        // Any MessagePack encoding is at least one byte, so the body follows
        // the header.
        assert!(bytes.len() > HEADER_LEN);
        assert_eq!(&bytes[..MAGIC_LEN], MAGIC);
        assert_eq!(bytes[MAGIC_LEN], 7);
    }

    // NOTE: `matches!` only evaluates to a `bool`; it must be wrapped in
    // `assert!` or the test silently passes regardless of the variant.
    #[test]
    fn rejects_short_buffer() {
        let err = decode::<Vec<u8>>(MAGIC, 1, &[0u8; HEADER_LEN - 1]).unwrap_err();
        assert!(matches!(
            err,
            CodecError::Truncated { expected, actual }
                if expected == HEADER_LEN && actual == HEADER_LEN - 1
        ));
    }

    #[test]
    fn rejects_magic_without_version_byte() {
        let err = decode::<u8>(MAGIC, 1, MAGIC).unwrap_err();
        assert!(matches!(
            err,
            CodecError::Truncated { expected, actual }
                if expected == HEADER_LEN && actual == MAGIC_LEN
        ));
    }

    #[test]
    fn rejects_bad_magic() {
        let bytes = encode(b"NDOTH", 1, &0u8).unwrap();
        let err = decode::<u8>(MAGIC, 1, &bytes).unwrap_err();
        assert!(matches!(
            err,
            CodecError::Corrupt { ref detail } if detail == "bad magic"
        ));
    }

    #[test]
    fn rejects_version_mismatch() {
        let bytes = encode(MAGIC, 1, &0u8).unwrap();
        let err = decode::<u8>(MAGIC, 2, &bytes).unwrap_err();
        assert!(matches!(
            err,
            CodecError::Corrupt { ref detail } if detail == "unsupported version 1"
        ));
    }

    #[test]
    fn peek_version_returns_version_byte() {
        let bytes = encode(MAGIC, 9, &0u8).unwrap();
        assert_eq!(peek_version(MAGIC, &bytes), Some(9));
        assert_eq!(peek_version(b"NDOTH", &bytes), None);
        assert_eq!(peek_version(MAGIC, &[]), None);
        // Magic present but version byte missing.
        assert_eq!(peek_version(MAGIC, MAGIC), None);
    }
}