1use std::fmt;
2use std::hash::{Hash, Hasher};
3
4use bb_core::embedding::Embedding;
5
/// Returns the number of whole bytes needed to hold `nbits` bits.
///
/// The result is `nbits / 8` rounded up, so `0` maps to `0`, `1..=8` map to
/// `1`, `9..=16` map to `2`, and so on.
///
/// The quotient/remainder form is used instead of `(nbits + 7) / 8` so that
/// the computation cannot overflow for values of `nbits` close to
/// `usize::MAX`.
pub const fn bytes_for_nbits(nbits: usize) -> usize {
    let whole_bytes = nbits / 8;
    if nbits % 8 == 0 {
        whole_bytes
    } else {
        // A partially used trailing byte still occupies a full byte.
        whole_bytes + 1
    }
}
10
11#[derive(Clone, PartialEq, Eq)]
22pub struct PQCode<const M: usize, const NBITS: usize>
23where
24 [(); bytes_for_nbits(NBITS)]:,
25{
26 pub codes: [[u8; bytes_for_nbits(NBITS)]; M],
28}
29
30impl<const M: usize, const NBITS: usize> PQCode<M, NBITS>
31where
32 [(); bytes_for_nbits(NBITS)]:,
33{
34 pub const KSUB: usize = 1 << NBITS;
36
37 pub const BYTES_PER_CODE: usize = bytes_for_nbits(NBITS);
39
40 pub const TOTAL_BYTES: usize = M * Self::BYTES_PER_CODE;
42
43 pub fn new(codes: [[u8; bytes_for_nbits(NBITS)]; M]) -> Self {
45 Self { codes }
46 }
47
48 pub fn get(&self, m: usize) -> u32 {
50 let mut value = 0u32;
51 for (i, &byte) in self.codes[m].iter().enumerate() {
52 value |= (byte as u32) << (i * 8);
53 }
54 value
55 }
56
57 pub fn set(&mut self, m: usize, value: u32) {
59 for i in 0..Self::BYTES_PER_CODE {
60 self.codes[m][i] = ((value >> (i * 8)) & 0xFF) as u8;
61 }
62 }
63
64 pub fn m(&self) -> usize {
66 M
67 }
68
69 pub fn zeros() -> Self {
71 Self { codes: [[0u8; bytes_for_nbits(NBITS)]; M] }
72 }
73
74 pub fn from_indices(indices: &[u32]) -> Self {
76 let mut code = Self::zeros();
77 for (m, &idx) in indices.iter().take(M).enumerate() {
78 code.set(m, idx);
79 }
80 code
81 }
82}
83
84impl<const M: usize, const NBITS: usize> fmt::Debug for PQCode<M, NBITS>
85where
86 [(); bytes_for_nbits(NBITS)]:,
87{
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 let indices: Vec<u32> = (0..M).map(|m| self.get(m)).collect();
90 f.debug_struct("PQCode")
91 .field("M", &M)
92 .field("NBITS", &NBITS)
93 .field("indices", &indices)
94 .finish()
95 }
96}
97
98impl<const M: usize, const NBITS: usize> fmt::Display for PQCode<M, NBITS>
99where
100 [(); bytes_for_nbits(NBITS)]:,
101{
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 let indices: Vec<u32> = (0..M).map(|m| self.get(m)).collect();
104 write!(f, "PQCode<{}, {}>({:?})", M, NBITS, indices)
105 }
106}
107
108impl<const M: usize, const NBITS: usize> Hash for PQCode<M, NBITS>
109where
110 [(); bytes_for_nbits(NBITS)]:,
111{
112 fn hash<H: Hasher>(&self, state: &mut H) {
113 self.codes.hash(state);
114 }
115}
116
117impl<const M: usize, const NBITS: usize> PQCode<M, NBITS>
118where
119 [(); bytes_for_nbits(NBITS)]:,
120{
121 const _SIZE_CHECK: () = assert!(
124 std::mem::size_of::<[[u8; bytes_for_nbits(NBITS)]; M]>() == M * bytes_for_nbits(NBITS)
125 );
126}
127
128impl<const M: usize, const NBITS: usize> Embedding for PQCode<M, NBITS>
129where
130 [(); bytes_for_nbits(NBITS)]:,
131{
132 type Scalar = u8;
133
134 fn length() -> usize {
135 Self::TOTAL_BYTES
136 }
137
138 fn as_slice(&self) -> &[Self::Scalar] {
139 let _ = Self::_SIZE_CHECK;
141 unsafe {
142 std::slice::from_raw_parts(
143 self.codes.as_ptr() as *const u8,
144 Self::TOTAL_BYTES
145 )
146 }
147 }
148
149 fn from_slice(data: &[Self::Scalar]) -> Self {
150 let _ = Self::_SIZE_CHECK;
151 let mut codes = [[0u8; bytes_for_nbits(NBITS)]; M];
152 let total_bytes = M * bytes_for_nbits(NBITS);
153 let copy_len = data.len().min(total_bytes);
154 let flat = unsafe {
156 std::slice::from_raw_parts_mut(
157 codes.as_mut_ptr() as *mut u8,
158 total_bytes
159 )
160 };
161 flat[..copy_len].copy_from_slice(&data[..copy_len]);
162 Self { codes }
163 }
164
165 fn zeros() -> Self {
166 Self::zeros()
167 }
168}
169
#[cfg(feature = "proto")]
impl<const M: usize, const NBITS: usize> From<PQCode<M, NBITS>> for bb_core::proto::TensorProto
where
    [(); bytes_for_nbits(NBITS)]:,
{
    /// Converts the code into a `UINT8` tensor of shape
    /// `[M, BYTES_PER_CODE]` whose `raw_data` is a copy of the flattened
    /// code bytes. All other fields take their default values.
    fn from(code: PQCode<M, NBITS>) -> Self {
        let shape = vec![M as i64, PQCode::<M, NBITS>::BYTES_PER_CODE as i64];
        let payload = code.as_slice().to_vec();

        Self {
            dims: shape,
            data_type: bb_core::proto::DATA_TYPE_UINT8,
            raw_data: payload,
            ..Default::default()
        }
    }
}
184
#[cfg(feature = "proto")]
impl<const M: usize, const NBITS: usize> TryFrom<bb_core::proto::TensorProto> for PQCode<M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
{
    type Error = bb_core::proto::ProtoConversionError;

    /// Rebuilds a code from a tensor.
    ///
    /// # Errors
    ///
    /// * `InvalidDataType` if the tensor's data type is not `DATA_TYPE_UINT8`.
    /// * `InvalidTensorShape` if `dims` is not `[M, BYTES_PER_CODE]`.
    /// * `ConversionFailed` if `raw_data` does not hold exactly
    ///   `TOTAL_BYTES` bytes.
    fn try_from(proto: bb_core::proto::TensorProto) -> Result<Self, Self::Error> {
        use bb_core::proto::{ProtoConversionError, DATA_TYPE_UINT8};

        // 1. Element type must be unsigned bytes.
        let dtype = proto.data_type;
        if dtype != DATA_TYPE_UINT8 {
            return Err(ProtoConversionError::InvalidDataType {
                expected: DATA_TYPE_UINT8,
                actual: dtype,
            });
        }

        // 2. Shape must be exactly one row per slot.
        let wanted_shape = vec![M as i64, Self::BYTES_PER_CODE as i64];
        if proto.dims != wanted_shape {
            return Err(ProtoConversionError::InvalidTensorShape {
                expected: wanted_shape,
                actual: proto.dims,
            });
        }

        // 3. Payload length must agree with the shape.
        let received = proto.raw_data.len();
        if received != Self::TOTAL_BYTES {
            let message = format!(
                "Expected {} bytes in TensorProto raw_data, got {}",
                Self::TOTAL_BYTES,
                received
            );
            return Err(ProtoConversionError::ConversionFailed(message));
        }

        Ok(Self::from_slice(&proto.raw_data))
    }
}
221
222pub type PQCode8<const M: usize> = PQCode<M, 8>;
224
#[cfg(test)]
mod tests {
    use super::*;

    /// One-byte slots round-trip their indices and report the slot count.
    #[test]
    fn test_pq_code_creation_nbits8() {
        let wanted = [1u32, 2, 3, 4];
        let code = PQCode::<4, 8>::from_indices(&wanted);

        for (slot, &index) in wanted.iter().enumerate() {
            assert_eq!(code.get(slot), index);
        }
        assert_eq!(code.m(), 4);
    }

    /// Ten-bit indices need two bytes per slot and still round-trip.
    #[test]
    fn test_pq_code_creation_nbits10() {
        let wanted = [500u32, 1000, 100, 1023];
        let code = PQCode::<4, 10>::from_indices(&wanted);

        for (slot, &index) in wanted.iter().enumerate() {
            assert_eq!(code.get(slot), index);
        }
    }

    /// Sixteen-bit indices use both bytes of a slot.
    #[test]
    fn test_pq_code_creation_nbits16() {
        let wanted = [65535u32, 32768];
        let code = PQCode::<2, 16>::from_indices(&wanted);

        for (slot, &index) in wanted.iter().enumerate() {
            assert_eq!(code.get(slot), index);
        }
    }

    /// Bit widths round up to whole bytes.
    #[test]
    fn test_bytes_for_nbits() {
        let table: [(usize, usize); 8] = [
            (1, 1),
            (4, 1),
            (8, 1),
            (9, 2),
            (10, 2),
            (16, 2),
            (17, 3),
            (24, 3),
        ];

        for &(nbits, bytes) in table.iter() {
            assert_eq!(bytes_for_nbits(nbits), bytes);
        }
    }

    /// `TOTAL_BYTES` is the slot count times the per-slot byte width.
    #[test]
    fn test_pq_code_total_bytes() {
        // One byte per slot.
        assert_eq!(PQCode::<8, 8>::TOTAL_BYTES, 8);
        assert_eq!(PQCode::<16, 8>::TOTAL_BYTES, 16);

        // Two bytes per slot.
        assert_eq!(PQCode::<8, 10>::TOTAL_BYTES, 16);
        assert_eq!(PQCode::<16, 10>::TOTAL_BYTES, 32);
    }

    /// The `Embedding` view exposes the flattened bytes.
    #[test]
    fn test_pq_code_embedding_trait() {
        assert_eq!(PQCode::<8, 8>::length(), 8);
        assert_eq!(PQCode::<8, 10>::length(), 16);

        let filled = PQCode::<4, 8>::from_indices(&[5, 6, 7, 8]);
        assert_eq!(filled.as_slice(), &[5, 6, 7, 8]);

        let empty = PQCode::<4, 8>::zeros();
        assert_eq!(empty.as_slice(), &[0, 0, 0, 0]);
    }

    /// Two-byte slots are laid out little-endian: 500 = 0x01F4, 1000 = 0x03E8.
    #[test]
    fn test_pq_code_embedding_trait_nbits10() {
        let code = PQCode::<2, 10>::from_indices(&[500, 1000]);
        let bytes = code.as_slice();

        assert_eq!(bytes.len(), 4);
        assert_eq!(bytes, &[0xF4, 0x01, 0xE8, 0x03]);
    }

    /// Equal codes collide in a `HashMap`; different codes do not.
    #[test]
    fn test_pq_code_hash() {
        use std::collections::HashMap;

        let first = PQCode::<3, 8>::from_indices(&[1, 2, 3]);
        let same_as_first = PQCode::<3, 8>::from_indices(&[1, 2, 3]);
        let reversed = PQCode::<3, 8>::from_indices(&[3, 2, 1]);

        let mut by_code = HashMap::new();
        by_code.insert(first.clone(), "value1");
        by_code.insert(reversed, "value3");

        assert_eq!(by_code.get(&same_as_first), Some(&"value1"));
        assert_eq!(by_code.len(), 2);
    }

    /// `PQCode8<M>` is the same type as `PQCode<M, 8>`.
    #[test]
    fn test_pqcode8_alias() {
        let via_alias: PQCode8<4> = PQCode::from_indices(&[1, 2, 3, 4]);
        let via_full_name: PQCode<4, 8> = PQCode::from_indices(&[1, 2, 3, 4]);
        assert_eq!(via_alias, via_full_name);
    }
}