1pub mod builder;
5mod cache_codec;
6mod encoding;
7mod index;
8mod iter;
9pub mod json;
10pub mod parser;
11pub mod query;
12mod scorer;
13pub mod tokenizer;
14mod wand;
15
16use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18
19use arrow_schema::{DataType, Field};
20use async_trait::async_trait;
21pub use builder::InvertedIndexBuilder;
22use datafusion::execution::SendableRecordBatchStream;
23pub use index::*;
24use lance_core::{Result, cache::LanceCache};
25pub use lance_tokenizer::Language;
26pub use scorer::{MemBM25Scorer, Scorer};
27pub use tokenizer::*;
28
29use crate::scalar::inverted::query::{FtsSearchParams, Tokens};
30
31fn scorer_terms(
39 indices: &[Arc<InvertedIndex>],
40 query_tokens: &Tokens,
41 params: &FtsSearchParams,
42) -> Result<Vec<String>> {
43 let mut terms = Vec::new();
44 let mut seen = HashSet::new();
45
46 if !matches!(params.fuzziness, Some(n) if n != 0) {
47 for token in query_tokens {
48 if seen.insert(token.to_string()) {
49 terms.push(token.to_string());
50 }
51 }
52 return Ok(terms);
53 }
54
55 for index in indices {
56 let expanded = index.expand_fuzzy_tokens(query_tokens, params)?;
57 for idx in 0..expanded.len() {
58 let token = expanded.get_token(idx);
59 if seen.insert(token.to_string()) {
60 terms.push(token.to_string());
61 }
62 }
63 }
64 Ok(terms)
65}
66
67pub fn build_global_bm25_scorer(
82 indices: &[Arc<InvertedIndex>],
83 query_tokens: &Tokens,
84 params: &FtsSearchParams,
85) -> Result<MemBM25Scorer> {
86 let terms = scorer_terms(indices, query_tokens, params)?;
87 let first_index = indices.first().ok_or_else(|| {
88 lance_core::Error::invalid_input("FTS index requires at least one segment")
89 })?;
90 let (mut total_tokens, mut num_docs, first_token_docs) =
91 first_index.bm25_stats_for_terms(&terms);
92 let mut token_docs = HashMap::with_capacity(terms.len());
93 for (term, count) in terms.iter().cloned().zip(first_token_docs.into_iter()) {
94 token_docs.insert(term, count);
95 }
96
97 for index in indices.iter().skip(1) {
98 let (segment_total_tokens, segment_num_docs, segment_token_docs) =
99 index.bm25_stats_for_terms(&terms);
100 total_tokens += segment_total_tokens;
101 num_docs += segment_num_docs;
102 for (term, count) in terms.iter().zip(segment_token_docs.into_iter()) {
103 *token_docs
104 .get_mut(term)
105 .expect("global scorer terms should already be initialized") += count;
106 }
107 }
108
109 Ok(MemBM25Scorer::new(total_tokens, num_docs, token_docs))
110}
111
112use lance_core::Error;
113
114use crate::pbold;
115use crate::progress::IndexBuildProgress;
116use crate::{
117 frag_reuse::FragReuseIndex,
118 scalar::{
119 CreatedIndex, ScalarIndex,
120 expression::{FtsQueryParser, ScalarQueryParser},
121 registry::{ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest},
122 },
123};
124
125use super::IndexStore;
126
127#[derive(Debug, Default)]
128pub struct InvertedIndexPlugin;
129
130impl InvertedIndexPlugin {
131 pub async fn train_inverted_index(
132 data: SendableRecordBatchStream,
133 index_store: &dyn IndexStore,
134 params: InvertedIndexParams,
135 fragment_ids: Option<Vec<u32>>,
136 progress: Arc<dyn IndexBuildProgress>,
137 ) -> Result<CreatedIndex> {
138 let fragment_mask = fragment_ids.as_ref().and_then(|frag_ids| {
139 if !frag_ids.is_empty() {
140 Some((frag_ids[0] as u64) << 32)
144 } else {
145 None
146 }
147 });
148
149 let details = pbold::InvertedIndexDetails::try_from(¶ms)?;
150 let mut inverted_index =
151 InvertedIndexBuilder::new_with_fragment_mask(params, fragment_mask)
152 .with_progress(progress);
153 inverted_index.update(data, index_store, None).await?;
154 Ok(CreatedIndex {
155 index_details: prost_types::Any::from_msg(&details).unwrap(),
156 index_version: current_fts_format_version().index_version(),
157 files: Some(index_store.list_files_with_sizes().await?),
158 })
159 }
160
161 fn can_accelerate_queries(details: &pbold::InvertedIndexDetails) -> bool {
163 details.base_tokenizer == Some("simple".to_string())
164 && details.max_token_length.is_none()
165 && details.language == serde_json::to_string(&Language::English).unwrap()
166 && !details.stem
167 }
168}
169
170struct InvertedIndexTrainingRequest {
171 parameters: InvertedIndexParams,
172 criteria: TrainingCriteria,
173}
174
175impl InvertedIndexTrainingRequest {
176 pub fn new(parameters: InvertedIndexParams) -> Self {
177 Self {
178 parameters,
179 criteria: TrainingCriteria::new(TrainingOrdering::None).with_row_id(),
180 }
181 }
182}
183
184impl TrainingRequest for InvertedIndexTrainingRequest {
185 fn as_any(&self) -> &dyn std::any::Any {
186 self
187 }
188
189 fn criteria(&self) -> &TrainingCriteria {
190 &self.criteria
191 }
192}
193
194#[async_trait]
195impl ScalarIndexPlugin for InvertedIndexPlugin {
196 fn name(&self) -> &str {
197 "Inverted"
198 }
199
200 fn new_training_request(
201 &self,
202 params: &str,
203 field: &Field,
204 ) -> Result<Box<dyn TrainingRequest>> {
205 match field.data_type() {
206 DataType::Utf8 | DataType::LargeUtf8 | DataType::LargeBinary => (),
207 DataType::List(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8) => (),
208 DataType::LargeList(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8) => (),
209
210 _ => return Err(Error::invalid_input_source(format!(
211 "A inverted index can only be created on a Utf8 or LargeUtf8 field/list or LargeBinary field. Column has type {:?}",
212 field.data_type()
213 )
214 .into()))
215 }
216
217 let params = serde_json::from_str::<InvertedIndexParams>(params)?;
218 Ok(Box::new(InvertedIndexTrainingRequest::new(params)))
219 }
220
221 fn provides_exact_answer(&self) -> bool {
222 false
223 }
224
225 fn version(&self) -> u32 {
226 max_supported_fts_format_version().index_version()
227 }
228
229 fn new_query_parser(
230 &self,
231 index_name: String,
232 _index_details: &prost_types::Any,
233 ) -> Option<Box<dyn ScalarQueryParser>> {
234 let Ok(index_details) = _index_details.to_msg::<pbold::InvertedIndexDetails>() else {
235 return None;
236 };
237
238 if Self::can_accelerate_queries(&index_details) {
239 Some(Box::new(FtsQueryParser::new(
240 index_name,
241 self.name().to_string(),
242 )))
243 } else {
244 None
245 }
246 }
247
248 async fn train_index(
259 &self,
260 data: SendableRecordBatchStream,
261 index_store: &dyn IndexStore,
262 request: Box<dyn TrainingRequest>,
263 fragment_ids: Option<Vec<u32>>,
264 progress: Arc<dyn IndexBuildProgress>,
265 ) -> Result<CreatedIndex> {
266 let request = (request as Box<dyn std::any::Any>)
267 .downcast::<InvertedIndexTrainingRequest>()
268 .map_err(|_| {
269 Error::invalid_input_source(
270 "must provide training request created by new_training_request".into(),
271 )
272 })?;
273 Self::train_inverted_index(
274 data,
275 index_store,
276 request.parameters.clone(),
277 fragment_ids,
278 progress,
279 )
280 .await
281 }
282
283 async fn load_index(
288 &self,
289 index_store: Arc<dyn IndexStore>,
290 _index_details: &prost_types::Any,
291 frag_reuse_index: Option<Arc<FragReuseIndex>>,
292 cache: &LanceCache,
293 ) -> Result<Arc<dyn ScalarIndex>> {
294 Ok(
295 InvertedIndex::load(index_store, frag_reuse_index, cache).await?
296 as Arc<dyn ScalarIndex>,
297 )
298 }
299
300 fn details_as_json(&self, details: &prost_types::Any) -> Result<serde_json::Value> {
301 let index_details = details.to_msg::<pbold::InvertedIndexDetails>()?;
302 let index_params = InvertedIndexParams::try_from(&index_details)?;
303 Ok(serde_json::json!(&index_params))
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// The plugin's reported version must equal the index version of the
    /// maximum supported FTS format version.
    #[test]
    fn test_plugin_version_tracks_max_supported_format() {
        let reported = InvertedIndexPlugin.version();
        let expected = max_supported_fts_format_version().index_version();
        assert_eq!(reported, expected);
    }
}