use std::sync::Arc;
use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array};
use lance::Dataset;
use lance::dataset::scanner::ColumnOrdering;
use lance::dataset::{InsertBuilder, WriteParams};
use lance_index::scalar::inverted::query::{FtsQuery, PhraseQuery};
use lance_index::scalar::{FullTextSearchQuery, InvertedIndexParams};
use lance_index::{DatasetIndexExt, IndexType};
use tantivy::tokenizer::Language;
use super::{strip_score_column, test_fts, test_scan, test_take};
use crate::utils::DatasetTestCases;
fn base_inverted_params(with_position: bool) -> InvertedIndexParams {
InvertedIndexParams::new("simple".to_string(), Language::English)
.with_position(with_position)
.lower_case(true)
.stem(false)
.remove_stop_words(false)
.ascii_folding(false)
.max_token_length(None)
}
/// Creates [`InvertedIndexParams`] for `base_tokenizer` with English as the language.
///
/// `lower_case` and `with_position` are forwarded to the builder unchanged.
/// Every other option is pinned so test cases differ only in those two
/// knobs: no stemming, no stop-word removal, no ASCII folding, and
/// `max_token_length` set to `None`.
fn params_for(base_tokenizer: &str, lower_case: bool, with_position: bool) -> InvertedIndexParams {
    let tokenizer = String::from(base_tokenizer);
    let builder = InvertedIndexParams::new(tokenizer, Language::English);
    // Caller-controlled options first, then the fixed ones (same order as before).
    let builder = builder.with_position(with_position).lower_case(lower_case);
    builder
        .stem(false)
        .remove_stop_words(false)
        .ascii_folding(false)
        .max_token_length(None)
}
/// Runs a full-text search against `ds` and returns all hits as one batch.
///
/// The scan applies `query`, then the optional SQL-style `filter` predicate,
/// and orders the output by the `id` column (ascending, nulls first) so that
/// callers can compare against a deterministic row order.
///
/// # Panics
///
/// Panics if configuring the scanner or executing the scan fails; this is a
/// test helper and treats any such failure as a test failure.
async fn run_fts(ds: &Dataset, query: FullTextSearchQuery, filter: Option<&str>) -> RecordBatch {
    let by_id = vec![ColumnOrdering::asc_nulls_first("id".to_string())];

    let mut scan = ds.scan();
    scan.full_text_search(query).unwrap();
    if let Some(expr) = filter {
        scan.filter(expr).unwrap();
    }
    scan.order_by(Some(by_id)).unwrap();

    scan.try_into_batch().await.unwrap()
}
/// Asserts that a full-text search over `ds` returns exactly the expected rows.
///
/// The search is executed through `run_fts` (results ordered by `id`), the
/// result is passed through `strip_score_column` together with `original`'s
/// schema, and the outcome is compared against the rows of `original`
/// selected by `expected_ids`.
///
/// `expected_ids` are used as *row offsets* into `original`. That matches the
/// `id` column only because every fixture in this file builds `id` as
/// `0..n`; list them in ascending order to match the scan ordering.
///
/// # Panics
///
/// Panics if an expected id is negative, if an offset is out of bounds for
/// `original`, if the scan fails, or if the batches differ.
async fn assert_fts_expected(
    original: &RecordBatch,
    ds: &Dataset,
    query: FullTextSearchQuery,
    filter: Option<&str>,
    expected_ids: &[i32],
) {
    let scanned = run_fts(ds, query, filter).await;
    let scanned = strip_score_column(&scanned, original.schema().as_ref());

    // Checked conversion: a bare `as u32` would silently wrap a negative id
    // into a huge offset and surface as an opaque out-of-bounds `take` error.
    let row_offsets: Vec<u32> = expected_ids
        .iter()
        .map(|&id| u32::try_from(id).expect("expected ids must be non-negative row offsets"))
        .collect();
    let row_offsets = UInt32Array::from(row_offsets);

    let expected = arrow::compute::take_record_batch(original, &row_offsets).unwrap();
    assert_eq!(&expected, &scanned);
}
/// A single-term search must return the same rows before and after the
/// inverted index on `text` is explicitly created.
///
/// The fixture mixes multi-word strings, an empty string and a null; the
/// rows containing the token `hello` are 0, 1, 2 and 6.
#[tokio::test]
async fn test_inverted_basic_equivalence() {
    let ids = Arc::new(Int32Array::from((0..10).collect::<Vec<i32>>()));
    let text_values = vec![
        Some("hello world"),
        Some("world hello"),
        Some("hello"),
        Some("lance database"),
        Some(""),
        None,
        Some("hello lance"),
        Some("lance"),
        Some("database"),
        Some("world"),
    ];
    let text = Arc::new(StringArray::from(text_values)) as ArrayRef;
    let batch = RecordBatch::try_from_iter(vec![("id", ids as ArrayRef), ("text", text)]).unwrap();

    DatasetTestCases::from_data(batch.clone())
        .run(|ds, original| async move {
            let mut ds = ds;
            let query = FullTextSearchQuery::new("hello".to_string())
                .with_column("text".to_string())
                .unwrap();
            let expected_ids = [0, 1, 2, 6];

            // Search once before this test creates the index on `text`...
            assert_fts_expected(&original, &ds, query.clone(), None, &expected_ids).await;

            let params = base_inverted_params(false);
            ds.create_index(&["text"], IndexType::Inverted, None, &params, true)
                .await
                .unwrap();

            // ...and once after; both must yield the same rows.
            assert_fts_expected(&original, &ds, query, None, &expected_ids).await;

            // Shared checks from the parent module; creating the index must
            // not disturb plain scans or takes either.
            test_fts(&original, &ds, "text", "hello", None, true, false).await;
            test_scan(&original, &ds).await;
            test_take(&original, &ds).await;
        })
        .await;
}
/// A phrase query against a position-enabled inverted index must match only
/// rows where the terms appear adjacently and in order.
///
/// For the phrase `lance database`, rows 0, 3 and 4 qualify. Row 1 has a word
/// between the terms, row 2 has them reversed, and row 5 is null.
#[tokio::test]
async fn test_inverted_phrase_query_with_positions() {
    let ids = Arc::new(Int32Array::from((0..6).collect::<Vec<i32>>()));
    let text_values = vec![
        Some("lance database"),
        Some("lance and database"),
        Some("database lance"),
        Some("lance database test"),
        Some("lance database"),
        None,
    ];
    let text = Arc::new(StringArray::from(text_values)) as ArrayRef;
    let batch = RecordBatch::try_from_iter(vec![("id", ids as ArrayRef), ("text", text)]).unwrap();

    DatasetTestCases::from_data(batch.clone())
        .run(|ds, original| async move {
            let mut ds = ds;

            // The index is built with `with_position(true)` for the phrase query.
            let params = base_inverted_params(true);
            ds.create_index(&["text"], IndexType::Inverted, None, &params, true)
                .await
                .unwrap();

            let phrase = PhraseQuery::new("lance database".to_string())
                .with_column(Some("text".to_string()));
            let query = FullTextSearchQuery::new_query(FtsQuery::Phrase(phrase));
            assert_fts_expected(&original, &ds, query, None, &[0, 3, 4]).await;

            // NOTE(review): the trailing `true` presumably selects phrase
            // mode in `test_fts` — confirm against the helper's signature.
            test_fts(&original, &ds, "text", "lance database", None, true, true).await;
        })
        .await;
}
/// Full-text search combined with a scalar filter must return only rows that
/// satisfy both the query and the predicate.
///
/// `lance` occurs in rows 0, 1 and 3; row 1 has `category = 'drop'`, so the
/// filtered result is rows 0 and 3. The `category` column is exercised with
/// no index and with Bitmap, BTree, BloomFilter and ZoneMap index types.
#[tokio::test]
async fn test_inverted_with_filter() {
    let ids = Arc::new(Int32Array::from((0..5).collect::<Vec<i32>>()));
    let text_values = vec![
        Some("lance database"),
        Some("lance vector"),
        Some("random text"),
        Some("lance"),
        None,
    ];
    let categories = vec![
        Some("keep"),
        Some("drop"),
        Some("keep"),
        Some("keep"),
        Some("keep"),
    ];
    let text = Arc::new(StringArray::from(text_values)) as ArrayRef;
    let category = Arc::new(StringArray::from(categories)) as ArrayRef;
    let batch = RecordBatch::try_from_iter(vec![
        ("id", ids as ArrayRef),
        ("text", text),
        ("category", category),
    ])
    .unwrap();

    DatasetTestCases::from_data(batch.clone())
        .with_index_types(
            "category",
            [
                None,
                Some(IndexType::Bitmap),
                Some(IndexType::BTree),
                Some(IndexType::BloomFilter),
                Some(IndexType::ZoneMap),
            ],
        )
        .run(|ds, original| async move {
            let mut ds = ds;
            let params = base_inverted_params(false);
            ds.create_index(&["text"], IndexType::Inverted, None, &params, true)
                .await
                .unwrap();

            // One predicate shared by both checks below.
            let keep_only = "category = 'keep'";

            let query = FullTextSearchQuery::new("lance".to_string())
                .with_column("text".to_string())
                .unwrap();
            assert_fts_expected(&original, &ds, query, Some(keep_only), &[0, 3]).await;

            test_fts(
                &original,
                &ds,
                "text",
                "lance",
                Some(keep_only),
                true,
                false,
            )
            .await;
        })
        .await;
}
/// Searches for `hello` under several tokenizer / lower-casing / position
/// combinations and checks the matching rows for each.
///
/// The fixture spells the term as `Hello`, `hello`, `hello` and `HELLO` in
/// rows 0-3 (row 4 is null). With lower-casing enabled all four rows match;
/// with it disabled only the exact-case rows 1 and 2 do.
#[tokio::test]
async fn test_inverted_params_combinations() {
    let ids = Arc::new(Int32Array::from((0..5).collect::<Vec<i32>>()));
    let text_values = vec![
        Some("Hello there, this is a longer sentence about Lance."),
        Some("In this longer sentence we say hello to the database."),
        Some("Another line: hello world appears in a longer phrase."),
        Some("Saying HELLO loudly in a long sentence for testing."),
        None,
    ];
    let text = Arc::new(StringArray::from(text_values)) as ArrayRef;
    let batch = RecordBatch::try_from_iter(vec![("id", ids as ArrayRef), ("text", text)]).unwrap();

    // (label, index params, expected row ids, lower_case flag passed to `test_fts`)
    let cases = vec![
        (
            "simple_lc_pos",
            params_for("simple", true, true),
            vec![0, 1, 2, 3],
            true,
        ),
        (
            "simple_no_lc",
            params_for("simple", false, false),
            vec![1, 2],
            false,
        ),
        (
            "whitespace_lc",
            params_for("whitespace", true, false),
            vec![0, 1, 2, 3],
            true,
        ),
        (
            "whitespace_no_lc_pos",
            params_for("whitespace", false, true),
            vec![1, 2],
            false,
        ),
    ];

    // The loop owns each tuple's fields, so `params` is moved straight into
    // the builder and `expected` is only borrowed by the closure.
    for (_name, params, expected, lower_case) in cases {
        DatasetTestCases::from_data(batch.clone())
            .with_index_types_and_inverted_index_params("text", [Some(IndexType::Inverted)], params)
            .run(|ds, original| {
                // Each invocation gets its own copy to move into the future.
                let expected = expected.clone();
                async move {
                    let query = FullTextSearchQuery::new("hello".to_string())
                        .with_column("text".to_string())
                        .unwrap();
                    assert_fts_expected(&original, &ds, query.clone(), None, &expected).await;
                    test_fts(&original, &ds, "text", "hello", None, lower_case, false).await;
                }
            })
            .await;
    }
}
#[tokio::test]
async fn test_fts_after_delete_with_stable_row_ids() {
let ids = Arc::new(Int32Array::from((0..20).collect::<Vec<i32>>()));
let texts: Vec<Option<&str>> = (0..20)
.map(|i| match i % 4 {
0 => Some("alpha shared"),
1 => Some("beta shared"),
2 => Some("gamma shared"),
_ => Some("delta shared"),
})
.collect();
let text_col = Arc::new(StringArray::from(texts));
let batch = RecordBatch::try_from_iter(vec![
("id", ids as ArrayRef),
("text", text_col as ArrayRef),
])
.unwrap();
let mut ds = InsertBuilder::new("memory://")
.with_params(&WriteParams {
enable_stable_row_ids: true,
..Default::default()
})
.execute(vec![batch])
.await
.unwrap();
let params = InvertedIndexParams::default();
ds.create_index_builder(&["text"], IndexType::Inverted, ¶ms)
.await
.unwrap();
ds.delete("id IN (0, 1, 2, 3, 4)").await.unwrap();
let query = FullTextSearchQuery::new("shared".to_string())
.with_column("text".to_string())
.unwrap();
let mut scanner = ds.scan();
scanner.full_text_search(query).unwrap();
scanner
.order_by(Some(vec![ColumnOrdering::asc_nulls_first(
"id".to_string(),
)]))
.unwrap();
let result = scanner.try_into_batch().await.unwrap();
assert_eq!(result.num_rows(), 15);
let result_ids = result
.column_by_name("id")
.unwrap()
.as_any()
.downcast_ref::<Int32Array>()
.unwrap();
for id in result_ids.values().iter() {
assert!(*id >= 5, "Deleted row id {} should not appear", id);
}
}