1use crate::error::CodecError;
18
/// Identifier of the quantization scheme used for a vector's packed payload.
///
/// The enum is `repr(u16)`; the explicit discriminants below are the values
/// obtained by `mode as u16`, which is the form stored in
/// `QuantHeader::quant_mode`. The enum is `#[non_exhaustive]`, so code outside
/// this crate must keep a wildcard arm when matching on it.
#[repr(u16)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum QuantMode {
    Binary = 0,
    RaBitQ = 1,
    Bbq = 2,
    TernaryPacked = 3,
    TernarySimd = 4,
    TurboQuant4b = 5,
    Sq8 = 6,
    Pq = 7,
    Itq3S = 8,
    PolarQuant = 9,
}

impl QuantMode {
    /// Number of payload bits one dimension occupies under this mode.
    ///
    /// Used by `packed_bits_len` to size the packed-bits region. The match is
    /// deliberately exhaustive (no `_` arm) so that adding a variant forces
    /// this table to be updated.
    fn bits_per_weight(self) -> u32 {
        match self {
            Self::Binary | Self::RaBitQ | Self::Bbq => 1,
            Self::TernaryPacked | Self::TernarySimd | Self::Itq3S => 2,
            Self::TurboQuant4b | Self::PolarQuant => 4,
            Self::Sq8 | Self::Pq => 8,
        }
    }
}
62
/// Fixed-size header placed at the very start of every serialized vector.
///
/// The struct is `repr(C)`, so fields are laid out in declaration order; the
/// compile-time check below pins the total size to 32 bytes. The field sizes
/// (2 + 2 + 4 + 4 + 4 + 8 + 8) add up to exactly 32, i.e. the layout has no
/// padding bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct QuantHeader {
    /// Raw quantization-mode discriminant (the tests store `QuantMode as u16`).
    pub quant_mode: u16,
    /// Vector dimensionality.
    pub dim: u16,
    /// NOTE(review): scalar carried verbatim; its exact meaning is not visible
    /// in this file — confirm against the quantizer that fills it.
    pub global_scale: f32,
    /// NOTE(review): scalar carried verbatim; semantics defined by the producer.
    pub residual_norm: f32,
    /// NOTE(review): scalar carried verbatim; semantics defined by the producer.
    pub dot_quantized: f32,
    /// One bit per outlier slot (0..64). The number of set bits is the number
    /// of outlier entries stored after the packed bits.
    pub outlier_bitmask: u64,
    /// Eight bytes that no code in this file interprets.
    pub reserved: [u8; 8],
}

// The serialized layout depends on the header being exactly 32 bytes; fail
// the build rather than silently change the layout.
const _: () = {
    assert!(core::mem::size_of::<QuantHeader>() == 32);
};
89
/// Size in bytes of one serialized outlier entry: a `u32` dimension index
/// followed by an `f32` value.
const OUTLIER_ENTRY_BYTES: usize = core::mem::size_of::<u32>() + core::mem::size_of::<f32>();

/// Serialized vectors are padded to a multiple of this many bytes (128).
const ALIGN: usize = 1 << 7;
97
98pub fn target_size(quant_mode: QuantMode, dim: u16, outlier_count: u32) -> usize {
106 let packed_bits_bytes = packed_bits_len(quant_mode, dim);
107 let outlier_bytes = outlier_count as usize * OUTLIER_ENTRY_BYTES;
108 let raw = core::mem::size_of::<QuantHeader>() + packed_bits_bytes + outlier_bytes;
109 round_up_128(raw)
110}
111
112#[inline]
115fn packed_bits_len(quant_mode: QuantMode, dim: u16) -> usize {
116 let bpw = quant_mode.bits_per_weight() as usize;
117 let total_bits = dim as usize * bpw;
118 total_bits.div_ceil(8)
119}
120
121#[inline]
122fn round_up_128(n: usize) -> usize {
123 (n + ALIGN - 1) & !(ALIGN - 1)
124}
125
126pub struct UnifiedQuantizedVector {
138 buf: Vec<u8>,
140 packed_bits_len: usize,
142}
143
144impl UnifiedQuantizedVector {
145 pub fn new(
159 header: QuantHeader,
160 packed_bits: &[u8],
161 outliers: &[(u32, f32)],
162 ) -> Result<Self, CodecError> {
163 let expected_outlier_count = header.outlier_bitmask.count_ones() as usize;
164 if outliers.len() != expected_outlier_count {
165 return Err(CodecError::LayoutError {
166 detail: format!(
167 "outlier count mismatch: bitmask has {} bits set but {} outliers provided",
168 expected_outlier_count,
169 outliers.len()
170 ),
171 });
172 }
173 for &(dim_idx, _) in outliers {
174 if dim_idx >= 64 {
175 return Err(CodecError::LayoutError {
176 detail: format!("outlier dim_index {dim_idx} exceeds bitmask capacity of 64"),
177 });
178 }
179 }
180
181 let header_bytes = core::mem::size_of::<QuantHeader>();
182 let outlier_bytes = outliers.len() * OUTLIER_ENTRY_BYTES;
183 let raw = header_bytes + packed_bits.len() + outlier_bytes;
184 let total = round_up_128(raw);
185
186 let mut buf = vec![0u8; total];
187
188 let header_src = unsafe {
190 core::slice::from_raw_parts(&header as *const QuantHeader as *const u8, header_bytes)
191 };
192 buf[..header_bytes].copy_from_slice(header_src);
193
194 let pb_start = header_bytes;
196 let pb_end = pb_start + packed_bits.len();
197 buf[pb_start..pb_end].copy_from_slice(packed_bits);
198
199 let mut off = pb_end;
201 for &(dim_idx, value) in outliers {
202 buf[off..off + 4].copy_from_slice(&dim_idx.to_le_bytes());
203 buf[off + 4..off + 8].copy_from_slice(&value.to_le_bytes());
204 off += OUTLIER_ENTRY_BYTES;
205 }
206
207 Ok(Self {
208 buf,
209 packed_bits_len: packed_bits.len(),
210 })
211 }
212
213 #[inline]
217 pub fn header(&self) -> &QuantHeader {
218 let ptr = self.buf.as_ptr() as *const QuantHeader;
219 unsafe { &*ptr }
221 }
222
223 #[inline]
225 pub fn packed_bits(&self) -> &[u8] {
226 let start = core::mem::size_of::<QuantHeader>();
227 &self.buf[start..start + self.packed_bits_len]
228 }
229
230 #[inline]
232 pub fn outlier_count(&self) -> u32 {
233 self.header().outlier_bitmask.count_ones()
234 }
235
236 pub fn outlier_at(&self, slot: u32) -> Option<(u32, f32)> {
245 if slot >= 64 {
246 return None;
247 }
248 let bitmask = self.header().outlier_bitmask;
249 if bitmask & (1u64 << slot) == 0 {
250 return None;
251 }
252 let mask = bitmask & ((1u64 << slot).wrapping_sub(1));
254 let offset = mask.count_ones() as usize;
255
256 let header_bytes = core::mem::size_of::<QuantHeader>();
257 let base = header_bytes + self.packed_bits_len + offset * OUTLIER_ENTRY_BYTES;
258
259 let dim_idx = u32::from_le_bytes(self.buf[base..base + 4].try_into().ok()?);
260 let value = f32::from_le_bytes(self.buf[base + 4..base + 8].try_into().ok()?);
261 Some((dim_idx, value))
262 }
263
264 #[inline]
266 pub fn as_bytes(&self) -> &[u8] {
267 &self.buf
268 }
269}
270
271pub struct UnifiedQuantizedVectorRef<'a> {
277 buf: &'a [u8],
278 packed_bits_len: usize,
279}
280
281impl<'a> UnifiedQuantizedVectorRef<'a> {
282 pub fn from_bytes(buf: &'a [u8], packed_bits_len: usize) -> Result<Self, CodecError> {
289 let header_bytes = core::mem::size_of::<QuantHeader>();
290 if buf.len() < header_bytes + packed_bits_len {
291 return Err(CodecError::LayoutError {
292 detail: format!(
293 "buffer too short: need at least {} bytes, got {}",
294 header_bytes + packed_bits_len,
295 buf.len()
296 ),
297 });
298 }
299 Ok(Self {
300 buf,
301 packed_bits_len,
302 })
303 }
304
305 #[inline]
307 pub fn header(&self) -> &QuantHeader {
308 let ptr = self.buf.as_ptr() as *const QuantHeader;
309 unsafe { &*ptr }
311 }
312
313 #[inline]
315 pub fn packed_bits(&self) -> &[u8] {
316 let start = core::mem::size_of::<QuantHeader>();
317 &self.buf[start..start + self.packed_bits_len]
318 }
319
320 #[inline]
322 pub fn outlier_count(&self) -> u32 {
323 self.header().outlier_bitmask.count_ones()
324 }
325
326 pub fn outlier_at(&self, slot: u32) -> Option<(u32, f32)> {
328 if slot >= 64 {
329 return None;
330 }
331 let bitmask = self.header().outlier_bitmask;
332 if bitmask & (1u64 << slot) == 0 {
333 return None;
334 }
335 let mask = bitmask & ((1u64 << slot).wrapping_sub(1));
336 let offset = mask.count_ones() as usize;
337
338 let header_bytes = core::mem::size_of::<QuantHeader>();
339 let base = header_bytes + self.packed_bits_len + offset * OUTLIER_ENTRY_BYTES;
340
341 let dim_idx = u32::from_le_bytes(self.buf[base..base + 4].try_into().ok()?);
342 let value = f32::from_le_bytes(self.buf[base + 4..base + 8].try_into().ok()?);
343 Some((dim_idx, value))
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header with fixed, easily recognizable scalar fields so that
    /// individual tests only have to vary mode, dimension and bitmask.
    fn make_header(mode: QuantMode, dim: u16, bitmask: u64) -> QuantHeader {
        QuantHeader {
            dim,
            quant_mode: mode as u16,
            outlier_bitmask: bitmask,
            global_scale: 1.5,
            residual_norm: 0.25,
            dot_quantized: 2.5,
            reserved: [0xAB; 8],
        }
    }

    /// The serialized layout relies on the header being exactly 32 bytes.
    #[test]
    fn header_is_32_bytes() {
        let header_size = core::mem::size_of::<QuantHeader>();
        assert_eq!(header_size, 32);
    }

    /// `target_size` must always return a non-zero multiple of 128.
    #[test]
    fn target_size_is_128_multiple() {
        const MODES: [QuantMode; 5] = [
            QuantMode::Binary,
            QuantMode::RaBitQ,
            QuantMode::TernarySimd,
            QuantMode::TurboQuant4b,
            QuantMode::Sq8,
        ];
        const DIMS: [u16; 5] = [64, 128, 256, 512, 1536];
        const OUTLIER_COUNTS: [u32; 4] = [0, 1, 8, 64];

        for mode in MODES {
            for dim in DIMS {
                for outliers in OUTLIER_COUNTS {
                    let sz = target_size(mode, dim, outliers);
                    assert_eq!(
                        sz % 128,
                        0,
                        "target_size not 128-aligned for {mode:?}/{dim}/{outliers}"
                    );
                    assert!(
                        sz >= 128,
                        "target_size below minimum for {mode:?}/{dim}/{outliers}"
                    );
                }
            }
        }
    }

    /// A vector without outliers round-trips its packed bits and is padded.
    #[test]
    fn no_outliers_roundtrip() {
        let packed = [0xFFu8; 16];
        let vec =
            UnifiedQuantizedVector::new(make_header(QuantMode::Binary, 128, 0), &packed, &[])
                .unwrap();

        assert_eq!(vec.outlier_count(), 0);
        assert_eq!(vec.packed_bits(), &packed[..]);
        assert_eq!(vec.as_bytes().len() % 128, 0);
    }

    /// A single outlier is found at its own slot and nowhere else.
    #[test]
    fn one_outlier_roundtrip() {
        let header = make_header(QuantMode::Sq8, 64, 1u64 << 5);
        let packed = [0u8; 64];
        let vec = UnifiedQuantizedVector::new(header, &packed, &[(5u32, 42.0f32)]).unwrap();

        assert_eq!(vec.outlier_count(), 1);
        let (dim, val) = vec.outlier_at(5).expect("bit 5 should be set");
        assert_eq!(dim, 5);
        assert!((val - 42.0).abs() < f32::EPSILON);
        for unset_slot in [0u32, 6] {
            assert!(vec.outlier_at(unset_slot).is_none());
        }
    }

    /// Eight outliers spread over the whole bitmask are all retrievable.
    #[test]
    fn eight_outliers_roundtrip() {
        let bits = [0u32, 3, 7, 12, 20, 33, 50, 63];
        let bitmask = bits.iter().fold(0u64, |mask, &b| mask | (1u64 << b));
        let outlier_list: Vec<(u32, f32)> = bits
            .iter()
            .copied()
            .enumerate()
            .map(|(i, b)| (b, i as f32 * 1.1))
            .collect();

        let packed = [0xAAu8; 64];
        let header = make_header(QuantMode::TurboQuant4b, 128, bitmask);
        let vec = UnifiedQuantizedVector::new(header, &packed, &outlier_list).unwrap();

        assert_eq!(vec.outlier_count(), 8);
        for &(b, expected) in &outlier_list {
            let (dim, val) = vec
                .outlier_at(b)
                .unwrap_or_else(|| panic!("outlier at {b} missing"));
            assert_eq!(dim, b);
            assert!((val - expected).abs() < 1e-5, "value mismatch at dim {b}");
        }
    }

    /// The bytes of an owned vector can be re-read through the borrowed view.
    #[test]
    fn as_bytes_reborrow_via_ref() {
        let header = make_header(QuantMode::RaBitQ, 64, 1u64 << 10);
        let packed = [0u8; 8];
        let vec = UnifiedQuantizedVector::new(header, &packed, &[(10u32, 7.77f32)]).unwrap();

        let vref =
            UnifiedQuantizedVectorRef::from_bytes(vec.as_bytes(), vec.packed_bits_len).unwrap();

        assert_eq!(vref.outlier_count(), 1);
        let (dim, val) = vref.outlier_at(10).unwrap();
        assert_eq!(dim, 10);
        assert!((val - 7.77).abs() < 1e-5);
    }

    /// Every header field survives serialization unchanged.
    #[test]
    fn header_field_roundtrip() {
        let header = QuantHeader {
            quant_mode: QuantMode::Bbq as u16,
            dim: 512,
            global_scale: 4.5,
            residual_norm: 0.99,
            dot_quantized: -1.23,
            outlier_bitmask: 0xDEAD_BEEF_0000_0001,
            reserved: [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08],
        };
        let packed = vec![0u8; packed_bits_len(QuantMode::Bbq, 512)];

        // One outlier per set bit, in ascending bit order.
        let bitmask = header.outlier_bitmask;
        let outliers: Vec<(u32, f32)> = (0u32..64)
            .filter(|&bit| bitmask & (1u64 << bit) != 0)
            .map(|bit| (bit, bit as f32))
            .collect();

        let vec = UnifiedQuantizedVector::new(header, &packed, &outliers).unwrap();
        let h = vec.header();

        assert_eq!(h.quant_mode, QuantMode::Bbq as u16);
        assert_eq!(h.dim, 512);
        assert!((h.global_scale - 4.5).abs() < 1e-5);
        assert!((h.residual_norm - 0.99).abs() < 1e-5);
        assert!((h.dot_quantized - (-1.23)).abs() < 1e-5);
        assert_eq!(h.outlier_bitmask, 0xDEAD_BEEF_0000_0001);
        assert_eq!(h.reserved, [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
    }

    /// Lookups use the popcount of lower bits, so the middle entry must be
    /// found at rank 1, the first at rank 0 and the last at rank 2.
    #[test]
    fn outlier_ordering_popcnt() {
        let bitmask = [3u32, 17, 40]
            .iter()
            .fold(0u64, |mask, &bit| mask | (1u64 << bit));
        let packed = [0u8; 64];
        let outliers = [(3u32, 100.0f32), (17u32, 200.0f32), (40u32, 300.0f32)];
        let vec =
            UnifiedQuantizedVector::new(make_header(QuantMode::Sq8, 64, bitmask), &packed, &outliers)
                .unwrap();

        let (dim, val) = vec.outlier_at(17).expect("dim 17 should be an outlier");
        assert_eq!(dim, 17);
        assert!((val - 200.0).abs() < f32::EPSILON);

        let (dim0, val0) = vec.outlier_at(3).expect("dim 3 should be an outlier");
        assert_eq!(dim0, 3);
        assert!((val0 - 100.0).abs() < f32::EPSILON);

        let (dim2, val2) = vec.outlier_at(40).expect("dim 40 should be an outlier");
        assert_eq!(dim2, 40);
        assert!((val2 - 300.0).abs() < f32::EPSILON);
    }

    /// Slots at or beyond 64 are rejected instead of shifting out of range.
    #[test]
    fn out_of_range_slot_returns_none() {
        let packed = [0u8; 8];
        let vec = UnifiedQuantizedVector::new(make_header(QuantMode::Binary, 64, 0), &packed, &[])
            .unwrap();

        assert!(vec.outlier_at(64).is_none(), "slot 64 is out of range");
        assert!(vec.outlier_at(80).is_none(), "slot 80 is out of range");
        assert!(
            vec.outlier_at(u32::MAX).is_none(),
            "slot u32::MAX is out of range"
        );
    }

    /// A bitmask announcing one outlier with none supplied must be rejected.
    #[test]
    fn outlier_count_mismatch_is_error() {
        let header = make_header(QuantMode::Binary, 64, 1u64 << 2);
        let packed = [0u8; 8];
        let result = UnifiedQuantizedVector::new(header, &packed, &[]);
        assert!(
            result.is_err(),
            "should fail when outlier count mismatches bitmask"
        );
    }
}