use std::any::Any;
use std::sync::Arc;
use arrow_array::{
cast::{as_primitive_array, AsArray},
Array, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array,
};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use arrow_select::take::take;
use async_trait::async_trait;
use lance_core::{
format::RowAddress,
io::{read_fixed_stride_array, Reader},
ROW_ID_FIELD,
};
pub use lance_index::vector::pq::{PQBuildParams, ProductQuantizerImpl};
use lance_index::{
vector::{pq::ProductQuantizer, Query, DIST_COL},
Index, IndexType,
};
use lance_linalg::distance::MetricType;
use nohash_hasher::IntMap;
use roaring::RoaringBitmap;
use serde::Serialize;
use snafu::{location, Location};
use tracing::{instrument, Instrument};
use super::VectorIndex;
use crate::index::prefilter::PreFilter;
use crate::{arrow::*, utils::tokio::spawn_cpu};
use crate::{Error, Result};
/// Product Quantization (PQ) vector index.
///
/// A freshly constructed index (see `PQIndex::new`) only carries the quantizer
/// and metric; `code` and `row_ids` stay `None` until `load` produces a copy
/// with both populated.
#[derive(Clone)]
pub struct PQIndex {
/// Product quantizer used to build distance tables for queries.
pub pq: Arc<dyn ProductQuantizer>,
/// Flat PQ codes, `num_sub_vectors` bytes per indexed row.
/// `None` until the index data has been loaded.
pub code: Option<Arc<UInt8Array>>,
/// Row ids, one per group of `num_sub_vectors` codes in `code`.
/// `None` until the index data has been loaded.
pub row_ids: Option<Arc<UInt64Array>>,
/// Metric type associated with this index; reported by `Debug` and `statistics`.
metric_type: MetricType,
}
impl std::fmt::Debug for PQIndex {
    /// Formats the index as `PQ(m=<sub vectors>, nbits=<bits>, <metric>)`.
    ///
    /// Only the quantizer shape and the metric are printed; the code and
    /// row id arrays are intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let sub_vectors = self.pq.num_sub_vectors();
        let bits = self.pq.num_bits();
        let metric = &self.metric_type;
        f.write_fmt(format_args!(
            "PQ(m={}, nbits={}, {})",
            sub_vectors, bits, metric
        ))
    }
}
impl PQIndex {
    /// Creates a PQ index that holds only the quantizer and metric.
    ///
    /// `code` and `row_ids` start out as `None`; the returned value cannot be
    /// searched until a populated copy is produced by `load`.
    pub(crate) fn new(pq: Arc<dyn ProductQuantizer>, metric_type: MetricType) -> Self {
        Self {
            code: None,
            row_ids: None,
            pq,
            metric_type,
        }
    }

    /// Restricts `code` and `row_ids` to the rows selected by `pre_filter`.
    ///
    /// `pre_filter.filter_row_ids` yields the positions to keep; the same
    /// positions are then taken from both arrays so they stay aligned.
    ///
    /// * `code` - flat PQ codes, `num_sub_vectors` bytes per row.
    /// * `row_ids` - row ids, one per row of `code`.
    /// * `num_sub_vectors` - width of one row of `code`.
    ///
    /// # Errors
    ///
    /// Returns an error if `code` cannot be viewed as fixed-size rows of
    /// `num_sub_vectors` bytes, or if either `take` fails.
    fn filter_arrays(
        pre_filter: &PreFilter,
        code: Arc<UInt8Array>,
        row_ids: Arc<UInt64Array>,
        num_sub_vectors: i32,
    ) -> Result<(Arc<UInt8Array>, Arc<UInt64Array>)> {
        let indices_to_keep = pre_filter.filter_row_ids(row_ids.values());
        let indices_to_keep = UInt64Array::from(indices_to_keep);

        let row_ids = take(row_ids.as_ref(), &indices_to_keep, None)?;
        let row_ids = Arc::new(as_primitive_array(&row_ids).clone());

        // View the flat code buffer as one fixed-size list per row so that
        // `take` selects whole rows. A malformed buffer is reported as an
        // error instead of panicking.
        let code = FixedSizeListArray::try_new_from_values(code.as_ref().clone(), num_sub_vectors)?;
        let code = take(&code, &indices_to_keep, None)?;
        // Flatten the selected rows back into a plain byte array.
        let code = as_fixed_size_list_array(&code).values().clone();
        let code = Arc::new(as_primitive_array(&code).clone());

        Ok((code, row_ids))
    }
}
/// Serializable summary of a `PQIndex`, rendered to JSON by
/// `PQIndex::statistics`.
#[derive(Serialize)]
pub struct PQIndexStatistics {
/// Index kind label; `statistics` sets this to `"PQ"`.
index_type: String,
/// Value of the quantizer's `num_bits()`.
nbits: u32,
/// Value of the quantizer's `num_sub_vectors()`.
num_sub_vectors: usize,
/// Value of the quantizer's `dimension()`.
dimension: usize,
/// String form of the index's metric type.
metric_type: String,
}
#[async_trait]
impl Index for PQIndex {
fn as_any(&self) -> &dyn Any {
self
}
fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
self
}
fn index_type(&self) -> IndexType {
IndexType::Vector
}
fn statistics(&self) -> Result<String> {
Ok(serde_json::to_string(&PQIndexStatistics {
index_type: "PQ".to_string(),
nbits: self.pq.num_bits(),
num_sub_vectors: self.pq.num_sub_vectors(),
dimension: self.pq.dimension(),
metric_type: self.metric_type.to_string(),
})?)
}
async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
if let Some(row_ids) = &self.row_ids {
let mut frag_ids = row_ids
.values()
.iter()
.map(|&row_id| RowAddress::new_from_id(row_id).fragment_id())
.collect::<Vec<_>>();
frag_ids.sort();
frag_ids.dedup();
Ok(RoaringBitmap::from_sorted_iter(frag_ids).unwrap())
} else {
Err(Error::Index {
message: "PQIndex::caclulate_included_frags: PQ is not initialized".to_string(),
location: location!(),
})
}
}
}
#[async_trait]
impl VectorIndex for PQIndex {
    /// Searches the loaded PQ codes for the rows closest to `query`.
    ///
    /// The distance computation is handed to `spawn_cpu`. When `pre_filter`
    /// is not empty, codes and row ids are first narrowed with
    /// `filter_arrays`. At most `k * refine_factor` rows are returned
    /// (`refine_factor` defaults to 1), ordered by distance using Arrow's
    /// default sort options.
    ///
    /// The result has two columns: `DIST_COL` (`Float32`) and `ROW_ID_FIELD`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Index` if the codes or row ids have not been loaded,
    /// and propagates failures from the pre-filter, the quantizer, or Arrow.
    #[instrument(level = "debug", skip_all, name = "PQIndex::search")]
    async fn search(&self, query: &Query, pre_filter: Arc<PreFilter>) -> Result<RecordBatch> {
        // Check initialization before waiting on the pre-filter so an
        // unloaded index fails fast.
        let (code, row_ids) = match (&self.code, &self.row_ids) {
            (Some(code), Some(row_ids)) => (Arc::clone(code), Arc::clone(row_ids)),
            _ => {
                return Err(Error::Index {
                    message: "PQIndex::search: PQ is not initialized".to_string(),
                    location: location!(),
                })
            }
        };
        pre_filter.wait_for_ready().await?;

        // Everything moved into the closure must be owned.
        let pq = self.pq.clone();
        let query = query.clone();
        let num_sub_vectors = self.pq.num_sub_vectors() as i32;

        spawn_cpu(move || {
            let (code, row_ids) = if pre_filter.is_empty() {
                (code, row_ids)
            } else {
                Self::filter_arrays(pre_filter.as_ref(), code, row_ids, num_sub_vectors)?
            };

            let distances = pq.build_distance_table(query.key.as_ref(), &code)?;
            debug_assert_eq!(distances.len(), row_ids.len());

            let limit = query.k * query.refine_factor.unwrap_or(1) as usize;
            let indices = sort_to_indices(&distances, None, Some(limit))?;
            let distances = take(&distances, &indices, None)?;
            let row_ids = take(row_ids.as_ref(), &indices, None)?;

            let schema = Arc::new(ArrowSchema::new(vec![
                ArrowField::new(DIST_COL, DataType::Float32, true),
                ROW_ID_FIELD.clone(),
            ]));
            Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
        })
        .in_current_span()
        .await
    }

    /// A PQ index can always be loaded via [`VectorIndex::load`].
    fn is_loadable(&self) -> bool {
        true
    }

    /// Reads codes and row ids from `reader` and returns a populated index.
    ///
    /// `num_sub_vectors * length` `UInt8` codes are read starting at `offset`;
    /// `length` `UInt64` row ids are then read starting at
    /// `offset + num_sub_vectors * length`. The quantizer and metric are
    /// copied from `self`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `read_fixed_stride_array`.
    async fn load(
        &self,
        reader: &dyn Reader,
        offset: usize,
        length: usize,
    ) -> Result<Box<dyn VectorIndex>> {
        let pq_code_length = self.pq.num_sub_vectors() * length;
        let pq_code =
            read_fixed_stride_array(reader, &DataType::UInt8, offset, pq_code_length, ..).await?;

        // Row ids are stored directly after the codes.
        let row_id_offset = offset + pq_code_length;
        let row_ids =
            read_fixed_stride_array(reader, &DataType::UInt64, row_id_offset, length, ..).await?;

        Ok(Box::new(Self {
            code: Some(Arc::new(pq_code.as_primitive().clone())),
            row_ids: Some(Arc::new(row_ids.as_primitive().clone())),
            pq: self.pq.clone(),
            metric_type: self.metric_type,
        }))
    }

    /// PQ indices place no restrictions on remapping.
    fn check_can_remap(&self) -> Result<()> {
        Ok(())
    }

    /// Rewrites row ids according to `mapping`, dropping deleted rows.
    ///
    /// For each stored row id:
    /// * absent from `mapping`: the row is kept unchanged;
    /// * mapped to `Some(new_id)`: the row id becomes `new_id`;
    /// * mapped to `None`: the row and its codes are removed.
    ///
    /// # Errors
    ///
    /// Returns `Error::Index` if the codes or row ids have not been loaded.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer reports zero sub-vectors, because
    /// `chunks_exact` rejects a chunk size of 0.
    fn remap(&mut self, mapping: &IntMap<u64, Option<u64>>) -> Result<()> {
        // Clone the `Arc`s so `self` is free to be reassigned below.
        let (code, row_ids) = match (&self.code, &self.row_ids) {
            (Some(code), Some(row_ids)) => (Arc::clone(code), Arc::clone(row_ids)),
            _ => {
                return Err(Error::Index {
                    message: "PQIndex::remap: PQ is not initialized".to_string(),
                    location: location!(),
                })
            }
        };

        let num_sub_vectors = self.pq.num_sub_vectors();
        let mut new_row_ids = Vec::with_capacity(row_ids.len());
        let mut new_code = Vec::with_capacity(code.len());

        let rows = row_ids
            .values()
            .iter()
            .zip(code.values().chunks_exact(num_sub_vectors));
        for (old_row_id, row_code) in rows {
            let new_row_id = match mapping.get(old_row_id) {
                Some(Some(new_row_id)) => *new_row_id,
                // Explicitly mapped to `None`: the row was deleted.
                Some(None) => continue,
                None => *old_row_id,
            };
            new_row_ids.push(new_row_id);
            new_code.extend_from_slice(row_code);
        }

        self.row_ids = Some(Arc::new(UInt64Array::from(new_row_ids)));
        self.code = Some(Arc::new(UInt8Array::from(new_code)));
        Ok(())
    }
}
// Test-only module; its body is currently empty.
#[cfg(test)]
mod tests {}