use std::collections::HashMap;
use std::sync::Mutex;
use std::sync::atomic::{AtomicUsize, Ordering};
use arrow_array::RecordBatch;
use crossbeam_skiplist::SkipMap;
use datafusion::common::ScalarValue;
use lance_core::{Error, Result};
use lance_index::scalar::InvertedIndexParams;
use lance_index::scalar::inverted::tokenizer::lance_tokenizer::LanceTokenizer;
use tantivy::tokenizer::TokenStream;
use super::RowPosition;
/// Composite key for the posting skip-map.
///
/// Ordered by `token` first, then `row_position`, so every posting for a
/// given token forms one contiguous key range (see the `range(start..=end)`
/// scans in the search methods).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct FtsKey {
    // Indexed term, already normalized by the tokenizer.
    pub token: String,
    // Row the posting belongs to.
    pub row_position: RowPosition,
}
/// A single search hit: a row and its relevance score.
#[derive(Debug, Clone)]
pub struct FtsEntry {
    // Matching row.
    pub row_position: RowPosition,
    // BM25-style relevance score (higher is more relevant).
    pub score: f32,
}
/// Query expression tree for full-text search.
#[derive(Debug, Clone)]
pub enum FtsQueryExpr {
    /// Term match: any document containing at least one query token.
    Match {
        query: String,
        // Score multiplier applied to this clause's results.
        boost: f32,
    },
    /// Phrase match: tokens must appear in order.
    Phrase {
        query: String,
        // Maximum number of extra positions allowed between adjacent tokens.
        slop: u32,
        boost: f32,
    },
    /// Fuzzy match via Levenshtein expansion over the indexed vocabulary.
    Fuzzy {
        query: String,
        // Max edit distance; `None` derives it from token length
        // (see `auto_fuzziness`).
        fuzziness: Option<u32>,
        // Cap on how many vocabulary terms a token may expand to.
        max_expansions: usize,
        boost: f32,
    },
    /// Boolean combination of sub-queries (AND / OR / NOT semantics).
    Boolean {
        must: Vec<Self>,
        should: Vec<Self>,
        must_not: Vec<Self>,
    },
    /// Elasticsearch-style boosting query: documents matching `negative`
    /// have their score multiplied by `negative_boost`.
    Boost {
        positive: Box<Self>,
        negative: Option<Box<Self>>,
        negative_boost: f32,
    },
}
/// Default cap on how many indexed terms a fuzzy token may expand to.
pub const DEFAULT_MAX_EXPANSIONS: usize = 50;
/// Default WAND pruning factor; 1.0 disables pruning entirely.
pub const DEFAULT_WAND_FACTOR: f32 = 1.0;
/// Knobs for `FtsMemIndex::search_with_options`.
#[derive(Debug, Clone)]
pub struct SearchOptions {
    // WAND-style score pruning factor in [0.0, 1.0]; results scoring below
    // `threshold_score * wand_factor` are dropped. 1.0 = no pruning.
    pub wand_factor: f32,
    // Optional cap on the number of results returned.
    pub limit: Option<usize>,
}
impl Default for SearchOptions {
    /// No pruning, no limit.
    fn default() -> Self {
        Self {
            wand_factor: DEFAULT_WAND_FACTOR,
            limit: None,
        }
    }
}
impl SearchOptions {
    /// Equivalent to `SearchOptions::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Set the WAND pruning factor, clamped into `[0.0, 1.0]`.
    pub fn with_wand_factor(self, wand_factor: f32) -> Self {
        Self {
            wand_factor: wand_factor.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Cap the number of results returned.
    pub fn with_limit(self, limit: usize) -> Self {
        Self {
            limit: Some(limit),
            ..self
        }
    }
}
impl FtsQueryExpr {
    /// Term query over `query` (boost 1.0).
    pub fn match_query(query: impl Into<String>) -> Self {
        Self::Match {
            query: query.into(),
            boost: 1.0,
        }
    }

    /// Exact phrase query (slop 0, boost 1.0).
    pub fn phrase(query: impl Into<String>) -> Self {
        Self::phrase_with_slop(query, 0)
    }

    /// Phrase query allowing up to `slop` extra positions between
    /// consecutive tokens (boost 1.0).
    pub fn phrase_with_slop(query: impl Into<String>, slop: u32) -> Self {
        Self::Phrase {
            query: query.into(),
            slop,
            boost: 1.0,
        }
    }

    /// Fuzzy query with length-derived fuzziness and the default
    /// expansion cap (boost 1.0).
    pub fn fuzzy(query: impl Into<String>) -> Self {
        Self::fuzzy_with_options(query, None, DEFAULT_MAX_EXPANSIONS)
    }

    /// Fuzzy query with a fixed maximum edit distance (boost 1.0).
    pub fn fuzzy_with_distance(query: impl Into<String>, fuzziness: u32) -> Self {
        Self::fuzzy_with_options(query, Some(fuzziness), DEFAULT_MAX_EXPANSIONS)
    }

    /// Fuzzy query with explicit fuzziness and expansion cap (boost 1.0).
    pub fn fuzzy_with_options(
        query: impl Into<String>,
        fuzziness: Option<u32>,
        max_expansions: usize,
    ) -> Self {
        Self::Fuzzy {
            query: query.into(),
            fuzziness,
            max_expansions,
            boost: 1.0,
        }
    }

    /// Start building a boolean (must / should / must_not) query.
    pub fn boolean() -> BooleanQueryBuilder {
        BooleanQueryBuilder::new()
    }

    /// Boosting query with no negative clause (scores pass through).
    pub fn boosting(positive: Self) -> Self {
        Self::Boost {
            positive: Box::new(positive),
            negative: None,
            negative_boost: 1.0,
        }
    }

    /// Boosting query: documents matching `negative` have their score
    /// multiplied by `negative_boost`.
    pub fn boosting_with_negative(positive: Self, negative: Self, negative_boost: f32) -> Self {
        Self::Boost {
            positive: Box::new(positive),
            negative: Some(Box::new(negative)),
            negative_boost,
        }
    }

    /// Return this query with its boost replaced by `boost`.
    ///
    /// `Boolean` and `Boost` carry no boost field of their own, so for
    /// those variants the call is a documented no-op (the query is
    /// returned unchanged).
    pub fn with_boost(self, boost: f32) -> Self {
        match self {
            Self::Match { query, .. } => Self::Match { query, boost },
            Self::Phrase { query, slop, .. } => Self::Phrase { query, slop, boost },
            Self::Fuzzy {
                query,
                fuzziness,
                max_expansions,
                ..
            } => Self::Fuzzy {
                query,
                fuzziness,
                max_expansions,
                boost,
            },
            // No boost field on these variants: return them unchanged
            // instead of destructuring and rebuilding identical values.
            other @ (Self::Boolean { .. } | Self::Boost { .. }) => other,
        }
    }
}
/// Derive an edit-distance budget from a token's length in Unicode scalar
/// values: up to 2 chars -> 0 edits, 3 to 5 chars -> 1 edit, longer -> 2.
pub fn auto_fuzziness(token: &str) -> u32 {
    let len = token.chars().count();
    if len <= 2 {
        0
    } else if len <= 5 {
        1
    } else {
        2
    }
}
/// Levenshtein edit distance between `a` and `b`, measured in Unicode
/// scalar values (case-sensitive).
///
/// Uses a single rolling row of the dynamic-programming matrix, so memory
/// is O(len(b)) while time stays O(len(a) * len(b)).
pub fn levenshtein_distance(a: &str, b: &str) -> u32 {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();
    if source.is_empty() {
        return target.len() as u32;
    }
    if target.is_empty() {
        return source.len() as u32;
    }
    // row[j] holds the distance between the processed prefix of `source`
    // and the first j chars of `target`.
    let mut row: Vec<u32> = (0..=target.len() as u32).collect();
    for (i, &sc) in source.iter().enumerate() {
        // `diag` carries the previous row's value at column j.
        let mut diag = row[0];
        row[0] = i as u32 + 1;
        for (j, &tc) in target.iter().enumerate() {
            let substitution = diag + u32::from(sc != tc);
            diag = row[j + 1];
            row[j + 1] = substitution.min(diag + 1).min(row[j] + 1);
        }
    }
    row[target.len()]
}
/// Fluent builder for `FtsQueryExpr::Boolean`.
#[derive(Debug, Clone, Default)]
pub struct BooleanQueryBuilder {
    // Clauses every matching document must satisfy (AND).
    must: Vec<FtsQueryExpr>,
    // Optional clauses that add score (OR when `must` is empty).
    should: Vec<FtsQueryExpr>,
    // Clauses that exclude matching documents (NOT).
    must_not: Vec<FtsQueryExpr>,
}
impl BooleanQueryBuilder {
    /// Start an empty builder (no clauses).
    pub fn new() -> Self {
        Default::default()
    }

    /// Add a clause every matching document must satisfy.
    pub fn must(self, query: FtsQueryExpr) -> Self {
        let mut builder = self;
        builder.must.push(query);
        builder
    }

    /// Add an optional clause that contributes score.
    pub fn should(self, query: FtsQueryExpr) -> Self {
        let mut builder = self;
        builder.should.push(query);
        builder
    }

    /// Add a clause that excludes matching documents.
    pub fn must_not(self, query: FtsQueryExpr) -> Self {
        let mut builder = self;
        builder.must_not.push(query);
        builder
    }

    /// Finish and produce the boolean query expression.
    pub fn build(self) -> FtsQueryExpr {
        let Self {
            must,
            should,
            must_not,
        } = self;
        FtsQueryExpr::Boolean {
            must,
            should,
            must_not,
        }
    }
}
/// Per-(token, document) posting payload.
#[derive(Clone, Debug)]
pub struct PostingValue {
    // How many times the token occurs in the document.
    pub frequency: u32,
    // Token positions within the document, in occurrence order.
    pub positions: Vec<u32>,
}
/// In-memory inverted index over one string column, supporting term,
/// phrase, fuzzy, and boolean queries scored with BM25.
pub struct FtsMemIndex {
    // Field id of the indexed column.
    field_id: i32,
    // Name of the indexed column, looked up in each inserted batch.
    column_name: String,
    // (token, row) -> posting; keys for one token are contiguous.
    postings: SkipMap<FtsKey, PostingValue>,
    // Number of rows inserted (see `insert` for what counts as a doc).
    doc_count: AtomicUsize,
    // Tokenizer is stateful, so it is serialized behind a mutex.
    tokenizer: Mutex<Box<dyn LanceTokenizer>>,
    // Parameters the tokenizer was built from.
    params: InvertedIndexParams,
    // row position -> token count, for BM25 length normalization.
    doc_lengths: SkipMap<u64, u32>,
    // Total tokens across all documents (numerator of avg doc length).
    total_tokens: AtomicUsize,
    // token -> number of documents containing it (document frequency).
    doc_freq: SkipMap<String, AtomicUsize>,
}
// Manual Debug: the skip-maps and boxed tokenizer are omitted (the
// tokenizer is not Debug; the maps can be huge).
impl std::fmt::Debug for FtsMemIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FtsMemIndex")
            .field("field_id", &self.field_id)
            .field("column_name", &self.column_name)
            .field("doc_count", &self.doc_count)
            .field("params", &self.params)
            .finish()
    }
}
impl FtsMemIndex {
/// Create an index over `column_name` with default inverted-index
/// parameters.
pub fn new(field_id: i32, column_name: String) -> Self {
    let params = InvertedIndexParams::default();
    Self::with_params(field_id, column_name, params)
}
/// Create an index with explicit tokenizer/indexing parameters.
///
/// # Panics
/// Panics if `params.build()` cannot construct a tokenizer (e.g. an
/// invalid tokenizer configuration).
pub fn with_params(field_id: i32, column_name: String, params: InvertedIndexParams) -> Self {
    let tokenizer = params.build().expect("Failed to build tokenizer");
    Self {
        field_id,
        column_name,
        postings: SkipMap::new(),
        doc_count: AtomicUsize::new(0),
        tokenizer: Mutex::new(tokenizer),
        params,
        doc_lengths: SkipMap::new(),
        total_tokens: AtomicUsize::new(0),
        doc_freq: SkipMap::new(),
    }
}
/// Field id of the indexed column.
pub fn field_id(&self) -> i32 {
    self.field_id
}

/// Parameters this index was built with.
pub fn params(&self) -> &InvertedIndexParams {
    &self.params
}
/// Tokenize and index the configured string column of `batch`; rows are
/// addressed as `row_offset + row_idx`.
///
/// Batches whose schema lacks `column_name` are silently skipped (Ok).
pub fn insert(&self, batch: &RecordBatch, row_offset: u64) -> Result<()> {
    let col_idx = batch
        .schema()
        .column_with_name(&self.column_name)
        .map(|(idx, _)| idx);
    if col_idx.is_none() {
        // Column not present in this batch: nothing to index.
        return Ok(());
    }
    let column = batch.column(col_idx.unwrap());
    for row_idx in 0..batch.num_rows() {
        let value = ScalarValue::try_from_array(column.as_ref(), row_idx)?;
        let row_position = row_offset + row_idx as u64;
        if let ScalarValue::Utf8(Some(text)) | ScalarValue::LargeUtf8(Some(text)) = value {
            // token -> (term frequency, positions within this document)
            let mut term_data: HashMap<String, (u32, Vec<u32>)> = HashMap::new();
            {
                // Scope the tokenizer lock to the tokenization pass only.
                let mut tokenizer = self.tokenizer.lock().unwrap();
                let mut token_stream = tokenizer.token_stream_for_doc(&text);
                let mut position: u32 = 0;
                while let Some(token) = token_stream.next() {
                    let entry = term_data.entry(token.text.clone()).or_default();
                    entry.0 += 1;
                    entry.1.push(position);
                    position += 1;
                }
            }
            // Document length = total token occurrences (sum of term freqs).
            let doc_length: u32 = term_data.values().map(|(freq, _)| freq).sum();
            self.doc_lengths.insert(row_position, doc_length);
            self.total_tokens
                .fetch_add(doc_length as usize, Ordering::Relaxed);
            for (token, (freq, positions)) in term_data {
                // NOTE(review): get-then-insert is not atomic across
                // threads; concurrent `insert` calls could lose a
                // doc-frequency increment. Confirm inserts are serialized.
                if let Some(entry) = self.doc_freq.get(&token) {
                    entry.value().fetch_add(1, Ordering::Relaxed);
                } else {
                    self.doc_freq.insert(token.clone(), AtomicUsize::new(1));
                }
                let key = FtsKey {
                    token,
                    row_position,
                };
                self.postings.insert(
                    key,
                    PostingValue {
                        frequency: freq,
                        positions,
                    },
                );
            }
        }
        // Incremented for every row, including null / non-text rows; this
        // feeds BM25's average document length. TODO confirm intended.
        self.doc_count.fetch_add(1, Ordering::Relaxed);
    }
    Ok(())
}
/// Term search: return every document containing at least one token of
/// `term`, scored with BM25. Results are unsorted.
pub fn search(&self, term: &str) -> Vec<FtsEntry> {
    // Tokenize the query the same way documents were tokenized.
    let tokens: Vec<String> = {
        let mut tokenizer = self.tokenizer.lock().unwrap();
        let mut token_stream = tokenizer.token_stream_for_search(term);
        let mut tokens = Vec::new();
        while let Some(token) = token_stream.next() {
            tokens.push(token.text.clone());
        }
        tokens
    };
    // Standard BM25 constants: K1 = term-frequency saturation,
    // B = strength of document-length normalization.
    const K1: f32 = 1.2;
    const B: f32 = 0.75;
    let n = self.doc_count.load(Ordering::Relaxed) as f32;
    let total_tokens = self.total_tokens.load(Ordering::Relaxed) as f32;
    // Average document length; 1.0 avoids division by zero on an empty index.
    let avgdl = if n > 0.0 { total_tokens / n } else { 1.0 };
    // doc -> [(term frequency, document frequency)] for each matched token.
    let mut doc_term_info: HashMap<RowPosition, Vec<(u32, usize)>> = HashMap::new();
    for token in &tokens {
        let df = self
            .doc_freq
            .get(token)
            .map(|e| e.value().load(Ordering::Relaxed))
            .unwrap_or(0);
        if df == 0 {
            // Token absent from the index.
            continue;
        }
        // Postings for one token are contiguous in (token, row) key order.
        let start = FtsKey {
            token: token.clone(),
            row_position: 0,
        };
        let end = FtsKey {
            token: token.clone(),
            row_position: u64::MAX,
        };
        for entry in self.postings.range(start..=end) {
            doc_term_info
                .entry(entry.key().row_position)
                .or_default()
                .push((entry.value().frequency, df));
        }
    }
    // Sum the per-term BM25 contributions for each candidate document.
    doc_term_info
        .into_iter()
        .map(|(row_position, term_infos)| {
            let dl = self
                .doc_lengths
                .get(&row_position)
                .map(|e| *e.value() as f32)
                .unwrap_or(1.0);
            let mut score: f32 = 0.0;
            for (tf, df) in term_infos {
                let df_f = df as f32;
                // "+ 1.0" inside ln keeps idf positive even when df > n/2.
                let idf = ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln();
                let tf_f = tf as f32;
                let numerator = tf_f * (K1 + 1.0);
                let denominator = tf_f + K1 * (1.0 - B + B * (dl / avgdl));
                score += idf * (numerator / denominator);
            }
            FtsEntry {
                row_position,
                score,
            }
        })
        .collect()
}
/// Phrase search: all tokens must appear in order, with at most `slop`
/// extra positions between consecutive tokens. Matches are scored with
/// BM25 over the phrase's tokens. Results are unsorted.
pub fn search_phrase(&self, phrase: &str, slop: u32) -> Vec<FtsEntry> {
    let tokens: Vec<String> = {
        let mut tokenizer = self.tokenizer.lock().unwrap();
        let mut token_stream = tokenizer.token_stream_for_search(phrase);
        let mut tokens = Vec::new();
        while let Some(token) = token_stream.next() {
            tokens.push(token.text.clone());
        }
        tokens
    };
    if tokens.is_empty() {
        return vec![];
    }
    // A one-token "phrase" degenerates to a plain term search.
    if tokens.len() == 1 {
        return self.search(phrase);
    }
    const K1: f32 = 1.2;
    const B: f32 = 0.75;
    let n = self.doc_count.load(Ordering::Relaxed) as f32;
    let total_tokens = self.total_tokens.load(Ordering::Relaxed) as f32;
    let avgdl = if n > 0.0 { total_tokens / n } else { 1.0 };
    // Materialize per-token posting maps (doc -> posting) for intersection.
    let mut token_postings: Vec<HashMap<RowPosition, PostingValue>> = Vec::new();
    for token in &tokens {
        let start = FtsKey {
            token: token.clone(),
            row_position: 0,
        };
        let end = FtsKey {
            token: token.clone(),
            row_position: u64::MAX,
        };
        let mut postings_for_token: HashMap<RowPosition, PostingValue> = HashMap::new();
        for entry in self.postings.range(start..=end) {
            postings_for_token.insert(entry.key().row_position, entry.value().clone());
        }
        token_postings.push(postings_for_token);
    }
    // Candidate documents are those containing the first token.
    let first_token_docs: Vec<RowPosition> = token_postings[0].keys().copied().collect();
    let mut matching_docs: Vec<FtsEntry> = Vec::new();
    for row_position in first_token_docs {
        // Cheap conjunction check before the positional check.
        let all_tokens_present = token_postings
            .iter()
            .all(|tp| tp.contains_key(&row_position));
        if !all_tokens_present {
            continue;
        }
        if self.check_phrase_positions(&token_postings, row_position, slop) {
            let dl = self
                .doc_lengths
                .get(&row_position)
                .map(|e| *e.value() as f32)
                .unwrap_or(1.0);
            // Score as the sum of per-token BM25 contributions.
            let mut score: f32 = 0.0;
            for (token_idx, token) in tokens.iter().enumerate() {
                let df = self
                    .doc_freq
                    .get(token)
                    .map(|e| e.value().load(Ordering::Relaxed))
                    .unwrap_or(1) as f32;
                let tf = token_postings[token_idx]
                    .get(&row_position)
                    .map(|p| p.frequency as f32)
                    .unwrap_or(1.0);
                let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
                let numerator = tf * (K1 + 1.0);
                let denominator = tf + K1 * (1.0 - B + B * (dl / avgdl));
                score += idf * (numerator / denominator);
            }
            matching_docs.push(FtsEntry {
                row_position,
                score,
            });
        }
    }
    matching_docs
}
/// True if `row_position` contains the phrase tokens in order within the
/// given slop, trying every occurrence of the first token as an anchor.
fn check_phrase_positions(
    &self,
    token_postings: &[HashMap<RowPosition, PostingValue>],
    row_position: RowPosition,
    slop: u32,
) -> bool {
    // Collect each token's position list for this document; bail out if
    // any token is missing from the document.
    let mut all_positions: Vec<&Vec<u32>> = Vec::new();
    for tp in token_postings {
        if let Some(posting) = tp.get(&row_position) {
            all_positions.push(&posting.positions);
        } else {
            return false;
        }
    }
    // Try each occurrence of the first token as the phrase start.
    for &first_pos in all_positions[0] {
        if Self::check_phrase_from_position(&all_positions, first_pos, slop) {
            return true;
        }
    }
    false
}
/// True if the phrase can be completed starting from `first_pos` (an
/// occurrence of the first token), with each later token appearing
/// within `1..=1 + slop` positions after the previous one.
///
/// BUG FIX: the previous implementation greedily committed to the
/// *minimum* in-window position for each token. With `slop > 0` and a
/// middle token that occurs more than once, that greedy choice can
/// dead-end even though a later occurrence would let the rest of the
/// phrase match. This version backtracks over all candidate positions.
fn check_phrase_from_position(all_positions: &[&Vec<u32>], first_pos: u32, slop: u32) -> bool {
    // Recursively place each remaining token after `prev_pos`, exploring
    // every in-window candidate instead of only the smallest one.
    fn advance(rest: &[&Vec<u32>], prev_pos: u32, slop: u32) -> bool {
        let Some((positions, tail)) = rest.split_first() else {
            // Every token has been placed.
            return true;
        };
        let min_pos = prev_pos.saturating_add(1);
        let max_pos = prev_pos.saturating_add(1 + slop);
        positions
            .iter()
            .filter(|&&pos| pos >= min_pos && pos <= max_pos)
            .any(|&pos| advance(tail, pos, slop))
    }
    // Skip the first token's list: `first_pos` is already chosen from it.
    advance(all_positions.get(1..).unwrap_or(&[]), first_pos, slop)
}
/// Total number of (token, document) postings stored.
pub fn entry_count(&self) -> usize {
    self.postings.len()
}

/// Number of rows inserted (including rows with null/non-text values —
/// see `insert`).
pub fn doc_count(&self) -> usize {
    self.doc_count.load(Ordering::Relaxed)
}

/// True when no rows have been inserted yet.
pub fn is_empty(&self) -> bool {
    self.doc_count.load(Ordering::Relaxed) == 0
}

/// Name of the indexed column.
pub fn column_name(&self) -> &str {
    &self.column_name
}
/// Expand `term` into indexed terms within `max_distance` Levenshtein
/// edits, returning at most `max_expansions` `(term, distance)` pairs
/// sorted by distance (closest first).
///
/// NOTE(review): for `max_distance > 0` this scans the entire vocabulary
/// (O(#terms) per query token) — acceptable for an in-memory buffer,
/// costly for large vocabularies.
pub fn expand_fuzzy(
    &self,
    term: &str,
    max_distance: u32,
    max_expansions: usize,
) -> Vec<(String, u32)> {
    let mut matches: Vec<(String, u32)> = Vec::new();
    if max_distance == 0 {
        // Distance 0 is an exact lookup; skip the vocabulary scan.
        if self.doc_freq.get(term).is_some() {
            matches.push((term.to_string(), 0));
        }
        return matches;
    }
    for entry in self.doc_freq.iter() {
        let indexed_term = entry.key();
        let distance = levenshtein_distance(term, indexed_term);
        if distance <= max_distance {
            matches.push((indexed_term.clone(), distance));
        }
    }
    // Keep only the closest `max_expansions` candidates.
    matches.sort_by_key(|(_, d)| *d);
    matches.truncate(max_expansions);
    matches
}
/// Fuzzy search: each query token is expanded into nearby indexed terms
/// (see `expand_fuzzy`), and all expanded terms contribute BM25 score as
/// in `search`. `fuzziness: None` derives the edit budget per token via
/// `auto_fuzziness`. Results are unsorted.
pub fn search_fuzzy(
    &self,
    query: &str,
    fuzziness: Option<u32>,
    max_expansions: usize,
) -> Vec<FtsEntry> {
    let tokens: Vec<String> = {
        let mut tokenizer = self.tokenizer.lock().unwrap();
        let mut token_stream = tokenizer.token_stream_for_search(query);
        let mut tokens = Vec::new();
        while let Some(token) = token_stream.next() {
            tokens.push(token.text.clone());
        }
        tokens
    };
    if tokens.is_empty() {
        return vec![];
    }
    // Same BM25 setup as `search`.
    const K1: f32 = 1.2;
    const B: f32 = 0.75;
    let n = self.doc_count.load(Ordering::Relaxed) as f32;
    let total_tokens = self.total_tokens.load(Ordering::Relaxed) as f32;
    let avgdl = if n > 0.0 { total_tokens / n } else { 1.0 };
    // doc -> [(term frequency, document frequency)] over expanded terms.
    let mut doc_term_info: HashMap<RowPosition, Vec<(u32, usize)>> = HashMap::new();
    for token in &tokens {
        let max_distance = fuzziness.unwrap_or_else(|| auto_fuzziness(token));
        let expanded = self.expand_fuzzy(token, max_distance, max_expansions);
        for (matched_term, _distance) in expanded {
            let df = self
                .doc_freq
                .get(&matched_term)
                .map(|e| e.value().load(Ordering::Relaxed))
                .unwrap_or(0);
            if df == 0 {
                continue;
            }
            // Contiguous posting range for the expanded term.
            let start = FtsKey {
                token: matched_term.clone(),
                row_position: 0,
            };
            let end = FtsKey {
                token: matched_term,
                row_position: u64::MAX,
            };
            for entry in self.postings.range(start..=end) {
                doc_term_info
                    .entry(entry.key().row_position)
                    .or_default()
                    .push((entry.value().frequency, df));
            }
        }
    }
    // Score each candidate document; identical formula to `search`.
    doc_term_info
        .into_iter()
        .map(|(row_position, term_infos)| {
            let dl = self
                .doc_lengths
                .get(&row_position)
                .map(|e| *e.value() as f32)
                .unwrap_or(1.0);
            let mut score: f32 = 0.0;
            for (tf, df) in term_infos {
                let df_f = df as f32;
                let idf = ((n - df_f + 0.5) / (df_f + 0.5) + 1.0).ln();
                let tf_f = tf as f32;
                let numerator = tf_f * (K1 + 1.0);
                let denominator = tf_f + K1 * (1.0 - B + B * (dl / avgdl));
                score += idf * (numerator / denominator);
            }
            FtsEntry {
                row_position,
                score,
            }
        })
        .collect()
}
/// Evaluate a query expression tree, dispatching to the specialized
/// search methods and applying per-clause boosts.
pub fn search_query(&self, query: &FtsQueryExpr) -> Vec<FtsEntry> {
    // Multiply every score by `boost`, skipping the no-op case.
    fn apply_boost(mut entries: Vec<FtsEntry>, boost: f32) -> Vec<FtsEntry> {
        if boost != 1.0 {
            for entry in entries.iter_mut() {
                entry.score *= boost;
            }
        }
        entries
    }
    match query {
        FtsQueryExpr::Match { query, boost } => apply_boost(self.search(query), *boost),
        FtsQueryExpr::Phrase { query, slop, boost } => {
            apply_boost(self.search_phrase(query, *slop), *boost)
        }
        FtsQueryExpr::Fuzzy {
            query,
            fuzziness,
            max_expansions,
            boost,
        } => apply_boost(self.search_fuzzy(query, *fuzziness, *max_expansions), *boost),
        FtsQueryExpr::Boolean {
            must,
            should,
            must_not,
        } => self.search_boolean(must, should, must_not),
        FtsQueryExpr::Boost {
            positive,
            negative,
            negative_boost,
        } => self.search_boost(positive, negative.as_deref(), *negative_boost),
    }
}
/// Evaluate `query`, sort results by descending score, optionally prune
/// low scores WAND-style (`wand_factor < 1.0`), and apply the limit.
///
/// Pruning threshold: with a limit, the k-th best score times
/// `wand_factor`; without one, the best score times `wand_factor`.
pub fn search_with_options(
    &self,
    query: &FtsQueryExpr,
    options: SearchOptions,
) -> Vec<FtsEntry> {
    let mut results = self.search_query(query);
    // Descending by score; NaN scores compare as equal.
    results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    if options.wand_factor < 1.0 {
        match options.limit {
            // BUG FIX: guard `limit > 0` — the previous code indexed
            // `results[limit - 1]`, which underflows (and panics) when
            // a caller passes `limit == 0`.
            Some(limit) if limit > 0 && results.len() > limit => {
                let top_k_score = results[limit - 1].score;
                let threshold = top_k_score * options.wand_factor;
                results.retain(|e| e.score >= threshold);
            }
            // limit == 0 (truncated to empty below) or already within the
            // limit: nothing to prune.
            Some(_) => {}
            None => {
                if let Some(max_entry) = results.first() {
                    let threshold = max_entry.score * options.wand_factor;
                    results.retain(|e| e.score >= threshold);
                }
            }
        }
    }
    if let Some(limit) = options.limit {
        results.truncate(limit);
    }
    results
}
/// Evaluate a boosting query: run the positive query, then multiply the
/// score of any document also matched by the negative query by
/// `negative_boost`. Without a negative query, scores pass through.
fn search_boost(
    &self,
    positive: &FtsQueryExpr,
    negative: Option<&FtsQueryExpr>,
    negative_boost: f32,
) -> Vec<FtsEntry> {
    let mut results = self.search_query(positive);
    if let Some(neg_query) = negative {
        // Row positions matched by the negative clause.
        let demoted: std::collections::HashSet<RowPosition> = self
            .search_query(neg_query)
            .into_iter()
            .map(|e| e.row_position)
            .collect();
        for entry in results.iter_mut() {
            if demoted.contains(&entry.row_position) {
                entry.score *= negative_boost;
            }
        }
    }
    results
}
/// Evaluate a boolean query.
///
/// Semantics:
/// - `must`: documents must match every clause; clause scores are summed.
/// - `should`: with `must` present, only boosts scores of already-matched
///   documents; alone, it selects documents (union) and sums scores.
/// - `must_not`: removes matching documents. With no `must`/`should`
///   clauses the result is empty (there is no "match all" base set).
fn search_boolean(
    &self,
    must: &[FtsQueryExpr],
    should: &[FtsQueryExpr],
    must_not: &[FtsQueryExpr],
) -> Vec<FtsEntry> {
    // Union of everything excluded by must_not clauses.
    let excluded: std::collections::HashSet<RowPosition> = must_not
        .iter()
        .flat_map(|q| self.search_query(q))
        .map(|e| e.row_position)
        .collect();
    let mut result_map: HashMap<RowPosition, f32> = if must.is_empty() {
        // No must clauses: union of should matches, scores summed.
        let mut map = HashMap::new();
        for q in should {
            for entry in self.search_query(q) {
                *map.entry(entry.row_position).or_default() += entry.score;
            }
        }
        map
    } else {
        // Intersect must clauses, accumulating scores as we go.
        let first_results = self.search_query(&must[0]);
        let mut map: HashMap<RowPosition, f32> = first_results
            .into_iter()
            .map(|e| (e.row_position, e.score))
            .collect();
        for q in must.iter().skip(1) {
            let results = self.search_query(q);
            let result_set: HashMap<RowPosition, f32> = results
                .into_iter()
                .map(|e| (e.row_position, e.score))
                .collect();
            // Keep only docs also matching this clause; add its score.
            map = map
                .into_iter()
                .filter_map(|(pos, score)| result_set.get(&pos).map(|s| (pos, score + s)))
                .collect();
        }
        // should clauses only raise scores of docs already in the map.
        for q in should {
            for entry in self.search_query(q) {
                if let Some(score) = map.get_mut(&entry.row_position) {
                    *score += entry.score;
                }
            }
        }
        map
    };
    for pos in &excluded {
        result_map.remove(pos);
    }
    result_map
        .into_iter()
        .map(|(row_position, score)| FtsEntry {
            row_position,
            score,
        })
        .collect()
}
/// Convert this in-memory index into an `InnerBuilder`, remapping every
/// row position to its reversed counterpart `total_rows - pos - 1`.
///
/// NOTE(review): the reversal presumably matches how the enclosing
/// buffer assigns final row order at flush time — confirm with callers.
///
/// # Errors
/// Returns an internal I/O error if a posting references a row position
/// missing from `doc_lengths`, or a collected token is missing from the
/// `TokenSet` (both indicate index corruption).
pub fn to_index_builder_reversed(
    &self,
    partition_id: u64,
    total_rows: usize,
) -> Result<lance_index::scalar::inverted::builder::InnerBuilder> {
    use lance_index::scalar::inverted::builder::{InnerBuilder, PositionRecorder};
    use lance_index::scalar::inverted::{DocSet, PostingListBuilder, TokenSet};
    if self.is_empty() {
        // Nothing indexed: emit an empty builder.
        return Ok(InnerBuilder::new(
            partition_id,
            self.params.has_positions(),
            Default::default(),
        ));
    }
    let total_rows_u64 = total_rows as u64;
    let with_position = self.params.has_positions();
    // (reversed position, doc length) for every indexed document.
    let mut doc_entries: Vec<(u64, u32)> = self
        .doc_lengths
        .iter()
        .map(|e| {
            let original_pos = *e.key();
            let reversed_pos = total_rows_u64 - original_pos - 1;
            (reversed_pos, *e.value())
        })
        .collect();
    doc_entries.sort_by_key(|(pos, _)| *pos);
    // Doc ids are assigned in ascending reversed-position order, so the
    // mapping below is consistent with the DocSet append order.
    let mut docs = DocSet::default();
    let mut reversed_pos_to_doc_id: HashMap<u64, u32> =
        HashMap::with_capacity(doc_entries.len());
    for (idx, (reversed_pos, num_tokens)) in doc_entries.into_iter().enumerate() {
        docs.append(reversed_pos, num_tokens);
        reversed_pos_to_doc_id.insert(reversed_pos, idx as u32);
    }
    // Group postings per token, rewriting row positions to doc ids.
    let mut tokens = TokenSet::default();
    let mut token_postings: HashMap<String, Vec<(u32, PostingValue)>> = HashMap::new();
    for entry in self.postings.iter() {
        let token = entry.key().token.clone();
        let original_pos = entry.key().row_position;
        let reversed_pos = total_rows_u64 - original_pos - 1;
        let doc_id = *reversed_pos_to_doc_id.get(&reversed_pos).ok_or_else(|| {
            Error::io(format!(
                "FTS index internal error: doc_id not found for reversed position {} (original: {}, total_rows: {})",
                reversed_pos, original_pos, total_rows
            ))
        })?;
        token_postings
            .entry(token)
            .or_default()
            .push((doc_id, entry.value().clone()));
    }
    // Register tokens in sorted order so token ids are deterministic.
    let mut sorted_tokens: Vec<_> = token_postings.keys().cloned().collect();
    sorted_tokens.sort();
    for token in &sorted_tokens {
        tokens.add(token.clone());
    }
    // One posting list per registered token, indexed by token id.
    let mut posting_lists: Vec<PostingListBuilder> = (0..tokens.len())
        .map(|_| PostingListBuilder::new(with_position))
        .collect();
    for (token, mut postings) in token_postings {
        let token_id = tokens.get(&token).ok_or_else(|| {
            Error::io(format!(
                "FTS index internal error: token '{}' not found in TokenSet",
                token
            ))
        })? as usize;
        // Posting lists must be appended in ascending doc-id order.
        postings.sort_by_key(|(doc_id, _)| *doc_id);
        for (doc_id, value) in postings {
            let position_recorder = if with_position {
                PositionRecorder::Position(value.positions.into())
            } else {
                PositionRecorder::Count(value.frequency)
            };
            posting_lists[token_id].add(doc_id, position_recorder);
        }
    }
    let mut builder = InnerBuilder::new(partition_id, with_position, Default::default());
    builder.set_tokens(tokens);
    builder.set_docs(docs);
    builder.set_posting_lists(posting_lists);
    Ok(builder)
}
}
/// Declarative configuration for building an FTS index over a column.
#[derive(Debug, Clone)]
pub struct FtsIndexConfig {
    // Index name.
    pub name: String,
    // Field id of the indexed column.
    pub field_id: i32,
    // Column name to index.
    pub column: String,
    // Tokenizer / inverted-index parameters.
    pub params: InvertedIndexParams,
}
impl FtsIndexConfig {
    /// Config with default inverted-index parameters.
    pub fn new(name: String, field_id: i32, column: String) -> Self {
        Self::with_params(name, field_id, column, InvertedIndexParams::default())
    }

    /// Config with explicit inverted-index parameters.
    pub fn with_params(
        name: String,
        field_id: i32,
        column: String,
        params: InvertedIndexParams,
    ) -> Self {
        Self {
            name,
            field_id,
            column,
            params,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::{Int32Array, StringArray};
use arrow_schema::{DataType, Field, Schema as ArrowSchema};
use std::sync::Arc;
/// Two-column schema (`id`, nullable `description`) used by most tests.
fn create_test_schema() -> Arc<ArrowSchema> {
    Arc::new(ArrowSchema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("description", DataType::Utf8, true),
    ]))
}

/// Three rows: "hello world", "goodbye world", "hello again".
fn create_test_batch(schema: &ArrowSchema) -> RecordBatch {
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(Int32Array::from(vec![0, 1, 2])),
            Arc::new(StringArray::from(vec![
                "hello world",
                "goodbye world",
                "hello again",
            ])),
        ],
    )
    .unwrap()
}

/// Basic term search: hit counts and row positions after one insert.
#[test]
fn test_fts_index_insert_and_search() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    assert_eq!(index.doc_count(), 3);
    // "hello" appears in rows 0 and 2.
    let entries = index.search("hello");
    assert!(!entries.is_empty());
    assert_eq!(entries.len(), 2);
    // "world" appears in rows 0 and 1.
    let entries = index.search("world");
    assert!(!entries.is_empty());
    assert_eq!(entries.len(), 2);
    // "goodbye" appears only in row 1.
    let entries = index.search("goodbye");
    assert!(!entries.is_empty());
    assert_eq!(entries.len(), 1);
    assert_eq!(entries[0].row_position, 1);
    // Unknown terms yield no hits.
    let entries = index.search("nonexistent");
    assert!(entries.is_empty());
}
/// Five rows exercising token order and gaps between "alpha"/"beta"/"gamma".
fn create_phrase_test_batch(schema: &ArrowSchema) -> RecordBatch {
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4])),
            Arc::new(StringArray::from(vec![
                "alpha beta gamma",
                "beta alpha gamma",
                "alpha delta beta gamma",
                "alpha gamma",
                "alpha delta epsilon beta gamma",
            ])),
        ],
    )
    .unwrap()
}

/// Exact (slop = 0) phrase matches on adjacent tokens.
#[test]
fn test_phrase_search_exact_match() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_phrase_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    // Only row 0 has "alpha" immediately followed by "beta".
    let entries = index.search_phrase("alpha beta", 0);
    assert_eq!(
        entries.len(),
        1,
        "Expected 1 match for 'alpha beta', got {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    assert_eq!(entries[0].row_position, 0);
    let batch2 = create_test_batch(&schema);
    let index2 = FtsMemIndex::new(1, "description".to_string());
    index2.insert(&batch2, 0).unwrap();
    let entries = index2.search_phrase("hello world", 0);
    assert_eq!(entries.len(), 1);
    assert_eq!(entries[0].row_position, 0);
    let entries = index2.search_phrase("goodbye world", 0);
    assert_eq!(entries.len(), 1);
    assert_eq!(entries[0].row_position, 1);
}

/// Increasing slop admits progressively larger gaps between tokens.
#[test]
fn test_phrase_search_with_slop() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_phrase_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let entries = index.search_phrase("alpha beta", 0);
    assert_eq!(
        entries.len(),
        1,
        "slop=0 matches: {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    assert_eq!(entries[0].row_position, 0);
    // slop=1 also admits "alpha delta beta" (row 2).
    let entries = index.search_phrase("alpha beta", 1);
    assert_eq!(
        entries.len(),
        2,
        "slop=1 matches: {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    let positions: Vec<_> = entries.iter().map(|e| e.row_position).collect();
    assert!(positions.contains(&0));
    assert!(positions.contains(&2));
    // slop=2 also admits "alpha delta epsilon beta" (row 4).
    let entries = index.search_phrase("alpha beta", 2);
    assert_eq!(
        entries.len(),
        3,
        "slop=2 matches: {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    let entries = index.search_phrase("alpha gamma", 0);
    assert_eq!(
        entries.len(),
        2,
        "alpha gamma slop=0: {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    let entries = index.search_phrase("alpha gamma", 1);
    assert_eq!(
        entries.len(),
        3,
        "alpha gamma slop=1: {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
}

/// Token order matters; reversed or unknown phrases do not match.
#[test]
fn test_phrase_search_no_match() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_phrase_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    // "beta alpha" appears in order only in row 1.
    let entries = index.search_phrase("beta alpha", 0);
    assert_eq!(entries.len(), 1);
    assert_eq!(entries[0].row_position, 1);
    let entries = index.search_phrase("nonexistent phrase", 0);
    assert!(entries.is_empty());
    let entries = index.search_phrase("alpha hello", 0);
    assert!(entries.is_empty());
    let entries = index.search_phrase("gamma alpha", 0);
    assert!(entries.is_empty());
}

/// A single-token phrase degenerates to a plain term search.
#[test]
fn test_phrase_search_single_token() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_phrase_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let phrase_entries = index.search_phrase("alpha", 0);
    let search_entries = index.search("alpha");
    assert_eq!(phrase_entries.len(), search_entries.len());
}

/// An empty phrase yields no results.
#[test]
fn test_phrase_search_empty() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let entries = index.search_phrase("", 0);
    assert!(entries.is_empty());
}
/// Five rows mixing "rust"/"python"/"javascript" with "programming"/"web".
fn create_boolean_test_batch(schema: &ArrowSchema) -> RecordBatch {
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4])),
            Arc::new(StringArray::from(vec![
                "rust programming language",
                "python programming language",
                "rust web server",
                "python web framework",
                "javascript programming",
            ])),
        ],
    )
    .unwrap()
}

/// MUST clauses intersect: only row 0 has both "rust" and "programming".
#[test]
fn test_boolean_must_only() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_boolean_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let query = FtsQueryExpr::boolean()
        .must(FtsQueryExpr::match_query("rust"))
        .must(FtsQueryExpr::match_query("programming"))
        .build();
    let entries = index.search_query(&query);
    assert_eq!(
        entries.len(),
        1,
        "Expected 1 match for MUST(rust, programming), got {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    assert_eq!(entries[0].row_position, 0);
}

/// SHOULD clauses alone form a union over matches.
#[test]
fn test_boolean_should_only() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_boolean_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let query = FtsQueryExpr::boolean()
        .should(FtsQueryExpr::match_query("rust"))
        .should(FtsQueryExpr::match_query("python"))
        .build();
    let entries = index.search_query(&query);
    assert_eq!(
        entries.len(),
        4,
        "Expected 4 matches for SHOULD(rust, python), got {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    let positions: Vec<_> = entries.iter().map(|e| e.row_position).collect();
    assert!(positions.contains(&0));
    assert!(positions.contains(&1));
    assert!(positions.contains(&2));
    assert!(positions.contains(&3));
}

/// MUST_NOT alone has no base set to subtract from, so it returns empty.
#[test]
fn test_boolean_must_not_only() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_boolean_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let query = FtsQueryExpr::boolean()
        .must_not(FtsQueryExpr::match_query("rust"))
        .build();
    let entries = index.search_query(&query);
    assert!(
        entries.is_empty(),
        "MUST_NOT only should return empty, got {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
}

/// With MUST present, SHOULD only boosts the score of existing matches.
#[test]
fn test_boolean_must_with_should() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_boolean_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let query = FtsQueryExpr::boolean()
        .must(FtsQueryExpr::match_query("programming"))
        .should(FtsQueryExpr::match_query("rust"))
        .build();
    let entries = index.search_query(&query);
    assert_eq!(
        entries.len(),
        3,
        "Expected 3 matches for MUST(programming) SHOULD(rust), got {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    let doc0 = entries.iter().find(|e| e.row_position == 0).unwrap();
    let doc1 = entries.iter().find(|e| e.row_position == 1).unwrap();
    assert!(
        doc0.score > doc1.score,
        "Doc 0 (rust+programming) should score higher than doc 1 (programming only). Doc0: {}, Doc1: {}",
        doc0.score,
        doc1.score
    );
}

/// MUST_NOT removes matching documents from the MUST result set.
#[test]
fn test_boolean_must_with_must_not() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_boolean_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let query = FtsQueryExpr::boolean()
        .must(FtsQueryExpr::match_query("programming"))
        .must_not(FtsQueryExpr::match_query("python"))
        .build();
    let entries = index.search_query(&query);
    assert_eq!(
        entries.len(),
        2,
        "Expected 2 matches for MUST(programming) MUST_NOT(python), got {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    let positions: Vec<_> = entries.iter().map(|e| e.row_position).collect();
    assert!(positions.contains(&0));
    assert!(positions.contains(&4));
    assert!(!positions.contains(&1));
}

/// All three clause types together.
#[test]
fn test_boolean_combined() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_boolean_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let query = FtsQueryExpr::boolean()
        .must(FtsQueryExpr::match_query("web"))
        .should(FtsQueryExpr::match_query("rust"))
        .must_not(FtsQueryExpr::match_query("framework"))
        .build();
    let entries = index.search_query(&query);
    assert_eq!(
        entries.len(),
        1,
        "Expected 1 match for MUST(web) SHOULD(rust) MUST_NOT(framework), got {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    assert_eq!(entries[0].row_position, 2);
}

/// Phrase queries nest inside boolean clauses.
#[test]
fn test_boolean_nested_phrase() {
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    let batch = create_boolean_test_batch(&schema);
    index.insert(&batch, 0).unwrap();
    let query = FtsQueryExpr::boolean()
        .must(FtsQueryExpr::phrase("programming language"))
        .build();
    let entries = index.search_query(&query);
    assert_eq!(
        entries.len(),
        2,
        "Expected 2 matches for MUST(phrase 'programming language'), got {:?}",
        entries.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
    let positions: Vec<_> = entries.iter().map(|e| e.row_position).collect();
    assert!(positions.contains(&0));
    assert!(positions.contains(&1));
}
#[test]
fn test_search_query_match() {
let schema = create_test_schema();
let index = FtsMemIndex::new(1, "description".to_string());
let batch = create_test_batch(&schema);
index.insert(&batch, 0).unwrap();
let query = FtsQueryExpr::match_query("hello");
let entries = index.search_query(&query);
assert_eq!(entries.len(), 2);
}
#[test]
fn test_search_query_phrase() {
    // Phrase search requires adjacent, in-order tokens; only doc 0 qualifies.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_test_batch(&schema), 0).unwrap();

    let hits = index.search_query(&FtsQueryExpr::phrase("hello world"));
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].row_position, 0);
}
#[test]
fn test_search_query_with_boost() {
    // with_boost(2.0) must scale every score by exactly 2x without
    // changing which documents match.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_test_batch(&schema), 0).unwrap();

    let plain = index.search_query(&FtsQueryExpr::match_query("hello"));
    let boosted = index.search_query(&FtsQueryExpr::match_query("hello").with_boost(2.0));

    assert_eq!(plain.len(), boosted.len());
    for (base, doubled) in plain.iter().zip(&boosted) {
        let expected = base.score * 2.0;
        assert!(
            (doubled.score - expected).abs() < 0.001,
            "Boosted score {} should be 2x original {}",
            doubled.score,
            base.score
        );
    }
}
#[test]
fn test_levenshtein_distance() {
    // (left, right, expected edit distance)
    let cases = [
        ("hello", "hello", 0),
        ("hello", "hallo", 1),  // substitution
        ("hello", "hell", 1),   // deletion
        ("hello", "helloo", 1), // insertion
        ("hello", "hxllo", 1),
        ("hello", "hxxlo", 2),
        ("abc", "xyz", 3),
        ("", "", 0),
        ("hello", "", 5),
        ("", "hello", 5),
        ("Hello", "hello", 1), // case-sensitive comparison
    ];
    for (a, b, expected) in cases {
        assert_eq!(
            levenshtein_distance(a, b),
            expected,
            "levenshtein_distance({:?}, {:?})",
            a,
            b
        );
    }
}
#[test]
fn test_auto_fuzziness() {
    // Fuzziness scales with term length: 0..=2 chars => 0, 3..=5 => 1, 6+ => 2.
    let cases = [
        ("", 0),
        ("a", 0),
        ("ab", 0),
        ("abc", 1),
        ("abcd", 1),
        ("abcde", 1),
        ("abcdef", 2),
        ("programming", 2),
    ];
    for (term, expected) in cases {
        assert_eq!(auto_fuzziness(term), expected, "auto_fuzziness({:?})", term);
    }
}
/// Builds a five-row batch with near-miss spellings of "alpha"
/// ("alpho", "alphax") for fuzzy-match tests.
fn create_fuzzy_test_batch(schema: &ArrowSchema) -> RecordBatch {
    let ids = Int32Array::from(vec![0, 1, 2, 3, 4]);
    let docs = StringArray::from(vec![
        "alpha beta gamma",
        "alpho beta delta",
        "alpha delta epsilon",
        "omega zeta",
        "alphax gamma",
    ]);
    RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ids), Arc::new(docs)]).unwrap()
}
#[test]
fn test_expand_fuzzy_exact_match() {
    // With max edit distance 0, only the exact term may be returned.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let expansions = index.expand_fuzzy("alpha", 0, 50);
    assert_eq!(
        expansions.len(),
        1,
        "Expected 1 match for 'alpha', got {:?}",
        expansions
    );
    assert_eq!(expansions[0].0, "alpha");
    assert_eq!(expansions[0].1, 0);

    // An unknown term expands to nothing at distance 0.
    assert!(index.expand_fuzzy("nonexistent", 0, 50).is_empty());
}
#[test]
fn test_expand_fuzzy_single_edit() {
    // 'alpho' is one substitution from 'alpha'; both terms exist in the index.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let expansions = index.expand_fuzzy("alpho", 1, 50);

    let has_alpha_at_one = expansions
        .iter()
        .any(|(term, dist)| term == "alpha" && *dist == 1);
    assert!(
        has_alpha_at_one,
        "Expected 'alpha' with distance 1, got {:?}",
        expansions
    );

    let has_exact = expansions.iter().any(|(term, _)| term == "alpho");
    assert!(has_exact, "Expected 'alpho' in matches, got {:?}", expansions);
}
#[test]
fn test_expand_fuzzy_max_expansions() {
    // The max_expansions argument caps how many expanded terms come back.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let expansions = index.expand_fuzzy("a", 10, 3);
    assert!(
        expansions.len() <= 3,
        "Expected at most 3 matches, got {}",
        expansions.len()
    );
}
#[test]
fn test_search_fuzzy_basic() {
    // Fuzzy search at distance 1 should reach docs containing 'alpho'
    // itself or its close neighbor 'alpha' (rows 0, 1, 2).
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let hits = index.search_fuzzy("alpho", Some(1), 50);
    assert!(!hits.is_empty(), "Expected matches for fuzzy 'alpho'");

    let rows: Vec<_> = hits.iter().map(|e| e.row_position).collect();
    assert!(
        rows.contains(&0) || rows.contains(&1) || rows.contains(&2),
        "Expected to match docs with alpha/alpho, got {:?}",
        rows
    );
}
#[test]
fn test_search_fuzzy_auto_fuzziness() {
    // Passing None lets the index derive fuzziness from the term length.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let hits = index.search_fuzzy("alpho", None, 50);
    assert!(!hits.is_empty(), "Expected matches with auto-fuzziness");
}
#[test]
fn test_search_fuzzy_no_match() {
    // 'xyz' has no exact match; at distance 1 we only require the call
    // not to panic, whatever it returns.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let hits = index.search_fuzzy("xyz", Some(0), 50);
    assert!(hits.is_empty(), "Expected no matches for 'xyz'");

    let _ = index.search_fuzzy("xyz", Some(1), 50);
}
#[test]
fn test_search_query_fuzzy() {
    // FtsQueryExpr::fuzzy routes through search_query and finds near matches.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let hits = index.search_query(&FtsQueryExpr::fuzzy("alpho"));
    assert!(!hits.is_empty(), "Expected matches for fuzzy query 'alpho'");
}
#[test]
fn test_search_query_fuzzy_with_distance() {
    // An explicit edit distance of 1 still reaches 'alpha' from 'alpho'.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let hits = index.search_query(&FtsQueryExpr::fuzzy_with_distance("alpho", 1));
    assert!(
        !hits.is_empty(),
        "Expected matches for fuzzy query with distance 1"
    );
}
#[test]
fn test_search_query_fuzzy_with_boost() {
    // Boosting a fuzzy query must double every per-document score
    // while leaving the result set unchanged.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let plain = index.search_query(&FtsQueryExpr::fuzzy("alpho"));
    let boosted = index.search_query(&FtsQueryExpr::fuzzy("alpho").with_boost(2.0));
    assert_eq!(plain.len(), boosted.len());

    for base in &plain {
        // Pair entries by row position, since result ordering may differ.
        let doubled = boosted
            .iter()
            .find(|e| e.row_position == base.row_position)
            .unwrap();
        let expected = base.score * 2.0;
        assert!(
            (doubled.score - expected).abs() < 0.001,
            "Boosted score {} should be 2x original {}",
            doubled.score,
            base.score
        );
    }
}
#[test]
fn test_boolean_with_fuzzy() {
    // Fuzzy MUST combined with MUST_NOT: docs containing 'delta'
    // (rows 1 and 2) are excluded even though they fuzzy-match 'alpho'.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_fuzzy_test_batch(&schema), 0).unwrap();

    let query = FtsQueryExpr::boolean()
        .must(FtsQueryExpr::fuzzy_with_distance("alpho", 1))
        .must_not(FtsQueryExpr::match_query("delta"))
        .build();
    let hits = index.search_query(&query);
    let rows: Vec<_> = hits.iter().map(|e| e.row_position).collect();

    assert!(
        !rows.contains(&1),
        "Doc 1 should be excluded due to MUST_NOT, got {:?}",
        rows
    );
    assert!(
        !rows.contains(&2),
        "Doc 2 should be excluded due to MUST_NOT, got {:?}",
        rows
    );
    assert!(rows.contains(&0), "Doc 0 should be included, got {:?}", rows);
}
/// Builds a five-row batch mixing 'rust'/'python'/'javascript' docs
/// so boosting tests can demote the python rows (1 and 3).
fn create_boost_test_batch(schema: &ArrowSchema) -> RecordBatch {
    let ids = Int32Array::from(vec![0, 1, 2, 3, 4]);
    let docs = StringArray::from(vec![
        "rust programming language",
        "python programming language",
        "rust web server",
        "python web framework",
        "javascript programming",
    ]);
    RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ids), Arc::new(docs)]).unwrap()
}
#[test]
fn test_boost_query_positive_only() {
    // A boosting query without a negative clause is equivalent to its
    // positive query: 'programming' appears in docs 0, 1, and 4.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_boost_test_batch(&schema), 0).unwrap();

    let query = FtsQueryExpr::boosting(FtsQueryExpr::match_query("programming"));
    let hits = index.search_query(&query);
    assert_eq!(
        hits.len(),
        3,
        "Expected 3 matches for 'programming', got {:?}",
        hits.iter().map(|e| e.row_position).collect::<Vec<_>>()
    );
}
#[test]
fn test_boost_query_with_negative() {
    // The negative clause ('python') demotes doc 1 below docs matching
    // only the positive clause, without removing it from the results.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_boost_test_batch(&schema), 0).unwrap();

    let query = FtsQueryExpr::boosting_with_negative(
        FtsQueryExpr::match_query("programming"),
        FtsQueryExpr::match_query("python"),
        0.5,
    );
    let hits = index.search_query(&query);
    assert_eq!(hits.len(), 3);

    // Look up a doc's score by row position; panics if the doc is missing.
    let score_of = |row| {
        hits.iter()
            .find(|e| e.row_position == row)
            .map(|e| e.score)
            .expect("expected doc in results")
    };
    let score0 = score_of(0);
    let score1 = score_of(1);
    let score4 = score_of(4);

    assert!(
        score1 < score0,
        "Doc 1 (python) should have lower score than doc 0 (rust). Doc0: {}, Doc1: {}",
        score0,
        score1
    );
    assert!(
        score1 < score4,
        "Doc 1 (python) should have lower score than doc 4 (javascript). Doc1: {}, Doc4: {}",
        score1,
        score4
    );
}
#[test]
fn test_boost_query_negative_boost_factor() {
    // The negative_boost factor multiplies the score of docs matching the
    // negative clause: 1.0 = no demotion, 0.5 = half score, 0.0 = zeroed.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_boost_test_batch(&schema), 0).unwrap();

    // Same positive/negative pair, varying only the demotion factor.
    let demote = |factor| {
        FtsQueryExpr::boosting_with_negative(
            FtsQueryExpr::match_query("programming"),
            FtsQueryExpr::match_query("python"),
            factor,
        )
    };
    // Score of doc 1 ("python programming language") under a given factor.
    let doc1_score = |factor| {
        index
            .search_query(&demote(factor))
            .iter()
            .find(|e| e.row_position == 1)
            .unwrap()
            .score
    };

    let full = doc1_score(1.0);
    let half = doc1_score(0.5);
    let zero = doc1_score(0.0);

    assert!(
        (half - full * 0.5).abs() < 0.001,
        "Half demotion should give half score. Expected {}, got {}",
        full * 0.5,
        half
    );
    assert!(
        zero.abs() < 0.001,
        "Zero demotion should give zero score, got {}",
        zero
    );
}
#[test]
fn test_boost_query_no_negative_match() {
    // When no positive hit also matches the negative clause, boosting
    // must leave scores identical to the plain positive query.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_boost_test_batch(&schema), 0).unwrap();

    let boosted = index.search_query(&FtsQueryExpr::boosting_with_negative(
        FtsQueryExpr::match_query("rust"),
        FtsQueryExpr::match_query("python"),
        0.1,
    ));
    assert_eq!(boosted.len(), 2);

    let baseline = index.search_query(&FtsQueryExpr::match_query("rust"));
    for hit in &boosted {
        let reference = baseline
            .iter()
            .find(|e| e.row_position == hit.row_position)
            .unwrap();
        assert!(
            (hit.score - reference.score).abs() < 0.001,
            "Scores should match when no negative overlap. Got {} vs {}",
            hit.score,
            reference.score
        );
    }
}
#[test]
fn test_boost_query_nested() {
    // The positive side may itself be a boolean query; python docs
    // (rows 1 and 3) should be demoted relative to the rest.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_boost_test_batch(&schema), 0).unwrap();

    let positive = FtsQueryExpr::boolean()
        .should(FtsQueryExpr::match_query("programming"))
        .should(FtsQueryExpr::match_query("web"))
        .build();
    let query =
        FtsQueryExpr::boosting_with_negative(positive, FtsQueryExpr::match_query("python"), 0.5);
    let hits = index.search_query(&query);
    assert!(hits.len() >= 4, "Should match multiple docs");

    let is_python_doc = |row| row == 1 || row == 3;
    let python: Vec<_> = hits
        .iter()
        .filter(|e| is_python_doc(e.row_position))
        .collect();
    let others: Vec<_> = hits
        .iter()
        .filter(|e| !is_python_doc(e.row_position))
        .collect();

    // Loose sanity check (deliberately weak): demoted docs should not
    // dominate the score distribution.
    if !python.is_empty() && !others.is_empty() {
        let top_python = python.iter().map(|e| e.score).fold(0.0f32, f32::max);
        let top_other = others.iter().map(|e| e.score).fold(0.0f32, f32::max);
        assert!(
            python.iter().any(|e| e.score < top_other) || top_python <= top_other,
            "Python docs should generally have lower scores"
        );
    }
}
#[test]
fn test_search_options_default() {
    // Defaults: full WAND recall (factor 1.0) and no result limit.
    let defaults = SearchOptions::default();
    assert_eq!(defaults.wand_factor, 1.0);
    assert!(defaults.limit.is_none());
}
#[test]
fn test_search_options_builder() {
    // Builder methods chain and record both settings (order-independent).
    let opts = SearchOptions::new().with_limit(10).with_wand_factor(0.5);
    assert_eq!(opts.limit, Some(10));
    assert_eq!(opts.wand_factor, 0.5);
}
#[test]
fn test_search_options_wand_factor_clamped() {
    // Out-of-range factors are clamped into [0.0, 1.0].
    let above = SearchOptions::new().with_wand_factor(2.0);
    assert_eq!(above.wand_factor, 1.0);

    let below = SearchOptions::new().with_wand_factor(-0.5);
    assert_eq!(below.wand_factor, 0.0);
}
/// Builds a five-row batch with varying 'alpha' term frequencies so
/// WAND pruning and score-ordering tests have a clear ranking.
fn create_wand_test_batch(schema: &ArrowSchema) -> RecordBatch {
    let ids = Int32Array::from(vec![0, 1, 2, 3, 4]);
    let docs = StringArray::from(vec![
        "alpha alpha alpha beta",
        "alpha beta gamma",
        "beta gamma delta",
        "alpha alpha",
        "alpha",
    ]);
    RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ids), Arc::new(docs)]).unwrap()
}
#[test]
fn test_search_with_options_full_recall() {
    // Default options (wand_factor=1.0, no limit) must return every
    // matching doc, sorted by descending score.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_wand_test_batch(&schema), 0).unwrap();

    let results = index.search_with_options(
        &FtsQueryExpr::match_query("alpha"),
        SearchOptions::default(),
    );
    assert_eq!(results.len(), 4, "Expected 4 matches with full recall");

    for pair in results.windows(2) {
        assert!(
            pair[0].score >= pair[1].score,
            "Results should be sorted by score descending"
        );
    }
}
#[test]
fn test_search_with_options_with_limit() {
    // limit=2 must return exactly the top two scorers of the full search.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_wand_test_batch(&schema), 0).unwrap();

    let query = FtsQueryExpr::match_query("alpha");
    let limited = index.search_with_options(&query, SearchOptions::new().with_limit(2));
    assert_eq!(limited.len(), 2, "Expected 2 matches with limit=2");

    // Rank the unlimited results to know which docs should survive the cut.
    let mut ranked = index.search_query(&query);
    ranked.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());

    assert_eq!(
        limited[0].row_position, ranked[0].row_position,
        "First result should be highest scorer"
    );
    assert_eq!(
        limited[1].row_position, ranked[1].row_position,
        "Second result should be second highest scorer"
    );
}
#[test]
fn test_search_with_options_wand_factor_pruning() {
    // With wand_factor=0.5, every returned doc must score at least half
    // of the best full-search score, and pruning never adds results.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_wand_test_batch(&schema), 0).unwrap();

    let query = FtsQueryExpr::match_query("alpha");
    let full = index.search_query(&query);
    let mut ranked = full.clone();
    ranked.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());

    let pruned = index.search_with_options(&query, SearchOptions::new().with_wand_factor(0.5));
    if !pruned.is_empty() {
        let threshold = ranked[0].score * 0.5;
        for hit in &pruned {
            assert!(
                hit.score >= threshold - 0.001,
                "With wand_factor=0.5, all results should score >= {} but got {}",
                threshold,
                hit.score
            );
        }
        assert!(
            pruned.len() <= full.len(),
            "Pruned results should not exceed full results"
        );
    }
}
#[test]
fn test_search_with_options_wand_factor_with_limit() {
    // Limit and wand_factor compose: at most `limit` results, still sorted.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_wand_test_batch(&schema), 0).unwrap();

    let query = FtsQueryExpr::match_query("alpha");
    assert!(
        index.search_query(&query).len() >= 3,
        "Need at least 3 results for this test"
    );

    let options = SearchOptions::new().with_limit(2).with_wand_factor(0.5);
    let results = index.search_with_options(&query, options);
    assert!(results.len() <= 2, "Should not exceed limit");
    if let [first, second, ..] = results.as_slice() {
        assert!(first.score >= second.score);
    }
}
#[test]
fn test_search_with_options_empty_results() {
    // A term absent from the corpus yields no hits regardless of options.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_wand_test_batch(&schema), 0).unwrap();

    let options = SearchOptions::new().with_limit(10).with_wand_factor(0.5);
    let results = index.search_with_options(&FtsQueryExpr::match_query("nonexistent"), options);
    assert!(
        results.is_empty(),
        "Should return empty for non-matching query"
    );
}
#[test]
fn test_search_with_options_boolean_query() {
    // search_with_options also accepts compound queries; the limit and
    // descending-score ordering must still hold.
    let schema = create_test_schema();
    let index = FtsMemIndex::new(1, "description".to_string());
    index.insert(&create_wand_test_batch(&schema), 0).unwrap();

    let query = FtsQueryExpr::boolean()
        .should(FtsQueryExpr::match_query("alpha"))
        .should(FtsQueryExpr::match_query("beta"))
        .build();
    let results = index.search_with_options(&query, SearchOptions::new().with_limit(3));
    assert!(results.len() <= 3, "Should not exceed limit");

    for pair in results.windows(2) {
        assert!(pair[0].score >= pair[1].score);
    }
}
}