mod code;
mod distance;
mod sdc;
pub use code::{PQCode, PQCode8, bytes_for_nbits};
pub use distance::PQDistanceTable;
pub use sdc::SDCTable;
use std::fmt;
use std::sync::Arc;
use bb_core::{
Codec,
embedding::{Embedding, EmbeddingSpace},
index::{OpId, OpRef},
};
use bb_ml::KMeans;
/// Product quantizer over the embedding space `S`.
///
/// Every embedding is split into `M` contiguous sub-vectors of `dsub`
/// components, and each sub-vector is represented by the index of one of
/// `1 << NBITS` centroids belonging to that subspace.
#[derive(Clone)]
pub struct ProductQuantizer<S: EmbeddingSpace, const M: usize, const NBITS: usize>
where
[(); bytes_for_nbits(NBITS)]:,
{
/// The embedding space the quantizer was created for.
space: S,
/// Number of components in each sub-vector (`d / M`).
dsub: usize,
/// Full embedding length, as reported by `S::EmbeddingData::length()`.
d: usize,
/// Flat centroid table; the centroid `k` of subspace `s` starts at
/// `(s * ksub + k) * dsub` and spans `dsub` values. Empty until trained.
centroids: Arc<Vec<f32>>,
/// Centroid-to-centroid distance table built by training; `None` until then.
sdc_table: Option<Arc<SDCTable<M, NBITS>>>,
/// Whether a codebook is available (set by training or by loading one).
trained: bool,
/// Next value handed out by `alloc_op_id`.
next_op_id: u64,
/// Scratch space of `dsub` floats holding the sub-vector currently being processed.
subvec_buffer: Vec<f32>,
}
impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> fmt::Debug
    for ProductQuantizer<S, M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
{
    /// Prints the quantizer's shape parameters and training state.
    ///
    /// The centroid table, the SDC table and the wrapped space are left out
    /// of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ksub = 1usize << NBITS;
        let mut out = f.debug_struct("ProductQuantizer");
        out.field("M", &M);
        out.field("NBITS", &NBITS);
        out.field("ksub", &ksub);
        out.field("dsub", &self.dsub);
        out.field("trained", &self.trained);
        out.finish()
    }
}
impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> PartialEq
    for ProductQuantizer<S, M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
{
    /// Two quantizers are equal when they share the same sub-vector length,
    /// the same training state and the same centroid values.
    ///
    /// The wrapped space, the SDC table, the op-id counter and the scratch
    /// buffer do not take part in the comparison.
    fn eq(&self, other: &Self) -> bool {
        // Cheap scalar checks first; the centroid table is only compared when
        // both of them agree.
        if self.dsub != other.dsub || self.trained != other.trained {
            return false;
        }
        self.centroids == other.centroids
    }
}
// Marker impl declaring that `eq` above is a full equivalence relation.
// NOTE(review): `eq` compares the `f32` centroid table and `NaN != NaN`, so a
// quantizer holding a NaN centroid would not compare equal to itself —
// confirm that centroids can never be NaN.
impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> Eq for ProductQuantizer<S, M, NBITS>
where
[(); bytes_for_nbits(NBITS)]:,
{}
impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> ProductQuantizer<S, M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
{
    /// Number of centroids in each subspace codebook (`2^NBITS`).
    pub const KSUB: usize = 1 << NBITS;

    /// Creates an untrained quantizer for `space`.
    ///
    /// The embedding length reported by `S::EmbeddingData::length()` is split
    /// into `M` sub-vectors of equal size.
    ///
    /// # Panics
    ///
    /// Panics if the embedding length is not divisible by `M`.
    pub fn new(space: S) -> Self {
        let d = S::EmbeddingData::length();
        assert!(
            d % M == 0,
            "dimension {} must be divisible by M={}",
            d,
            M
        );
        let dsub = d / M;
        Self {
            space,
            dsub,
            d,
            centroids: Arc::new(Vec::new()),
            sdc_table: None,
            trained: false,
            next_op_id: 1,
            subvec_buffer: vec![0.0; dsub],
        }
    }

    /// Returns the embedding space this quantizer was created with.
    pub fn space(&self) -> &S {
        &self.space
    }

    /// Hands out the current operation id and advances the counter by one.
    fn alloc_op_id(&mut self) -> OpId {
        let id = OpId(self.next_op_id);
        self.next_op_id += 1;
        id
    }

    /// Number of subspaces (`M`).
    pub fn m(&self) -> usize {
        M
    }

    /// Number of centroids per subspace ([`Self::KSUB`]).
    pub fn ksub(&self) -> usize {
        Self::KSUB
    }

    /// Number of components in each sub-vector.
    pub fn dsub(&self) -> usize {
        self.dsub
    }

    /// Returns the index of the centroid of `subspace` with the smallest
    /// squared Euclidean distance to the sub-vector held in `subvec_buffer`.
    ///
    /// Ties keep the lowest index because only a strictly smaller distance
    /// replaces the current best.
    fn find_nearest_centroid(&self, subspace: usize) -> usize {
        let ksub = Self::KSUB;
        let mut best_idx = 0;
        let mut best_dist = f32::MAX;
        for k in 0..ksub {
            let centroid_offset = (subspace * ksub + k) * self.dsub;
            let centroid = &self.centroids[centroid_offset..centroid_offset + self.dsub];
            let dist: f32 = self
                .subvec_buffer
                .iter()
                .zip(centroid)
                .map(|(&s, &c)| {
                    let diff = s - c;
                    diff * diff
                })
                .sum();
            if dist < best_dist {
                best_dist = dist;
                best_idx = k;
            }
        }
        best_idx
    }

    /// Copies sub-vector number `subspace` of `slice` into `subvec_buffer`,
    /// converting every component to `f32`.
    ///
    /// # Panics
    ///
    /// Panics if `slice` is too short to contain the requested sub-vector.
    fn fill_subvec_buffer(
        &mut self,
        slice: &[<S::EmbeddingData as Embedding>::Scalar],
        subspace: usize,
    ) {
        let start = subspace * self.dsub;
        // Slice once so the range is bounds-checked a single time.
        let source = &slice[start..start + self.dsub];
        for (slot, &value) in self.subvec_buffer.iter_mut().zip(source) {
            *slot = value.into();
        }
    }

    /// Encodes `embedding` as one centroid index per subspace.
    ///
    /// Takes `&mut self` because the internal scratch buffer is reused for
    /// every sub-vector.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer has not been trained.
    pub fn encode_embedding(&mut self, embedding: &S::EmbeddingData) -> PQCode<M, NBITS> {
        assert!(self.trained, "codec must be trained before encoding");
        let slice = embedding.as_slice();
        let mut code = PQCode::<M, NBITS>::zeros();
        for subspace in 0..M {
            self.fill_subvec_buffer(slice, subspace);
            let nearest = self.find_nearest_centroid(subspace);
            code.set(subspace, nearest as u32);
        }
        code
    }

    /// Reconstructs an embedding from `code` by concatenating, for every
    /// subspace, the centroid that the code selects.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer has not been trained.
    pub fn decode_code(&self, code: &PQCode<M, NBITS>) -> S::EmbeddingData {
        assert!(self.trained, "codec must be trained before decoding");
        let ksub = Self::KSUB;
        let mut result = vec![0.0f32; self.d];
        for m in 0..M {
            let c = code.get(m) as usize;
            let centroid_offset = (m * ksub + c) * self.dsub;
            let centroid = &self.centroids[centroid_offset..centroid_offset + self.dsub];
            let start = m * self.dsub;
            result[start..start + self.dsub].copy_from_slice(centroid);
        }
        // Convert back from the internal `f32` representation to the space's scalar type.
        let scalars: Vec<<S::EmbeddingData as Embedding>::Scalar> =
            result.into_iter().map(|x| x.into()).collect();
        S::EmbeddingData::from_slice(&scalars)
    }

    /// Learns one codebook of [`Self::KSUB`] centroids per subspace from
    /// `data` and rebuilds the SDC table from the new centroids.
    ///
    /// Any previously trained codebook is replaced.
    ///
    /// # Panics
    ///
    /// Panics if `data` is empty or holds fewer than [`Self::KSUB`] embeddings.
    pub fn train_on(&mut self, data: &[S::EmbeddingData]) {
        assert!(!data.is_empty(), "training data cannot be empty");
        let ksub = Self::KSUB;
        assert!(
            data.len() >= ksub,
            "need at least {} data points (ksub), got {}",
            ksub,
            data.len()
        );
        let mut centroids = vec![0.0; M * ksub * self.dsub];
        for subspace in 0..M {
            // Gather this subspace's slice of every training vector as `f32`.
            let subvectors: Vec<Vec<f32>> = data
                .iter()
                .map(|emb| {
                    let slice = emb.as_slice();
                    let start = subspace * self.dsub;
                    let end = start + self.dsub;
                    slice[start..end].iter().map(|&s| s.into()).collect()
                })
                .collect();
            // NOTE(review): the third argument (25) is presumably the k-means
            // iteration budget — confirm against `KMeans::fit`.
            let kmeans = KMeans::fit(&subvectors, ksub, 25);
            for (k, centroid) in kmeans.centroids.iter().enumerate() {
                let offset = (subspace * ksub + k) * self.dsub;
                centroids[offset..offset + self.dsub].copy_from_slice(centroid);
            }
        }
        let sdc = SDCTable::from_centroids_with_distance(
            &centroids,
            self.dsub,
            S::slice_distance,
        );
        self.centroids = Arc::new(centroids);
        self.sdc_table = Some(Arc::new(sdc));
        self.trained = true;
    }

    /// Precomputes the distance from every sub-vector of `query` to every
    /// centroid of the matching subspace.
    ///
    /// Entry `subspace * ksub + k` of the table is the `S::slice_distance`
    /// between the query's sub-vector `subspace` and centroid `k`.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer has not been trained.
    pub fn build_distance_table(
        &mut self,
        query: &S::EmbeddingData,
    ) -> PQDistanceTable<S, M, NBITS> {
        assert!(self.trained, "codec must be trained before distance computation");
        let ksub = Self::KSUB;
        let query_slice = query.as_slice();
        let mut table = Vec::with_capacity(M * ksub);
        for subspace in 0..M {
            self.fill_subvec_buffer(query_slice, subspace);
            for k in 0..ksub {
                let centroid_offset = (subspace * ksub + k) * self.dsub;
                let centroid = &self.centroids[centroid_offset..centroid_offset + self.dsub];
                let dist = S::slice_distance(&self.subvec_buffer, centroid);
                table.push(S::DistanceValue::from(dist));
            }
        }
        PQDistanceTable::new(table, ksub)
    }

    /// Returns the centroid-to-centroid distance table, or `None` when no
    /// table has been built or loaded yet.
    pub fn sdc_table(&self) -> Option<&SDCTable<M, NBITS>> {
        self.sdc_table.as_deref()
    }
}
/// Operation handle whose outcome is already known when it is constructed.
pub struct EagerOpRef<T, E> {
/// Identifier returned by `OpRef::id`.
id: OpId,
/// Pending outcome; the first `finish` call takes it and leaves `None`.
result: Option<Result<T, E>>,
/// Copy of the error, kept so that later `finish` calls can report it again.
cached_error: Option<E>,
}
impl<T, E: Clone> EagerOpRef<T, E> {
    /// Builds a handle for an operation that already succeeded with `value`.
    pub fn ok(id: OpId, value: T) -> Self {
        let result = Some(Ok(value));
        Self {
            id,
            result,
            cached_error: None,
        }
    }

    /// Builds a handle for an operation that already failed with `error`.
    ///
    /// A clone of the error becomes the pending result, while the original is
    /// retained so it can be reported on every later `finish` call.
    pub fn err(id: OpId, error: E) -> Self {
        let pending = error.clone();
        Self {
            id,
            result: Some(Err(pending)),
            cached_error: Some(error),
        }
    }
}
/// Error types that can express "this operation was already finished".
///
/// `EagerOpRef::finish` falls back to this value when its result has been
/// taken and no error was cached.
pub trait FinishableError: Clone {
/// Returns the error value that signals a repeated `finish` call.
fn already_finished() -> Self;
}
impl FinishableError for PQError {
    /// Maps a repeated `finish` call to the `AlreadyFinished` variant.
    fn already_finished() -> Self {
        Self::AlreadyFinished
    }
}
impl<T, E: FinishableError> OpRef for EagerOpRef<T, E> {
    type Info = ();
    type Stats = ();
    type Result = T;
    type Error = E;

    /// Identifier assigned when the handle was created.
    fn id(&self) -> &OpId {
        &self.id
    }

    /// Eager operations carry no extra info; always `Some(())`.
    fn info(&self) -> Option<Self::Info> {
        Some(())
    }

    /// Eager operations carry no statistics; always `Some(())`.
    fn stats(&self) -> Option<Self::Stats> {
        Some(())
    }

    /// Always `true`: the outcome exists from the moment of construction.
    fn is_finished(&self) -> bool {
        true
    }

    /// Hands out the stored outcome on the first call.
    ///
    /// Later calls return the cached error again if the operation failed, and
    /// `E::already_finished()` otherwise.
    fn finish(&mut self) -> Result<Self::Result, Self::Error> {
        if let Some(outcome) = self.result.take() {
            return outcome;
        }
        let replay = match &self.cached_error {
            Some(error) => error.clone(),
            None => E::already_finished(),
        };
        Err(replay)
    }
}
/// Errors reported by the product-quantizer codec.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PQError {
/// The quantizer has no trained codebook yet.
NotTrained,
/// `finish` was called on an operation whose result had already been taken.
AlreadyFinished,
/// Reading, writing, encoding or decoding a codebook failed; carries the
/// underlying error's message. Only present with the `codec` feature.
#[cfg(feature = "codec")]
SerializationError(String),
}
impl std::fmt::Display for PQError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PQError::NotTrained => write!(f, "codec not trained"),
PQError::AlreadyFinished => write!(f, "operation already finished"),
#[cfg(feature = "codec")]
PQError::SerializationError(e) => write!(f, "serialization error: {}", e),
}
}
}
// `PQError` wraps no source error, so the trait's default methods are enough.
impl std::error::Error for PQError {}
impl<S: EmbeddingSpace + Default, const M: usize, const NBITS: usize> Default
    for ProductQuantizer<S, M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
{
    /// Creates an untrained quantizer over the default instance of `S`.
    ///
    /// Goes through `new`, so it panics under the same condition as `new`.
    fn default() -> Self {
        let space = S::default();
        Self::new(space)
    }
}
/// Serializable snapshot of a trained quantizer's codebook.
///
/// NOTE(review): this type is gated on the `serde` feature while the
/// `save_codebook` / `load_codebook` methods that use it are gated on
/// `codec` — confirm that enabling `codec` also enables `serde`.
#[cfg(feature = "serde")]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PQCodebook {
/// Flat centroid table copied from the quantizer.
pub centroids: Vec<f32>,
/// Raw contents of the SDC table (`SDCTable::table_data`).
pub sdc_table: Vec<f32>,
/// `ksub` value reported by the SDC table.
pub sdc_ksub: usize,
/// Number of components in each sub-vector.
pub dsub: usize,
/// Full embedding length.
pub d: usize,
/// Number of subspaces (`M`) of the quantizer that wrote the codebook.
pub m: usize,
/// Bits per sub-code (`NBITS`) of the quantizer that wrote the codebook.
pub nbits: usize,
}
#[cfg(feature = "codec")]
impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> ProductQuantizer<S, M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
{
    /// Serializes the trained codebook with `bincode` and writes it to `path`.
    ///
    /// # Errors
    ///
    /// Returns [`PQError::NotTrained`] if the quantizer is untrained or has no
    /// SDC table, and [`PQError::SerializationError`] if encoding or writing
    /// the file fails.
    pub fn save_codebook(&self, path: &std::path::Path) -> Result<(), PQError> {
        if !self.trained {
            return Err(PQError::NotTrained);
        }
        let sdc = self.sdc_table.as_ref().ok_or(PQError::NotTrained)?;
        let codebook = PQCodebook {
            centroids: (*self.centroids).clone(),
            sdc_table: sdc.table_data().to_vec(),
            sdc_ksub: sdc.ksub(),
            dsub: self.dsub,
            d: self.d,
            m: M,
            nbits: NBITS,
        };
        let encoded = bincode::serialize(&codebook)
            .map_err(|e| PQError::SerializationError(e.to_string()))?;
        std::fs::write(path, encoded)
            .map_err(|e| PQError::SerializationError(e.to_string()))?;
        Ok(())
    }

    /// Reads a codebook from `path` and builds a trained quantizer over `space`.
    ///
    /// The file contents are validated before use: `M` and `NBITS` must match
    /// this type, the stored dimensions must match the embedding length of
    /// `S`, and the centroid table must hold exactly `M * KSUB * dsub` values.
    /// Without these checks a mismatched file would only surface later as an
    /// out-of-bounds panic while encoding or decoding.
    ///
    /// # Errors
    ///
    /// Returns [`PQError::SerializationError`] if the file cannot be read or
    /// decoded, or if any of the validations above fails.
    pub fn load_codebook(path: &std::path::Path, space: S) -> Result<Self, PQError> {
        let data = std::fs::read(path)
            .map_err(|e| PQError::SerializationError(e.to_string()))?;
        let codebook: PQCodebook = bincode::deserialize(&data)
            .map_err(|e| PQError::SerializationError(e.to_string()))?;
        if codebook.m != M || codebook.nbits != NBITS {
            return Err(PQError::SerializationError(format!(
                "Codebook M={}, NBITS={} does not match expected M={}, NBITS={}",
                codebook.m, codebook.nbits, M, NBITS
            )));
        }

        // The stored layout must agree with what `new` would derive for `S`.
        let expected_d = S::EmbeddingData::length();
        let layout_ok = M != 0
            && expected_d % M == 0
            && codebook.d == expected_d
            && codebook.dsub == expected_d / M;
        if !layout_ok {
            return Err(PQError::SerializationError(format!(
                "Codebook d={}, dsub={} does not match embedding length {} split into M={} subspaces",
                codebook.d, codebook.dsub, expected_d, M
            )));
        }

        // Encoding and decoding index the centroid table assuming this exact size.
        let expected_centroids = M * Self::KSUB * codebook.dsub;
        if codebook.centroids.len() != expected_centroids {
            return Err(PQError::SerializationError(format!(
                "Codebook holds {} centroid values, expected {}",
                codebook.centroids.len(),
                expected_centroids
            )));
        }

        let sdc_table = SDCTable::from_raw(codebook.sdc_table, codebook.sdc_ksub);
        Ok(Self {
            space,
            dsub: codebook.dsub,
            d: codebook.d,
            centroids: Arc::new(codebook.centroids),
            sdc_table: Some(Arc::new(sdc_table)),
            trained: true,
            // Start at 1 like `new`, so allocated ids never collide with the
            // fixed `OpId(0)` used by encode/decode handles.
            next_op_id: 1,
            subvec_buffer: vec![0.0; codebook.dsub],
        })
    }
}
impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> Codec<S>
    for ProductQuantizer<S, M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
{
    type Encoded = PQCode<M, NBITS>;
    type EncodeRef<'b> = EagerOpRef<PQCode<M, NBITS>, PQError> where Self: 'b;
    type DecodeRef<'b> = EagerOpRef<S::EmbeddingData, PQError> where Self: 'b;
    type TrainRef<'b> = EagerOpRef<(), PQError> where Self: 'b;
    type ObserveRef<'b> = EagerOpRef<(), PQError> where Self: 'b;

    /// Encodes one embedding, returning an already-finished handle.
    ///
    /// The handle carries [`PQError::NotTrained`] when the quantizer has no
    /// codebook. Encode and decode handles all use the fixed id `OpId(0)`.
    // NOTE(review): `train`/`observe` allocate fresh ids while encode reuses
    // `OpId(0)` even though `&mut self` is available — confirm this is intended.
    fn encode(&mut self, embedding: &S::EmbeddingData) -> Self::EncodeRef<'_> {
        let id = OpId(0);
        if !self.trained {
            return EagerOpRef::err(id, PQError::NotTrained);
        }
        EagerOpRef::ok(id, self.encode_embedding(embedding))
    }

    /// Encodes every embedding in order, one handle per input.
    fn encode_batch(&mut self, embeddings: &[S::EmbeddingData]) -> Vec<Self::EncodeRef<'_>> {
        embeddings
            .iter()
            .map(|embedding| {
                if self.trained {
                    EagerOpRef::ok(OpId(0), self.encode_embedding(embedding))
                } else {
                    EagerOpRef::err(OpId(0), PQError::NotTrained)
                }
            })
            .collect()
    }

    /// Decodes one code, returning an already-finished handle.
    ///
    /// The handle carries [`PQError::NotTrained`] when the quantizer has no
    /// codebook.
    fn decode(&self, encoded: &Self::Encoded) -> Self::DecodeRef<'_> {
        let id = OpId(0);
        if !self.trained {
            return EagerOpRef::err(id, PQError::NotTrained);
        }
        EagerOpRef::ok(id, self.decode_code(encoded))
    }

    /// Decodes every code in order, one handle per input.
    fn decode_batch(&self, encoded: &[Self::Encoded]) -> Vec<Self::DecodeRef<'_>> {
        encoded.iter().map(|code| self.decode(code)).collect()
    }

    /// Size of one encoded vector in bytes; always known for this codec.
    fn code_size(&self) -> Option<usize> {
        Some(PQCode::<M, NBITS>::TOTAL_BYTES)
    }

    /// Trains the codebooks on `embeddings` and returns a finished handle.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `train_on`: empty input or fewer
    /// embeddings than centroids per subspace.
    fn train(&mut self, embeddings: &[S::EmbeddingData]) -> Self::TrainRef<'_> {
        let id = self.alloc_op_id();
        self.train_on(embeddings);
        EagerOpRef::ok(id, ())
    }

    /// Accepts an observation without using it; only a fresh op id is allocated.
    fn observe(&mut self, _embedding: &S::EmbeddingData) -> Self::ObserveRef<'_> {
        EagerOpRef::ok(self.alloc_op_id(), ())
    }

    /// Observes every embedding in order, one handle per input.
    fn observe_batch(&mut self, embeddings: &[S::EmbeddingData]) -> Vec<Self::ObserveRef<'_>> {
        // Forward the real embedding instead of fabricating a zero vector per item.
        embeddings
            .iter()
            .map(|embedding| self.observe(embedding))
            .collect()
    }

    /// Whether a codebook is available.
    fn is_trained(&self) -> bool {
        self.trained
    }
}
impl<S: EmbeddingSpace, const M: usize, const NBITS: usize> EmbeddingSpace
    for ProductQuantizer<S, M, NBITS>
where
    [(); bytes_for_nbits(NBITS)]:,
    <S::EmbeddingData as Embedding>::Scalar: Into<f32> + From<f32>,
{
    type EmbeddingData = PQCode<M, NBITS>;
    type DistanceValue = S::DistanceValue;
    type Prepared = PQCode<M, NBITS>;

    /// Identifier of this space.
    fn space_id(&self) -> &'static str {
        "pq"
    }

    /// Distance between two codes, looked up in the SDC table.
    ///
    /// # Panics
    ///
    /// Panics if no SDC table is available, i.e. before training.
    fn distance(
        &self,
        lhs: &Self::EmbeddingData,
        rhs: &Self::EmbeddingData,
    ) -> Self::DistanceValue {
        let table = self
            .sdc_table
            .as_deref()
            .expect("ProductQuantizer must be trained before computing distances");
        let raw = table.distance(lhs, rhs);
        S::DistanceValue::from(raw)
    }

    /// Codes need no preprocessing; the prepared form is a copy of the code.
    fn prepare(&self, embedding: &Self::EmbeddingData) -> Self::Prepared {
        embedding.clone()
    }

    /// Same as [`Self::distance`], since prepared values are plain codes.
    fn distance_prepared(
        &self,
        prepared: &Self::Prepared,
        target: &Self::EmbeddingData,
    ) -> Self::DistanceValue {
        self.distance(prepared, target)
    }

    /// Length of an element of this space: the byte size of one code.
    fn length() -> usize {
        PQCode::<M, NBITS>::TOTAL_BYTES
    }

    /// Delegates to the wrapped space's slice distance.
    fn slice_distance(a: &[f32], b: &[f32]) -> f32 {
        S::slice_distance(a, b)
    }

    /// Delegates to the wrapped space's mapping.
    fn infinite_mapping(native_distance: &Self::DistanceValue) -> f32 {
        S::infinite_mapping(native_distance)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bb_core::embedding::{F32Embedding, F32L2Space};

    /// Eight-dimensional L2 space shared by all tests.
    type Space = F32L2Space<8>;

    /// Builds `n` vectors; vector `i` is `[i, i + 0.1, ..., i + 0.7]`.
    fn make_test_vectors(n: usize) -> Vec<F32Embedding<8>> {
        let mut vectors = Vec::with_capacity(n);
        for i in 0..n {
            let base = i as f32;
            vectors.push(F32Embedding([
                base,
                base + 0.1,
                base + 0.2,
                base + 0.3,
                base + 0.4,
                base + 0.5,
                base + 0.6,
                base + 0.7,
            ]));
        }
        vectors
    }

    /// A fresh 8-bit quantizer reports the expected shape and is untrained.
    #[test]
    fn test_pq_creation() {
        let quantizer = ProductQuantizer::<Space, 2, 8>::new(F32L2Space::<8>);
        assert_eq!(quantizer.m(), 2);
        assert_eq!(quantizer.ksub(), 256);
        assert_eq!(quantizer.dsub(), 4);
        assert!(!quantizer.is_trained());
    }

    /// Same as above with 2-bit sub-codes (four centroids per subspace).
    #[test]
    fn test_pq_creation_nbits2() {
        let quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        assert_eq!(quantizer.m(), 2);
        assert_eq!(quantizer.ksub(), 4);
        assert_eq!(quantizer.dsub(), 4);
        assert!(!quantizer.is_trained());
    }

    /// Training flips the trained flag and encoding yields an `M`-entry code.
    #[test]
    fn test_pq_train_and_encode() {
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);
        quantizer.train_on(&samples);
        assert!(quantizer.is_trained());
        let code = quantizer.encode_embedding(&samples[0]);
        assert_eq!(code.m(), 2);
    }

    /// An encode/decode round trip stays within a loose per-dimension bound.
    #[test]
    fn test_pq_encode_decode() {
        let mut quantizer = ProductQuantizer::<Space, 2, 4>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);
        quantizer.train_on(&samples);
        let original = &samples[50];
        let code = quantizer.encode_embedding(original);
        let decoded = quantizer.decode_code(&code);
        let expected = original.as_slice();
        let actual = decoded.as_slice();
        for i in 0..8 {
            let diff = (expected[i] - actual[i]).abs();
            assert!(diff < 10.0, "dimension {} differs by {}", i, diff);
        }
    }

    /// The asymmetric (query vs. code) distance is non-negative.
    #[test]
    fn test_pq_adc_distance() {
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);
        quantizer.train_on(&samples);
        let query = &samples[0];
        let target = quantizer.encode_embedding(&samples[1]);
        let table = quantizer.build_distance_table(query);
        let dist = table.distance(&target);
        assert!(dist.value() >= 0.0);
    }

    /// The symmetric (code vs. code) table gives zero self-distance.
    #[test]
    fn test_pq_sdc() {
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);
        quantizer.train_on(&samples);
        let first = quantizer.encode_embedding(&samples[0]);
        let second = quantizer.encode_embedding(&samples[1]);
        let sdc = quantizer
            .sdc_table()
            .expect("should have SDC table after training");
        assert_eq!(sdc.distance(&first, &first), 0.0);
        let dist = sdc.distance(&first, &second);
        assert!(dist >= 0.0);
    }

    /// The quantizer behaves as an embedding space over its own codes.
    #[test]
    fn test_pq_embedding_space() {
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);
        quantizer.train_on(&samples);
        let first = quantizer.encode_embedding(&samples[0]);
        let second = quantizer.encode_embedding(&samples[1]);
        let dist = quantizer.distance(&first, &second);
        assert!(dist.value() >= 0.0);
        assert_eq!(quantizer.distance(&first, &first).value(), 0.0);
        let prepared = quantizer.prepare(&first);
        let dist_prepared = quantizer.distance_prepared(&prepared, &second);
        assert_eq!(dist, dist_prepared);
    }

    /// Train, encode and decode through the `Codec` trait's op handles.
    #[test]
    fn test_codec_trait() {
        let mut quantizer = ProductQuantizer::<Space, 2, 2>::new(F32L2Space::<8>);
        let samples = make_test_vectors(100);

        let mut train_ref = quantizer.train(&samples);
        assert!(train_ref.is_finished());
        train_ref.finish().unwrap();
        assert!(quantizer.is_trained());

        let mut encode_ref = quantizer.encode(&samples[0]);
        assert!(encode_ref.is_finished());
        let code = encode_ref.finish().unwrap();

        let mut decode_ref = quantizer.decode(&code);
        assert!(decode_ref.is_finished());
        let _decoded = decode_ref.finish().unwrap();
    }

    /// Code size is `M` times the bytes needed for one `NBITS`-bit index.
    #[test]
    fn test_code_size() {
        let one_byte_codes = ProductQuantizer::<Space, 4, 8>::new(F32L2Space::<8>);
        assert_eq!(one_byte_codes.code_size(), Some(4));
        let two_byte_codes = ProductQuantizer::<Space, 4, 10>::new(F32L2Space::<8>);
        assert_eq!(two_byte_codes.code_size(), Some(8));
    }

    /// 10-bit sub-codes (1024 centroids); ignored by default.
    #[test]
    #[ignore]
    fn test_pq_nbits10() {
        let mut quantizer = ProductQuantizer::<Space, 2, 10>::new(F32L2Space::<8>);
        let samples: Vec<F32Embedding<8>> = (0..2000)
            .map(|i| {
                let base = (i as f32) * 0.01;
                F32Embedding([
                    base,
                    base + 0.1,
                    base + 0.2,
                    base + 0.3,
                    base + 0.4,
                    base + 0.5,
                    base + 0.6,
                    base + 0.7,
                ])
            })
            .collect();
        quantizer.train_on(&samples);
        assert!(quantizer.is_trained());
        assert_eq!(quantizer.ksub(), 1024);
        let code = quantizer.encode_embedding(&samples[500]);
        let idx0 = code.get(0);
        let idx1 = code.get(1);
        assert!(idx0 < 1024);
        assert!(idx1 < 1024);
    }
}