use std::{cmp::min, collections::HashMap, sync::Arc};
use arrow_array::{
cast::AsArray,
types::{Float32Type, UInt64Type, UInt8Type},
FixedSizeListArray, Float32Array, RecordBatch, UInt64Array, UInt8Array,
};
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use lance_core::{datatypes::Schema, Error, Result, ROW_ID};
use lance_file::{reader::FileReader, writer::FileWriter};
use lance_io::{
object_store::ObjectStore,
traits::{WriteExt, Writer},
utils::read_message,
};
use lance_linalg::{distance::MetricType, MatrixView};
use lance_table::{format::SelfDescribingFileReader, io::manifest::ManifestDescribing};
use object_store::path::Path;
use serde::{Deserialize, Serialize};
use snafu::{location, Location};
use super::{distance::build_distance_table_l2, num_centroids, ProductQuantizerImpl};
use crate::{
pb,
vector::{
graph::storage::{DistCalculator, VectorStorage},
ivf::storage::IvfData,
pq::transform::PQTransformer,
quantizer::{QuantizerMetadata, QuantizerStorage},
transform::Transformer,
PQ_CODE_COLUMN,
},
IndexMetadata, INDEX_METADATA_SCHEMA_KEY,
};
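/// Schema metadata key under which the PQ metadata JSON is stored.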
pub const PQ_METADATA_KEY: &str = "lance:pq";
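/// Metadata for a product quantizer, serialized as JSON into the schema
/// metadata of the storage file.
///
/// The codebook itself is stored separately as a protobuf tensor at
/// `codebook_position`, so it is skipped during (de)serialization.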
#[derive(Clone, Serialize, Deserialize)]
pub struct ProductQuantizationMetadata {
pub codebook_position: usize,
pub num_bits: u32,
pub num_sub_vectors: usize,
pub dimension: usize,
#[serde(skip)]
pub codebook: Option<FixedSizeListArray>,
}
#[async_trait]
impl QuantizerMetadata for ProductQuantizationMetadata {
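    /// Load the PQ metadata JSON from the schema metadata of `reader`, then
    /// read the codebook tensor that `codebook_position` points to.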
async fn load(reader: &FileReader) -> Result<Self> {
let metadata = reader
.schema()
.metadata
            .get(PQ_METADATA_KEY)
.ok_or(Error::Index {
message: format!(
"Reading PQ storage: metadata key {} not found",
                    PQ_METADATA_KEY
),
location: location!(),
})?;
let mut metadata: Self = serde_json::from_str(metadata).map_err(|_| Error::Index {
message: format!("Failed to parse PQ metadata: {}", metadata),
location: location!(),
})?;
let codebook_tensor: pb::Tensor =
read_message(reader.object_reader.as_ref(), metadata.codebook_position).await?;
metadata.codebook = Some(FixedSizeListArray::try_from(&codebook_tensor)?);
Ok(metadata)
}
}
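/// Write each partition of PQ storage to `path` in sequence, then append an
/// [`IvfData`] block recording the number of rows in each partition.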
#[allow(dead_code)]
pub async fn write_parted_product_quantizations(
object_store: &ObjectStore,
path: &Path,
partitions: Box<dyn Iterator<Item = ProductQuantizationStorage>>,
) -> Result<()> {
let mut peek = partitions.peekable();
let first = peek.peek().ok_or(Error::Index {
message: "No partitions to write".to_string(),
location: location!(),
})?;
let schema = first.schema();
let lance_schema = Schema::try_from(schema.as_ref())?;
let mut writer = FileWriter::<ManifestDescribing>::try_new(
object_store,
path,
lance_schema,
        &Default::default(),
    )
.await?;
let mut ivf_data = IvfData::empty();
for storage in peek {
let num_rows = storage.write_partition(&mut writer).await?;
ivf_data.add_partition(num_rows as u32);
}
ivf_data.write(&mut writer).await?;
Ok(())
}
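/// In-memory storage for PQ-encoded vectors: the codebook, the flattened PQ
/// codes, and the row ids they belong to.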
#[derive(Clone, Debug)]
pub struct ProductQuantizationStorage {
codebook: Arc<Float32Array>,
batch: RecordBatch,
num_bits: u32,
num_sub_vectors: usize,
dimension: usize,
metric_type: MetricType,
pq_code: Arc<UInt8Array>,
row_ids: Arc<UInt64Array>,
}
impl PartialEq for ProductQuantizationStorage {
fn eq(&self, other: &Self) -> bool {
self.metric_type.eq(&other.metric_type)
&& self.codebook.eq(&other.codebook)
&& self.num_bits.eq(&other.num_bits)
&& self.num_sub_vectors.eq(&other.num_sub_vectors)
&& self.dimension.eq(&other.dimension)
&& self.batch.columns().eq(other.batch.columns())
}
}
#[allow(dead_code)]
impl ProductQuantizationStorage {
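    /// Create PQ storage from a [`RecordBatch`] containing a `ROW_ID` column
    /// and a fixed-size-list `PQ_CODE_COLUMN` of `UInt8` codes.
    ///
    /// The flattened code and row-id arrays are cached for fast access.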
pub fn new(
codebook: Arc<Float32Array>,
batch: RecordBatch,
num_bits: u32,
num_sub_vectors: usize,
dimension: usize,
metric_type: MetricType,
) -> Result<Self> {
let Some(row_ids) = batch.column_by_name(ROW_ID) else {
return Err(Error::Index {
message: "Row ID column not found from PQ storage".to_string(),
location: location!(),
});
};
let row_ids: Arc<UInt64Array> = row_ids
.as_primitive_opt::<UInt64Type>()
.ok_or(Error::Index {
message: "Row ID column is not of type UInt64".to_string(),
location: location!(),
})?
.clone()
.into();
let Some(pq_col) = batch.column_by_name(PQ_CODE_COLUMN) else {
return Err(Error::Index {
message: format!("{PQ_CODE_COLUMN} column not found from PQ storage"),
location: location!(),
});
};
        let pq_code_fsl = pq_col.as_fixed_size_list_opt().ok_or(Error::Index {
            message: format!(
                "{PQ_CODE_COLUMN} column is not a FixedSizeList: {}",
                pq_col.data_type()
            ),
            location: location!(),
        })?;
let pq_code: Arc<UInt8Array> = pq_code_fsl
.values()
.as_primitive_opt::<UInt8Type>()
.ok_or(Error::Index {
message: format!(
"{PQ_CODE_COLUMN} column is not of type UInt8: {}",
pq_col.data_type()
),
location: location!(),
})?
.clone()
.into();
Ok(Self {
codebook,
batch,
pq_code,
row_ids,
num_sub_vectors,
num_bits,
dimension,
metric_type,
})
}
pub fn batch(&self) -> &RecordBatch {
&self.batch
}
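    /// Build storage by encoding the `vector_col` column of `batch` with `quantizer`.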
pub async fn build(
quantizer: Arc<ProductQuantizerImpl<Float32Type>>,
batch: &RecordBatch,
vector_col: &str,
) -> Result<Self> {
let codebook = quantizer.codebook.clone();
let num_bits = quantizer.num_bits;
let dimension = quantizer.dimension;
let num_sub_vectors = quantizer.num_sub_vectors;
let metric_type = quantizer.metric_type;
let transform = PQTransformer::new(quantizer, vector_col, PQ_CODE_COLUMN);
let batch = transform.transform(batch).await?;
Self::new(
codebook,
batch,
num_bits,
num_sub_vectors,
dimension,
metric_type,
)
}
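    /// Load the full PQ storage from a self-described Lance file at `path`.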
pub async fn load(object_store: &ObjectStore, path: &Path) -> Result<Self> {
let reader = FileReader::try_new_self_described(object_store, path, None).await?;
let schema = reader.schema();
let metadata_str = schema
.metadata
.get(INDEX_METADATA_SCHEMA_KEY)
.ok_or(Error::Index {
message: format!(
"Reading PQ storage: index key {} not found",
INDEX_METADATA_SCHEMA_KEY
),
location: location!(),
})?;
let index_metadata: IndexMetadata =
serde_json::from_str(metadata_str).map_err(|_| Error::Index {
message: format!("Failed to parse index metadata: {}", metadata_str),
location: location!(),
})?;
let metric_type: MetricType = MetricType::try_from(index_metadata.distance_type.as_str())?;
let metadata = ProductQuantizationMetadata::load(&reader).await?;
Self::load_partition(&reader, 0..reader.len(), metric_type, &metadata).await
}
pub fn schema(&self) -> SchemaRef {
self.batch.schema()
}
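    /// Map storage-local ids to the row ids they refer to.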
pub fn get_row_ids(&self, ids: &[u32]) -> Vec<u64> {
ids.iter()
.map(|&id| self.row_ids.value(id as usize))
.collect()
}
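    /// Write the rows of this storage to `writer` in fixed-size chunks.
    ///
    /// Returns the total number of rows written.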
pub async fn write_partition(
&self,
writer: &mut FileWriter<ManifestDescribing>,
) -> Result<usize> {
        let batch_size: usize = 10240;
        for offset in (0..self.batch.num_rows()).step_by(batch_size) {
let length = min(batch_size, self.batch.num_rows() - offset);
let slice = self.batch.slice(offset, length);
writer.write(&[slice]).await?;
}
Ok(self.batch.num_rows())
}
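    /// Write the codebook tensor, all rows, and the PQ / index metadata to
    /// `writer`, then finish the file.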
pub async fn write_full(&self, writer: &mut FileWriter<ManifestDescribing>) -> Result<()> {
let pos = writer.object_writer.tell().await?;
let mat = MatrixView::<Float32Type>::new(self.codebook.clone(), self.dimension);
let codebook_tensor = pb::Tensor::from(&mat);
writer
.object_writer
.write_protobuf(&codebook_tensor)
.await?;
self.write_partition(writer).await?;
let metadata = ProductQuantizationMetadata {
codebook_position: pos,
num_bits: self.num_bits,
num_sub_vectors: self.num_sub_vectors,
dimension: self.dimension,
codebook: None,
};
let index_metadata = IndexMetadata {
index_type: "PQ".to_string(),
distance_type: self.metric_type.to_string(),
};
let mut schema_metadata = HashMap::new();
schema_metadata.insert(
            PQ_METADATA_KEY.to_string(),
serde_json::to_string(&metadata)?,
);
schema_metadata.insert(
INDEX_METADATA_SCHEMA_KEY.to_string(),
serde_json::to_string(&index_metadata)?,
);
writer.finish_with_metadata(&schema_metadata).await?;
Ok(())
}
}
#[async_trait]
impl QuantizerStorage for ProductQuantizationStorage {
type Metadata = ProductQuantizationMetadata;
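    /// Load a `range` of rows as a single partition, reusing the codebook
    /// carried in `metadata`.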
async fn load_partition(
reader: &FileReader,
range: std::ops::Range<usize>,
metric_type: MetricType,
metadata: &Self::Metadata,
) -> Result<Self> {
let codebook = Arc::new(
metadata
.codebook
.as_ref()
.ok_or(Error::Index {
message: "Codebook not found in PQ metadata".to_string(),
location: location!(),
})?
.values()
.as_primitive::<Float32Type>()
.clone(),
);
let schema = reader.schema();
let batch = reader.read_range(range, schema, None).await?;
Self::new(
codebook,
batch,
metadata.num_bits,
metadata.num_sub_vectors,
metadata.dimension,
metric_type,
)
}
}
impl VectorStorage for ProductQuantizationStorage {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn len(&self) -> usize {
self.batch.num_rows()
}
fn row_ids(&self) -> &[u64] {
self.row_ids.values()
}
fn metric_type(&self) -> MetricType {
self.metric_type
}
fn dist_calculator(&self, query: &[f32]) -> Box<dyn DistCalculator> {
Box::new(PQDistCalculator::new(
self.codebook.values(),
self.num_bits,
self.num_sub_vectors,
self.pq_code.clone(),
query,
self.metric_type(),
))
}
}
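/// Computes distances from a query to PQ codes by looking up a distance table
/// built once per query (asymmetric distance computation).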
struct PQDistCalculator {
distance_table: Vec<f32>,
pq_code: Arc<UInt8Array>,
num_sub_vectors: usize,
num_centroids: usize,
}
impl PQDistCalculator {
fn new(
codebook: &[f32],
num_bits: u32,
num_sub_vectors: usize,
pq_code: Arc<UInt8Array>,
query: &[f32],
metric_type: MetricType,
) -> Self {
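        // Cosine reuses the L2 table: this assumes the codebook and query are
        // L2-normalized, in which case L2 distance ranks the same as cosine.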
let distance_table = if matches!(metric_type, MetricType::Cosine | MetricType::L2) {
build_distance_table_l2(codebook, num_bits, num_sub_vectors, query)
} else {
unimplemented!("Metric type not supported: {:?}", metric_type);
};
Self {
distance_table,
num_sub_vectors,
pq_code,
num_centroids: num_centroids(num_bits),
}
}
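    /// Return the PQ code (one `u8` per sub-vector) of the vector at `id`.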
fn get_pq_code(&self, id: u32) -> &[u8] {
let start = id as usize * self.num_sub_vectors;
let end = start + self.num_sub_vectors;
&self.pq_code.values()[start..end]
}
}
impl DistCalculator for PQDistCalculator {
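    /// For each id, sum the per-sub-vector table entries selected by its PQ
    /// code to obtain the (approximate) distance to the query.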
fn distance(&self, ids: &[u32]) -> Vec<f32> {
ids.iter()
.map(|&id| {
let pq_code = self.get_pq_code(id);
pq_code
.iter()
.enumerate()
.map(|(i, &c)| self.distance_table[i * self.num_centroids + c as usize])
.sum()
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use lance_arrow::FixedSizeListArrayExt;
use lance_core::ROW_ID_FIELD;
const DIM: usize = 32;
const TOTAL: usize = 512;
const NUM_SUB_VECTORS: usize = 16;
async fn create_pq_storage() -> ProductQuantizationStorage {
let codebook = Arc::new(Float32Array::from_iter_values(
(0..256 * DIM).map(|v| v as f32),
));
let pq = Arc::new(ProductQuantizerImpl::<Float32Type>::new(
NUM_SUB_VECTORS,
8,
DIM,
codebook,
MetricType::L2,
));
let schema = ArrowSchema::new(vec![
Field::new(
"vectors",
DataType::FixedSizeList(
Field::new_list_field(DataType::Float32, true).into(),
DIM as i32,
),
true,
),
ROW_ID_FIELD.clone(),
]);
let vectors = Float32Array::from_iter_values((0..TOTAL * DIM).map(|v| v as f32));
let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64));
let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap();
let batch =
RecordBatch::try_new(schema.into(), vec![Arc::new(fsl), Arc::new(row_ids)]).unwrap();
ProductQuantizationStorage::build(pq.clone(), &batch, "vectors")
.await
.unwrap()
}
#[tokio::test]
async fn test_build_pq_storage() {
let storage = create_pq_storage().await;
assert_eq!(storage.len(), TOTAL);
assert_eq!(storage.num_sub_vectors, NUM_SUB_VECTORS);
assert_eq!(storage.codebook.len(), 256 * DIM);
assert_eq!(storage.pq_code.len(), TOTAL * NUM_SUB_VECTORS);
assert_eq!(storage.row_ids.len(), TOTAL);
}
#[tokio::test]
async fn test_read_write_pq_storage() {
let storage = create_pq_storage().await;
let store = ObjectStore::memory();
let path = Path::from("pq_storage");
let schema = Schema::try_from(storage.schema().as_ref()).unwrap();
let mut file_writer = FileWriter::<ManifestDescribing>::try_new(
&store,
&path,
schema.clone(),
&Default::default(),
)
.await
.unwrap();
storage.write_full(&mut file_writer).await.unwrap();
let storage2 = ProductQuantizationStorage::load(&store, &path)
.await
.unwrap();
assert_eq!(storage, storage2);
}
}