use std::{collections::HashMap, sync::Arc};
use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
use arrow_schema::Field;
use async_trait::async_trait;
use ivf::storage::IvfModel;
use lance_core::{Result, ROW_ID_FIELD};
use lance_io::traits::Reader;
use lance_linalg::distance::DistanceType;
use lazy_static::lazy_static;
use quantizer::{QuantizationType, Quantizer};
use v3::subindex::SubIndexType;
pub mod bq;
pub mod flat;
pub mod graph;
pub mod hnsw;
pub mod ivf;
pub mod kmeans;
pub mod pq;
pub mod quantizer;
pub mod residual;
pub mod sq;
pub mod storage;
pub mod transform;
pub mod utils;
pub mod v3;
use super::pb;
use crate::{prefilter::PreFilter, Index};
pub use residual::RESIDUAL_COLUMN;
/// Name of the distance column (`"_distance"`); used as the non-nullable
/// `Float32` field of `VECTOR_RESULT_SCHEMA`.
pub const DIST_COL: &str = "_distance";
/// The string `"distance_type"`.
/// NOTE(review): presumably a metadata key under which the distance type is
/// stored — confirm against the code that reads/writes it.
pub const DISTANCE_TYPE_KEY: &str = "distance_type";
/// Column name `"__index_uuid"`.
/// NOTE(review): presumably holds the UUID of an index — confirm with users of this constant.
pub const INDEX_UUID_COLUMN: &str = "__index_uuid";
/// Column name `"__ivf_part_id"`.
/// NOTE(review): presumably holds the IVF partition id of each row — confirm.
pub const PART_ID_COLUMN: &str = "__ivf_part_id";
/// Column name `"__pq_code"`.
/// NOTE(review): presumably holds product-quantization codes — confirm.
pub const PQ_CODE_COLUMN: &str = "__pq_code";
/// Column name `"__sq_code"`.
/// NOTE(review): presumably holds scalar-quantization codes — confirm.
pub const SQ_CODE_COLUMN: &str = "__sq_code";
lazy_static! {
    /// Lazily initialised, shared Arrow schema with two fields, in order:
    /// a non-nullable `Float32` field named [`DIST_COL`], followed by a
    /// clone of `ROW_ID_FIELD`.
    pub static ref VECTOR_RESULT_SCHEMA: arrow_schema::SchemaRef = {
        let distance_field = Field::new(DIST_COL, arrow_schema::DataType::Float32, false);
        let fields = vec![distance_field, ROW_ID_FIELD.clone()];
        arrow_schema::SchemaRef::new(arrow_schema::Schema::new(fields))
    };
}
/// Parameters of a single vector query, as passed to [`VectorIndex::search`]
/// and the other query-taking methods of [`VectorIndex`].
#[derive(Clone, Debug)]
pub struct Query {
    /// Column the query targets, identified by name.
    pub column: String,

    /// The query key as an Arrow array.
    /// NOTE(review): presumably the query vector — confirm with callers.
    pub key: ArrayRef,

    /// The `k` of the query.
    /// NOTE(review): presumably the number of top results to return — confirm.
    pub k: usize,

    /// Number of probes.
    /// NOTE(review): presumably how many partitions are searched — confirm.
    pub nprobes: usize,

    /// Optional `ef` setting.
    /// NOTE(review): looks like the HNSW candidate-list size — confirm.
    pub ef: Option<usize>,

    /// Optional refine factor.
    /// NOTE(review): `None` presumably means no refine step — confirm.
    pub refine_factor: Option<u32>,

    /// Distance type associated with this query.
    pub metric_type: DistanceType,

    /// Flag controlling index usage.
    /// NOTE(review): `false` presumably forces a non-indexed search — confirm.
    pub use_index: bool,
}
impl From<pb::VectorMetricType> for DistanceType {
fn from(proto: pb::VectorMetricType) -> Self {
match proto {
pb::VectorMetricType::L2 => Self::L2,
pb::VectorMetricType::Cosine => Self::Cosine,
pb::VectorMetricType::Dot => Self::Dot,
pb::VectorMetricType::Hamming => Self::Hamming,
}
}
}
/// Converts a [`DistanceType`] into the protobuf metric enum, mapping every
/// variant to the variant of the same name.
impl From<DistanceType> for pb::VectorMetricType {
    fn from(distance: DistanceType) -> Self {
        // Exhaustive on purpose: a new `DistanceType` variant must fail to compile here.
        match distance {
            DistanceType::Hamming => Self::Hamming,
            DistanceType::Dot => Self::Dot,
            DistanceType::Cosine => Self::Cosine,
            DistanceType::L2 => Self::L2,
        }
    }
}
/// Interface of a vector index: querying (whole index or a single partition),
/// on-demand loading from a [`Reader`], row-id remapping, and accessors for
/// the distance type, IVF model, quantizer and sub-index type.
#[async_trait]
// NOTE(review): the reason for this lint exemption is not visible in this
// file — presumably triggered by the `async_trait` expansion; confirm.
#[allow(clippy::redundant_pub_crate)]
pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index {
    /// Runs `query` against the index with the given `pre_filter` and returns
    /// the result as a [`RecordBatch`].
    ///
    /// NOTE(review): the batch presumably follows `VECTOR_RESULT_SCHEMA` —
    /// confirm against implementors.
    async fn search(&self, query: &Query, pre_filter: Arc<dyn PreFilter>) -> Result<RecordBatch>;

    /// Returns partition ids for `query` as a [`UInt32Array`].
    ///
    /// NOTE(review): presumably the partitions that should be probed — confirm.
    fn find_partitions(&self, query: &Query) -> Result<UInt32Array>;

    /// Runs `query` with `pre_filter` against the partition identified by
    /// `partition_id` and returns the result as a [`RecordBatch`].
    async fn search_in_partition(
        &self,
        partition_id: usize,
        query: &Query,
        pre_filter: Arc<dyn PreFilter>,
    ) -> Result<RecordBatch>;

    /// Reports whether this index is loadable.
    ///
    /// NOTE(review): presumably "can be loaded on demand via `load`" — confirm.
    fn is_loadable(&self) -> bool;

    /// Reports whether this index uses residuals.
    fn use_residual(&self) -> bool;

    /// Returns `Ok(())` when the index can be remapped.
    ///
    /// NOTE(review): the error presumably explains why remapping is not
    /// possible — confirm against implementors.
    fn check_can_remap(&self) -> Result<()>;

    /// Loads an index from `reader` using `offset` and `length`, returning it
    /// as a new boxed [`VectorIndex`].
    ///
    /// NOTE(review): `offset`/`length` units are not visible here — confirm.
    async fn load(
        &self,
        reader: Arc<dyn Reader>,
        offset: usize,
        length: usize,
    ) -> Result<Box<dyn VectorIndex>>;

    /// Loads a partition from `reader`.
    ///
    /// The default implementation ignores `_partition_id` and forwards
    /// `reader`, `offset` and `length` to [`VectorIndex::load`].
    async fn load_partition(
        &self,
        reader: Arc<dyn Reader>,
        offset: usize,
        length: usize,
        _partition_id: usize,
    ) -> Result<Box<dyn VectorIndex>> {
        let loaded = self.load(reader, offset, length).await?;
        Ok(loaded)
    }

    /// Returns an iterator over borrowed `u64` row ids of this index.
    fn row_ids(&self) -> Box<dyn Iterator<Item = &'_ u64> + '_>;

    /// Remaps the index in place according to `mapping`.
    ///
    /// NOTE(review): presumably keys are old row ids and values the new row
    /// ids, with `None` marking a deleted row — confirm against implementors.
    fn remap(&mut self, mapping: &HashMap<u64, Option<u64>>) -> Result<()>;

    /// Returns the [`DistanceType`] of this index.
    fn metric_type(&self) -> DistanceType;

    /// Returns the [`IvfModel`] of this index.
    fn ivf_model(&self) -> IvfModel;

    /// Returns the [`Quantizer`] of this index.
    fn quantizer(&self) -> Quantizer;

    /// Returns the sub-index type together with the quantization type.
    fn sub_index_type(&self) -> (SubIndexType, QuantizationType);
}