pub mod builder;
mod encoding;
mod index;
mod iter;
pub mod json;
pub mod parser;
pub mod query;
mod scorer;
pub mod tokenizer;
mod wand;
use std::sync::Arc;
use arrow_schema::{DataType, Field};
use async_trait::async_trait;
pub use builder::InvertedIndexBuilder;
use datafusion::execution::SendableRecordBatchStream;
pub use index::*;
use lance_core::{Result, cache::LanceCache};
pub use lance_tokenizer::Language;
pub use scorer::MemBM25Scorer;
pub use tokenizer::*;
use lance_core::Error;
use crate::pbold;
use crate::progress::IndexBuildProgress;
use crate::{
frag_reuse::FragReuseIndex,
scalar::{
CreatedIndex, ScalarIndex,
expression::{FtsQueryParser, ScalarQueryParser},
registry::{ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest},
},
};
use super::IndexStore;
/// Scalar-index plugin (registered under the name `"Inverted"`) that trains and
/// loads inverted full-text-search indices.
#[derive(Default, Debug)]
pub struct InvertedIndexPlugin;
impl InvertedIndexPlugin {
    /// Trains an inverted index from `data` and writes its files into `index_store`.
    ///
    /// If `fragment_ids` is present and non-empty, the *first* fragment id is shifted
    /// into the upper 32 bits of a `u64` and handed to the builder as its fragment mask.
    /// An absent or empty list means no mask.
    ///
    /// # Errors
    /// Fails if `params` cannot be converted into `InvertedIndexDetails`, if the
    /// builder fails while consuming `data`, or if the store's files cannot be listed.
    pub async fn train_inverted_index(
        data: SendableRecordBatchStream,
        index_store: &dyn IndexStore,
        params: InvertedIndexParams,
        fragment_ids: Option<Vec<u32>>,
        progress: Arc<dyn IndexBuildProgress>,
    ) -> Result<CreatedIndex> {
        // Only the first fragment id contributes; it occupies the high 32 bits.
        let fragment_mask = fragment_ids
            .as_deref()
            .and_then(|frag_ids| frag_ids.first())
            .map(|&frag_id| u64::from(frag_id) << 32);
        // Derive the details before `params` is moved into the builder.
        let details = pbold::InvertedIndexDetails::try_from(&params)?;
        let mut inverted_index =
            InvertedIndexBuilder::new_with_fragment_mask(params, fragment_mask)
                .with_progress(progress);
        inverted_index.update(data, index_store, None).await?;
        Ok(CreatedIndex {
            index_details: prost_types::Any::from_msg(&details)
                .expect("encoding index details into a growable Vec buffer cannot fail"),
            index_version: current_fts_format_version().index_version(),
            files: Some(index_store.list_files_with_sizes().await?),
        })
    }

    /// Returns `true` when an index with these `details` may back `FtsQueryParser`.
    ///
    /// That requires all of the following:
    /// - the `"simple"` base tokenizer;
    /// - no `max_token_length`;
    /// - a `language` equal to the JSON serialization of `Language::English`;
    /// - stemming disabled.
    fn can_accelerate_queries(details: &pbold::InvertedIndexDetails) -> bool {
        details.base_tokenizer.as_deref() == Some("simple")
            && details.max_token_length.is_none()
            && details.language
                == serde_json::to_string(&Language::English)
                    .expect("serializing a unit enum variant cannot fail")
            && !details.stem
    }
}
/// Training request returned by the inverted-index plugin's `new_training_request`.
///
/// It bundles the parsed index parameters with the training criteria
/// (`TrainingOrdering::None` plus `with_row_id()`).
struct InvertedIndexTrainingRequest {
    /// Index parameters, parsed from the JSON `params` string given to the plugin.
    parameters: InvertedIndexParams,
    /// Criteria handed back through `TrainingRequest::criteria`.
    criteria: TrainingCriteria,
}

impl InvertedIndexTrainingRequest {
    /// Wraps `parameters` with unordered, row-id-carrying training criteria.
    pub fn new(parameters: InvertedIndexParams) -> Self {
        let criteria = TrainingCriteria::new(TrainingOrdering::None).with_row_id();
        Self {
            criteria,
            parameters,
        }
    }
}

impl TrainingRequest for InvertedIndexTrainingRequest {
    fn as_any(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }

    fn criteria(&self) -> &TrainingCriteria {
        let TrainingCriteriaRef(criteria) = TrainingCriteriaRef(&self.criteria);
        criteria
    }
}

/// Local borrow wrapper used only to destructure the criteria reference.
struct TrainingCriteriaRef<'a>(&'a TrainingCriteria);
#[async_trait]
impl ScalarIndexPlugin for InvertedIndexPlugin {
fn name(&self) -> &str {
"Inverted"
}
fn new_training_request(
&self,
params: &str,
field: &Field,
) -> Result<Box<dyn TrainingRequest>> {
match field.data_type() {
DataType::Utf8 | DataType::LargeUtf8 | DataType::LargeBinary => (),
DataType::List(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8) => (),
DataType::LargeList(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8) => (),
_ => return Err(Error::invalid_input_source(format!(
"A inverted index can only be created on a Utf8 or LargeUtf8 field/list or LargeBinary field. Column has type {:?}",
field.data_type()
)
.into()))
}
let params = serde_json::from_str::<InvertedIndexParams>(params)?;
Ok(Box::new(InvertedIndexTrainingRequest::new(params)))
}
fn provides_exact_answer(&self) -> bool {
false
}
fn version(&self) -> u32 {
max_supported_fts_format_version().index_version()
}
fn new_query_parser(
&self,
index_name: String,
_index_details: &prost_types::Any,
) -> Option<Box<dyn ScalarQueryParser>> {
let Ok(index_details) = _index_details.to_msg::<pbold::InvertedIndexDetails>() else {
return None;
};
if Self::can_accelerate_queries(&index_details) {
Some(Box::new(FtsQueryParser::new(index_name)))
} else {
None
}
}
async fn train_index(
&self,
data: SendableRecordBatchStream,
index_store: &dyn IndexStore,
request: Box<dyn TrainingRequest>,
fragment_ids: Option<Vec<u32>>,
progress: Arc<dyn IndexBuildProgress>,
) -> Result<CreatedIndex> {
let request = (request as Box<dyn std::any::Any>)
.downcast::<InvertedIndexTrainingRequest>()
.map_err(|_| {
Error::invalid_input_source(
"must provide training request created by new_training_request".into(),
)
})?;
Self::train_inverted_index(
data,
index_store,
request.parameters.clone(),
fragment_ids,
progress,
)
.await
}
async fn load_index(
&self,
index_store: Arc<dyn IndexStore>,
_index_details: &prost_types::Any,
frag_reuse_index: Option<Arc<FragReuseIndex>>,
cache: &LanceCache,
) -> Result<Arc<dyn ScalarIndex>> {
Ok(
InvertedIndex::load(index_store, frag_reuse_index, cache).await?
as Arc<dyn ScalarIndex>,
)
}
fn details_as_json(&self, details: &prost_types::Any) -> Result<serde_json::Value> {
let index_details = details.to_msg::<pbold::InvertedIndexDetails>()?;
let index_params = InvertedIndexParams::try_from(&index_details)?;
Ok(serde_json::json!(&index_params))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The plugin's reported version must equal `max_supported_fts_format_version()`.
    #[test]
    fn test_plugin_version_tracks_max_supported_format() {
        let reported = InvertedIndexPlugin::default().version();
        let expected = max_supported_fts_format_version().index_version();
        assert_eq!(reported, expected);
    }
}