use std::ops::Range;
use std::sync::Arc;
use arrow_array::types::{Float16Type, Float32Type, Float64Type};
use arrow_array::{
cast::AsArray, types::UInt32Type, Array, FixedSizeListArray, RecordBatch, UInt32Array,
};
use arrow_schema::{DataType, Field};
use arrow_select::take::take;
use async_trait::async_trait;
use futures::{stream, StreamExt};
use lance_arrow::*;
use lance_core::{Error, Result};
use lance_linalg::{
distance::{Cosine, Dot, MetricType, L2},
MatrixView,
};
use log::info;
use snafu::{location, Location};
use tracing::{instrument, Instrument};
pub mod builder;
pub mod shuffler;
use super::{PART_ID_COLUMN, PQ_CODE_COLUMN, RESIDUAL_COLUMN};
use crate::vector::{
pq::{transform::PQTransformer, ProductQuantizer},
residual::ResidualTransform,
transform::Transformer,
};
pub use builder::IvfBuildParams;
use lance_linalg::kmeans::KMeans;
/// Build a type-erased [`Ivf`] from centroids of a concrete float type `T`.
///
/// Wraps `centroids` in a [`MatrixView`] with `dimension`-wide rows and
/// returns the resulting [`IvfImpl`] behind an `Arc<dyn Ivf>`.
fn new_ivf_impl<T: ArrowFloatType + Dot + Cosine + L2 + 'static>(
    centroids: &T::ArrayType,
    dimension: usize,
    metric_type: MetricType,
    transforms: Vec<Arc<dyn Transformer>>,
    range: Option<Range<u32>>,
) -> Arc<dyn Ivf> {
    Arc::new(IvfImpl::<T>::new(
        MatrixView::<T>::new(Arc::new(centroids.clone()), dimension),
        metric_type,
        transforms,
        range,
    ))
}
pub fn new_ivf(
centroids: &dyn Array,
dimension: usize,
metric_type: MetricType,
transforms: Vec<Arc<dyn Transformer>>,
range: Option<Range<u32>>,
) -> Result<Arc<dyn Ivf>> {
match centroids.data_type() {
DataType::Float16 => Ok(new_ivf_impl::<Float16Type>(
centroids.as_primitive(),
dimension,
metric_type,
transforms,
range,
)),
DataType::Float32 => Ok(new_ivf_impl::<Float32Type>(
centroids.as_primitive(),
dimension,
metric_type,
transforms,
range,
)),
DataType::Float64 => Ok(new_ivf_impl::<Float64Type>(
centroids.as_primitive(),
dimension,
metric_type,
transforms,
range,
)),
_ => Err(Error::Index {
message: format!(
"new_ivf: centroids is not expected type: {}",
centroids.data_type()
),
location: location!(),
}),
}
}
/// Build a type-erased [`Ivf`] with a product-quantization transform chain
/// from centroids of a concrete float type `T`.
fn new_ivf_with_pq_impl<T: ArrowFloatType + Dot + Cosine + L2 + 'static>(
    centroids: &T::ArrayType,
    dimension: usize,
    metric_type: MetricType,
    vector_column: &str,
    pq: Arc<dyn ProductQuantizer>,
    range: Option<Range<u32>>,
) -> Arc<dyn Ivf> {
    let centroid_mat = MatrixView::<T>::new(Arc::new(centroids.clone()), dimension);
    let ivf = IvfImpl::<T>::new_with_pq(centroid_mat, metric_type, vector_column, pq, range);
    Arc::new(ivf)
}
pub fn new_ivf_with_pq(
centroids: &dyn Array,
dimension: usize,
metric_type: MetricType,
vector_column: &str,
pq: Arc<dyn ProductQuantizer>,
range: Option<Range<u32>>,
) -> Result<Arc<dyn Ivf>> {
match centroids.data_type() {
DataType::Float16 => Ok(new_ivf_with_pq_impl::<Float16Type>(
centroids.as_primitive(),
dimension,
metric_type,
vector_column,
pq,
range,
)),
DataType::Float32 => Ok(new_ivf_with_pq_impl::<Float32Type>(
centroids.as_primitive(),
dimension,
metric_type,
vector_column,
pq,
range,
)),
DataType::Float64 => Ok(new_ivf_with_pq_impl::<Float64Type>(
centroids.as_primitive(),
dimension,
metric_type,
vector_column,
pq,
range,
)),
_ => Err(Error::Index {
message: format!(
"new_ivf_with_pq: centroids is not expected type: {}",
centroids.data_type()
),
location: location!(),
}),
}
}
/// IVF (inverted file) model: assigns vectors to partitions around a fixed
/// set of centroids and prepares record batches for writing into the index.
#[async_trait]
pub trait Ivf: Send + Sync + std::fmt::Debug {
    /// Compute the partition id of each vector in `data`.
    async fn compute_partitions(&self, data: &FixedSizeListArray) -> Result<UInt32Array>;
    /// Compute `vector - centroid[partition]` for every vector in `original`.
    ///
    /// When `partitions` is `None`, partition ids are computed first via
    /// [`Ivf::compute_partitions`].
    async fn compute_residual(
        &self,
        original: &FixedSizeListArray,
        partitions: Option<&UInt32Array>,
    ) -> Result<FixedSizeListArray>;
    /// Find the `nprobes` partitions whose centroids are closest to `query`.
    fn find_partitions(&self, query: &dyn Array, nprobes: usize) -> Result<UInt32Array>;
    /// Add a partition-id column for `column` in `batch` and apply the
    /// configured transform chain (e.g. residual / PQ encoding).
    ///
    /// When `partition_ids` is `None`, ids are computed from the data.
    async fn partition_transform(
        &self,
        batch: &RecordBatch,
        column: &str,
        partition_ids: Option<UInt32Array>,
    ) -> Result<RecordBatch>;
}
/// Concrete [`Ivf`] implementation, parameterized over the float type of the
/// centroids (f16 / f32 / f64).
#[derive(Debug, Clone)]
pub struct IvfImpl<T: ArrowFloatType + Dot + L2 + Cosine> {
    /// Centroid matrix: one row per partition, `ndim()` columns.
    centroids: MatrixView<T>,
    /// Transforms applied, in order, by `partition_transform`.
    transforms: Vec<Arc<dyn Transformer>>,
    /// Distance metric used for partition assignment and search.
    metric_type: MetricType,
    /// When set, `partition_transform` keeps only rows whose partition id
    /// falls inside this range.
    partition_range: Option<Range<u32>>,
}
impl<T: ArrowFloatType + Dot + L2 + Cosine + 'static> IvfImpl<T> {
    /// Create an IVF model from pre-computed centroids.
    ///
    /// `range`, when set, restricts `partition_transform` to rows whose
    /// partition id falls inside it.
    pub fn new(
        centroids: MatrixView<T>,
        metric_type: MetricType,
        transforms: Vec<Arc<dyn Transformer>>,
        range: Option<Range<u32>>,
    ) -> Self {
        Self {
            centroids,
            metric_type,
            transforms,
            partition_range: range,
        }
    }

    /// Create an IVF model whose transform chain ends in product quantization.
    ///
    /// When the quantizer is trained on residuals, a [`ResidualTransform`] is
    /// inserted before the [`PQTransformer`] (which then reads the residual
    /// column); otherwise PQ is applied to the raw vector column directly.
    fn new_with_pq(
        centroids: MatrixView<T>,
        metric_type: MetricType,
        vector_column: &str,
        pq: Arc<dyn ProductQuantizer>,
        range: Option<Range<u32>>,
    ) -> Self {
        let transforms: Vec<Arc<dyn Transformer>> = if pq.use_residual() {
            vec![
                Arc::new(ResidualTransform::new(
                    centroids.clone(),
                    PART_ID_COLUMN,
                    vector_column,
                )),
                Arc::new(PQTransformer::new(
                    pq.clone(),
                    RESIDUAL_COLUMN,
                    PQ_CODE_COLUMN,
                )),
            ]
        } else {
            vec![Arc::new(PQTransformer::new(
                pq.clone(),
                vector_column,
                PQ_CODE_COLUMN,
            ))]
        };
        Self {
            // No clone needed: `centroids` is still owned here (only a clone
            // was moved into the residual transform above).
            centroids,
            metric_type,
            transforms,
            partition_range: range,
        }
    }

    /// Dimension of each centroid (and therefore of each input vector).
    fn dimension(&self) -> usize {
        self.centroids.ndim()
    }

    /// Assign every row of `data` to its nearest centroid, in parallel.
    ///
    /// Rows are split into up to `num_cpus` contiguous chunks, each assigned
    /// concurrently via `compute_partitions`; results are collected in chunk
    /// order, so the output order matches the input row order.
    #[instrument(level = "debug", skip(data))]
    async fn do_compute_partitions(&self, data: &MatrixView<T>) -> UInt32Array {
        use lance_linalg::kmeans::compute_partitions;

        let dimension = data.ndim();
        let centroids = self.centroids.data();
        let data = data.data();
        let metric_type = self.metric_type;

        let num_centroids = centroids.len() / dimension;
        let num_rows = data.len() / dimension;

        // Guard against empty input: `chunks` below would be 0 and
        // `num_rows / chunks` would panic with a division by zero.
        if num_rows == 0 {
            return UInt32Array::from(Vec::<u32>::new());
        }

        let chunks = std::cmp::min(num_cpus::get(), num_rows);

        info!(
            "computing partition on {} chunks, out of {} centroids, and {} vectors",
            chunks, num_centroids, num_rows,
        );
        // Ceil-divide rows over chunks so the final chunk absorbs any remainder.
        let chunk_size = num_rows / chunks + if num_rows % chunks > 0 { 1 } else { 0 };
        let stride = chunk_size * dimension;

        let result: Vec<Vec<Option<u32>>> = stream::iter(0..chunks)
            .map(|chunk_id| stride * chunk_id..std::cmp::min(stride * (chunk_id + 1), data.len()))
            // Rounding up `chunk_size` can leave trailing empty ranges; skip them.
            .filter(|range| futures::future::ready(range.start < range.end))
            .map(|range| async {
                let range: Range<usize> = range;
                let centroids = centroids.clone();
                let data = Arc::new(
                    data.slice(range.start, range.end - range.start)
                        .as_any()
                        .downcast_ref::<T::ArrayType>()
                        .unwrap()
                        .clone(),
                );
                compute_partitions::<T>(centroids, data, dimension, metric_type)
                    .in_current_span()
                    .await
            })
            .buffered(chunks)
            .collect::<Vec<_>>()
            .await;
        UInt32Array::from_iter(result.iter().flatten().copied())
    }
}
#[async_trait]
impl<T: ArrowFloatType + Dot + L2 + Cosine + 'static> Ivf for IvfImpl<T> {
async fn compute_partitions(&self, data: &FixedSizeListArray) -> Result<UInt32Array> {
let array = data
.values()
.as_any()
.downcast_ref::<T::ArrayType>()
.ok_or(Error::Index {
message: format!(
"Ivf::compute_partitions: data is not expected type: {} got {}",
T::FLOAT_TYPE,
data.values().data_type()
),
location: Default::default(),
})?;
let mat = MatrixView::<T>::new(Arc::new(array.clone()), data.value_length());
Ok(self.do_compute_partitions(&mat).await)
}
async fn compute_residual(
&self,
original: &FixedSizeListArray,
partitions: Option<&UInt32Array>,
) -> Result<FixedSizeListArray> {
let flatten_arr = original
.values()
.as_any()
.downcast_ref::<T::ArrayType>()
.ok_or(Error::Index {
message: format!(
"Ivf::compute_residual: original is not expected type: {} got {}",
T::FLOAT_TYPE,
original.values().data_type()
),
location: Default::default(),
})?;
let part_ids = if let Some(part_ids) = partitions {
part_ids.clone()
} else {
self.compute_partitions(original).await?
};
let dim = original.value_length() as usize;
let mut residual_arr: Vec<T::Native> = Vec::with_capacity(original.values().len());
flatten_arr
.as_slice()
.chunks_exact(dim)
.zip(part_ids.values())
.for_each(|(vector, &part_id)| {
let centroid = self.centroids.row(part_id as usize).unwrap();
residual_arr.extend(vector.iter().zip(centroid.iter()).map(|(&v, &c)| v - c));
});
let arr = T::ArrayType::from(residual_arr);
Ok(FixedSizeListArray::try_new_from_values(arr, dim as i32)?)
}
fn find_partitions(&self, query: &dyn Array, nprobes: usize) -> Result<UInt32Array> {
let query = query
.as_any()
.downcast_ref::<T::ArrayType>()
.ok_or(Error::Index {
message: format!(
"Ivf::find_partition: query is not expected type: {} got {}",
T::FLOAT_TYPE,
query.data_type()
),
location: Default::default(),
})?;
let kmeans = KMeans::<T>::with_centroids(
self.centroids.data().clone(),
self.dimension(),
self.metric_type,
);
Ok(kmeans.find_partitions(query.as_slice(), nprobes)?)
}
async fn partition_transform(
&self,
batch: &RecordBatch,
column: &str,
partition_ids: Option<UInt32Array>,
) -> Result<RecordBatch> {
let vector_arr = batch.column_by_name(column).ok_or(Error::Index {
message: format!("Column {} does not exist.", column),
location: location!(),
})?;
let data = vector_arr.as_fixed_size_list_opt().ok_or(Error::Index {
message: format!(
"Column {} is not a vector type: {}",
column,
vector_arr.data_type()
),
location: location!(),
})?;
let part_ids = if let Some(part_ids) = partition_ids {
part_ids
} else {
self.compute_partitions(data).await?
};
let (part_ids, batch) = if let Some(part_range) = self.partition_range.as_ref() {
let idx_in_range: UInt32Array = part_ids
.iter()
.enumerate()
.filter(|(_idx, part_id)| part_id.map(|r| part_range.contains(&r)).unwrap_or(false))
.map(|(idx, _)| idx as u32)
.collect();
let part_ids = take(&part_ids, &idx_in_range, None)?
.as_primitive::<UInt32Type>()
.clone();
let batch = batch.take(&idx_in_range)?;
(part_ids, batch)
} else {
(part_ids, batch.clone())
};
let field = Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), false);
let mut batch = batch.try_with_column(field, Arc::new(part_ids))?;
for transform in self.transforms.as_slice() {
batch = transform.transform(&batch).await?;
}
Ok(batch)
}
}