use std::ops::Range;
use std::sync::Arc;
use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, UInt32Array};
pub use builder::IvfBuildParams;
use lance_core::Result;
use lance_linalg::distance::{DistanceType, MetricType};
use tracing::instrument;
use crate::vector::bq::builder::RabitQuantizer;
use crate::vector::bq::transform::RQTransformer;
use crate::vector::ivf::transform::PartitionTransformer;
use crate::vector::kmeans::{compute_partitions_arrow_array, kmeans_find_partitions_arrow_array};
use crate::vector::{pq::ProductQuantizer, transform::Transformer};
use super::flat::transform::FlatTransformer;
use super::pq::transform::PQTransformer;
use super::quantizer::Quantization;
use super::residual::ResidualTransform;
use super::sq::ScalarQuantizer;
use super::sq::transform::SQTransformer;
use super::transform::KeepFiniteVectors;
use super::{PART_ID_COLUMN, PQ_CODE_COLUMN, SQ_CODE_COLUMN};
use super::{quantizer::Quantizer, residual::compute_residual};
pub mod builder;
pub mod shuffler;
pub mod storage;
mod transform;
/// Creates an [`IvfTransformer`] from centroids, a distance type and a ready-made stage list.
///
/// This forwards its arguments unchanged to [`IvfTransformer::new`]; no stages are added,
/// removed or reordered.
///
/// * `centroids` - centroids stored on the transformer.
/// * `metric_type` - distance type stored on the transformer.
/// * `transforms` - stages applied, in order, when the transformer runs.
pub fn new_ivf_transformer(
centroids: FixedSizeListArray,
metric_type: DistanceType,
transforms: Vec<Arc<dyn Transformer>>,
) -> IvfTransformer {
IvfTransformer::new(centroids, metric_type, transforms)
}
/// Builds the [`IvfTransformer`] whose pipeline matches the given `quantizer`.
///
/// Dispatch:
/// * `Quantizer::Flat` / `Quantizer::FlatBin` -> [`IvfTransformer::new_flat`]
///   (the quantizer payload itself is not used),
/// * `Quantizer::Product` -> [`IvfTransformer::with_pq`],
/// * `Quantizer::Scalar` -> `IvfTransformer::with_sq`,
/// * `Quantizer::Rabit` -> `IvfTransformer::with_rq`.
///
/// `centroids`, `metric_type`, `vector_column` and `range` are forwarded untouched to the
/// selected constructor.
///
/// Every arm currently succeeds, so the returned `Result` is always `Ok`.
pub fn new_ivf_transformer_with_quantizer(
    centroids: FixedSizeListArray,
    metric_type: MetricType,
    vector_column: &str,
    quantizer: Quantizer,
    range: Option<Range<u32>>,
) -> Result<IvfTransformer> {
    let transformer = match quantizer {
        Quantizer::Flat(_) | Quantizer::FlatBin(_) => {
            IvfTransformer::new_flat(centroids, metric_type, vector_column, range)
        }
        Quantizer::Product(pq) => {
            IvfTransformer::with_pq(centroids, metric_type, vector_column, pq, range)
        }
        Quantizer::Scalar(sq) => {
            IvfTransformer::with_sq(centroids, metric_type, vector_column, sq, range)
        }
        Quantizer::Rabit(rq) => {
            IvfTransformer::with_rq(centroids, metric_type, vector_column, rq, range)
        }
    };
    Ok(transformer)
}
/// A pipeline of [`Transformer`] stages bundled with the centroids and distance type
/// used by the partition / residual helper methods.
#[derive(Debug)]
pub struct IvfTransformer {
/// Centroids handed to the residual and partition-search helpers.
centroids: FixedSizeListArray,
/// Stages applied one after another, in order, by `Transformer::transform`.
transforms: Vec<Arc<dyn Transformer>>,
/// Distance type passed to the residual and partition-search helpers.
distance_type: DistanceType,
}
impl IvfTransformer {
pub fn new(
centroids: FixedSizeListArray,
metric_type: MetricType,
transforms: Vec<Arc<dyn Transformer>>,
) -> Self {
Self {
centroids,
distance_type: metric_type,
transforms,
}
}
pub fn new_partition_transformer(
centroids: FixedSizeListArray,
distance_type: DistanceType,
vector_column: &str,
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];
let distance_type = if distance_type == MetricType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
vector_column,
)));
MetricType::L2
} else {
distance_type
};
transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
let partition_transform = Arc::new(PartitionTransformer::new(
centroids.clone(),
distance_type,
vector_column,
));
transforms.push(partition_transform);
Self::new(centroids, distance_type, transforms)
}
pub fn new_flat(
centroids: FixedSizeListArray,
distance_type: DistanceType,
vector_column: &str,
range: Option<Range<u32>>,
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];
let dt = if distance_type == DistanceType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
vector_column,
)));
MetricType::L2
} else {
distance_type
};
transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
let ivf_transform = Arc::new(PartitionTransformer::new(
centroids.clone(),
dt,
vector_column,
));
transforms.push(ivf_transform);
if let Some(range) = range {
transforms.push(Arc::new(transform::PartitionFilter::new(
PART_ID_COLUMN,
range,
)));
}
transforms.push(Arc::new(FlatTransformer::new(vector_column)));
Self::new(centroids, distance_type, transforms)
}
pub fn with_pq(
centroids: FixedSizeListArray,
distance_type: DistanceType,
vector_column: &str,
pq: ProductQuantizer,
range: Option<Range<u32>>,
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];
let distance_type = if distance_type == MetricType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
vector_column,
)));
MetricType::L2
} else {
distance_type
};
transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
let partition_transform = Arc::new(PartitionTransformer::new(
centroids.clone(),
distance_type,
vector_column,
));
transforms.push(partition_transform);
if let Some(range) = range {
transforms.push(Arc::new(transform::PartitionFilter::new(
PART_ID_COLUMN,
range,
)));
}
if ProductQuantizer::use_residual(distance_type) {
transforms.push(Arc::new(ResidualTransform::new(
centroids.clone(),
PART_ID_COLUMN,
vector_column,
)));
}
transforms.push(Arc::new(PQTransformer::new(
pq,
vector_column,
PQ_CODE_COLUMN,
)));
Self::new(centroids, distance_type, transforms)
}
fn with_sq(
centroids: FixedSizeListArray,
metric_type: MetricType,
vector_column: &str,
sq: ScalarQuantizer,
range: Option<Range<u32>>,
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];
let distance_type = if metric_type == MetricType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
vector_column,
)));
MetricType::L2
} else {
metric_type
};
transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
let partition_transformer = Arc::new(PartitionTransformer::new(
centroids.clone(),
distance_type,
vector_column,
));
transforms.push(partition_transformer);
if let Some(range) = range {
transforms.push(Arc::new(transform::PartitionFilter::new(
PART_ID_COLUMN,
range,
)));
}
transforms.push(Arc::new(SQTransformer::new(
sq,
vector_column.to_owned(),
SQ_CODE_COLUMN.to_owned(),
)));
Self::new(centroids, distance_type, transforms)
}
fn with_rq(
centroids: FixedSizeListArray,
distance_type: DistanceType,
vector_column: &str,
rq: RabitQuantizer,
range: Option<Range<u32>>,
) -> Self {
let mut transforms: Vec<Arc<dyn Transformer>> =
vec![Arc::new(super::transform::Flatten::new(vector_column))];
let distance_type = if distance_type == MetricType::Cosine {
transforms.push(Arc::new(super::transform::NormalizeTransformer::new(
vector_column,
)));
MetricType::L2
} else {
distance_type
};
transforms.push(Arc::new(KeepFiniteVectors::new(vector_column)));
let partition_transform = Arc::new(
PartitionTransformer::new(centroids.clone(), distance_type, vector_column)
.with_distance(true),
);
transforms.push(partition_transform);
if let Some(range) = range {
transforms.push(Arc::new(transform::PartitionFilter::new(
PART_ID_COLUMN,
range,
)));
}
transforms.push(Arc::new(ResidualTransform::new(
centroids.clone(),
PART_ID_COLUMN,
vector_column,
)));
transforms.push(Arc::new(RQTransformer::new(
rq,
distance_type,
centroids.clone(),
vector_column,
)));
Self::new(centroids, distance_type, transforms)
}
#[inline]
pub fn compute_residual(&self, data: &FixedSizeListArray) -> Result<FixedSizeListArray> {
compute_residual(&self.centroids, data, Some(self.distance_type), None)
}
#[inline]
pub fn compute_partitions(&self, data: &FixedSizeListArray) -> Result<UInt32Array> {
Ok(
compute_partitions_arrow_array(&self.centroids, data, self.distance_type)
.map(|(part_ids, _)| part_ids.into())?,
)
}
pub fn find_partitions(
&self,
query: &dyn Array,
nprobes: usize,
) -> Result<(UInt32Array, Float32Array)> {
Ok(kmeans_find_partitions_arrow_array(
&self.centroids,
query,
nprobes,
self.distance_type,
)?)
}
}
impl Transformer for IvfTransformer {
    /// Runs every configured stage, in order, starting from a clone of `batch`.
    ///
    /// Each stage receives the output of the previous one. The first stage that fails stops
    /// the pipeline and its error is returned; with no stages the cloned input is returned.
    #[instrument(name = "IvfTransformer::transform", level = "debug", skip_all)]
    fn transform(&self, batch: &RecordBatch) -> Result<RecordBatch> {
        self.transforms
            .iter()
            .try_fold(batch.clone(), |current, stage| stage.transform(&current))
    }
}