#![allow(unsafe_code)]
use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::cmp::Ordering;
use core::fmt;
use crate::codec::codec_config::CodecConfig;
#[cfg(not(feature = "simd"))]
use crate::codec::kernels::scalar as scalar_kernel;
use crate::errors::CodecError;
/// An immutable scalar-quantization codebook mapping small integer indices to
/// `f32` reconstruction values.
///
/// Invariants (enforced by [`Codebook::new`]): the table holds exactly
/// `1 << bit_width` entries, strictly increasing under `f32::total_cmp`.
/// Entries live behind an `Arc`, so `Clone` is a cheap refcount bump.
#[derive(Clone)]
pub struct Codebook {
/// Reconstruction values, strictly ascending by `f32::total_cmp`; length is
/// exactly `1 << bit_width`.
entries: Arc<[f32]>,
/// Number of index bits; determines the entry count.
bit_width: u8,
}
impl Codebook {
/// Builds a codebook from `entries`, validating the codec invariants.
///
/// Checks, in order:
/// * `1 << bit_width` must fit in a `u32`;
/// * `entries.len()` must equal `1 << bit_width`;
/// * entries must be strictly increasing under `f32::total_cmp`.
///
/// # Errors
/// `UnsupportedBitWidth`, `CodebookEntryCount`, `CodebookDuplicate`, or
/// `CodebookNotSorted` when the corresponding check fails.
pub fn new(entries: Box<[f32]>, bit_width: u8) -> Result<Self, CodecError> {
    let expected = 1u32
        .checked_shl(u32::from(bit_width))
        .ok_or(CodecError::UnsupportedBitWidth { got: bit_width })?;
    // A length that does not even fit in u32 cannot match `expected`;
    // report it saturated to u32::MAX.
    let got = match u32::try_from(entries.len()) {
        Ok(n) => n,
        Err(_) => {
            return Err(CodecError::CodebookEntryCount {
                expected,
                got: u32::MAX,
                bit_width,
            })
        }
    };
    if got != expected {
        return Err(CodecError::CodebookEntryCount {
            expected,
            got,
            bit_width,
        });
    }
    // Walk adjacent pairs to enforce strict ascending order. `distinct`
    // counts unique values seen so far and is reported on duplicates.
    let mut distinct: u32 = u32::from(!entries.is_empty());
    for pair in entries.windows(2) {
        match f32::total_cmp(&pair[0], &pair[1]) {
            Ordering::Less => distinct += 1,
            Ordering::Equal => {
                return Err(CodecError::CodebookDuplicate {
                    expected,
                    got: distinct,
                })
            }
            Ordering::Greater => return Err(CodecError::CodebookNotSorted),
        }
    }
    Ok(Self {
        entries: Arc::from(entries),
        bit_width,
    })
}
/// Trains a codebook from `vectors` by taking evenly spaced quantiles of the
/// sorted training samples (linear interpolation between neighboring order
/// statistics), then validates the result via [`Codebook::new`].
///
/// # Errors
/// Returns `InsufficientTrainingData` when `vectors` is empty or does not
/// contain enough distinct values to fill the codebook, plus any error
/// `Codebook::new` reports for the derived entries.
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
pub fn train(vectors: &[f32], config: &CodecConfig) -> Result<Self, CodecError> {
    let num_entries_u32 = config.num_codebook_entries();
    let num_entries = num_entries_u32 as usize;
    if vectors.is_empty() {
        return Err(CodecError::InsufficientTrainingData {
            expected: num_entries_u32,
        });
    }
    // Sort an f64 copy of the samples; the widening keeps the interpolation
    // below well away from f32 rounding until the final cast.
    let mut flat: Vec<f64> = vectors.iter().copied().map(f64::from).collect();
    flat.sort_by(f64::total_cmp);
    let len = flat.len();
    let last_idx = len.saturating_sub(1);
    let num_entries_minus_one =
        num_entries
            .checked_sub(1)
            .ok_or(CodecError::InsufficientTrainingData {
                expected: num_entries_u32,
            })?;
    let divisor = num_entries_minus_one as f64;
    let span = last_idx as f64;
    let mut entries_f32: Vec<f32> = Vec::with_capacity(num_entries);
    for k in 0..num_entries {
        // Quantile fraction in [0, 1]. FIX: for a single-entry codebook
        // (bit_width == 0) `divisor` is 0.0, and the previous unconditional
        // `k as f64 / divisor` computed 0.0 / 0.0 == NaN, which propagated
        // through floor/interpolation into a NaN entry. Use the minimum
        // quantile (0.0) in that degenerate case.
        let q = if num_entries_minus_one == 0 {
            0.0
        } else {
            (k as f64) / divisor
        };
        let h = q * span;
        let floor_h = libm::floor(h);
        let frac = h - floor_h;
        let i = floor_h as usize;
        // Clamp the upper neighbor so the final quantile reads flat[last_idx].
        let i_plus_one = i.saturating_add(1).min(last_idx);
        let lo = *flat.get(i).ok_or(CodecError::InsufficientTrainingData {
            expected: num_entries_u32,
        })?;
        let hi = *flat
            .get(i_plus_one)
            .ok_or(CodecError::InsufficientTrainingData {
                expected: num_entries_u32,
            })?;
        let value_f64 = lo + frac * (hi - lo);
        entries_f32.push(value_f64 as f32);
    }
    entries_f32.sort_by(f32::total_cmp);
    // Count distinct entries. `Codebook::new` requires strictly increasing
    // values, so fewer distinct quantiles than slots means the training data
    // was too concentrated to fill the codebook.
    let mut distinct: u32 = 1;
    let mut iter = entries_f32.iter();
    if let Some(first) = iter.next() {
        let mut prev = *first;
        for &value in iter {
            if f32::total_cmp(&prev, &value) == Ordering::Less {
                distinct += 1;
            }
            prev = value;
        }
    }
    if distinct < num_entries_u32 {
        return Err(CodecError::InsufficientTrainingData {
            expected: num_entries_u32,
        });
    }
    Self::new(entries_f32.into_boxed_slice(), config.bit_width())
}
/// Number of reconstruction values in the table (equals `1 << bit_width` by
/// construction), saturating to `u32::MAX` if the length ever exceeded it.
#[inline]
pub fn num_entries(&self) -> u32 {
    match u32::try_from(self.entries.len()) {
        Ok(n) => n,
        Err(_) => u32::MAX,
    }
}
/// Number of bits used per quantized index.
#[inline]
pub const fn bit_width(&self) -> u8 {
self.bit_width
}
/// Borrows the sorted reconstruction table.
#[inline]
pub fn entries(&self) -> &[f32] {
    self.entries.as_ref()
}
/// Encodes `values` into codebook indices, writing one byte per value into
/// `indices`. Dispatches to the SIMD kernel when the `simd` feature is
/// enabled, otherwise to the scalar kernel.
///
/// # Errors
/// Propagates any `CodecError` reported by the selected kernel.
pub fn quantize_into(&self, values: &[f32], indices: &mut [u8]) -> Result<(), CodecError> {
    let table: &[f32] = &self.entries;
    #[cfg(feature = "simd")]
    {
        crate::codec::simd_api::quantize_into(table, values, indices)
    }
    #[cfg(not(feature = "simd"))]
    {
        scalar_kernel::quantize_into(table, values, indices)
    }
}
/// Decodes codebook `indices` back into `f32` values, writing into `values`.
/// Dispatches to the SIMD kernel when the `simd` feature is enabled,
/// otherwise to the scalar kernel.
///
/// # Errors
/// Propagates any `CodecError` reported by the selected kernel.
pub fn dequantize_into(&self, indices: &[u8], values: &mut [f32]) -> Result<(), CodecError> {
    let table: &[f32] = &self.entries;
    #[cfg(feature = "simd")]
    {
        crate::codec::simd_api::dequantize_into(table, indices, values)
    }
    #[cfg(not(feature = "simd"))]
    {
        scalar_kernel::dequantize_into(table, indices, values)
    }
}
/// Allocating convenience wrapper around [`Self::quantize_into`]: returns a
/// fresh index buffer with one byte per input value.
///
/// # Errors
/// Propagates any `CodecError` from `quantize_into`.
pub fn quantize(&self, values: &[f32]) -> Result<Vec<u8>, CodecError> {
    let mut indices = Vec::new();
    indices.resize(values.len(), 0u8);
    self.quantize_into(values, &mut indices)?;
    Ok(indices)
}
/// Allocating convenience wrapper around [`Self::dequantize_into`]: returns a
/// fresh value buffer with one `f32` per input index.
///
/// # Errors
/// Propagates any `CodecError` from `dequantize_into`.
pub fn dequantize(&self, indices: &[u8]) -> Result<Vec<f32>, CodecError> {
    let mut values = Vec::new();
    values.resize(indices.len(), 0.0f32);
    self.dequantize_into(indices, &mut values)?;
    Ok(values)
}
}
impl fmt::Debug for Codebook {
    /// Renders the codebook as a struct with its width, entry count, and the
    /// full entry table.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("Codebook");
        dbg.field("bit_width", &self.bit_width);
        dbg.field("num_entries", &self.num_entries());
        dbg.field("entries", &self.entries);
        dbg.finish()
    }
}
impl PartialEq for Codebook {
    /// Bit-exact equality: entries are compared via `f32::to_bits`, so NaN
    /// payloads must match exactly and `0.0` differs from `-0.0`.
    fn eq(&self, other: &Self) -> bool {
        let same_shape =
            self.bit_width == other.bit_width && self.entries.len() == other.entries.len();
        same_shape
            && self
                .entries
                .iter()
                .map(|v| v.to_bits())
                .eq(other.entries.iter().map(|v| v.to_bits()))
    }
}