use std::sync::Arc;
use arrow_array::{RecordBatch, UInt32Array, cast::AsArray};
use arrow_select::concat::concat_batches;
use datafusion::datasource::MemTable;
use datafusion::prelude::SessionContext;
use lance::Dataset;
use lance::dataset::scanner::ColumnOrdering;
use lance_datafusion::udf::register_functions;
use lance_index::scalar::FullTextSearchQuery;
use lance_index::scalar::inverted::query::{FtsQuery, PhraseQuery};
/// Builds a fresh DataFusion session with Lance's SQL functions registered,
/// so expected-result queries can use the same UDFs as the scanner paths.
fn create_datafusion_context() -> SessionContext {
    let context = SessionContext::new();
    register_functions(&context);
    context
}
mod inverted;
mod primitives;
mod vectors;
/// A full-table scan ordered by `id` must reproduce the original batch exactly.
async fn test_scan(original: &RecordBatch, ds: &Dataset) {
    let ordering = vec![ColumnOrdering::asc_nulls_first("id".to_string())];
    let mut scanner = ds.scan();
    scanner.order_by(Some(ordering)).unwrap();
    let scanned = scanner.try_into_batch().await.unwrap();
    assert_eq!(original, &scanned);
}
/// Verifies `Dataset::take` against Arrow's `take_record_batch` kernel for a
/// set of representative index patterns (ordered, reversed, single, empty,
/// duplicated, and boundary indices).
///
/// Fix vs. previous version: the hard-coded cases assumed a batch with at
/// least 6 rows — `num_rows - 1` underflowed (panicking) on an empty batch,
/// and indices like `5` were out of range for small batches, failing the test
/// for reasons unrelated to the take path. We now return early for empty
/// batches and skip any case that references a row beyond the batch.
async fn test_take(original: &RecordBatch, ds: &Dataset) {
    let num_rows = original.num_rows();
    // Nothing meaningful to take from an empty batch, and `num_rows - 1`
    // below would underflow.
    if num_rows == 0 {
        return;
    }
    let cases: Vec<Vec<usize>> = vec![
        vec![0, 1, 2],
        vec![5, 3, 1],
        vec![0],
        vec![],
        (0..num_rows.min(10)).collect(),
        vec![num_rows - 1, 0],
        vec![1, 1, 2],
        vec![0, 0, 0],
        vec![num_rows - 1, num_rows - 1],
    ];
    for indices in cases {
        // The fixed index lists above assume a reasonably sized batch; skip
        // any case that would read past the end rather than spuriously fail.
        if indices.iter().any(|&i| i >= num_rows) {
            continue;
        }
        let indices_u64: Vec<u64> = indices.iter().map(|&i| i as u64).collect();
        let taken_ds = ds.take(&indices_u64, ds.schema().clone()).await.unwrap();
        // Arrow's take kernel is the reference implementation.
        let indices_u32: Vec<u32> = indices.iter().map(|&i| i as u32).collect();
        let indices_array = UInt32Array::from(indices_u32);
        let taken_rb = arrow::compute::take_record_batch(original, &indices_array).unwrap();
        assert_eq!(
            taken_rb, taken_ds,
            "Take results don't match for indices: {:?}",
            indices
        );
    }
}
/// Scans the dataset with `predicate` applied and compares the result to the
/// same predicate evaluated by DataFusion over an in-memory copy of
/// `original`. Both sides are ordered by `id` so row order is deterministic.
async fn test_filter(original: &RecordBatch, ds: &Dataset, predicate: &str) {
    // Lance side: filtered, ordered scan.
    let mut scanner = ds.scan();
    scanner.filter(predicate).unwrap();
    scanner
        .order_by(Some(vec![ColumnOrdering::asc_nulls_first("id".to_string())]))
        .unwrap();
    let scanned = scanner.try_into_batch().await.unwrap();

    // Reference side: run the identical predicate through DataFusion SQL.
    let ctx = create_datafusion_context();
    let table = MemTable::try_new(original.schema(), vec![vec![original.clone()]]).unwrap();
    ctx.register_table("t", Arc::new(table)).unwrap();
    let sql = format!("SELECT * FROM t WHERE {} ORDER BY id", predicate);
    let batches = ctx.sql(&sql).await.unwrap().collect().await.unwrap();
    let expected = concat_batches(&original.schema(), &batches).unwrap();

    assert_eq!(&expected, &scanned);
}
/// Projects `batch` down to exactly the columns named in `schema`, discarding
/// any extra columns the scanner appended (e.g. an FTS score column).
/// Panics if `schema` names a column that `batch` does not contain.
fn strip_score_column(batch: &RecordBatch, schema: &arrow_schema::Schema) -> RecordBatch {
    let mut columns = Vec::with_capacity(schema.fields().len());
    for field in schema.fields() {
        let column = batch.column_by_name(field.name()).unwrap().clone();
        columns.push(column);
    }
    RecordBatch::try_new(Arc::new(schema.clone()), columns).unwrap()
}
/// Runs a full-text search through the Lance scanner and checks the returned
/// rows against a LIKE-based approximation evaluated by DataFusion.
///
/// * `column` - the text column the search targets.
/// * `query` - the search text; an empty (or token-free) query is expected to
///   match every row that passes `filter`.
/// * `filter` - optional extra SQL predicate ANDed with the search condition.
/// * `lower_case` - when true, the expected-result SQL lowercases both the
///   column and the query before LIKE matching (assumes the FTS tokenizer is
///   case-insensitive in this configuration — TODO confirm against index setup).
/// * `phrase_query` - when true the query is matched as one contiguous phrase
///   (single LIKE of the whole string); otherwise each token must appear
///   somewhere in the column (AND of per-token LIKEs).
async fn test_fts(
    original: &RecordBatch,
    ds: &Dataset,
    column: &str,
    query: &str,
    filter: Option<&str>,
    lower_case: bool,
    phrase_query: bool,
) {
    let mut scanner = ds.scan();
    // Build either a phrase query pinned to `column`, or a plain match query.
    let fts_query = if phrase_query {
        let phrase = PhraseQuery::new(query.to_string()).with_column(Some(column.to_string()));
        FullTextSearchQuery::new_query(FtsQuery::Phrase(phrase))
    } else {
        FullTextSearchQuery::new(query.to_string())
            .with_column(column.to_string())
            .unwrap()
    };
    scanner.full_text_search(fts_query).unwrap();
    if let Some(predicate) = filter {
        scanner.filter(predicate).unwrap();
    }
    // Order by `id` so both sides compare with a deterministic row order.
    scanner
        .order_by(Some(vec![ColumnOrdering::asc_nulls_first(
            "id".to_string(),
        )]))
        .unwrap();
    let scanned = scanner.try_into_batch().await.unwrap();
    // Drop the extra score column the FTS scan adds so schemas line up with
    // the original batch.
    let scanned = strip_score_column(&scanned, original.schema().as_ref());
    // Reference side: the original batch registered as table `t` in DataFusion.
    let ctx = create_datafusion_context();
    let table = MemTable::try_new(original.schema(), vec![vec![original.clone()]]).unwrap();
    ctx.register_table("t", Arc::new(table)).unwrap();
    let col_expr = if lower_case {
        format!("lower(t.{})", column)
    } else {
        format!("t.{}", column)
    };
    let normalized_query = if lower_case {
        query.to_lowercase()
    } else {
        query.to_string()
    };
    // Helper: run `SELECT * FROM t WHERE <clause> ORDER BY id` and merge the
    // result batches into one RecordBatch for comparison.
    let expected_from_where = |where_clause: String| async move {
        let sql = format!("SELECT * FROM t WHERE {} ORDER BY id", where_clause);
        let df = ctx.sql(&sql).await.unwrap();
        let expected_batches = df.collect().await.unwrap();
        concat_batches(&original.schema(), &expected_batches).unwrap()
    };
    let expected = if normalized_query.is_empty() {
        // Empty query: only the extra filter (or nothing) constrains rows.
        expected_from_where(filter.unwrap_or("true").to_string()).await
    } else if phrase_query {
        // Phrase: the whole query string must appear as-is in the column.
        let predicate = format!("{} LIKE '%{}%'", col_expr, normalized_query);
        let where_clause = if let Some(extra) = filter {
            format!("{} AND {}", predicate, extra)
        } else {
            predicate
        };
        expected_from_where(where_clause).await
    } else {
        let tokens = collect_tokens(&normalized_query);
        if tokens.is_empty() {
            // Query was all separators: behaves like an empty query.
            expected_from_where(filter.unwrap_or("true").to_string()).await
        } else {
            // Match query: every token must appear somewhere in the column.
            let predicate = tokens
                .into_iter()
                .map(|token| format!("{} LIKE '%{}%'", col_expr, token))
                .collect::<Vec<_>>()
                .join(" AND ");
            let where_clause = if let Some(extra) = filter {
                format!("{} AND {}", predicate, extra)
            } else {
                predicate
            };
            expected_from_where(where_clause).await
        }
    };
    assert_eq!(&expected, &scanned);
}
/// Splits `text` on every non-alphanumeric character and returns the
/// non-empty fragments (runs of separators yield no empty tokens).
fn collect_tokens(text: &str) -> Vec<&str> {
    let mut tokens = Vec::new();
    for piece in text.split(|c: char| !c.is_alphanumeric()) {
        if !piece.is_empty() {
            tokens.push(piece);
        }
    }
    tokens
}
/// Runs an approximate-nearest-neighbor scan against `column` and compares
/// the top-10 rows with a brute-force `array_distance` query in DataFusion.
///
/// The query vector is the first vector stored in `column` of `original`
/// (so at least one exact match with distance 0 exists).
///
/// NOTE(review): rows at equal distance may tie-break differently between
/// Lance and DataFusion, and the scanner result may carry an extra distance
/// column — the per-column index comparison below assumes any extra columns
/// are appended after the original schema's columns. TODO confirm.
async fn test_ann(original: &RecordBatch, ds: &Dataset, column: &str, predicate: Option<&str>) {
    let vector_column = original.column_by_name(column).unwrap();
    let fixed_size_list = vector_column.as_fixed_size_list();
    // First `value_length` values of the flattened child array == row 0's vector.
    let vector_values = fixed_size_list
        .values()
        .slice(0, fixed_size_list.value_length() as usize);
    let query_vector = vector_values;
    let mut scanner = ds.scan();
    scanner
        .nearest(column, query_vector.as_ref(), 10)
        .unwrap()
        .prefilter(true)
        .refine(2);
    if let Some(pred) = predicate {
        scanner.filter(pred).unwrap();
    }
    let result = scanner.try_into_batch().await.unwrap();
    // Reference side: exact distance ordering via DataFusion SQL.
    let ctx = create_datafusion_context();
    let table = MemTable::try_new(original.schema(), vec![vec![original.clone()]]).unwrap();
    ctx.register_table("t", Arc::new(table)).unwrap();
    // Assumes the vector column stores f32 values — TODO confirm for all datasets.
    let float_array = query_vector.as_primitive::<arrow::datatypes::Float32Type>();
    // Render the query vector as a SQL array literal, e.g. "[0.1, 0.2, ...]".
    let vector_values_str = float_array
        .values()
        .iter()
        .map(|v| v.to_string())
        .collect::<Vec<_>>()
        .join(", ");
    let sql = format!(
        "SELECT * FROM t {} ORDER BY array_distance(t.{}, [{}]) LIMIT 10",
        if let Some(pred) = predicate {
            format!("WHERE {}", pred)
        } else {
            String::new()
        },
        column,
        vector_values_str
    );
    let df = ctx.sql(&sql).await.unwrap();
    let expected_batches = df.collect().await.unwrap();
    let expected = concat_batches(&original.schema(), &expected_batches).unwrap();
    assert_eq!(
        expected.num_rows(),
        result.num_rows(),
        "Different number of results"
    );
    // Compare column-by-column over the original schema; extra columns in
    // `result` (if any) beyond that range are not checked here.
    for (col_idx, field) in original.schema().fields().iter().enumerate() {
        let expected_col = expected.column(col_idx);
        let result_col = result.column(col_idx);
        assert_eq!(
            expected_col,
            result_col,
            "Column '{}' differs between DataFusion and Lance results",
            field.name()
        );
    }
}