use super::{PqError, PqResult};
use serde::{Serialize, Deserialize};
/// Table of centroids grouped by sub-quantizer, together with its declared shape.
///
/// Presumably the codebook of a product quantizer (the error type is named
/// `PqError`) — confirm against the surrounding module.
///
/// The type derives `Serialize`/`Deserialize`, so an instance may be produced
/// without going through `new` or `from_centroids`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Codebook {
/// Declared number of sub-quantizers; `new` uses it as the outer length of `centroids`.
num_subquantizers: usize,
/// Declared number of centroids per sub-quantizer; `validate` accepts `1..=256`.
num_centroids: usize,
/// Declared number of `f32` components in each centroid.
subvector_dimension: usize,
/// Centroid values, addressed as `centroids[subquantizer][centroid][component]`.
centroids: Vec<Vec<Vec<f32>>>,
}
/// Construction, access, validation, and nearest-centroid search for [`Codebook`].
///
/// Every lookup into the nested centroid table uses checked access, so a codebook
/// whose stored table disagrees with its declared shape (possible after
/// deserialization) produces an error instead of a panic.
impl Codebook {
    /// Creates a codebook of the requested shape with every centroid component set to `0.0`.
    ///
    /// The arguments are stored as given and are not checked here; call
    /// [`Codebook::validate`] to verify them.
    pub fn new(
        num_subquantizers: usize,
        num_centroids: usize,
        subvector_dimension: usize,
    ) -> Self {
        let centroids = vec![
            vec![vec![0.0; subvector_dimension]; num_centroids];
            num_subquantizers
        ];
        Self {
            num_subquantizers,
            num_centroids,
            subvector_dimension,
            centroids,
        }
    }

    /// Builds a codebook from an existing `centroids[subquantizer][centroid][component]` table.
    ///
    /// The shape is taken from the first sub-quantizer and its first centroid; every
    /// other entry must match it.
    ///
    /// # Errors
    ///
    /// Returns `PqError::InvalidConfig` when `centroids` is empty, when the first
    /// sub-quantizer holds no centroids, or when any sub-quantizer or centroid has a
    /// different length from the first one.
    pub fn from_centroids(centroids: Vec<Vec<Vec<f32>>>) -> PqResult<Self> {
        let first_subquantizer = centroids
            .first()
            .ok_or_else(|| PqError::InvalidConfig("Centroids cannot be empty".to_string()))?;
        // Without a first centroid there is nothing to derive the dimension from;
        // report it instead of indexing out of bounds.
        let first_centroid = first_subquantizer.first().ok_or_else(|| {
            PqError::InvalidConfig("Sub-quantizer 0 has no centroids".to_string())
        })?;

        let num_subquantizers = centroids.len();
        let num_centroids = first_subquantizer.len();
        let subvector_dimension = first_centroid.len();

        for (sq_idx, sq_centroids) in centroids.iter().enumerate() {
            Self::check_subquantizer_shape(
                sq_idx,
                sq_centroids,
                num_centroids,
                subvector_dimension,
            )?;
        }

        Ok(Self {
            num_subquantizers,
            num_centroids,
            subvector_dimension,
            centroids,
        })
    }

    /// Returns the centroid stored at (`subquantizer_idx`, `centroid_idx`).
    ///
    /// # Errors
    ///
    /// Returns `PqError::InvalidSubQuantizerIndex` or `PqError::InvalidCentroidIndex`
    /// when the corresponding index is outside the declared shape or the stored table.
    pub fn get_centroid(&self, subquantizer_idx: usize, centroid_idx: usize) -> PqResult<&[f32]> {
        let sq_centroids = self.get_subquantizer_centroids(subquantizer_idx)?;
        if centroid_idx >= self.num_centroids {
            return Err(PqError::InvalidCentroidIndex(centroid_idx));
        }
        sq_centroids
            .get(centroid_idx)
            .map(Vec::as_slice)
            .ok_or(PqError::InvalidCentroidIndex(centroid_idx))
    }

    /// Replaces the centroid stored at (`subquantizer_idx`, `centroid_idx`).
    ///
    /// # Errors
    ///
    /// Checks are made in this order: sub-quantizer index
    /// (`PqError::InvalidSubQuantizerIndex`), centroid index
    /// (`PqError::InvalidCentroidIndex`), then the length of `centroid` against the
    /// declared sub-vector dimension (`PqError::DimensionMismatch`).
    pub fn set_centroid(
        &mut self,
        subquantizer_idx: usize,
        centroid_idx: usize,
        centroid: Vec<f32>,
    ) -> PqResult<()> {
        if subquantizer_idx >= self.num_subquantizers {
            return Err(PqError::InvalidSubQuantizerIndex(subquantizer_idx));
        }
        if centroid_idx >= self.num_centroids {
            return Err(PqError::InvalidCentroidIndex(centroid_idx));
        }
        if centroid.len() != self.subvector_dimension {
            return Err(PqError::DimensionMismatch {
                expected: self.subvector_dimension,
                actual: centroid.len(),
            });
        }
        // The declared shape was checked above; the checked lookups below only fail
        // when the stored table is smaller than declared.
        let slot = self
            .centroids
            .get_mut(subquantizer_idx)
            .ok_or(PqError::InvalidSubQuantizerIndex(subquantizer_idx))?
            .get_mut(centroid_idx)
            .ok_or(PqError::InvalidCentroidIndex(centroid_idx))?;
        *slot = centroid;
        Ok(())
    }

    /// Returns all centroids of one sub-quantizer.
    ///
    /// # Errors
    ///
    /// Returns `PqError::InvalidSubQuantizerIndex` when `subquantizer_idx` is outside
    /// the declared shape or the stored table.
    pub fn get_subquantizer_centroids(&self, subquantizer_idx: usize) -> PqResult<&[Vec<f32>]> {
        if subquantizer_idx >= self.num_subquantizers {
            return Err(PqError::InvalidSubQuantizerIndex(subquantizer_idx));
        }
        self.centroids
            .get(subquantizer_idx)
            .map(Vec::as_slice)
            .ok_or(PqError::InvalidSubQuantizerIndex(subquantizer_idx))
    }

    /// Declared number of sub-quantizers.
    pub fn num_subquantizers(&self) -> usize {
        self.num_subquantizers
    }

    /// Declared number of centroids per sub-quantizer.
    pub fn num_centroids(&self) -> usize {
        self.num_centroids
    }

    /// Declared number of components per centroid.
    pub fn subvector_dimension(&self) -> usize {
        self.subvector_dimension
    }

    /// Full vector dimension: `num_subquantizers * subvector_dimension`.
    pub fn dimension(&self) -> usize {
        self.num_subquantizers * self.subvector_dimension
    }

    /// Bytes occupied by the `f32` centroid values implied by the declared shape.
    ///
    /// Only the `f32` payload is counted; the `Vec` bookkeeping is not included.
    pub fn memory_size(&self) -> usize {
        self.num_subquantizers
            * self.num_centroids
            * self.subvector_dimension
            * std::mem::size_of::<f32>()
    }

    /// Finds the centroid of `subquantizer_idx` closest to `subvector` by squared L2 distance.
    ///
    /// Ties resolve to the lowest index. If no distance compares below `f32::MAX`
    /// (for example when every distance is NaN or infinite), index `0` is returned.
    ///
    /// # Errors
    ///
    /// * `PqError::InvalidSubQuantizerIndex` for an out-of-range `subquantizer_idx`.
    /// * `PqError::DimensionMismatch` when `subvector` has the wrong length.
    /// * `PqError::InvalidConfig` when the sub-quantizer holds no centroids.
    /// * `PqError::InvalidCentroidIndex` when the winning index does not fit in `u8`.
    pub fn find_nearest_centroid(
        &self,
        subquantizer_idx: usize,
        subvector: &[f32],
    ) -> PqResult<u8> {
        let candidates = self.get_subquantizer_centroids(subquantizer_idx)?;
        if subvector.len() != self.subvector_dimension {
            return Err(PqError::DimensionMismatch {
                expected: self.subvector_dimension,
                actual: subvector.len(),
            });
        }
        if candidates.is_empty() {
            // There is no valid code to hand back for an empty sub-quantizer.
            return Err(PqError::InvalidConfig(format!(
                "Sub-quantizer {} has no centroids",
                subquantizer_idx
            )));
        }

        let mut min_distance = f32::MAX;
        let mut min_idx = 0usize;
        for (idx, centroid) in candidates.iter().enumerate() {
            let distance = l2_distance_squared(subvector, centroid);
            if distance < min_distance {
                min_distance = distance;
                min_idx = idx;
            }
        }

        // A plain `as u8` cast would wrap indices above 255 onto a different centroid.
        <u8 as std::convert::TryFrom<usize>>::try_from(min_idx)
            .map_err(|_| PqError::InvalidCentroidIndex(min_idx))
    }

    /// Checks the declared shape, the stored table's shape, and that every value is finite.
    ///
    /// The stored table is compared with the declared shape because the derived
    /// `Deserialize` implementation fills the fields without running a constructor.
    ///
    /// # Errors
    ///
    /// Returns `PqError::InvalidConfig` describing the first problem found.
    pub fn validate(&self) -> PqResult<()> {
        if self.num_subquantizers == 0 {
            return Err(PqError::InvalidConfig(
                "num_subquantizers must be > 0".to_string(),
            ));
        }
        if self.num_centroids == 0 || self.num_centroids > 256 {
            return Err(PqError::InvalidConfig(format!(
                "num_centroids must be between 1 and 256, got {}",
                self.num_centroids
            )));
        }
        if self.subvector_dimension == 0 {
            return Err(PqError::InvalidConfig(
                "subvector_dimension must be > 0".to_string(),
            ));
        }

        if self.centroids.len() != self.num_subquantizers {
            return Err(PqError::InvalidConfig(format!(
                "Codebook stores {} sub-quantizers, expected {}",
                self.centroids.len(),
                self.num_subquantizers
            )));
        }
        for (sq_idx, sq_centroids) in self.centroids.iter().enumerate() {
            Self::check_subquantizer_shape(
                sq_idx,
                sq_centroids,
                self.num_centroids,
                self.subvector_dimension,
            )?;
        }

        for (sq_idx, sq_centroids) in self.centroids.iter().enumerate() {
            for (c_idx, centroid) in sq_centroids.iter().enumerate() {
                for (dim_idx, &value) in centroid.iter().enumerate() {
                    if !value.is_finite() {
                        return Err(PqError::InvalidConfig(format!(
                            "Non-finite value at subquantizer {}, centroid {}, dimension {}",
                            sq_idx, c_idx, dim_idx
                        )));
                    }
                }
            }
        }
        Ok(())
    }

    /// Checks that one sub-quantizer holds `num_centroids` centroids, each with
    /// `subvector_dimension` components.
    fn check_subquantizer_shape(
        sq_idx: usize,
        sq_centroids: &[Vec<f32>],
        num_centroids: usize,
        subvector_dimension: usize,
    ) -> PqResult<()> {
        if sq_centroids.len() != num_centroids {
            return Err(PqError::InvalidConfig(format!(
                "Sub-quantizer {} has {} centroids, expected {}",
                sq_idx,
                sq_centroids.len(),
                num_centroids
            )));
        }
        for (c_idx, centroid) in sq_centroids.iter().enumerate() {
            if centroid.len() != subvector_dimension {
                return Err(PqError::InvalidConfig(format!(
                    "Sub-quantizer {}, centroid {} has dimension {}, expected {}",
                    sq_idx,
                    c_idx,
                    centroid.len(),
                    subvector_dimension
                )));
            }
        }
        Ok(())
    }
}
/// Squared Euclidean distance between `a` and `b`.
///
/// Elements are paired by position; when the slices differ in length, the surplus
/// elements of the longer one are ignored.
fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    let squared_deltas = a.iter().zip(b).map(|(&lhs, &rhs)| {
        let delta = lhs - rhs;
        delta.powi(2)
    });
    squared_deltas.sum()
}
/// Unit tests for [`Codebook`] and `l2_distance_squared`.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// The accessors report the shape passed to `new`.
    #[test]
    fn test_codebook_creation() {
        let codebook = Codebook::new(8, 256, 96);
        assert_eq!(codebook.num_subquantizers(), 8);
        assert_eq!(codebook.num_centroids(), 256);
        assert_eq!(codebook.subvector_dimension(), 96);
        assert_eq!(codebook.dimension(), 768);
    }

    /// A centroid written with `set_centroid` is read back unchanged.
    #[test]
    fn test_codebook_get_set_centroid() {
        let mut codebook = Codebook::new(2, 4, 3);
        let centroid = vec![1.0, 2.0, 3.0];
        codebook.set_centroid(0, 0, centroid.clone()).unwrap();
        let retrieved = codebook.get_centroid(0, 0).unwrap();
        assert_eq!(retrieved, &centroid[..]);
    }

    /// Out-of-range sub-quantizer and centroid indices are rejected.
    #[test]
    fn test_codebook_invalid_indices() {
        let codebook = Codebook::new(2, 4, 3);
        assert!(codebook.get_centroid(5, 0).is_err());
        assert!(codebook.get_centroid(0, 10).is_err());
    }

    /// A centroid whose length differs from the sub-vector dimension is rejected.
    #[test]
    fn test_set_centroid_rejects_wrong_dimension() {
        let mut codebook = Codebook::new(2, 4, 3);
        assert!(codebook.set_centroid(0, 0, vec![1.0, 2.0]).is_err());
    }

    /// The nearest centroid is chosen by squared L2 distance.
    #[test]
    fn test_find_nearest_centroid() {
        let mut codebook = Codebook::new(1, 3, 2);
        codebook.set_centroid(0, 0, vec![0.0, 0.0]).unwrap();
        codebook.set_centroid(0, 1, vec![1.0, 0.0]).unwrap();
        codebook.set_centroid(0, 2, vec![0.0, 1.0]).unwrap();

        let nearest = codebook.find_nearest_centroid(0, &[0.9, 0.1]).unwrap();
        assert_eq!(nearest, 1);

        let nearest = codebook.find_nearest_centroid(0, &[0.1, 0.9]).unwrap();
        assert_eq!(nearest, 2);
    }

    /// 8 * 256 * 96 values of 4 bytes each.
    #[test]
    fn test_codebook_memory_size() {
        let codebook = Codebook::new(8, 256, 96);
        assert_eq!(codebook.memory_size(), 786_432);
    }

    #[test]
    fn test_l2_distance_squared() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let dist = l2_distance_squared(&a, &b);
        assert_eq!(dist, 2.0);
    }

    /// A zero-initialised codebook of a valid shape passes validation.
    #[test]
    fn test_codebook_validation() {
        let codebook = Codebook::new(8, 256, 96);
        assert!(codebook.validate().is_ok());
    }

    /// The shape is inferred from the supplied table.
    #[test]
    fn test_codebook_from_centroids() {
        let centroids = vec![
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            vec![vec![5.0, 6.0], vec![7.0, 8.0]],
        ];
        let codebook = Codebook::from_centroids(centroids).unwrap();
        assert_eq!(codebook.num_subquantizers(), 2);
        assert_eq!(codebook.num_centroids(), 2);
        assert_eq!(codebook.subvector_dimension(), 2);

        let centroid = codebook.get_centroid(0, 0).unwrap();
        assert_eq!(centroid, &[1.0, 2.0]);
    }

    /// Empty, ragged, and unevenly sized tables are rejected.
    #[test]
    fn test_from_centroids_rejects_inconsistent_shapes() {
        assert!(Codebook::from_centroids(Vec::new()).is_err());

        let ragged = vec![
            vec![vec![1.0, 2.0]],
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
        ];
        assert!(Codebook::from_centroids(ragged).is_err());

        let uneven_dimensions = vec![vec![vec![1.0, 2.0], vec![3.0]]];
        assert!(Codebook::from_centroids(uneven_dimensions).is_err());
    }
}