lance_index/scalar/
inverted.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4pub mod builder;
5mod encoding;
6mod index;
7mod iter;
8pub mod json;
9mod merger;
10pub mod parser;
11pub mod query;
12mod scorer;
13pub mod tokenizer;
14mod wand;
15
16use std::sync::Arc;
17
18use arrow_schema::{DataType, Field};
19use async_trait::async_trait;
20pub use builder::InvertedIndexBuilder;
21use datafusion::execution::SendableRecordBatchStream;
22pub use index::*;
23use lance_core::{cache::LanceCache, Result};
24use tantivy::tokenizer::Language;
25pub use tokenizer::*;
26
27use lance_core::Error;
28use snafu::location;
29
30use crate::pbold;
31use crate::{
32    frag_reuse::FragReuseIndex,
33    scalar::{
34        expression::{FtsQueryParser, ScalarQueryParser},
35        registry::{ScalarIndexPlugin, TrainingCriteria, TrainingOrdering, TrainingRequest},
36        CreatedIndex, ScalarIndex,
37    },
38};
39
40use super::IndexStore;
41
42#[derive(Debug, Default)]
43pub struct InvertedIndexPlugin;
44
45impl InvertedIndexPlugin {
46    pub async fn train_inverted_index(
47        data: SendableRecordBatchStream,
48        index_store: &dyn IndexStore,
49        params: InvertedIndexParams,
50        fragment_ids: Option<Vec<u32>>,
51    ) -> Result<CreatedIndex> {
52        let fragment_mask = fragment_ids.as_ref().and_then(|frag_ids| {
53            if !frag_ids.is_empty() {
54                // Create a mask with fragment_id in high 32 bits for distributed indexing
55                // This mask is used to filter partitions belonging to specific fragments
56                // If multiple fragments processed, use first fragment_id <<32 as mask
57                Some((frag_ids[0] as u64) << 32)
58            } else {
59                None
60            }
61        });
62
63        let details = pbold::InvertedIndexDetails::try_from(&params)?;
64        let mut inverted_index =
65            InvertedIndexBuilder::new_with_fragment_mask(params, fragment_mask);
66        inverted_index.update(data, index_store).await?;
67        Ok(CreatedIndex {
68            index_details: prost_types::Any::from_msg(&details).unwrap(),
69            index_version: INVERTED_INDEX_VERSION,
70        })
71    }
72
73    /// Return true if the query can be used to speed up contains_tokens queries
74    fn can_accelerate_queries(details: &pbold::InvertedIndexDetails) -> bool {
75        details.base_tokenizer == Some("simple".to_string())
76            && details.max_token_length.is_none()
77            && details.language == serde_json::to_string(&Language::English).unwrap()
78            && !details.stem
79    }
80}
81
82struct InvertedIndexTrainingRequest {
83    parameters: InvertedIndexParams,
84    criteria: TrainingCriteria,
85}
86
87impl InvertedIndexTrainingRequest {
88    pub fn new(parameters: InvertedIndexParams) -> Self {
89        Self {
90            parameters,
91            criteria: TrainingCriteria::new(TrainingOrdering::None).with_row_id(),
92        }
93    }
94}
95
96impl TrainingRequest for InvertedIndexTrainingRequest {
97    fn as_any(&self) -> &dyn std::any::Any {
98        self
99    }
100
101    fn criteria(&self) -> &TrainingCriteria {
102        &self.criteria
103    }
104}
105
106#[async_trait]
107impl ScalarIndexPlugin for InvertedIndexPlugin {
108    fn name(&self) -> &str {
109        "Inverted"
110    }
111
112    fn new_training_request(
113        &self,
114        params: &str,
115        field: &Field,
116    ) -> Result<Box<dyn TrainingRequest>> {
117        match field.data_type() {
118            DataType::Utf8 | DataType::LargeUtf8 | DataType::LargeBinary => (),
119            DataType::List(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8) => (),
120            DataType::LargeList(f) if matches!(f.data_type(), DataType::Utf8 | DataType::LargeUtf8) => (),
121
122            _ => return Err(Error::InvalidInput {
123                source: format!(
124                    "A inverted index can only be created on a Utf8 or LargeUtf8 field/list or LargeBinary field. Column has type {:?}",
125                    field.data_type()
126                )
127                    .into(),
128                location: location!(),
129            })
130        }
131
132        let params = serde_json::from_str::<InvertedIndexParams>(params)?;
133        Ok(Box::new(InvertedIndexTrainingRequest::new(params)))
134    }
135
136    fn provides_exact_answer(&self) -> bool {
137        false
138    }
139
140    fn version(&self) -> u32 {
141        INVERTED_INDEX_VERSION
142    }
143
144    fn new_query_parser(
145        &self,
146        index_name: String,
147        _index_details: &prost_types::Any,
148    ) -> Option<Box<dyn ScalarQueryParser>> {
149        let Ok(index_details) = _index_details.to_msg::<pbold::InvertedIndexDetails>() else {
150            return None;
151        };
152
153        if Self::can_accelerate_queries(&index_details) {
154            Some(Box::new(FtsQueryParser::new(index_name)))
155        } else {
156            None
157        }
158    }
159
160    /// Train a new index
161    ///
162    /// The provided data must fulfill all the criteria returned by `training_criteria`.
163    /// It is the caller's responsibility to ensure this.
164    ///
165    /// Returns index details that describe the index.  These details can potentially be
166    /// useful for planning (although this will currently require inside information on
167    /// the index type) and they will need to be provided when loading the index.
168    ///
169    /// It is the caller's responsibility to store these details somewhere.
170    async fn train_index(
171        &self,
172        data: SendableRecordBatchStream,
173        index_store: &dyn IndexStore,
174        request: Box<dyn TrainingRequest>,
175        fragment_ids: Option<Vec<u32>>,
176    ) -> Result<CreatedIndex> {
177        let request = (request as Box<dyn std::any::Any>)
178            .downcast::<InvertedIndexTrainingRequest>()
179            .map_err(|_| Error::InvalidInput {
180                source: "must provide training request created by new_training_request".into(),
181                location: location!(),
182            })?;
183        Self::train_inverted_index(data, index_store, request.parameters.clone(), fragment_ids)
184            .await
185    }
186
187    /// Load an index from storage
188    ///
189    /// The index details should match the details that were returned when the index was
190    /// originally trained.
191    async fn load_index(
192        &self,
193        index_store: Arc<dyn IndexStore>,
194        _index_details: &prost_types::Any,
195        frag_reuse_index: Option<Arc<FragReuseIndex>>,
196        cache: &LanceCache,
197    ) -> Result<Arc<dyn ScalarIndex>> {
198        Ok(
199            InvertedIndex::load(index_store, frag_reuse_index, cache).await?
200                as Arc<dyn ScalarIndex>,
201        )
202    }
203
204    fn details_as_json(&self, details: &prost_types::Any) -> Result<serde_json::Value> {
205        let index_details = details.to_msg::<pbold::InvertedIndexDetails>()?;
206        let index_params = InvertedIndexParams::try_from(&index_details)?;
207        Ok(serde_json::json!(&index_params))
208    }
209}