use std::any::Any;
use std::fmt::Debug;
use std::{collections::HashMap, sync::Arc};
use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt32Array};
use arrow_schema::Field;
use async_trait::async_trait;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use deepsize::DeepSizeOf;
use futures::stream;
use ivf::storage::IvfModel;
use lance_core::{Error, ROW_ID_FIELD, Result};
use lance_io::traits::Reader;
use lance_linalg::distance::DistanceType;
use quantizer::{QuantizationType, Quantizer};
use std::sync::LazyLock;
use v3::subindex::SubIndexType;
pub mod bq;
pub mod distributed;
pub mod flat;
pub mod graph;
pub mod hnsw;
pub mod ivf;
pub mod kmeans;
pub mod pq;
pub mod quantizer;
pub mod residual;
pub mod shared;
pub mod sq;
pub mod storage;
pub mod transform;
pub mod utils;
pub mod v3;
use super::pb;
use crate::metrics::MetricsCollector;
use crate::{Index, prefilter::PreFilter};
/// Name of the distance column in [`VECTOR_RESULT_SCHEMA`].
pub const DIST_COL: &str = "_distance";
/// Key string `"distance_type"`; presumably used in index/schema metadata to
/// record the [`DistanceType`] — confirm at the use sites.
pub const DISTANCE_TYPE_KEY: &str = "distance_type";
/// Internal column name `__index_uuid`; presumably carries the UUID of the
/// index a row came from — confirm.
pub const INDEX_UUID_COLUMN: &str = "__index_uuid";
/// Name of the IVF partition-id column; used by [`PART_ID_FIELD`].
pub const PART_ID_COLUMN: &str = "__ivf_part_id";
/// Internal column name `__dist_q_c`; presumably holds the query-to-centroid
/// distance (compare [`Query::dist_q_c`]) — confirm.
pub const DIST_Q_C_COLUMN: &str = "__dist_q_c";
/// Name of the centroid-distance column; used by [`CENTROID_DIST_FIELD`].
pub const CENTROID_DIST_COLUMN: &str = "__centroid_dist";
/// Internal column name for PQ codes (presumably produced by the `pq`
/// quantizer) — confirm.
pub const PQ_CODE_COLUMN: &str = "__pq_code";
/// Internal column name for SQ codes (presumably produced by the `sq`
/// quantizer) — confirm.
pub const SQ_CODE_COLUMN: &str = "__sq_code";
/// Metadata key `_loss`; presumably records a training loss — confirm.
pub const LOSS_METADATA_KEY: &str = "_loss";
/// Opaque, type-erased state produced by
/// [`VectorIndex::prepare_partition_search`] and consumed by
/// [`VectorIndex::search_prepared_partition`]. Its concrete type is private to
/// each implementation (presumably recovered by downcasting through `Any`).
pub type PreparedPartitionSearchHandle = Box<dyn Any + Send>;
/// Caller-supplied hook that observes, and can cut short, the
/// partition-by-partition search driven by the default
/// [`VectorIndex::search_partitions`].
pub trait PartitionSearchControl: Send + Sync {
    /// Returns `true` to end the search. The default `search_partitions`
    /// checks this before each partition, so no further partitions are
    /// searched once it returns `true`.
    fn should_stop(&self) -> bool;
    /// Receives every batch produced by a partition search. The default
    /// implementation ignores the batch.
    fn record_batch(&self, _batch: &RecordBatch) {}
}
pub static VECTOR_RESULT_SCHEMA: LazyLock<arrow_schema::SchemaRef> = LazyLock::new(|| {
arrow_schema::SchemaRef::new(arrow_schema::Schema::new(vec![
Field::new(DIST_COL, arrow_schema::DataType::Float32, true),
ROW_ID_FIELD.clone(),
]))
});
pub static PART_ID_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
arrow_schema::Field::new(PART_ID_COLUMN, arrow_schema::DataType::UInt32, true)
});
pub static CENTROID_DIST_FIELD: LazyLock<arrow_schema::Field> = LazyLock::new(|| {
arrow_schema::Field::new(CENTROID_DIST_COLUMN, arrow_schema::DataType::Float32, true)
});
pub const DEFAULT_QUERY_PARALLELISM: i32 = 0;
/// Parameters of a vector search against an index.
#[derive(Debug, Clone)]
pub struct Query {
    /// Column the query targets (presumably the vector column).
    pub column: String,
    /// The query key. NOTE(review): the array type and length that indices
    /// expect are not visible here — confirm.
    pub key: ArrayRef,
    /// Requested result count (presumably top-k nearest neighbors) — confirm.
    pub k: usize,
    /// Optional lower bound on result distance; inclusivity not visible
    /// here — confirm.
    pub lower_bound: Option<f32>,
    /// Optional upper bound on result distance; inclusivity not visible
    /// here — confirm.
    pub upper_bound: Option<f32>,
    /// Minimum number of partitions to probe (presumably IVF nprobes).
    pub minimum_nprobes: usize,
    /// Optional cap on the number of partitions to probe.
    pub maximum_nprobes: Option<usize>,
    /// Optional `ef` search parameter, presumably for graph (HNSW)
    /// sub-indices — confirm.
    pub ef: Option<usize>,
    /// Optional refine factor; presumably controls over-fetching candidates
    /// for re-ranking — confirm.
    pub refine_factor: Option<u32>,
    /// Optional distance-metric override; what `None` falls back to is not
    /// visible here (presumably the index's own metric).
    pub metric_type: Option<DistanceType>,
    /// Whether an index may be used (presumably `false` forces a flat
    /// scan) — confirm.
    pub use_index: bool,
    /// Requested query parallelism (compare `DEFAULT_QUERY_PARALLELISM`).
    pub query_parallelism: i32,
    /// Query-to-centroid distance for the partition being searched. The
    /// default [`VectorIndex::search_partitions`] overwrites it for each
    /// partition before calling `search_in_partition`.
    pub dist_q_c: f32,
}
impl From<pb::VectorMetricType> for DistanceType {
    /// Translates the protobuf metric enum into the matching in-memory
    /// [`DistanceType`]; every protobuf variant has exactly one counterpart.
    fn from(metric: pb::VectorMetricType) -> Self {
        match metric {
            pb::VectorMetricType::Hamming => DistanceType::Hamming,
            pb::VectorMetricType::Dot => DistanceType::Dot,
            pb::VectorMetricType::Cosine => DistanceType::Cosine,
            pb::VectorMetricType::L2 => DistanceType::L2,
        }
    }
}

impl From<DistanceType> for pb::VectorMetricType {
    /// Translates a [`DistanceType`] into its protobuf representation; the
    /// inverse of the conversion above.
    fn from(distance: DistanceType) -> Self {
        match distance {
            DistanceType::Hamming => pb::VectorMetricType::Hamming,
            DistanceType::Dot => pb::VectorMetricType::Dot,
            DistanceType::Cosine => pb::VectorMetricType::Cosine,
            DistanceType::L2 => pb::VectorMetricType::L2,
        }
    }
}
/// A searchable vector index.
///
/// Implementors provide whole-index search ([`VectorIndex::search`]) and
/// single-partition search ([`VectorIndex::search_in_partition`]). The default
/// [`VectorIndex::search_partitions`] builds a lazy stream on top of the
/// latter. Several defaulted methods panic via `unimplemented!` unless
/// overridden; see their `# Panics` sections.
#[async_trait]
// NOTE(review): the reason for this lint allowance is not recorded here —
// document it or drop the attribute.
#[allow(clippy::redundant_pub_crate)]
pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index {
    /// Searches the index for `query`, applying `pre_filter` and reporting to
    /// `metrics`, and returns the results as one batch (presumably shaped like
    /// [`VECTOR_RESULT_SCHEMA`] — confirm per implementation).
    async fn search(
        &self,
        query: &Query,
        pre_filter: Arc<dyn PreFilter>,
        metrics: &dyn MetricsCollector,
    ) -> Result<RecordBatch>;

    /// Chooses the partitions to probe for `query`. Returns partition ids
    /// together with a `Float32Array`, presumably the aligned query-to-centroid
    /// distances (the `partitions` / `q_c_dists` inputs of
    /// [`VectorIndex::search_partitions`]) — confirm.
    fn find_partitions(&self, query: &Query) -> Result<(UInt32Array, Float32Array)>;

    /// Total number of partitions in the index.
    fn total_partitions(&self) -> usize;

    /// Searches the single partition `partition_id` for `query`.
    async fn search_in_partition(
        &self,
        partition_id: usize,
        query: &Query,
        pre_filter: Arc<dyn PreFilter>,
        metrics: &dyn MetricsCollector,
    ) -> Result<RecordBatch>;

    /// First phase of a two-phase partition search. It returns an opaque
    /// handle to be passed to [`VectorIndex::search_prepared_partition`].
    /// Only call this when [`VectorIndex::supports_prepared_partition_search`]
    /// returns `true`.
    ///
    /// # Panics
    ///
    /// The default implementation always panics (`unimplemented!`).
    async fn prepare_partition_search(
        &self,
        _partition_id: usize,
        _query: &Query,
        _pre_filter: Arc<dyn PreFilter>,
        _metrics: &dyn MetricsCollector,
    ) -> Result<PreparedPartitionSearchHandle> {
        unimplemented!("prepared partition search is not supported for this index")
    }

    /// Second phase of a two-phase partition search. It consumes a handle
    /// produced by [`VectorIndex::prepare_partition_search`]; unlike the first
    /// phase, this method is synchronous.
    ///
    /// # Panics
    ///
    /// The default implementation always panics (`unimplemented!`).
    fn search_prepared_partition(
        &self,
        _prepared: PreparedPartitionSearchHandle,
        _metrics: &dyn MetricsCollector,
    ) -> Result<RecordBatch> {
        unimplemented!("prepared partition search is not supported for this index")
    }

    /// Whether the two-phase prepared partition search is implemented.
    /// Defaults to `false`.
    fn supports_prepared_partition_search(&self) -> bool {
        false
    }

    /// Parallelism this index suggests for a CPU pool of the given size. The
    /// default ignores the pool size and returns `1`.
    fn auto_query_parallelism(&self, _cpu_pool_size: usize) -> usize {
        1
    }

    /// Lazily searches the partitions at positions `start_idx..end_idx` of
    /// `partitions`, one at a time and in order, yielding one batch per
    /// partition searched.
    ///
    /// For each position `idx`, a copy of `query` has `dist_q_c` set to
    /// `q_c_dists[idx]` and is passed to
    /// [`VectorIndex::search_in_partition`] with partition id
    /// `partitions[idx]`. If `control` is given, `should_stop` is checked
    /// before each partition (returning `true` ends the stream) and
    /// `record_batch` receives every produced batch. The returned stream
    /// declares [`VECTOR_RESULT_SCHEMA`] as its schema.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error, before any search runs, if
    /// `partitions` and `q_c_dists` differ in length, or if
    /// `start_idx > end_idx` or `end_idx > partitions.len()`. Errors from
    /// `search_in_partition` are yielded as stream items.
    #[allow(clippy::too_many_arguments)]
    async fn search_partitions(
        self: Arc<Self>,
        query: Query,
        partitions: Arc<UInt32Array>,
        q_c_dists: Arc<Float32Array>,
        start_idx: usize,
        end_idx: usize,
        pre_filter: Arc<dyn PreFilter>,
        control: Option<Arc<dyn PartitionSearchControl>>,
        metrics: Arc<dyn MetricsCollector>,
    ) -> Result<SendableRecordBatchStream>
    where
        Self: 'static,
    {
        // Validate up front so the indexing inside the stream stays in bounds.
        if partitions.len() != q_c_dists.len() {
            return Err(Error::invalid_input(format!(
                "partition count {} does not match centroid distance count {}",
                partitions.len(),
                q_c_dists.len()
            )));
        }
        if start_idx > end_idx || end_idx > partitions.len() {
            return Err(Error::invalid_input(format!(
                "invalid partition search range [{start_idx}, {end_idx}) for {} partitions",
                partitions.len()
            )));
        }
        // The unfold state is the next position in `partitions` to search.
        let stream = stream::try_unfold(start_idx, move |idx| {
            // Each step returns an `async move` future that must own its
            // inputs, so the shared handles are cloned per step.
            let index = self.clone();
            let partitions = partitions.clone();
            let q_c_dists = q_c_dists.clone();
            let query = query.clone();
            let pre_filter = pre_filter.clone();
            let control = control.clone();
            let metrics = metrics.clone();
            async move {
                // End of range, or the caller's control asked to stop.
                if idx >= end_idx
                    || control
                        .as_ref()
                        .is_some_and(|control| control.should_stop())
                {
                    return Ok(None);
                }
                // `idx` is a position; the partition id is the value stored there.
                let part_id = partitions.value(idx);
                let mut query = query;
                query.dist_q_c = q_c_dists.value(idx);
                index
                    .search_in_partition(part_id as usize, &query, pre_filter, metrics.as_ref())
                    .await
                    .map(|batch| {
                        if let Some(control) = control.as_ref() {
                            control.record_batch(&batch);
                        }
                        Some((batch, idx + 1))
                    })
                    .map_err(Into::into)
            }
        });
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            VECTOR_RESULT_SCHEMA.clone(),
            stream,
        )))
    }

    /// Whether this index can be loaded via [`VectorIndex::load`]
    /// (presumably) — confirm.
    fn is_loadable(&self) -> bool;

    /// Whether the index works on residual vectors (presumably relative to
    /// partition centroids) — confirm.
    fn use_residual(&self) -> bool;

    /// Loads an index from `reader`. `offset` and `length` presumably locate
    /// the index bytes within the reader — confirm.
    async fn load(
        &self,
        reader: Arc<dyn Reader>,
        offset: usize,
        length: usize,
    ) -> Result<Box<dyn VectorIndex>>;

    /// Loads the index for a single partition. The default ignores the
    /// partition id and delegates to [`VectorIndex::load`].
    async fn load_partition(
        &self,
        reader: Arc<dyn Reader>,
        offset: usize,
        length: usize,
        _partition_id: usize,
    ) -> Result<Box<dyn VectorIndex>> {
        self.load(reader, offset, length).await
    }

    /// Streams the contents of one partition, optionally including vectors.
    ///
    /// # Panics
    ///
    /// The default implementation always panics (`unimplemented!`); only
    /// IVF indices are expected to override it.
    async fn partition_reader(
        &self,
        _partition_id: usize,
        _with_vector: bool,
        _metrics: &dyn MetricsCollector,
    ) -> Result<SendableRecordBatchStream> {
        unimplemented!("only for IVF")
    }

    /// Streams the index contents as record batches. `with_vector`
    /// presumably controls whether vector data is included.
    async fn to_batch_stream(&self, with_vector: bool) -> Result<SendableRecordBatchStream>;

    /// Number of rows held by the index.
    fn num_rows(&self) -> u64;

    /// Iterates over the row ids held by the index.
    fn row_ids(&self) -> Box<dyn Iterator<Item = &'_ u64> + '_>;

    /// Rewrites row ids according to `mapping` (old id to new id). A `None`
    /// value presumably marks a removed row — confirm.
    async fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Result<()>;

    /// Distance metric used by this index.
    fn metric_type(&self) -> DistanceType;

    /// The IVF model of this index.
    fn ivf_model(&self) -> &IvfModel;

    /// The quantizer used by this index, returned by value.
    fn quantizer(&self) -> Quantizer;

    /// Size of partition `part_id` (presumably its row count) — confirm.
    fn partition_size(&self, part_id: usize) -> usize;

    /// The sub-index type and quantization type that make up this index.
    fn sub_index_type(&self) -> (SubIndexType, QuantizationType);
}
/// Values kept in a vector-index cache (as the name suggests). Entries must be
/// `Debug`, thread-safe (`Send + Sync`) and `DeepSizeOf`, presumably so the
/// cache can account for each entry's memory — confirm.
pub trait VectorIndexCacheEntry: Debug + Send + Sync + DeepSizeOf {
    /// Returns `self` as `&dyn Any`, presumably so callers can downcast to the
    /// concrete entry type.
    fn as_any(&self) -> &dyn Any;
}