use std::ops::Range;
use std::sync::Arc;
use arrow_array::FixedSizeListArray;
use lance_core::{Error, Result};
use lance_file::{reader::FileReader, writer::FileWriter};
use lance_io::{traits::WriteExt, utils::read_message};
use lance_table::io::manifest::ManifestDescribing;
use log::debug;
use serde::{Deserialize, Serialize};
use snafu::{location, Location};
use crate::pb::Ivf as PbIvf;
/// Schema-metadata key under which the serialized [`IvfMetadata`] (the
/// position of the IVF protobuf message) is stored in an index file.
pub const IVF_METADATA_KEY: &str = "lance:ivf";
/// Metadata key for tagging data with its IVF partition — presumably read by
/// callers elsewhere; it is not referenced in this file (TODO confirm usage).
pub const IVF_PARTITION_KEY: &str = "lance:ivf:partition";
/// In-memory IVF (inverted file) partition metadata: optional centroids plus
/// per-partition row counts and their cumulative row offsets.
// NOTE: `#[warn(dead_code)]` was a no-op (dead_code warns by default);
// `#[allow(dead_code)]` is what was clearly intended here.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct IvfData {
    /// Centroid of each partition; `None` when the index stores no centroids.
    centroids: Option<Arc<FixedSizeListArray>>,

    /// Number of rows in each partition.
    lengths: Vec<u32>,

    /// Cumulative row offsets; always `lengths.len() + 1` entries, starting at 0,
    /// so `row_range` can safely index `partition + 1`.
    partition_row_offsets: Vec<usize>,
}
/// Pointer to the IVF protobuf message inside an index file; serialized as
/// JSON into the schema metadata under [`IVF_METADATA_KEY`].
#[derive(Serialize, Deserialize, Debug)]
struct IvfMetadata {
    // Position of the `PbIvf` protobuf message in the file, as returned by
    // `write_protobuf` and consumed by `read_message`.
    pb_position: usize,
}
impl IvfData {
pub fn empty() -> Self {
Self {
centroids: None,
lengths: vec![],
partition_row_offsets: vec![0],
}
}
pub fn with_centroids(centroids: Arc<FixedSizeListArray>) -> Self {
Self {
centroids: Some(centroids),
lengths: vec![],
partition_row_offsets: vec![0],
}
}
pub async fn load(reader: &FileReader) -> Result<Self> {
let schema = reader.schema();
let meta_str = schema.metadata.get(IVF_METADATA_KEY).ok_or(Error::Index {
message: format!("{} not found during search", IVF_METADATA_KEY),
location: location!(),
})?;
let ivf_metadata: IvfMetadata =
serde_json::from_str(meta_str).map_err(|e| Error::Index {
message: format!("Failed to parse IVF metadata: {}", e),
location: location!(),
})?;
let pb: PbIvf = read_message(
reader.object_reader.as_ref(),
ivf_metadata.pb_position as usize,
)
.await?;
Self::try_from(pb)
}
pub async fn write(&self, writer: &mut FileWriter<ManifestDescribing>) -> Result<()> {
let pb = PbIvf::try_from(self)?;
let pos = writer.object_writer.write_protobuf(&pb).await?;
let ivf_metadata = IvfMetadata { pb_position: pos };
writer.add_metadata(IVF_METADATA_KEY, &serde_json::to_string(&ivf_metadata)?);
Ok(())
}
pub fn add_partition(&mut self, num_rows: u32) {
self.lengths.push(num_rows);
let last_offset = self.partition_row_offsets.last().copied().unwrap_or(0);
self.partition_row_offsets
.push(last_offset + num_rows as usize);
}
pub fn has_centroids(&self) -> bool {
self.centroids.is_some()
}
pub fn num_partitions(&self) -> usize {
self.lengths.len()
}
pub fn row_range(&self, partition: usize) -> Range<usize> {
let start = self.partition_row_offsets[partition];
let end = self.partition_row_offsets[partition + 1];
start..end
}
}
impl TryFrom<PbIvf> for IvfData {
    type Error = Error;

    /// Convert the protobuf representation into in-memory IVF metadata,
    /// rebuilding the cumulative row offsets from the partition lengths.
    fn try_from(proto: PbIvf) -> Result<Self> {
        let centroids = if let Some(tensor) = proto.centroids_tensor.as_ref() {
            debug!("Ivf: loading IVF centroids from index format v2");
            Some(Arc::new(FixedSizeListArray::try_from(tensor)?))
        } else {
            None
        };
        // Prefix sums: offsets[0] = 0, offsets[i + 1] = offsets[i] + lengths[i].
        let mut partition_row_offsets = Vec::with_capacity(proto.lengths.len() + 1);
        let mut total = 0_usize;
        partition_row_offsets.push(total);
        for &len in &proto.lengths {
            total += len as usize;
            partition_row_offsets.push(total);
        }
        Ok(Self {
            centroids,
            // `proto` is consumed here, so `lengths` can be moved instead of
            // cloned (the previous `proto.lengths.clone()` was a redundant copy).
            lengths: proto.lengths,
            partition_row_offsets,
        })
    }
}
impl TryFrom<&IvfData> for PbIvf {
    type Error = Error;

    /// Build the protobuf message for on-disk storage.
    ///
    /// `centroids` and `offsets` are left empty; offsets are reconstructed
    /// from `lengths` when the message is loaded back.
    fn try_from(meta: &IvfData) -> Result<Self> {
        // Convert the centroids (if any) into their tensor representation,
        // propagating any conversion failure.
        let centroids_tensor = meta
            .centroids
            .as_ref()
            .map(|c| c.as_ref().try_into())
            .transpose()?;
        Ok(Self {
            centroids: vec![],
            lengths: meta.lengths.clone(),
            offsets: vec![],
            centroids_tensor,
        })
    }
}
#[cfg(test)]
mod tests {
    use arrow_array::{Float32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema as ArrowSchema};
    use lance_core::datatypes::Schema;
    use lance_io::object_store::ObjectStore;
    use lance_table::format::SelfDescribingFileReader;
    use object_store::path::Path;

    use super::*;

    /// Fixture: IVF metadata with two partitions of 20 and 50 rows.
    fn two_partition_ivf() -> IvfData {
        let mut ivf = IvfData::empty();
        ivf.add_partition(20);
        ivf.add_partition(50);
        ivf
    }

    #[test]
    fn test_ivf_find_rows() {
        let ivf = two_partition_ivf();
        assert_eq!(ivf.row_range(0), 0..20);
        assert_eq!(ivf.row_range(1), 20..70);
    }

    #[tokio::test]
    async fn test_write_and_load() {
        let ivf = two_partition_ivf();
        let store = ObjectStore::memory();
        let path = Path::from("/foo");
        let arrow_schema = ArrowSchema::new(vec![Field::new("a", DataType::Float32, true)]);
        let schema = Schema::try_from(&arrow_schema).unwrap();

        // Write a minimal one-row batch plus the IVF metadata, then finish
        // the file (scoped so the writer is dropped before reading back).
        {
            let mut writer =
                FileWriter::try_new(&store, &path, schema.clone(), &Default::default())
                    .await
                    .unwrap();
            let batch = RecordBatch::try_new(
                Arc::new(arrow_schema),
                vec![Arc::new(Float32Array::from(vec![Some(1.0)]))],
            )
            .unwrap();
            writer.write(&[batch]).await.unwrap();
            ivf.write(&mut writer).await.unwrap();
            writer.finish().await.unwrap();
        }

        // Read the file back and verify the metadata round-trips intact.
        let reader = FileReader::try_new_self_described(&store, &path, None)
            .await
            .unwrap();
        assert!(reader.schema().metadata.contains_key(IVF_METADATA_KEY));
        let loaded = IvfData::load(&reader).await.unwrap();
        assert_eq!(ivf, loaded);
        assert_eq!(loaded.num_partitions(), 2);
    }
}