Skip to main content

bytesandbrains_codec/pq/
code.rs

1use std::fmt;
2use std::hash::{Hash, Hasher};
3
4use bb_core::embedding::Embedding;
5
/// Number of whole bytes needed to hold `nbits` bits, i.e. `ceil(nbits / 8)`.
///
/// Declared `const` so it can size arrays in const-generic positions
/// (e.g. `[u8; bytes_for_nbits(NBITS)]`).
pub const fn bytes_for_nbits(nbits: usize) -> usize {
    const BITS_PER_BYTE: usize = 8;
    // Adding (BITS_PER_BYTE - 1) before the division rounds any partial byte up.
    (nbits + (BITS_PER_BYTE - 1)) / BITS_PER_BYTE
}
10
11/// Encoded representation from Product Quantization.
12///
13/// Stores M centroid indices, each using ceil(NBITS/8) bytes (byte-aligned).
14///
15/// - M: number of subquantizers
16/// - NBITS: bits per centroid index
17///
18/// Examples:
19/// - `PQCode<8, 8>`: 8 subquantizers, 256 centroids each, 8 bytes total
20/// - `PQCode<16, 10>`: 16 subquantizers, 1024 centroids each, 32 bytes total
21#[derive(Clone, PartialEq, Eq)]
22pub struct PQCode<const M: usize, const NBITS: usize>
23where
24    [(); bytes_for_nbits(NBITS)]:,
25{
26    /// Raw byte storage: M * ceil(NBITS/8) bytes, stored as [[u8; B]; M]
27    pub codes: [[u8; bytes_for_nbits(NBITS)]; M],
28}
29
30impl<const M: usize, const NBITS: usize> PQCode<M, NBITS>
31where
32    [(); bytes_for_nbits(NBITS)]:,
33{
34    /// Number of centroids per subquantizer (2^NBITS)
35    pub const KSUB: usize = 1 << NBITS;
36
37    /// Bytes per centroid index
38    pub const BYTES_PER_CODE: usize = bytes_for_nbits(NBITS);
39
40    /// Total bytes for all M codes
41    pub const TOTAL_BYTES: usize = M * Self::BYTES_PER_CODE;
42
43    /// Create a new PQCode from raw byte array
44    pub fn new(codes: [[u8; bytes_for_nbits(NBITS)]; M]) -> Self {
45        Self { codes }
46    }
47
48    /// Get centroid index at position m as u32
49    pub fn get(&self, m: usize) -> u32 {
50        let mut value = 0u32;
51        for (i, &byte) in self.codes[m].iter().enumerate() {
52            value |= (byte as u32) << (i * 8);
53        }
54        value
55    }
56
57    /// Set centroid index at position m
58    pub fn set(&mut self, m: usize, value: u32) {
59        for i in 0..Self::BYTES_PER_CODE {
60            self.codes[m][i] = ((value >> (i * 8)) & 0xFF) as u8;
61        }
62    }
63
64    /// Number of subquantizers
65    pub fn m(&self) -> usize {
66        M
67    }
68
69    /// Create a zeroed PQCode
70    pub fn zeros() -> Self {
71        Self { codes: [[0u8; bytes_for_nbits(NBITS)]; M] }
72    }
73
74    /// Create from individual centroid indices
75    pub fn from_indices(indices: &[u32]) -> Self {
76        let mut code = Self::zeros();
77        for (m, &idx) in indices.iter().take(M).enumerate() {
78            code.set(m, idx);
79        }
80        code
81    }
82}
83
84impl<const M: usize, const NBITS: usize> fmt::Debug for PQCode<M, NBITS>
85where
86    [(); bytes_for_nbits(NBITS)]:,
87{
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        let indices: Vec<u32> = (0..M).map(|m| self.get(m)).collect();
90        f.debug_struct("PQCode")
91            .field("M", &M)
92            .field("NBITS", &NBITS)
93            .field("indices", &indices)
94            .finish()
95    }
96}
97
98impl<const M: usize, const NBITS: usize> fmt::Display for PQCode<M, NBITS>
99where
100    [(); bytes_for_nbits(NBITS)]:,
101{
102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103        let indices: Vec<u32> = (0..M).map(|m| self.get(m)).collect();
104        write!(f, "PQCode<{}, {}>({:?})", M, NBITS, indices)
105    }
106}
107
108impl<const M: usize, const NBITS: usize> Hash for PQCode<M, NBITS>
109where
110    [(); bytes_for_nbits(NBITS)]:,
111{
112    fn hash<H: Hasher>(&self, state: &mut H) {
113        self.codes.hash(state);
114    }
115}
116
117impl<const M: usize, const NBITS: usize> PQCode<M, NBITS>
118where
119    [(); bytes_for_nbits(NBITS)]:,
120{
121    /// Compile-time assertion: verify [[u8; B]; M] has same size as [u8; M*B]
122    /// This ensures our unsafe transmutation in Embedding impl is sound.
123    const _SIZE_CHECK: () = assert!(
124        std::mem::size_of::<[[u8; bytes_for_nbits(NBITS)]; M]>() == M * bytes_for_nbits(NBITS)
125    );
126}
127
128impl<const M: usize, const NBITS: usize> Embedding for PQCode<M, NBITS>
129where
130    [(); bytes_for_nbits(NBITS)]:,
131{
132    type Scalar = u8;
133
134    fn length() -> usize {
135        Self::TOTAL_BYTES
136    }
137
138    fn as_slice(&self) -> &[Self::Scalar] {
139        // Safe: [[u8; B]; M] has same layout as [u8; M*B] (verified by _SIZE_CHECK)
140        let _ = Self::_SIZE_CHECK;
141        unsafe {
142            std::slice::from_raw_parts(
143                self.codes.as_ptr() as *const u8,
144                Self::TOTAL_BYTES
145            )
146        }
147    }
148
149    fn from_slice(data: &[Self::Scalar]) -> Self {
150        let _ = Self::_SIZE_CHECK;
151        let mut codes = [[0u8; bytes_for_nbits(NBITS)]; M];
152        let total_bytes = M * bytes_for_nbits(NBITS);
153        let copy_len = data.len().min(total_bytes);
154        // Safe: [[u8; B]; M] has same layout as [u8; M*B] (verified by _SIZE_CHECK)
155        let flat = unsafe {
156            std::slice::from_raw_parts_mut(
157                codes.as_mut_ptr() as *mut u8,
158                total_bytes
159            )
160        };
161        flat[..copy_len].copy_from_slice(&data[..copy_len]);
162        Self { codes }
163    }
164
165    fn zeros() -> Self {
166        Self::zeros()
167    }
168}
169
#[cfg(feature = "proto")]
impl<const M: usize, const NBITS: usize> From<PQCode<M, NBITS>> for bb_core::proto::TensorProto
where
    [(); bytes_for_nbits(NBITS)]:,
{
    /// Serialize a code as a 2-D `DATA_TYPE_UINT8` tensor.
    ///
    /// The tensor shape is `[M, BYTES_PER_CODE]`. Its `raw_data` holds the
    /// flattened code bytes, and all other fields take their defaults.
    fn from(code: PQCode<M, NBITS>) -> Self {
        let rows = M as i64;
        let cols = PQCode::<M, NBITS>::BYTES_PER_CODE as i64;
        let payload = code.as_slice().to_vec();
        Self {
            dims: vec![rows, cols],
            data_type: bb_core::proto::DATA_TYPE_UINT8,
            raw_data: payload,
            ..Default::default()
        }
    }
}
184
#[cfg(feature = "proto")]
impl<const M: usize, const NBITS: usize> TryFrom<bb_core::proto::TensorProto> for PQCode<M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
{
    type Error = bb_core::proto::ProtoConversionError;

    /// Rebuild a code from a `TensorProto`.
    ///
    /// The proto is validated strictly, in this order:
    /// 1. `data_type` must be `DATA_TYPE_UINT8` (`InvalidDataType` otherwise).
    /// 2. `dims` must equal `[M, BYTES_PER_CODE]` (`InvalidTensorShape` otherwise).
    /// 3. `raw_data` must hold exactly `TOTAL_BYTES` bytes (`ConversionFailed`
    ///    otherwise).
    fn try_from(proto: bb_core::proto::TensorProto) -> Result<Self, Self::Error> {
        use bb_core::proto::{ProtoConversionError, DATA_TYPE_UINT8};

        let element_type = proto.data_type;
        if element_type != DATA_TYPE_UINT8 {
            return Err(ProtoConversionError::InvalidDataType {
                expected: DATA_TYPE_UINT8,
                actual: element_type,
            });
        }

        let wanted_shape = vec![M as i64, Self::BYTES_PER_CODE as i64];
        if proto.dims != wanted_shape {
            return Err(ProtoConversionError::InvalidTensorShape {
                expected: wanted_shape,
                actual: proto.dims,
            });
        }

        let byte_count = proto.raw_data.len();
        if byte_count != Self::TOTAL_BYTES {
            return Err(ProtoConversionError::ConversionFailed(format!(
                "Expected {} bytes in TensorProto raw_data, got {}",
                Self::TOTAL_BYTES,
                byte_count
            )));
        }

        // Length was verified above, so `from_slice` copies every byte exactly.
        Ok(<Self as Embedding>::from_slice(&proto.raw_data))
    }
}
221
222/// Backwards-compatible alias for common case (nbits=8)
223pub type PQCode8<const M: usize> = PQCode<M, 8>;
224
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pq_code_creation_nbits8() {
        let code = PQCode::<4, 8>::from_indices(&[1, 2, 3, 4]);
        assert_eq!(code.get(0), 1);
        assert_eq!(code.get(1), 2);
        assert_eq!(code.get(2), 3);
        assert_eq!(code.get(3), 4);
        assert_eq!(code.m(), 4);
    }

    #[test]
    fn test_pq_code_creation_nbits10() {
        // 10 bits = 2 bytes per code, up to 1024 centroids
        let code = PQCode::<4, 10>::from_indices(&[500, 1000, 100, 1023]);
        assert_eq!(code.get(0), 500);
        assert_eq!(code.get(1), 1000);
        assert_eq!(code.get(2), 100);
        assert_eq!(code.get(3), 1023);
    }

    #[test]
    fn test_pq_code_creation_nbits16() {
        // 16 bits = 2 bytes per code, up to 65536 centroids
        let code = PQCode::<2, 16>::from_indices(&[65535, 32768]);
        assert_eq!(code.get(0), 65535);
        assert_eq!(code.get(1), 32768);
    }

    #[test]
    fn test_from_indices_short_and_long_input() {
        // Missing trailing positions stay zero.
        let short = PQCode::<4, 8>::from_indices(&[7]);
        assert_eq!(short.as_slice(), &[7, 0, 0, 0]);
        // Entries beyond M are ignored.
        let long = PQCode::<2, 8>::from_indices(&[1, 2, 3]);
        assert_eq!(long.as_slice(), &[1, 2]);
    }

    #[test]
    fn test_bytes_for_nbits() {
        assert_eq!(bytes_for_nbits(1), 1);
        assert_eq!(bytes_for_nbits(4), 1);
        assert_eq!(bytes_for_nbits(8), 1);
        assert_eq!(bytes_for_nbits(9), 2);
        assert_eq!(bytes_for_nbits(10), 2);
        assert_eq!(bytes_for_nbits(16), 2);
        assert_eq!(bytes_for_nbits(17), 3);
        assert_eq!(bytes_for_nbits(24), 3);
    }

    #[test]
    fn test_pq_code_total_bytes() {
        assert_eq!(PQCode::<8, 8>::TOTAL_BYTES, 8);
        assert_eq!(PQCode::<8, 10>::TOTAL_BYTES, 16);
        assert_eq!(PQCode::<16, 8>::TOTAL_BYTES, 16);
        assert_eq!(PQCode::<16, 10>::TOTAL_BYTES, 32);
    }

    #[test]
    fn test_pq_code_embedding_trait() {
        assert_eq!(PQCode::<8, 8>::length(), 8);
        assert_eq!(PQCode::<8, 10>::length(), 16);

        let code = PQCode::<4, 8>::from_indices(&[5, 6, 7, 8]);
        assert_eq!(code.as_slice(), &[5, 6, 7, 8]);

        let zeros = PQCode::<4, 8>::zeros();
        assert_eq!(zeros.as_slice(), &[0, 0, 0, 0]);
    }

    #[test]
    fn test_pq_code_embedding_trait_nbits10() {
        // nbits=10 uses 2 bytes per code
        let code = PQCode::<2, 10>::from_indices(&[500, 1000]);
        let slice = code.as_slice();
        assert_eq!(slice.len(), 4); // 2 codes * 2 bytes each

        // Verify little-endian encoding
        // 500 = 0x01F4 -> [0xF4, 0x01]
        assert_eq!(slice[0], 0xF4);
        assert_eq!(slice[1], 0x01);
        // 1000 = 0x03E8 -> [0xE8, 0x03]
        assert_eq!(slice[2], 0xE8);
        assert_eq!(slice[3], 0x03);
    }

    #[test]
    fn test_from_slice_roundtrip() {
        let code = PQCode::<2, 10>::from_indices(&[500, 1000]);
        let rebuilt = PQCode::<2, 10>::from_slice(code.as_slice());
        assert_eq!(rebuilt, code);
        assert_eq!(rebuilt.get(0), 500);
        assert_eq!(rebuilt.get(1), 1000);
    }

    #[test]
    fn test_from_slice_short_and_long_input() {
        // Short input: the remaining bytes stay zero.
        let short = PQCode::<4, 8>::from_slice(&[9, 8]);
        assert_eq!(short.as_slice(), &[9, 8, 0, 0]);
        // Long input: the extra bytes are ignored.
        let long = PQCode::<2, 8>::from_slice(&[1, 2, 3, 4]);
        assert_eq!(long.as_slice(), &[1, 2]);
    }

    #[test]
    fn test_display_and_debug() {
        let code = PQCode::<3, 8>::from_indices(&[1, 2, 3]);
        assert_eq!(code.to_string(), "PQCode<3, 8>([1, 2, 3])");
        assert_eq!(
            format!("{:?}", code),
            "PQCode { M: 3, NBITS: 8, indices: [1, 2, 3] }"
        );
    }

    #[test]
    fn test_pq_code_hash() {
        use std::collections::HashMap;
        let mut map = HashMap::new();

        let code1 = PQCode::<3, 8>::from_indices(&[1, 2, 3]);
        let code2 = PQCode::<3, 8>::from_indices(&[1, 2, 3]);
        let code3 = PQCode::<3, 8>::from_indices(&[3, 2, 1]);

        map.insert(code1.clone(), "value1");
        map.insert(code3, "value3");

        assert_eq!(map.get(&code2), Some(&"value1"));
        assert_eq!(map.len(), 2);
    }

    #[test]
    fn test_pqcode8_alias() {
        // PQCode8<M> should be equivalent to PQCode<M, 8>
        let code1: PQCode8<4> = PQCode::from_indices(&[1, 2, 3, 4]);
        let code2: PQCode<4, 8> = PQCode::from_indices(&[1, 2, 3, 4]);
        assert_eq!(code1, code2);
    }
}