use std::{cmp::max, sync::Arc};
use lance_index::{DatasetIndexExt, IndexType};
pub use lance_linalg::distance::MetricType;
pub mod vector;
use crate::{utils::default_vector_column, Error, Result, Table};
/// Fully resolved parameters for creating an index.
///
/// `IndexBuilder::build` assembles one of these variants from the builder's
/// settings and then dispatches on it to create the index.
pub enum IndexParams {
/// Parameters for a scalar index.
Scalar {
/// Forwarded as the `replace` argument of `create_scalar_index`.
replace: bool,
},
/// Parameters for an IVF-PQ vector index.
IvfPq {
/// Forwarded as the `replace` argument of `create_index`.
replace: bool,
/// Distance metric handed to `VectorIndexParams::ivf_pq`.
metric_type: MetricType,
/// Partition count; converted to `usize` before being handed to lance.
num_partitions: u64,
/// Sub-vector count; converted to `usize` before being handed to lance.
num_sub_vectors: u32,
/// Bit count; narrowed to `u8` with an `as` cast when the index is built,
/// so values above 255 are truncated.
num_bits: u32,
/// Carried in the parameters but not read when `IndexBuilder::build`
/// creates the index.
sample_rate: u32,
/// Iteration limit; converted to `usize` before being handed to lance.
max_iterations: u32,
},
}
/// Builder that collects index settings for a table and creates the index
/// when `build` is called.
pub struct IndexBuilder {
// Table the index is created on.
table: Arc<dyn Table>,
// Columns to index. `build` accepts exactly one column, or an empty list,
// in which case it falls back to `default_vector_column`.
columns: Vec<String>,
// Optional index name, set through `name()`.
name: Option<String>,
// Forwarded as the `replace` flag when the index is created; `new` sets it to `true`.
replace: bool,
// Kind of index to build; `new` sets it to `IndexType::Scalar`.
index_type: IndexType,
// Distance metric for vector indices; `new` sets it to `MetricType::L2`.
metric_type: MetricType,
// Partition count; when `None`, `build` derives it from the table's row count.
num_partitions: Option<u32>,
// Sub-vector count; when `None`, `build` derives it from the column's
// `FixedSizeList` length.
num_sub_vectors: Option<u32>,
// Bit count for vector indices; `new` sets it to 8.
num_bits: u32,
// Sample rate; `new` sets it to 256.
sample_rate: u32,
// Iteration limit for vector index training; `new` sets it to 50.
max_iterations: u32,
}
impl IndexBuilder {
pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self {
IndexBuilder {
table,
columns: columns.iter().map(|c| c.to_string()).collect(),
name: None,
replace: true,
index_type: IndexType::Scalar,
metric_type: MetricType::L2,
num_partitions: None,
num_sub_vectors: None,
num_bits: 8,
sample_rate: 256,
max_iterations: 50,
}
}
pub fn scalar(&mut self) -> &mut Self {
self.index_type = IndexType::Scalar;
self
}
pub fn ivf_pq(&mut self) -> &mut Self {
self.index_type = IndexType::Vector;
self
}
pub fn columns(&mut self, cols: &[&str]) -> &mut Self {
self.columns = cols.iter().map(|s| s.to_string()).collect();
self
}
pub fn replace(&mut self, v: bool) -> &mut Self {
self.replace = v;
self
}
pub fn name(&mut self, name: &str) -> &mut Self {
self.name = Some(name.to_string());
self
}
pub fn metric_type(&mut self, metric_type: MetricType) -> &mut Self {
self.metric_type = metric_type;
self
}
pub fn num_partitions(&mut self, num_partitions: u32) -> &mut Self {
self.num_partitions = Some(num_partitions);
self
}
pub fn num_sub_vectors(&mut self, num_sub_vectors: u32) -> &mut Self {
self.num_sub_vectors = Some(num_sub_vectors);
self
}
pub fn num_bits(&mut self, num_bits: u32) -> &mut Self {
self.num_bits = num_bits;
self
}
pub fn sample_rate(&mut self, sample_rate: u32) -> &mut Self {
self.sample_rate = sample_rate;
self
}
pub fn max_iterations(&mut self, max_iterations: u32) -> &mut Self {
self.max_iterations = max_iterations;
self
}
pub async fn build(&self) -> Result<()> {
let schema = self.table.schema();
let mut index_type = &self.index_type;
let columns = if self.columns.is_empty() {
index_type = &IndexType::Vector;
vec![default_vector_column(&schema, None)?]
} else {
self.columns.clone()
};
if columns.len() != 1 {
return Err(Error::Schema {
message: "Only one column is supported for index".to_string(),
});
}
let column = &columns[0];
let field = schema.field_with_name(column)?;
let params = match index_type {
IndexType::Scalar => IndexParams::Scalar {
replace: self.replace,
},
IndexType::Vector => {
let num_partitions = if let Some(n) = self.num_partitions {
n
} else {
suggested_num_partitions(self.table.count_rows().await?)
};
let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
n
} else {
match field.data_type() {
arrow_schema::DataType::FixedSizeList(_, n) => {
Ok::<u32, Error>(suggested_num_sub_vectors(*n as u32))
}
_ => Err(Error::Schema {
message: format!(
"Column '{}' is not a FixedSizeList",
&self.columns[0]
),
}),
}?
};
IndexParams::IvfPq {
replace: self.replace,
metric_type: self.metric_type,
num_partitions: num_partitions as u64,
num_sub_vectors,
num_bits: self.num_bits,
sample_rate: self.sample_rate,
max_iterations: self.max_iterations,
}
}
};
let tbl = self
.table
.as_native()
.expect("Only native table is supported here");
let mut dataset = tbl.clone_inner_dataset();
match params {
IndexParams::Scalar { replace } => {
self.table
.as_native()
.unwrap()
.create_scalar_index(column, replace)
.await?
}
IndexParams::IvfPq {
replace,
metric_type,
num_partitions,
num_sub_vectors,
num_bits,
max_iterations,
..
} => {
let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq(
num_partitions as usize,
num_bits as u8,
num_sub_vectors as usize,
false,
metric_type,
max_iterations as usize,
);
dataset
.create_index(
&[column],
IndexType::Vector,
None,
&lance_idx_params,
replace,
)
.await?;
}
}
tbl.reset_dataset(dataset);
Ok(())
}
}
/// Suggests a partition count for a table with `rows` rows.
///
/// The result is the square root of `rows`, truncated to an integer and
/// never less than 1.
fn suggested_num_partitions(rows: usize) -> u32 {
    let root = (rows as f64).sqrt();
    (root as u32).max(1)
}
/// Suggests a PQ sub-vector count for vectors of dimension `dim`.
///
/// Returns `dim / 16` when `dim` is a multiple of 16, otherwise `dim / 8`
/// when it is a multiple of 8. For any other dimension a warning is logged
/// and 1 is returned.
fn suggested_num_sub_vectors(dim: u32) -> u32 {
    match (dim % 16, dim % 8) {
        (0, _) => dim / 16,
        (_, 0) => dim / 8,
        _ => {
            log::warn!(
                "The dimension of the vector is not divisible by 8 or 16, \
                 which may cause performance degradation in PQ"
            );
            1
        }
    }
}