use super::*;
use crate::utils::kmeans::{ClusterID, KMeans, Vectors};
use rand::seq::IteratorRandom;
use std::rc::Rc;
/// Inverted-file index with product quantization (IVF-PQ).
///
/// Records are assigned to the cluster of their nearest coarse centroid and
/// their vectors are stored in compressed form: each vector is split into
/// `params.sub_dimension` parts and every part is replaced by the index of
/// its nearest code in that part's codebook.
#[derive(Debug, Serialize, Deserialize)]
pub struct IndexIVFPQ {
/// Configuration the index was created with.
params: ParamsIVFPQ,
/// Bookkeeping such as the `built` flag and the last inserted record ID.
metadata: IndexMetadata,
/// Product-quantized records keyed by their record ID.
data: HashMap<RecordID, RecordPQ>,
/// Coarse centroids; `centroids[i]` belongs to `clusters[i]`.
centroids: Vec<Vector>,
/// IDs of the records assigned to each coarse centroid.
clusters: Vec<Vec<RecordID>>,
/// Per-part code vectors: `codebook[part][code]` is the subvector that a
/// stored code decodes to.
codebook: Vec<Vec<Vector>>,
}
impl IndexIVFPQ {
    /// Trains the product-quantization codebook from the given vectors.
    ///
    /// Every vector is split into `params.sub_dimension` parts. For each
    /// part, k-means with `params.sub_centroids` clusters is fitted on the
    /// matching subvectors and the resulting centroids become that part's
    /// codes in `self.codebook`.
    fn create_codebook(&mut self, vectors: Vectors) {
        for part in 0..usize::from(self.params.sub_dimension) {
            let subvectors: Vec<Vector> = vectors
                .iter()
                .map(|vector| self.get_subvector(part, vector))
                .collect();

            let mut kmeans = KMeans::new(
                self.params.sub_centroids,
                self.params.max_iterations,
                self.params.metric,
            );
            let training: Vec<&Vector> = subvectors.iter().collect();
            kmeans.fit(Rc::from(training));

            // `to_vec` already yields owned centroids, so they can be stored
            // directly without cloning every centroid a second time.
            self.codebook[part] = kmeans.centroids().to_vec();
        }
    }

    /// Returns the ID of the coarse centroid closest to `vector`.
    ///
    /// Falls back to the default `ClusterID` when there are no centroids.
    fn find_nearest_centroid(&self, vector: &Vector) -> ClusterID {
        self.centroids
            .par_iter()
            .enumerate()
            .map(|(i, centroid)| (i, self.metric().distance(vector, centroid)))
            // `total_cmp` is a total order on floats, so a NaN distance can
            // no longer cause a panic here.
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| ClusterID(i as u16))
            .unwrap_or_default()
    }

    /// Returns the index of the code in `codebook[part_index]` that is
    /// closest to `subvector`, or `0` when that codebook part is empty.
    ///
    /// # Panics
    ///
    /// Panics if `part_index` is not a valid codebook part.
    fn find_nearest_code(
        &self,
        part_index: usize,
        subvector: &Vector,
    ) -> usize {
        self.codebook[part_index]
            .par_iter()
            .enumerate()
            .map(|(i, code)| (i, self.metric().distance(subvector, code)))
            // NaN-safe comparison, see `find_nearest_centroid`.
            .min_by(|(_, a), (_, b)| a.total_cmp(b))
            .map(|(i, _)| i)
            .unwrap_or_default()
    }

    /// Compresses `vector` into one code per part, where each code is the
    /// index of the nearest codebook entry for that part.
    fn quantize_vector(&self, vector: &Vector) -> VectorPQ {
        (0..self.params.sub_dimension as usize)
            .into_par_iter()
            .map(|part| {
                let subvector = self.get_subvector(part, vector);
                // Codes are stored as `u8`: a code index above 255 would be
                // truncated, so at most 256 codes per part are usable.
                self.find_nearest_code(part, &subvector) as u8
            })
            .collect::<Vec<u8>>()
            .into()
    }

    /// Reconstructs an approximate vector by concatenating, in part order,
    /// the codebook entry each stored code refers to.
    ///
    /// # Panics
    ///
    /// Panics if a code does not exist in the codebook of its part.
    fn dequantize_vector(&self, vector_pq: &VectorPQ) -> Vector {
        vector_pq
            .0
            .par_iter()
            .enumerate()
            .flat_map(|(part, code)| self.codebook[part][*code as usize].to_vec())
            .collect::<Vec<f32>>()
            .into()
    }

    /// Copies the `part_index`-th of `params.sub_dimension` equally sized
    /// slices out of `vector`.
    ///
    /// The slice length is `vector.len() / sub_dimension` using integer
    /// division, so trailing elements that do not fill a whole part are not
    /// covered by any subvector.
    ///
    /// # Panics
    ///
    /// Panics if `params.sub_dimension` is zero or if the requested slice
    /// lies outside of `vector`.
    fn get_subvector(&self, part_index: usize, vector: &Vector) -> Vector {
        let dim = vector.len() / self.params.sub_dimension as usize;
        let start = part_index * dim;
        let end = start + dim;
        Vector(vector.0[start..end].to_vec().into_boxed_slice())
    }
}
impl IndexOps for IndexIVFPQ {
fn new(params: impl IndexParams) -> Result<Self, Error> {
let params = downcast_params::<ParamsIVFPQ>(params)?;
let codebook = vec![vec![]; params.sub_dimension as usize];
let clusters = vec![vec![]; params.centroids];
if params.sampling <= 0.0 || params.sampling > 1.0 {
let code = ErrorCode::RequestError;
let message = "Sampling must be between 0.0 and 1.0.";
return Err(Error::new(code, message));
}
let index = IndexIVFPQ {
params,
metadata: IndexMetadata::default(),
data: HashMap::new(),
centroids: vec![],
clusters,
codebook,
};
Ok(index)
}
}
impl VectorIndex for IndexIVFPQ {
    /// Distance metric used for clustering, quantization and search.
    fn metric(&self) -> &DistanceMetric {
        &self.params.metric
    }

    /// Bookkeeping information about the index.
    fn metadata(&self) -> &IndexMetadata {
        &self.metadata
    }

    /// Trains the index on a random sample of `records` and then inserts
    /// all of them.
    ///
    /// The sample holds `records.len() * params.sampling` records (at least
    /// one). It is used to train both the product-quantization codebook and
    /// the coarse centroids.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`insert`](Self::insert).
    fn build(
        &mut self,
        records: HashMap<RecordID, Record>,
    ) -> Result<(), Error> {
        let mut rng = rand::thread_rng();

        // The cast truncates toward zero, which for small inputs would give
        // an empty training set and leave the index without centroids.
        let sample =
            ((records.len() as f32 * self.params.sampling) as usize).max(1);

        let vectors = records
            .values()
            .choose_multiple(&mut rng, sample)
            .par_iter()
            .map(|&record| &record.vector)
            .collect::<Vec<&Vector>>();
        let vectors: Vectors = Rc::from(vectors.as_slice());

        self.create_codebook(vectors.clone());

        let centroids = {
            let mut kmeans = KMeans::new(
                self.params.centroids,
                self.params.max_iterations,
                self.metric().to_owned(),
            );
            kmeans.fit(vectors.clone());
            kmeans.centroids().to_vec()
        };
        self.centroids = centroids;

        // `insert` refuses to work on an unbuilt index, so flag it first.
        self.metadata.built = true;
        self.insert(records)?;
        Ok(())
    }

    /// Adds `records` to the index; records whose ID already exists are
    /// ignored.
    ///
    /// Each new record is assigned to the cluster of its nearest centroid,
    /// that centroid is moved toward the record as a running mean, and the
    /// record's vector is stored in quantized form.
    ///
    /// # Errors
    ///
    /// Returns a `RequestError` if `records` is non-empty and the index has
    /// not been built yet.
    fn insert(
        &mut self,
        records: HashMap<RecordID, Record>,
    ) -> Result<(), Error> {
        if records.is_empty() {
            return Ok(());
        }

        if !self.metadata().built {
            let code = ErrorCode::RequestError;
            let message = "Unable to insert records into an unbuilt index.";
            return Err(Error::new(code, message));
        }

        let records: HashMap<RecordID, Record> = records
            .into_iter()
            .filter(|(id, _)| !self.data.contains_key(id))
            .collect();

        // If every record was a duplicate there is nothing to insert.
        // Returning here keeps `last_inserted` from being reset to `None`.
        if records.is_empty() {
            return Ok(());
        }

        for (id, record) in records.iter() {
            let vector = &record.vector;
            let cid = self.find_nearest_centroid(vector).to_usize();

            // Running mean: the current centroid is weighted by the cluster
            // size (at least 1) and the new vector is folded in.
            let count = self.clusters[cid].len().max(1) as f32;
            let new_count = count + 1.0;
            let centroid: Vec<f32> = self.centroids[cid]
                .0
                .iter()
                .zip(vector.0.iter())
                .map(|(c, v)| ((c * count) + v) / new_count)
                .collect();

            self.centroids[cid] = centroid.into();
            self.clusters[cid].push(id.to_owned());
        }

        self.metadata.last_inserted = records.keys().max().copied();

        let records: HashMap<RecordID, RecordPQ> = records
            .into_par_iter()
            .map(|(id, record)| {
                let vector = self.quantize_vector(&record.vector);
                let data = record.data;
                (id, RecordPQ { vector, data })
            })
            .collect();

        self.data.par_extend(records);
        Ok(())
    }

    /// Replaces the records whose IDs already exist in the index; records
    /// with unknown IDs are ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`delete`](Self::delete) and
    /// [`insert`](Self::insert).
    fn update(
        &mut self,
        records: HashMap<RecordID, Record>,
    ) -> Result<(), Error> {
        let records: HashMap<RecordID, Record> = records
            .into_iter()
            .filter(|(id, _)| self.data.contains_key(id))
            .collect();

        let ids: Vec<RecordID> = records.keys().cloned().collect();
        self.delete(ids)?;
        self.insert(records)
    }

    /// Removes the records with the given IDs from the stored data and from
    /// every cluster. Unknown IDs are ignored.
    fn delete(&mut self, ids: Vec<RecordID>) -> Result<(), Error> {
        if ids.is_empty() {
            return Ok(());
        }

        // Hash the IDs once so each membership test is O(1) instead of a
        // linear scan of `ids` for every stored record.
        let ids: std::collections::HashSet<RecordID> =
            ids.into_iter().collect();

        self.data.retain(|id, _| !ids.contains(id));
        self.clusters.par_iter_mut().for_each(|cluster| {
            cluster.retain(|id| !ids.contains(id));
        });

        Ok(())
    }

    /// Returns up to `k` records closest to `query`, sorted by the ordering
    /// of `SearchResult`.
    ///
    /// Centroids are ranked by their distance to `query`, and the first
    /// `params.num_probes` non-empty clusters are scanned. Records rejected
    /// by `filters` are skipped; the rest are compared to the query using
    /// their dequantized vectors.
    ///
    /// # Panics
    ///
    /// Panics if a cluster references a record ID that has no stored data.
    fn search(
        &self,
        query: Vector,
        k: usize,
        filters: Filters,
    ) -> Result<Vec<SearchResult>, Error> {
        let mut centroid_distances: Vec<(usize, f32)> = self
            .centroids
            .par_iter()
            .enumerate()
            .map(|(i, centroid)| (i, self.metric().distance(centroid, &query)))
            .collect();

        // `total_cmp` is a total order on floats, so a NaN distance cannot
        // make the sort panic.
        centroid_distances.par_sort_by(|(_, a), (_, b)| a.total_cmp(b));

        let nearest_centroids: Vec<ClusterID> = centroid_distances
            .iter()
            .take(self.params.centroids)
            .map(|(i, _)| (*i).into())
            .collect();

        let mut probes = 0;
        let mut results = BinaryHeap::new();

        for centroid_id in nearest_centroids {
            if probes >= self.params.num_probes {
                break;
            }

            // Empty clusters do not count as a probe.
            let cluster = &self.clusters[centroid_id.to_usize()];
            if cluster.is_empty() {
                continue;
            }

            probes += 1;

            for &record_id in cluster {
                let record = self
                    .data
                    .get(&record_id)
                    .expect("every clustered record ID has stored data");

                // Filter on the borrowed data first so that rejected
                // records are never cloned or dequantized.
                if !filters.apply(&record.data) {
                    continue;
                }

                let vector = self.dequantize_vector(&record.vector);
                let distance = self.metric().distance(&vector, &query);
                let data = record.data.clone();
                results.push(SearchResult { id: record_id, distance, data });

                // Keep the heap at `k` entries by evicting its top element.
                if results.len() > k {
                    results.pop();
                }
            }
        }

        Ok(results.into_sorted_vec())
    }

    /// Number of records stored in the index.
    fn len(&self) -> usize {
        self.data.len()
    }

    /// Exposes the index as [`Any`] so it can be downcast to its concrete
    /// type.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Configuration of an [`IndexIVFPQ`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParamsIVFPQ {
/// Number of coarse clusters (k-means centroids) records are grouped into.
pub centroids: usize,
/// Iteration limit passed to every k-means run.
pub max_iterations: usize,
/// Number of codes trained for each part of the codebook.
pub sub_centroids: usize,
/// Number of parts a vector is split into for product quantization.
/// Note that this is the part count, not the length of each part.
pub sub_dimension: u8,
/// Maximum number of non-empty clusters scanned per search.
pub num_probes: u8,
/// Fraction of the records used for training; must be in `(0.0, 1.0]`.
pub sampling: f32,
/// Distance metric used for clustering, quantization and search.
pub metric: DistanceMetric,
}
impl Default for ParamsIVFPQ {
fn default() -> Self {
Self {
centroids: 512,
max_iterations: 100,
sub_centroids: 256,
sub_dimension: 8,
num_probes: 16,
sampling: 0.25,
metric: DistanceMetric::Euclidean,
}
}
}
impl IndexParams for ParamsIVFPQ {
fn metric(&self) -> &DistanceMetric {
&self.metric
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Trains a codebook on four small vectors and expects the first vector
    /// to come back unchanged after quantizing and dequantizing it.
    #[test]
    fn test_product_quantization() {
        let params = ParamsIVFPQ {
            max_iterations: 10,
            sub_centroids: 8,
            sub_dimension: 2,
            sampling: 1.0,
            ..Default::default()
        };
        let mut index = IndexIVFPQ::new(params).unwrap();

        let data: Vec<Vector> = vec![
            vec![1.0, 2.0, 3.0, 4.0].into(),
            vec![5.0, 6.0, 7.0, 8.0].into(),
            vec![9.0, 10.0, 11.0, 12.0].into(),
            vec![13.0, 14.0, 15.0, 16.0].into(),
        ];
        let references: Vec<&Vector> = data.iter().collect();
        let vectors: Vectors = Rc::from(references.as_slice());
        index.create_codebook(vectors);

        let original = &data[0];
        let encoded = index.quantize_vector(original);
        let decoded = index.dequantize_vector(&encoded);
        assert_eq!(decoded.to_vec(), original.to_vec());
    }

    /// Runs the shared index test suite against a small IVF-PQ index.
    #[test]
    fn test_ivfpq_index() {
        let mut index = IndexIVFPQ::new(ParamsIVFPQ {
            centroids: 5,
            sub_centroids: 16,
            max_iterations: 20,
            sampling: 0.5,
            ..Default::default()
        })
        .unwrap();

        index_tests::test_index(&mut index);
    }
}