use std::sync::Arc;
use arrow_array::Float32Array;
use arrow_schema::Schema;
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance::dataset::Dataset;
use lance_linalg::distance::MetricType;
use crate::error::Result;
use crate::utils::default_vector_column;
/// Number of nearest neighbours requested by a vector search when the
/// [`Query`] has no explicit limit set.
const DEFAULT_TOP_K: usize = 10;
/// A builder describing either a plain scan or a nearest-neighbour search
/// over a [`Dataset`].
///
/// Values are configured through the chained setter methods and are only
/// applied to a [`Scanner`] when [`Query::execute_stream`] is called.
#[derive(Clone)]
pub struct Query {
    /// Dataset that `execute_stream` scans.
    dataset: Arc<Dataset>,

    // --- vector search settings ---
    /// Query vector; when `None`, `execute_stream` performs a plain scan
    /// instead of a nearest-neighbour search.
    query_vector: Option<Float32Array>,
    /// Name of the vector column to search. When `None`, the column is
    /// resolved with `default_vector_column` at execution time.
    column: Option<String>,
    /// Value forwarded to `Scanner::nprobs` (initialised to 20 by `new`).
    nprobes: usize,
    /// Value forwarded to `Scanner::refine` when set.
    refine_factor: Option<u32>,
    /// Value forwarded to `Scanner::distance_metric` when set.
    metric_type: Option<MetricType>,
    /// Value forwarded to `Scanner::use_index` (initialised to `true`).
    use_index: bool,

    // --- general scan settings ---
    /// For a vector search this is the `k` passed to `Scanner::nearest`
    /// (falling back to `DEFAULT_TOP_K`); for a plain scan it is the row limit.
    limit: Option<usize>,
    /// Filter expression forwarded to `Scanner::filter` when set.
    filter: Option<String>,
    /// Column names forwarded to `Scanner::project` when set.
    select: Option<Vec<String>>,
    /// Value forwarded to `Scanner::prefilter` (initialised to `false`).
    prefilter: bool,
}
impl Query {
    /// Creates a query over `dataset` with default settings.
    ///
    /// Defaults: no query vector, no explicit column, no limit, no filter,
    /// no projection, no refine factor, no metric type, `nprobes = 20`,
    /// `use_index = true` and `prefilter = false`.
    pub(crate) fn new(dataset: Arc<Dataset>) -> Self {
        Self {
            dataset,
            query_vector: None,
            column: None,
            limit: None,
            nprobes: 20,
            refine_factor: None,
            metric_type: None,
            use_index: true,
            filter: None,
            select: None,
            prefilter: false,
        }
    }

    /// Builds a [`Scanner`] from the configured options and returns the
    /// resulting stream of record batches.
    ///
    /// When a query vector has been set via [`Query::nearest_to`], a
    /// nearest-neighbour search is configured using the explicit column (or
    /// the one resolved by `default_vector_column`) and `limit` as `k`
    /// (`DEFAULT_TOP_K` when unset). Otherwise a plain scan with the optional
    /// row limit is configured.
    ///
    /// # Errors
    ///
    /// Returns an error if the default vector column cannot be resolved, if
    /// the scanner rejects the nearest/limit/projection/filter settings, or
    /// if the stream cannot be created.
    pub async fn execute_stream(&self) -> Result<DatasetRecordBatchStream> {
        let mut scanner: Scanner = self.dataset.scan();

        if let Some(query) = self.query_vector.as_ref() {
            // Borrow the configured column name when present; only allocate
            // when the column has to be resolved from the schema.
            let inferred: String;
            let column: &str = match self.column.as_deref() {
                Some(name) => name,
                None => {
                    let arrow_schema = Schema::from(self.dataset.schema());
                    inferred = default_vector_column(&arrow_schema, Some(query.len() as i32))?;
                    inferred.as_str()
                }
            };
            scanner.nearest(column, query, self.limit.unwrap_or(DEFAULT_TOP_K))?;
        } else {
            scanner.limit(self.limit.map(|limit| limit as i64), None)?;
        }

        // These are applied after `nearest` so the original call order is kept.
        scanner.nprobs(self.nprobes);
        scanner.use_index(self.use_index);
        scanner.prefilter(self.prefilter);

        // `project` and `filter` are fallible: propagate their errors instead
        // of silently running the query without the projection or filter.
        if let Some(columns) = self.select.as_ref() {
            scanner.project(columns.as_slice())?;
        }
        if let Some(filter) = self.filter.as_ref() {
            scanner.filter(filter)?;
        }
        if let Some(refine_factor) = self.refine_factor {
            scanner.refine(refine_factor);
        }
        if let Some(metric_type) = self.metric_type {
            scanner.distance_metric(metric_type);
        }

        Ok(scanner.try_into_stream().await?)
    }

    /// Sets the name of the vector column to search.
    ///
    /// Only used when a query vector has been set; otherwise ignored by
    /// [`Query::execute_stream`].
    pub fn column(mut self, column: &str) -> Self {
        self.column = Some(column.to_string());
        self
    }

    /// Sets the maximum number of results.
    ///
    /// Used as `k` for a vector search, or as the row limit for a plain scan.
    pub fn limit(mut self, limit: usize) -> Self {
        self.limit = Some(limit);
        self
    }

    /// Sets the query vector, turning this query into a nearest-neighbour
    /// search. The slice is copied into an owned [`Float32Array`].
    pub fn nearest_to(mut self, vector: &[f32]) -> Self {
        self.query_vector = Some(Float32Array::from(vector.to_vec()));
        self
    }

    /// Sets the value forwarded to `Scanner::nprobs` at execution time.
    pub fn nprobes(mut self, nprobes: usize) -> Self {
        self.nprobes = nprobes;
        self
    }

    /// Sets the value forwarded to `Scanner::refine` at execution time.
    pub fn refine_factor(mut self, refine_factor: u32) -> Self {
        self.refine_factor = Some(refine_factor);
        self
    }

    /// Sets the distance metric forwarded to `Scanner::distance_metric`.
    pub fn metric_type(mut self, metric_type: MetricType) -> Self {
        self.metric_type = Some(metric_type);
        self
    }

    /// Sets the flag forwarded to `Scanner::use_index` at execution time.
    pub fn use_index(mut self, use_index: bool) -> Self {
        self.use_index = use_index;
        self
    }

    /// Sets the filter expression forwarded to `Scanner::filter`.
    ///
    /// The expression is validated only when the query is executed; an
    /// expression the scanner rejects makes [`Query::execute_stream`] return
    /// an error.
    pub fn filter(mut self, filter: impl AsRef<str>) -> Self {
        self.filter = Some(filter.as_ref().to_string());
        self
    }

    /// Sets the columns forwarded to `Scanner::project`.
    ///
    /// The column names are validated only when the query is executed.
    pub fn select(mut self, columns: &[impl AsRef<str>]) -> Self {
        self.select = Some(columns.iter().map(|c| c.as_ref().to_string()).collect());
        self
    }

    /// Sets the flag forwarded to `Scanner::prefilter` at execution time.
    ///
    /// NOTE(review): presumably controls whether the filter is applied before
    /// the vector search rather than after it — confirm against the lance
    /// `Scanner::prefilter` documentation.
    pub fn prefilter(mut self, prefilter: bool) -> Self {
        self.prefilter = prefilter;
        self
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use arrow_array::{
        cast::AsArray, Float32Array, Int32Array, RecordBatch, RecordBatchIterator,
        RecordBatchReader,
    };
    use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema};
    use futures::StreamExt;
    use lance::dataset::Dataset;
    use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
    use tempfile::tempdir;

    use crate::query::Query;
    use crate::table::{NativeTable, Table};

    /// Builder setters must store exactly the values they were given, and a
    /// later `nearest_to` call must replace the earlier query vector.
    #[tokio::test]
    async fn test_setters_getters() {
        let dataset = Dataset::write(make_test_batches(), "memory://foo", None)
            .await
            .unwrap();

        let query = Query::new(Arc::new(dataset)).nearest_to(&[0.1, 0.2]);
        let first_vector = Some(Float32Array::from_iter_values([0.1, 0.2]));
        assert_eq!(query.query_vector, first_vector);

        let query = query
            .nearest_to(&[9.8, 8.7])
            .limit(100)
            .nprobes(1000)
            .use_index(true)
            .metric_type(MetricType::Cosine)
            .refine_factor(999);
        let second_vector = Float32Array::from_iter_values([9.8, 8.7]);
        assert_eq!(query.query_vector.unwrap(), second_vector);
        assert_eq!(query.limit.unwrap(), 100);
        assert_eq!(query.nprobes, 1000);
        assert!(query.use_index);
        assert_eq!(query.metric_type, Some(MetricType::Cosine));
        assert_eq!(query.refine_factor, Some(999));
    }

    /// Runs a filtered vector search with and without `prefilter`.
    #[tokio::test]
    async fn test_execute() {
        let dataset = Arc::new(
            Dataset::write(make_non_empty_batches(), "memory://foo", None)
                .await
                .unwrap(),
        );

        // Without prefilter every batch is expected to hold fewer than 10
        // rows (presumably because the filter runs after the top-k search).
        let mut post_filtered = Query::new(dataset.clone())
            .nearest_to(&[0.1; 4])
            .limit(10)
            .filter("id % 2 == 0")
            .execute_stream()
            .await
            .expect("should have result");
        while let Some(item) = post_filtered.next().await {
            let batch = item.expect("should be Ok");
            assert!(batch.num_rows() < 10);
        }

        // With prefilter every batch is expected to hold exactly 10 rows.
        // The filter is passed as an owned `String` to exercise `AsRef<str>`.
        let mut pre_filtered = Query::new(dataset)
            .nearest_to(&[0.1; 4])
            .limit(10)
            .filter(String::from("id % 2 == 0"))
            .prefilter(true)
            .execute_stream()
            .await
            .expect("should have result");
        while let Some(item) = pre_filtered.next().await {
            let batch = item.expect("should be Ok");
            assert!(batch.num_rows() == 10);
        }
    }

    /// A query without a vector is a plain scan that still honours the filter.
    #[tokio::test]
    async fn test_execute_no_vector() {
        let dataset = Arc::new(
            Dataset::write(make_non_empty_batches(), "memory://foo", None)
                .await
                .unwrap(),
        );

        let mut stream = Query::new(dataset.clone())
            .filter("id % 2 == 0")
            .execute_stream()
            .await
            .expect("should have result");
        while let Some(item) = stream.next().await {
            let batch = item.expect("should be Ok");
            let ids: &Int32Array = batch["id"].as_primitive();
            assert!(ids.iter().all(|id| id.unwrap() % 2 == 0));
        }
    }

    /// Generated data with a `vector` column and an incrementing `id` column.
    fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static {
        let vector_col = Box::new(RandomVector::new().named("vector".to_string()));
        let id_col = Box::new(IncrementingInt32::new().named("id".to_string()));
        BatchGenerator::new().col(vector_col).col(id_col).batch(512)
    }

    /// A reader yielding one empty batch with `key`, `vector` (128-wide
    /// fixed-size list of `Float32`) and `uri` columns.
    fn make_test_batches() -> impl RecordBatchReader + Send + 'static {
        const DIM: i32 = 128;

        let item_field = Arc::new(ArrowField::new("item", DataType::Float32, true));
        let fields = vec![
            ArrowField::new("key", DataType::Int32, false),
            ArrowField::new("vector", DataType::FixedSizeList(item_field, DIM), true),
            ArrowField::new("uri", DataType::Utf8, true),
        ];
        let schema = Arc::new(ArrowSchema::new(fields));

        let empty_batch = RecordBatch::new_empty(schema.clone());
        RecordBatchIterator::new(std::iter::once(Ok(empty_batch)), schema)
    }

    /// `Table::search` must hand the query vector through to the `Query`.
    #[tokio::test]
    async fn test_search() {
        let tmp_dir = tempdir().unwrap();
        let dataset_path = tmp_dir.path().join("test.lance");
        let uri = dataset_path.to_str().unwrap();

        Dataset::write(make_test_batches(), uri, None)
            .await
            .unwrap();

        let table = NativeTable::open(uri).await.unwrap();
        let query = table.search(&[0.1, 0.2]);
        assert_eq!(&[0.1, 0.2], query.query_vector.unwrap().values());
    }
}