Skip to main content

lance_index/scalar/
inverted.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4pub mod builder;
5mod cache_codec;
6mod encoding;
7mod index;
8mod iter;
9pub mod json;
10pub mod parser;
11pub mod query;
12mod scorer;
13pub mod tokenizer;
14mod wand;
15
16use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18
19use arrow_schema::{DataType, Field};
20use async_trait::async_trait;
21pub use builder::InvertedIndexBuilder;
22use datafusion::execution::SendableRecordBatchStream;
23pub use index::*;
24use lance_core::{Result, cache::LanceCache};
25pub use lance_tokenizer::Language;
26pub use scorer::{MemBM25Scorer, Scorer};
27pub use tokenizer::*;
28
29use crate::scalar::inverted::query::{FtsSearchParams, Tokens};
30
31/// Collect the unique terms needed to build a shared BM25 scorer.
32///
33/// The scorer only needs corpus-level document frequencies, so we keep a
34/// deduplicated term list here instead of constructing a full `Tokens`
35/// object with positions. When fuzziness is enabled, each segment may
36/// contribute additional terms (via `expand_fuzzy_tokens`); the union of
37/// those terms is what the global scorer must cover.
38fn scorer_terms(
39    indices: &[Arc<InvertedIndex>],
40    query_tokens: &Tokens,
41    params: &FtsSearchParams,
42) -> Result<Vec<String>> {
43    let mut terms = Vec::new();
44    let mut seen = HashSet::new();
45
46    if !matches!(params.fuzziness, Some(n) if n != 0) {
47        for token in query_tokens {
48            if seen.insert(token.to_string()) {
49                terms.push(token.to_string());
50            }
51        }
52        return Ok(terms);
53    }
54
55    for index in indices {
56        let expanded = index.expand_fuzzy_tokens(query_tokens, params)?;
57        for idx in 0..expanded.len() {
58            let token = expanded.get_token(idx);
59            if seen.insert(token.to_string()) {
60                terms.push(token.to_string());
61            }
62        }
63    }
64    Ok(terms)
65}
66
67/// Build a shared [`MemBM25Scorer`] across a set of FTS index segments.
68///
69/// Aggregates each segment's `(total_tokens, num_docs, per_term_doc_freq)`
70/// statistics — obtained via [`InvertedIndex::bm25_stats_for_terms`] — into a
71/// single corpus-wide scorer, so that BM25 IDF scoring uses *global*
72/// statistics rather than per-segment statistics. Computes the union of
73/// fuzzy-expanded terms when `params.fuzziness` is set.
74///
75/// Public as the canonical producer paired with the `with_base_scorer`
76/// consumer on FTS exec types: callers holding `Arc<InvertedIndex>` segment
77/// handles locally can construct an injectable scorer without reimplementing
78/// per-segment stat aggregation, term deduplication, and fuzzy-expansion
79/// union. Keeps a single source of truth for BM25 IDF arithmetic across
80/// segments.
81pub fn build_global_bm25_scorer(
82    indices: &[Arc<InvertedIndex>],
83    query_tokens: &Tokens,
84    params: &FtsSearchParams,
85) -> Result<MemBM25Scorer> {
86    let terms = scorer_terms(indices, query_tokens, params)?;
87    let first_index = indices.first().ok_or_else(|| {
88        lance_core::Error::invalid_input("FTS index requires at least one segment")
89    })?;
90    let (mut total_tokens, mut num_docs, first_token_docs) =
91        first_index.bm25_stats_for_terms(&terms);
92    let mut token_docs = HashMap::with_capacity(terms.len());
93    for (term, count) in terms.iter().cloned().zip(first_token_docs.into_iter()) {
94        token_docs.insert(term, count);
95    }
96
97    for index in indices.iter().skip(1) {
98        let (segment_total_tokens, segment_num_docs, segment_token_docs) =
99            index.bm25_stats_for_terms(&terms);
100        total_tokens += segment_total_tokens;
101        num_docs += segment_num_docs;
102        for (term, count) in terms.iter().zip(segment_token_docs.into_iter()) {
103            *token_docs
104                .get_mut(term)
105                .expect("global scorer terms should already be initialized") += count;
106        }
107    }
108
109    Ok(MemBM25Scorer::new(total_tokens, num_docs, token_docs))
110}
111
112use lance_core::Error;
113
114use crate::pbold;
115use crate::progress::IndexBuildProgress;
116use crate::{
117    frag_reuse::FragReuseIndex,
118    scalar::{
119        CreatedIndex, ScalarIndex,
120        expression::{FtsQueryParser, ScalarQueryParser},
121        registry::{ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest},
122    },
123};
124
125use super::IndexStore;
126
127#[derive(Debug, Default)]
128pub struct InvertedIndexPlugin;
129
130impl InvertedIndexPlugin {
131    pub async fn train_inverted_index(
132        data: SendableRecordBatchStream,
133        index_store: &dyn IndexStore,
134        params: InvertedIndexParams,
135        fragment_ids: Option<Vec<u32>>,
136        progress: Arc<dyn IndexBuildProgress>,
137    ) -> Result<CreatedIndex> {
138        let fragment_mask = fragment_ids.as_ref().and_then(|frag_ids| {
139            if !frag_ids.is_empty() {
140                // Create a mask with fragment_id in high 32 bits for distributed indexing
141                // This mask is used to filter partitions belonging to specific fragments
142                // If multiple fragments processed, use first fragment_id <<32 as mask
143                Some((frag_ids[0] as u64) << 32)
144            } else {
145                None
146            }
147        });
148
149        let details = pbold::InvertedIndexDetails::try_from(&params)?;
150        let mut inverted_index =
151            InvertedIndexBuilder::new_with_fragment_mask(params, fragment_mask)
152                .with_progress(progress);
153        inverted_index.update(data, index_store, None).await?;
154        Ok(CreatedIndex {
155            index_details: prost_types::Any::from_msg(&details).unwrap(),
156            index_version: current_fts_format_version().index_version(),
157            files: Some(index_store.list_files_with_sizes().await?),
158        })
159    }
160
161    /// Return true if the query can be used to speed up contains_tokens queries
162    fn can_accelerate_queries(details: &pbold::InvertedIndexDetails) -> bool {
163        details.base_tokenizer == Some("simple".to_string())
164            && details.max_token_length.is_none()
165            && details.language == serde_json::to_string(&Language::English).unwrap()
166            && !details.stem
167    }
168}
169
170struct InvertedIndexTrainingRequest {
171    parameters: InvertedIndexParams,
172    criteria: TrainingCriteria,
173}
174
175impl InvertedIndexTrainingRequest {
176    pub fn new(parameters: InvertedIndexParams) -> Self {
177        Self {
178            parameters,
179            criteria: TrainingCriteria::new(TrainingOrdering::None).with_row_id(),
180        }
181    }
182}
183
184impl TrainingRequest for InvertedIndexTrainingRequest {
185    fn as_any(&self) -> &dyn std::any::Any {
186        self
187    }
188
189    fn criteria(&self) -> &TrainingCriteria {
190        &self.criteria
191    }
192}
193
194#[async_trait]
195impl ScalarIndexPlugin for InvertedIndexPlugin {
196    fn name(&self) -> &str {
197        "Inverted"
198    }
199
200    fn new_training_request(
201        &self,
202        params: &str,
203        field: &Field,
204    ) -> Result<Box<dyn TrainingRequest>> {
205        match field.data_type() {
206            DataType::Utf8 | DataType::LargeUtf8 | DataType::LargeBinary => (),
207            DataType::List(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8) => (),
208            DataType::LargeList(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8) => (),
209
210            _ => return Err(Error::invalid_input_source(format!(
211                "A inverted index can only be created on a Utf8 or LargeUtf8 field/list or LargeBinary field. Column has type {:?}",
212                field.data_type()
213            )
214                .into()))
215        }
216
217        let params = serde_json::from_str::<InvertedIndexParams>(params)?;
218        Ok(Box::new(InvertedIndexTrainingRequest::new(params)))
219    }
220
221    fn provides_exact_answer(&self) -> bool {
222        false
223    }
224
225    fn version(&self) -> u32 {
226        max_supported_fts_format_version().index_version()
227    }
228
229    fn new_query_parser(
230        &self,
231        index_name: String,
232        _index_details: &prost_types::Any,
233    ) -> Option<Box<dyn ScalarQueryParser>> {
234        let Ok(index_details) = _index_details.to_msg::<pbold::InvertedIndexDetails>() else {
235            return None;
236        };
237
238        if Self::can_accelerate_queries(&index_details) {
239            Some(Box::new(FtsQueryParser::new(
240                index_name,
241                self.name().to_string(),
242            )))
243        } else {
244            None
245        }
246    }
247
248    /// Train a new index
249    ///
250    /// The provided data must fulfill all the criteria returned by `training_criteria`.
251    /// It is the caller's responsibility to ensure this.
252    ///
253    /// Returns index details that describe the index.  These details can potentially be
254    /// useful for planning (although this will currently require inside information on
255    /// the index type) and they will need to be provided when loading the index.
256    ///
257    /// It is the caller's responsibility to store these details somewhere.
258    async fn train_index(
259        &self,
260        data: SendableRecordBatchStream,
261        index_store: &dyn IndexStore,
262        request: Box<dyn TrainingRequest>,
263        fragment_ids: Option<Vec<u32>>,
264        progress: Arc<dyn IndexBuildProgress>,
265    ) -> Result<CreatedIndex> {
266        let request = (request as Box<dyn std::any::Any>)
267            .downcast::<InvertedIndexTrainingRequest>()
268            .map_err(|_| {
269                Error::invalid_input_source(
270                    "must provide training request created by new_training_request".into(),
271                )
272            })?;
273        Self::train_inverted_index(
274            data,
275            index_store,
276            request.parameters.clone(),
277            fragment_ids,
278            progress,
279        )
280        .await
281    }
282
283    /// Load an index from storage
284    ///
285    /// The index details should match the details that were returned when the index was
286    /// originally trained.
287    async fn load_index(
288        &self,
289        index_store: Arc<dyn IndexStore>,
290        _index_details: &prost_types::Any,
291        frag_reuse_index: Option<Arc<FragReuseIndex>>,
292        cache: &LanceCache,
293    ) -> Result<Arc<dyn ScalarIndex>> {
294        Ok(
295            InvertedIndex::load(index_store, frag_reuse_index, cache).await?
296                as Arc<dyn ScalarIndex>,
297        )
298    }
299
300    fn details_as_json(&self, details: &prost_types::Any) -> Result<serde_json::Value> {
301        let index_details = details.to_msg::<pbold::InvertedIndexDetails>()?;
302        let index_params = InvertedIndexParams::try_from(&index_details)?;
303        Ok(serde_json::json!(&index_params))
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// The plugin's reported version must follow the maximum supported FTS
    /// format version.
    #[test]
    fn test_plugin_version_tracks_max_supported_format() {
        let expected = max_supported_fts_format_version().index_version();
        let actual = InvertedIndexPlugin.version();
        assert_eq!(actual, expected);
    }
}