use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::{
cast::as_primitive_array, types::Float32Type, Array, ArrayRef, FixedSizeListArray,
Float32Array, RecordBatch, UInt64Array, UInt8Array,
};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use arrow_select::take::take;
use async_trait::async_trait;
use lance_core::{
io::{read_fixed_stride_array, Reader},
ROW_ID_FIELD,
};
pub use lance_index::vector::pq::{PQBuildParams, ProductQuantizer};
use lance_index::vector::{Query, DIST_COL};
use lance_linalg::{
distance::{l2::l2_distance_batch, norm_l2::norm_l2, Dot, MetricType, Normalize},
matrix::MatrixView,
};
use serde::Serialize;
use tracing::instrument;
use super::VectorIndex;
use crate::index::{pb, prefilter::PreFilter, Index};
use crate::{arrow::*, utils::tokio::spawn_cpu};
use crate::{Error, Result};
#[derive(Clone)]
pub struct PQIndex {
pub nbits: u32,
pub num_sub_vectors: usize,
pub dimension: usize,
pub pq: Arc<ProductQuantizer>,
pub code: Option<Arc<UInt8Array>>,
pub row_ids: Option<Arc<UInt64Array>>,
metric_type: MetricType,
}
impl std::fmt::Debug for PQIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PQ(m={}, nbits={}, {})",
self.num_sub_vectors, self.nbits, self.metric_type
)
}
}
impl PQIndex {
pub(crate) fn new(pq: Arc<ProductQuantizer>, metric_type: MetricType) -> Self {
Self {
nbits: pq.num_bits,
num_sub_vectors: pq.num_sub_vectors,
dimension: pq.dimension,
code: None,
row_ids: None,
pq,
metric_type,
}
}
fn fast_l2_distances(&self, key: &Float32Array, code: &UInt8Array) -> Result<ArrayRef> {
let mut distance_table: Vec<f32> = vec![];
let sub_vector_length = self.dimension / self.num_sub_vectors;
for i in 0..self.num_sub_vectors {
let from = key.slice(i * sub_vector_length, sub_vector_length);
let subvec_centroids = self.pq.centroids(i);
let distances = l2_distance_batch(
as_primitive_array::<Float32Type>(&from).values(),
subvec_centroids,
sub_vector_length,
);
distance_table.extend(distances.values());
}
Ok(Arc::new(unsafe {
Float32Array::from_trusted_len_iter(
code.values().chunks_exact(self.num_sub_vectors).map(|c| {
Some(
c.iter()
.enumerate()
.map(|(sub_vec_idx, centroid)| {
distance_table[sub_vec_idx * 256 + *centroid as usize]
})
.sum::<f32>(),
)
}),
)
}))
}
fn cosine_distances(&self, query: &Float32Array, code: &UInt8Array) -> Result<ArrayRef> {
let num_centroids = ProductQuantizer::num_centroids(self.nbits);
let mut xy_table: Vec<f32> = Vec::with_capacity(self.num_sub_vectors * num_centroids);
let mut y2_table: Vec<f32> = Vec::with_capacity(self.num_sub_vectors * num_centroids);
let x_norm = norm_l2(query.values());
let sub_vector_length = self.dimension / self.num_sub_vectors;
for i in 0..self.num_sub_vectors {
let key_sub_vector =
&query.values()[i * sub_vector_length..(i + 1) * sub_vector_length];
let sub_vector_centroids = self.pq.centroids(i);
let xy = sub_vector_centroids
.chunks_exact(sub_vector_length)
.map(|cent| cent.dot(key_sub_vector));
xy_table.extend(xy);
let y2 = sub_vector_centroids
.chunks_exact(sub_vector_length)
.map(|cent| cent.norm_l2().powf(2.0));
y2_table.extend(y2);
}
Ok(Arc::new(Float32Array::from_iter(
code.values().chunks_exact(self.num_sub_vectors).map(|c| {
let xy = c
.iter()
.enumerate()
.map(|(sub_vec_idx, centroid)| {
let idx = sub_vec_idx * num_centroids + *centroid as usize;
xy_table[idx]
})
.sum::<f32>();
let y2 = c
.iter()
.enumerate()
.map(|(sub_vec_idx, centroid)| {
let idx = sub_vec_idx * num_centroids + *centroid as usize;
y2_table[idx]
})
.sum::<f32>();
1.0 - xy / (x_norm * y2.sqrt())
}),
)))
}
fn filter_arrays(
pre_filter: &PreFilter,
code: Arc<UInt8Array>,
row_ids: Arc<UInt64Array>,
num_sub_vectors: i32,
) -> Result<(Arc<UInt8Array>, Arc<UInt64Array>)> {
let indices_to_keep = pre_filter.filter_row_ids(row_ids.values());
let indices_to_keep = UInt64Array::from(indices_to_keep);
let row_ids = take(row_ids.as_ref(), &indices_to_keep, None)?;
let row_ids = Arc::new(as_primitive_array(&row_ids).clone());
let code = FixedSizeListArray::try_new_from_values(code.as_ref().clone(), num_sub_vectors)
.unwrap();
let code = take(&code, &indices_to_keep, None)?;
let code = as_fixed_size_list_array(&code).values().clone();
let code = Arc::new(as_primitive_array(&code).clone());
Ok((code, row_ids))
}
}
#[derive(Serialize)]
pub struct PQIndexStatistics {
index_type: String,
nbits: u32,
num_sub_vectors: usize,
dimension: usize,
metric_type: String,
}
impl Index for PQIndex {
fn as_any(&self) -> &dyn Any {
self
}
fn statistics(&self) -> Result<serde_json::Value> {
Ok(serde_json::to_value(PQIndexStatistics {
index_type: "PQ".to_string(),
nbits: self.nbits,
num_sub_vectors: self.num_sub_vectors,
dimension: self.dimension,
metric_type: self.metric_type.to_string(),
})?)
}
}
#[async_trait]
impl VectorIndex for PQIndex {
#[instrument(level = "debug", skip_all, name = "PQIndex::search")]
async fn search(&self, query: &Query, pre_filter: Arc<PreFilter>) -> Result<RecordBatch> {
if self.code.is_none() || self.row_ids.is_none() {
return Err(Error::Index {
message: "PQIndex::search: PQ is not initialized".to_string(),
});
}
pre_filter.wait_for_ready().await?;
let code = self.code.as_ref().unwrap().clone();
let row_ids = self.row_ids.as_ref().unwrap().clone();
let this = self.clone();
let query = query.clone();
let num_sub_vectors = self.pq.num_sub_vectors as i32;
spawn_cpu(move || {
let (code, row_ids) = if pre_filter.is_empty() {
Ok((code, row_ids))
} else {
Self::filter_arrays(pre_filter.as_ref(), code, row_ids, num_sub_vectors)
}?;
let distances = if this.metric_type == MetricType::L2 {
this.fast_l2_distances(&query.key, code.as_ref())?
} else {
this.cosine_distances(&query.key, code.as_ref())?
};
debug_assert_eq!(distances.len(), row_ids.len());
let limit = query.k * query.refine_factor.unwrap_or(1) as usize;
let indices = sort_to_indices(&distances, None, Some(limit))?;
let distances = take(&distances, &indices, None)?;
let row_ids = take(row_ids.as_ref(), &indices, None)?;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new(DIST_COL, DataType::Float32, true),
ROW_ID_FIELD.clone(),
]));
Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
})
.await
}
fn is_loadable(&self) -> bool {
true
}
async fn load(
&self,
reader: &dyn Reader,
offset: usize,
length: usize,
) -> Result<Box<dyn VectorIndex>> {
let pq_code_length = self.pq.num_sub_vectors * length;
let pq_code =
read_fixed_stride_array(reader, &DataType::UInt8, offset, pq_code_length, ..).await?;
let row_id_offset = offset + pq_code_length ;
let row_ids =
read_fixed_stride_array(reader, &DataType::UInt64, row_id_offset, length, ..).await?;
Ok(Box::new(Self {
nbits: self.pq.num_bits,
num_sub_vectors: self.pq.num_sub_vectors,
dimension: self.pq.dimension,
code: Some(Arc::new(as_primitive_array(&pq_code).clone())),
row_ids: Some(Arc::new(as_primitive_array(&row_ids).clone())),
pq: self.pq.clone(),
metric_type: self.metric_type,
}))
}
fn check_can_remap(&self) -> Result<()> {
Ok(())
}
fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Result<()> {
let code = self
.code
.as_ref()
.unwrap()
.values()
.chunks_exact(self.num_sub_vectors);
let row_ids = self.row_ids.as_ref().unwrap().values().iter();
let remapped = row_ids
.zip(code)
.filter_map(|(old_row_id, code)| {
let new_row_id = mapping.get(old_row_id).cloned();
let new_row_id = new_row_id.unwrap_or(Some(*old_row_id));
new_row_id.map(|new_row_id| (new_row_id, code))
})
.collect::<Vec<_>>();
self.row_ids = Some(Arc::new(UInt64Array::from_iter_values(
remapped.iter().map(|(row_id, _)| *row_id),
)));
self.code = Some(Arc::new(UInt8Array::from_iter_values(
remapped.into_iter().flat_map(|(_, code)| code).copied(),
)));
Ok(())
}
}
#[allow(clippy::fallible_impl_from)]
impl From<&ProductQuantizer> for pb::Pq {
fn from(pq: &ProductQuantizer) -> Self {
Self {
num_bits: pq.num_bits,
num_sub_vectors: pq.num_sub_vectors as u32,
dimension: pq.dimension as u32,
codebook: pq.codebook.values().to_vec(),
}
}
}
pub(crate) async fn train_pq(
data: &MatrixView<Float32Type>,
params: &PQBuildParams,
) -> Result<ProductQuantizer> {
ProductQuantizer::train(data, params).await
}
#[cfg(test)]
mod tests {}