use std::sync::Arc;
use std::{any::Any, collections::HashMap};
use arrow::compute::concat;
use arrow_array::UInt32Array;
use arrow_array::{
cast::{as_primitive_array, AsArray},
Array, FixedSizeListArray, RecordBatch, UInt64Array, UInt8Array,
};
use arrow_ord::sort::sort_to_indices;
use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
use arrow_select::take::take;
use async_trait::async_trait;
use deepsize::DeepSizeOf;
use lance_core::utils::tokio::spawn_cpu;
use lance_core::ROW_ID;
use lance_core::{utils::address::RowAddress, ROW_ID_FIELD};
use lance_index::vector::ivf::storage::IvfModel;
use lance_index::vector::pq::storage::{transpose, ProductQuantizationStorage};
use lance_index::vector::quantizer::{Quantization, QuantizationType, Quantizer};
use lance_index::vector::v3::subindex::SubIndexType;
use lance_index::{
vector::{pq::ProductQuantizer, Query, DIST_COL},
Index, IndexType,
};
use lance_io::{traits::Reader, utils::read_fixed_stride_array};
use lance_linalg::distance::{DistanceType, MetricType};
use log::info;
use roaring::RoaringBitmap;
use serde_json::json;
use snafu::{location, Location};
use tracing::{instrument, span, Level};
pub use lance_index::vector::pq::PQBuildParams;
use lance_linalg::kernels::normalize_fsl;
use super::VectorIndex;
use crate::index::prefilter::PreFilter;
use crate::index::vector::utils::maybe_sample_training_data;
use crate::{arrow::*, Dataset};
use crate::{Error, Result};
#[derive(Clone)]
pub struct PQIndex {
pub pq: ProductQuantizer,
pub code: Option<Arc<UInt8Array>>,
pub row_ids: Option<Arc<UInt64Array>>,
metric_type: MetricType,
}
impl DeepSizeOf for PQIndex {
fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
self.pq.deep_size_of_children(context)
+ self
.code
.as_ref()
.map(|code| code.get_array_memory_size())
.unwrap_or(0)
+ self
.row_ids
.as_ref()
.map(|row_ids| row_ids.get_array_memory_size())
.unwrap_or(0)
}
}
impl std::fmt::Debug for PQIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"PQ(m={}, nbits={}, {})",
self.pq.code_dim(),
self.pq.num_bits,
self.metric_type
)
}
}
impl PQIndex {
pub(crate) fn new(pq: ProductQuantizer, metric_type: MetricType) -> Self {
Self {
code: None,
row_ids: None,
pq,
metric_type,
}
}
fn filter_arrays(
pre_filter: &dyn PreFilter,
code: Arc<UInt8Array>,
row_ids: Arc<UInt64Array>,
_num_sub_vectors: i32,
) -> Result<(Arc<UInt8Array>, Arc<UInt64Array>)> {
let num_vectors = row_ids.len();
let indices_to_keep = pre_filter.filter_row_ids(Box::new(row_ids.values().iter()));
let indices_to_keep = UInt64Array::from(indices_to_keep);
let row_ids = take(row_ids.as_ref(), &indices_to_keep, None)?;
let row_ids = Arc::new(as_primitive_array(&row_ids).clone());
let code = code
.values()
.chunks_exact(num_vectors)
.flat_map(|c| {
let mut filtered = Vec::with_capacity(indices_to_keep.len());
for idx in indices_to_keep.values() {
filtered.push(c[*idx as usize]);
}
filtered
})
.collect();
Ok((Arc::new(code), row_ids))
}
fn get_pq_codes(transposed_codes: &UInt8Array, vec_idx: usize, num_vectors: usize) -> Vec<u8> {
transposed_codes
.values()
.iter()
.skip(vec_idx)
.step_by(num_vectors)
.cloned()
.collect()
}
}
#[async_trait]
impl Index for PQIndex {
fn as_any(&self) -> &dyn Any {
self
}
fn as_index(self: Arc<Self>) -> Arc<dyn Index> {
self
}
fn as_vector_index(self: Arc<Self>) -> Result<Arc<dyn VectorIndex>> {
Ok(self)
}
fn index_type(&self) -> IndexType {
IndexType::Vector
}
fn statistics(&self) -> Result<serde_json::Value> {
Ok(json!({
"index_type": "PQ",
"nbits": self.pq.num_bits,
"num_sub_vectors": self.pq.code_dim(),
"dimension": self.pq.dimension,
"metric_type": self.metric_type.to_string(),
}))
}
async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
if let Some(row_ids) = &self.row_ids {
let mut frag_ids = row_ids
.values()
.iter()
.map(|&row_id| RowAddress::from(row_id).fragment_id())
.collect::<Vec<_>>();
frag_ids.sort();
frag_ids.dedup();
Ok(RoaringBitmap::from_sorted_iter(frag_ids).unwrap())
} else {
Err(Error::Index {
message: "PQIndex::calculate_included_frags: PQ is not initialized".to_string(),
location: location!(),
})
}
}
}
#[async_trait]
impl VectorIndex for PQIndex {
#[instrument(level = "debug", skip_all, name = "PQIndex::search")]
async fn search(&self, query: &Query, pre_filter: Arc<dyn PreFilter>) -> Result<RecordBatch> {
if self.code.is_none() || self.row_ids.is_none() {
return Err(Error::Index {
message: "PQIndex::search: PQ is not initialized".to_string(),
location: location!(),
});
}
pre_filter.wait_for_ready().await?;
let code = self.code.as_ref().unwrap().clone();
let row_ids = self.row_ids.as_ref().unwrap().clone();
let pq = self.pq.clone();
let query = query.clone();
let num_sub_vectors = self.pq.code_dim() as i32;
spawn_cpu(move || {
let (code, row_ids) = if pre_filter.is_empty() {
Ok((code, row_ids))
} else {
Self::filter_arrays(pre_filter.as_ref(), code, row_ids, num_sub_vectors)
}?;
let distances = pq.compute_distances(query.key.as_ref(), &code)?;
debug_assert_eq!(distances.len(), row_ids.len());
let limit = query.k * query.refine_factor.unwrap_or(1) as usize;
let indices = sort_to_indices(&distances, None, Some(limit))?;
let distances = take(&distances, &indices, None)?;
let row_ids = take(row_ids.as_ref(), &indices, None)?;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new(DIST_COL, DataType::Float32, true),
ROW_ID_FIELD.clone(),
]));
Ok(RecordBatch::try_new(schema, vec![distances, row_ids])?)
})
.await
}
fn find_partitions(&self, _: &Query) -> Result<UInt32Array> {
unimplemented!("only for IVF")
}
async fn search_in_partition(
&self,
_: usize,
_: &Query,
_: Arc<dyn PreFilter>,
) -> Result<RecordBatch> {
unimplemented!("only for IVF")
}
fn is_loadable(&self) -> bool {
true
}
fn use_residual(&self) -> bool {
ProductQuantizer::use_residual(self.metric_type)
}
async fn load(
&self,
reader: Arc<dyn Reader>,
offset: usize,
length: usize,
) -> Result<Box<dyn VectorIndex>> {
let pq_code_length = self.pq.code_dim() * length;
let pq_codes = read_fixed_stride_array(
reader.as_ref(),
&DataType::UInt8,
offset,
pq_code_length,
..,
)
.await?;
let row_id_offset = offset + pq_code_length ;
let row_ids = read_fixed_stride_array(
reader.as_ref(),
&DataType::UInt64,
row_id_offset,
length,
..,
)
.await?;
let pq_codes = transpose(
pq_codes.as_primitive(),
row_ids.len(),
self.pq.num_sub_vectors,
);
Ok(Box::new(Self {
code: Some(Arc::new(pq_codes)),
row_ids: Some(Arc::new(row_ids.as_primitive().clone())),
pq: self.pq.clone(),
metric_type: self.metric_type,
}))
}
fn row_ids(&self) -> Box<dyn Iterator<Item = &u64>> {
todo!("this method is for only IVF_HNSW_* index");
}
fn check_can_remap(&self) -> Result<()> {
Ok(())
}
fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Result<()> {
let num_vectors = self.row_ids.as_ref().unwrap().len();
let row_ids = self.row_ids.as_ref().unwrap().values().iter();
let transposed_codes = self.code.as_ref().unwrap();
let remapped = row_ids
.enumerate()
.filter_map(|(vec_idx, old_row_id)| {
let new_row_id = mapping.get(old_row_id).cloned();
let new_row_id = new_row_id.unwrap_or(Some(*old_row_id));
new_row_id.map(|new_row_id| {
(
new_row_id,
Self::get_pq_codes(transposed_codes, vec_idx, num_vectors),
)
})
})
.collect::<Vec<_>>();
self.row_ids = Some(Arc::new(UInt64Array::from_iter_values(
remapped.iter().map(|(row_id, _)| *row_id),
)));
let pq_codes =
UInt8Array::from_iter_values(remapped.into_iter().flat_map(|(_, code)| code));
let transposed_codes = transpose(
&pq_codes,
self.row_ids.as_ref().unwrap().len(),
self.pq.num_sub_vectors,
);
self.code = Some(Arc::new(transposed_codes));
Ok(())
}
fn ivf_model(&self) -> IvfModel {
unimplemented!("only for IVF")
}
fn quantizer(&self) -> Quantizer {
unimplemented!("only for IVF")
}
fn sub_index_type(&self) -> (SubIndexType, QuantizationType) {
(SubIndexType::Flat, QuantizationType::Product)
}
fn metric_type(&self) -> MetricType {
self.metric_type
}
}
pub async fn build_pq_model(
dataset: &Dataset,
column: &str,
dim: usize,
metric_type: MetricType,
params: &PQBuildParams,
ivf: Option<&IvfModel>,
) -> Result<ProductQuantizer> {
if let Some(codebook) = ¶ms.codebook {
let dt = if metric_type == MetricType::Cosine {
info!("Normalize training data for PQ training: Cosine");
MetricType::L2
} else {
metric_type
};
return match codebook.data_type() {
DataType::Float16 | DataType::Float32 | DataType::Float64 => Ok(ProductQuantizer::new(
params.num_sub_vectors,
params.num_bits as u32,
dim,
FixedSizeListArray::try_new_from_values(
codebook.slice(0, codebook.len()),
dim as i32,
)?,
dt,
)),
_ => Err(Error::Index {
message: format!("Wrong codebook data type: {:?}", codebook.data_type()),
location: location!(),
}),
};
}
info!(
"Start to train PQ code: PQ{}, bits={}",
params.num_sub_vectors, params.num_bits
);
let expected_sample_size =
lance_index::vector::pq::num_centroids(params.num_bits as u32) * params.sample_rate;
info!(
"Loading training data for PQ. Sample size: {}",
expected_sample_size
);
let start = std::time::Instant::now();
let mut training_data =
maybe_sample_training_data(dataset, column, expected_sample_size).await?;
info!(
"Finished loading training data in {:02} seconds",
start.elapsed().as_secs_f32()
);
info!(
"starting to compute partitions for PQ training, sample size: {}",
training_data.value_length()
);
if metric_type == MetricType::Cosine {
info!("Normalize training data for PQ training: Cosine");
training_data = normalize_fsl(&training_data)?;
}
let training_data = if let Some(ivf) = ivf {
let ivf2 = lance_index::vector::ivf::new_ivf_transformer(
ivf.centroids.clone().unwrap(),
MetricType::L2,
vec![],
);
span!(Level::INFO, "compute residual for PQ training")
.in_scope(|| ivf2.compute_residual(&training_data))?
} else {
training_data
};
let num_codes = 2_usize.pow(params.num_bits as u32);
if training_data.len() < num_codes {
return Err(Error::Index {
message: format!(
"Not enough rows to train PQ. Requires {:?} rows but only {:?} available",
num_codes,
training_data.len()
),
location: location!(),
});
}
info!("Start train PQ: params={:#?}", params);
let pq = ProductQuantizer::build(&training_data, DistanceType::L2, params)?;
info!("Trained PQ in: {} seconds", start.elapsed().as_secs_f32());
Ok(pq)
}
pub(crate) fn build_pq_storage(
distance_type: DistanceType,
row_ids: Arc<dyn Array>,
code_array: Vec<Arc<dyn Array>>,
pq: ProductQuantizer,
) -> Result<ProductQuantizationStorage> {
let pq_arrs = code_array.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
let pq_column = concat(&pq_arrs)?;
std::mem::drop(code_array);
let pq_batch = RecordBatch::try_from_iter_with_nullable(vec![
(ROW_ID, row_ids, true),
(pq.column(), pq_column, false),
])?;
let pq_store = ProductQuantizationStorage::new(
pq.codebook.clone(),
pq_batch,
pq.num_bits,
pq.code_dim(),
pq.dimension,
distance_type,
false,
)?;
Ok(pq_store)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::index::vector::ivf::build_ivf_model;
use arrow::datatypes::Float32Type;
use arrow_array::RecordBatchIterator;
use arrow_schema::{Field, Schema};
use lance_index::vector::ivf::IvfBuildParams;
use lance_testing::datagen::generate_random_array_with_range;
use std::ops::Range;
use tempfile::tempdir;
const DIM: usize = 128;
async fn generate_dataset(
test_uri: &str,
range: Range<f32>,
) -> (Dataset, Arc<FixedSizeListArray>) {
let vectors = generate_random_array_with_range::<Float32Type>(1000 * DIM, range);
let metadata: HashMap<String, String> = vec![("test".to_string(), "ivf_pq".to_string())]
.into_iter()
.collect();
let schema = Arc::new(
Schema::new(vec![Field::new(
"vector",
DataType::FixedSizeList(
Arc::new(Field::new("item", DataType::Float32, true)),
DIM as i32,
),
true,
)])
.with_metadata(metadata),
);
let fsl = Arc::new(FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap());
let batch = RecordBatch::try_new(schema.clone(), vec![fsl.clone()]).unwrap();
let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone());
(Dataset::write(batches, test_uri, None).await.unwrap(), fsl)
}
#[tokio::test]
async fn test_build_pq_model_l2() {
let test_dir = tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();
let (dataset, _) = generate_dataset(test_uri, 100.0..120.0).await;
let centroids = generate_random_array_with_range::<Float32Type>(4 * DIM, -1.0..1.0);
let fsl = FixedSizeListArray::try_new_from_values(centroids, DIM as i32).unwrap();
let ivf = IvfModel::new(fsl);
let params = PQBuildParams::new(16, 8);
let pq = build_pq_model(&dataset, "vector", DIM, MetricType::L2, ¶ms, Some(&ivf))
.await
.unwrap();
assert_eq!(pq.code_dim(), 16);
assert_eq!(pq.num_bits, 8);
assert_eq!(pq.dimension, DIM);
let codebook = pq.codebook;
assert_eq!(codebook.len(), 256);
codebook
.values()
.as_primitive::<Float32Type>()
.values()
.iter()
.for_each(|v| {
assert!((99.0..121.0).contains(v));
});
}
#[tokio::test]
async fn test_build_pq_model_cosine() {
let test_dir = tempdir().unwrap();
let test_uri = test_dir.path().to_str().unwrap();
let (dataset, vectors) = generate_dataset(test_uri, 100.0..120.0).await;
let ivf_params = IvfBuildParams::new(4);
let ivf = build_ivf_model(&dataset, "vector", DIM, MetricType::Cosine, &ivf_params)
.await
.unwrap();
let params = PQBuildParams::new(16, 8);
let pq = build_pq_model(
&dataset,
"vector",
DIM,
MetricType::Cosine,
¶ms,
Some(&ivf),
)
.await
.unwrap();
assert_eq!(pq.code_dim(), 16);
assert_eq!(pq.num_bits, 8);
assert_eq!(pq.dimension, DIM);
#[allow(clippy::redundant_clone)]
let codebook = pq.codebook.clone();
assert_eq!(codebook.len(), 256);
codebook
.values()
.as_primitive::<Float32Type>()
.values()
.iter()
.for_each(|v| {
assert!((-1.0..1.0).contains(v));
});
let vectors = normalize_fsl(&vectors).unwrap();
let row = vectors.slice(0, 1);
let ivf2 = lance_index::vector::ivf::new_ivf_transformer(
ivf.centroids.clone().unwrap(),
MetricType::L2,
vec![],
);
let residual_query = ivf2.compute_residual(&row).unwrap();
let pq_code = pq.quantize(&residual_query).unwrap();
let distances = pq
.compute_distances(
&residual_query.value(0),
pq_code.as_fixed_size_list().values().as_primitive(),
)
.unwrap();
assert!(
distances.values().iter().all(|&d| d <= 0.001),
"distances: {:?}",
distances
);
}
}