use crate::vector::quantizer::QuantizerStorage;
use arrow::compute::concat_batches;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::SchemaRef;
use deepsize::DeepSizeOf;
use futures::prelude::stream::TryStreamExt;
use lance_arrow::RecordBatchExt;
use lance_core::{Error, ROW_ID, Result};
use lance_encoding::decoder::FilterExpression;
use lance_file::reader::FileReader;
use lance_io::ReadBatchParams;
use lance_linalg::distance::DistanceType;
use prost::Message;
use std::{any::Any, sync::Arc};
use crate::frag_reuse::FragReuseIndex;
use crate::{
pb,
vector::{
ivf::storage::{IVF_METADATA_KEY, IvfModel},
quantizer::Quantization,
},
};
use super::DISTANCE_TYPE_KEY;
use super::quantizer::{Quantizer, QuantizerMetadata};
/// Computes distances between a fixed anchor (a query vector or a stored
/// vector) and vectors addressed by their position id in a [`VectorStore`].
pub trait DistCalculator {
    /// Distance from the anchor to the vector at position `id`.
    fn distance(&self, id: u32) -> f32;
    /// Distances from the anchor to every vector in the store.
    /// `k_hint` is presumably a top-k hint that lets implementations tune
    /// internal work — TODO(review): confirm it never changes result length.
    fn distance_all(&self, k_hint: usize) -> Vec<f32>;
    /// Optional hint to prefetch data for vector `id`; default is a no-op.
    fn prefetch(&self, _id: u32) {}
}
/// Schema-metadata key under which quantizer storage metadata is stored,
/// serialized as a JSON list of JSON strings (see `IvfQuantizationStorage::try_new`).
pub const STORAGE_METADATA_KEY: &str = "storage_metadata";
/// A self-contained store of (quantized) vectors and their row ids.
///
/// Vectors are addressed by a `u32` position id; [`Self::row_id`] maps a
/// position back to the dataset row id.
pub trait VectorStore: Send + Sync + Sized + Clone {
    /// Distance calculator type, borrowing from the store.
    type DistanceCalculator<'a>: DistCalculator
    where
        Self: 'a;
    /// For downcasting to a concrete store implementation.
    fn as_any(&self) -> &dyn Any;
    /// Arrow schema of the batches produced by [`Self::to_batches`].
    fn schema(&self) -> &SchemaRef;
    /// Serialize the store's contents into record batches.
    fn to_batches(&self) -> Result<impl Iterator<Item = RecordBatch> + Send>;
    /// Number of vectors stored.
    fn len(&self) -> usize;
    /// `true` when the store holds no vectors.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Distance type used for all distance computations.
    fn distance_type(&self) -> DistanceType;
    /// Dataset row id of the vector at position `id`.
    fn row_id(&self, id: u32) -> u64;
    /// Iterator over all stored row ids.
    fn row_ids(&self) -> impl Iterator<Item = &u64>;
    /// Return a new store with the vectors from `batch` (column
    /// `vector_column`) appended; `self` is left unchanged.
    fn append_batch(&self, batch: RecordBatch, vector_column: &str) -> Result<Self>;
    /// Calculator for distances from `query`. `dist_q_c` is presumably a
    /// precomputed query-to-centroid distance used by some quantizers —
    /// TODO(review): confirm against implementations.
    fn dist_calculator(&self, query: ArrayRef, dist_q_c: f32) -> Self::DistanceCalculator<'_>;
    /// Calculator anchored at the stored vector at position `id`.
    fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_>;
    /// Distance between two stored vectors `u` and `v`.
    fn dist_between(&self, u: u32, v: u32) -> f32 {
        let dist_cal_u = self.dist_calculator_from_id(u);
        dist_cal_u.distance(v)
    }
}
/// Builds quantized vector storage (`Q::Storage`) from raw record batches,
/// quantizing the vector column when the batches do not already carry codes.
pub struct StorageBuilder<Q: Quantization> {
    /// Name of the raw vector column in the input batches.
    vector_column: String,
    /// Distance type the storage is built for.
    distance_type: DistanceType,
    /// Quantizer used to encode raw vectors into codes.
    quantizer: Q,
    /// Optional fragment-reuse index forwarded to storage construction.
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
}
impl<Q: Quantization> StorageBuilder<Q> {
pub fn new(
vector_column: String,
distance_type: DistanceType,
quantizer: Q,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
) -> Result<Self> {
Ok(Self {
vector_column,
distance_type,
quantizer,
frag_reuse_index,
})
}
pub fn build(&self, batches: Vec<RecordBatch>) -> Result<Q::Storage> {
let mut batch = concat_batches(batches[0].schema_ref(), batches.iter())?;
if batch.column_by_name(self.quantizer.column()).is_none() {
let vectors = batch
.column_by_name(&self.vector_column)
.ok_or(Error::index(format!(
"Vector column {} not found in batch",
self.vector_column
)))?;
let codes = self.quantizer.quantize(vectors)?;
batch = batch.drop_column(&self.vector_column)?.try_with_column(
arrow_schema::Field::new(self.quantizer.column(), codes.data_type().clone(), true),
codes,
)?;
}
debug_assert!(batch.column_by_name(ROW_ID).is_some());
debug_assert!(batch.column_by_name(self.quantizer.column()).is_some());
Q::Storage::try_from_batch(
batch,
&self.quantizer.metadata(None),
self.distance_type,
self.frag_reuse_index.clone(),
)
}
}
/// IVF-partitioned quantized vector storage backed by a Lance file.
///
/// Metadata (distance type, IVF model, quantizer metadata) is loaded up
/// front in `try_new`; partition data is read lazily via `load_partition`.
#[derive(Debug)]
pub struct IvfQuantizationStorage<Q: Quantization> {
    /// Reader over the underlying Lance file.
    reader: FileReader,
    /// Distance type decoded from the file's schema metadata.
    distance_type: DistanceType,
    /// Quantizer metadata decoded from the file (plus any auxiliary buffer).
    metadata: Q::Metadata,
    /// IVF model (partition layout) decoded from a global buffer.
    ivf: IvfModel,
    /// Optional fragment-reuse index forwarded to storage construction.
    frag_reuse_index: Option<Arc<FragReuseIndex>>,
}
impl<Q: Quantization> DeepSizeOf for IvfQuantizationStorage<Q> {
    /// Counts only the quantizer metadata and the IVF model; the file
    /// reader and fragment-reuse index are deliberately excluded.
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        let metadata_size = self.metadata.deep_size_of_children(context);
        let ivf_size = self.ivf.deep_size_of_children(context);
        metadata_size + ivf_size
    }
}
impl<Q: Quantization> IvfQuantizationStorage<Q> {
pub async fn try_new(
reader: FileReader,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
) -> Result<Self> {
let schema = reader.schema();
let distance_type = DistanceType::try_from(
schema
.metadata
.get(DISTANCE_TYPE_KEY)
.ok_or(Error::index(format!("{} not found", DISTANCE_TYPE_KEY)))?
.as_str(),
)?;
let ivf_pos = schema
.metadata
.get(IVF_METADATA_KEY)
.ok_or(Error::index(format!("{} not found", IVF_METADATA_KEY)))?
.parse()
.map_err(|e| Error::index(format!("Failed to decode IVF metadata: {}", e)))?;
let ivf_bytes = reader.read_global_buffer(ivf_pos).await?;
let ivf = IvfModel::try_from(pb::Ivf::decode(ivf_bytes)?)?;
let mut metadata: Vec<String> = serde_json::from_str(
schema
.metadata
.get(STORAGE_METADATA_KEY)
.ok_or(Error::index(format!("{} not found", STORAGE_METADATA_KEY)))?
.as_str(),
)?;
debug_assert_eq!(metadata.len(), 1);
let metadata = metadata
.pop()
.ok_or(Error::index("metadata is empty".to_string()))?;
let mut metadata: Q::Metadata = serde_json::from_str(&metadata)?;
if let Some(pos) = metadata.buffer_index() {
let bytes = reader.read_global_buffer(pos).await?;
metadata.parse_buffer(bytes)?;
}
Ok(Self {
reader,
distance_type,
metadata,
ivf,
frag_reuse_index,
})
}
pub fn num_rows(&self) -> u64 {
self.reader.num_rows()
}
pub fn partition_size(&self, part_id: usize) -> usize {
self.ivf.partition_size(part_id)
}
pub fn quantizer(&self) -> Result<Quantizer> {
let metadata = self.metadata();
Q::from_metadata(metadata, self.distance_type)
}
pub fn metadata(&self) -> &Q::Metadata {
&self.metadata
}
pub fn distance_type(&self) -> DistanceType {
self.distance_type
}
pub fn schema(&self) -> SchemaRef {
Arc::new(self.reader.schema().as_ref().into())
}
pub fn num_partitions(&self) -> usize {
self.ivf.num_partitions()
}
pub async fn load_partition(&self, part_id: usize) -> Result<Q::Storage> {
let range = self.ivf.row_range(part_id);
let batch = if range.is_empty() {
let schema = self.reader.schema();
let arrow_schema = arrow_schema::Schema::from(schema.as_ref());
RecordBatch::new_empty(Arc::new(arrow_schema))
} else {
let batches = self
.reader
.read_stream(
ReadBatchParams::Range(range),
u32::MAX,
1,
FilterExpression::no_filter(),
)?
.try_collect::<Vec<_>>()
.await?;
let schema = Arc::new(self.reader.schema().as_ref().into());
concat_batches(&schema, batches.iter())?
};
Q::Storage::try_from_batch(
batch,
self.metadata(),
self.distance_type,
self.frag_reuse_index.clone(),
)
}
}