// lance_index/scalar/inverted/tokenizer.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use lance_core::{Error, Result};
use serde::{Deserialize, Serialize};
use snafu::location;
use std::{env, path::PathBuf};

#[cfg(feature = "tokenizer-jieba")]
mod jieba;

pub mod lance_tokenizer;
#[cfg(feature = "tokenizer-lindera")]
mod lindera;

#[cfg(feature = "tokenizer-jieba")]
use jieba::JiebaTokenizerBuilder;

#[cfg(feature = "tokenizer-lindera")]
use lindera::LinderaTokenizerBuilder;

use crate::pbold;
use crate::scalar::inverted::tokenizer::lance_tokenizer::{
    JsonTokenizer, LanceTokenizer, TextTokenizer,
};

/// Inverted index parameters, mostly tokenizer configuration.
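///
/// # Example
///
/// A minimal sketch of configuring the parameters through the builder
/// methods below (the import path is assumed here):
///
/// ```ignore
/// use lance_index::scalar::inverted::tokenizer::InvertedIndexParams;
///
/// // Start from the defaults (`simple` tokenizer, English) and customize.
/// let params = InvertedIndexParams::default()
///     .base_tokenizer("ngram".to_owned())
///     .ngram_min_length(2)
///     .ngram_max_length(4);
/// ```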
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct InvertedIndexParams {
    /// The Lance tokenizer handles different input data types, such as text and JSON.
    /// - `text`: parse input documents into tokens
    /// - `json`: parse an input JSON string into tokens
    /// - `None`: infer the type automatically
    pub(crate) lance_tokenizer: Option<String>,
    /// base tokenizer:
    /// - `simple`: splits tokens on whitespace and punctuation
    /// - `whitespace`: splits tokens on whitespace
    /// - `raw`: no tokenization
    /// - `ngram`: N-Gram tokenizer
    /// - `lindera/*`: Lindera tokenizer
    /// - `jieba/*`: Jieba tokenizer
    ///
    /// `simple` is recommended for most cases and is the default
    pub(crate) base_tokenizer: String,

    /// Language for stemming and stop words.
    /// Only used when `stem` or `remove_stop_words` is true.
    pub(crate) language: tantivy::tokenizer::Language,

    /// If true, store the position of each term in the document.
    /// This can significantly increase the size of the index.
    /// If false, only store the frequency of each term in the document.
    /// Defaults to `false`.
    #[serde(default)]
    pub(crate) with_position: bool,

    /// Maximum token length:
    /// - `None`: no limit
    /// - `Some(n)`: remove tokens longer than `n`
    pub(crate) max_token_length: Option<usize>,

    /// Whether to lowercase tokens.
    #[serde(default = "bool_true")]
    pub(crate) lower_case: bool,

    /// Whether to apply stemming.
    #[serde(default = "bool_true")]
    pub(crate) stem: bool,

    /// Whether to remove stop words.
    #[serde(default = "bool_true")]
    pub(crate) remove_stop_words: bool,

    /// Custom stop words:
    /// - `None`: use the built-in stop words for `language`
    /// - `Some(words)`: use the given stop words instead
    pub(crate) custom_stop_words: Option<Vec<String>>,

    /// Whether to apply ASCII folding.
    #[serde(default = "bool_true")]
    pub(crate) ascii_folding: bool,

    /// Minimum N-Gram length, only used by the `ngram` tokenizer.
    #[serde(default = "default_min_ngram_length")]
    pub(crate) min_ngram_length: u32,

    /// Maximum N-Gram length, only used by the `ngram` tokenizer.
    #[serde(default = "default_max_ngram_length")]
    pub(crate) max_ngram_length: u32,

    /// Whether to generate only prefix N-Grams, only used by the `ngram` tokenizer.
    #[serde(default)]
    pub(crate) prefix_only: bool,
}

impl TryFrom<&InvertedIndexParams> for pbold::InvertedIndexDetails {
    type Error = Error;

    fn try_from(params: &InvertedIndexParams) -> Result<Self> {
        Ok(Self {
            base_tokenizer: Some(params.base_tokenizer.clone()),
            language: serde_json::to_string(&params.language)?,
            with_position: params.with_position,
            max_token_length: params.max_token_length.map(|l| l as u32),
            lower_case: params.lower_case,
            stem: params.stem,
            remove_stop_words: params.remove_stop_words,
            ascii_folding: params.ascii_folding,
            min_ngram_length: params.min_ngram_length,
            max_ngram_length: params.max_ngram_length,
            prefix_only: params.prefix_only,
        })
    }
}

impl TryFrom<&pbold::InvertedIndexDetails> for InvertedIndexParams {
    type Error = Error;

    fn try_from(details: &pbold::InvertedIndexDetails) -> Result<Self> {
        let defaults = Self::default();
        Ok(Self {
            lance_tokenizer: defaults.lance_tokenizer,
            base_tokenizer: details
                .base_tokenizer
                .clone()
                .unwrap_or(defaults.base_tokenizer),
            language: serde_json::from_str(details.language.as_str())?,
            with_position: details.with_position,
            max_token_length: details.max_token_length.map(|l| l as usize),
            lower_case: details.lower_case,
            stem: details.stem,
            remove_stop_words: details.remove_stop_words,
            custom_stop_words: defaults.custom_stop_words,
            ascii_folding: details.ascii_folding,
            min_ngram_length: details.min_ngram_length,
            max_ngram_length: details.max_ngram_length,
            prefix_only: details.prefix_only,
        })
    }
}

fn bool_true() -> bool {
    true
}

fn default_min_ngram_length() -> u32 {
    3
}

fn default_max_ngram_length() -> u32 {
    3
}

impl Default for InvertedIndexParams {
    fn default() -> Self {
        Self::new("simple".to_owned(), tantivy::tokenizer::Language::English)
    }
}

impl InvertedIndexParams {
    /// Create a new `InvertedIndexParams` with the given base tokenizer and language.
    ///
    /// The `base_tokenizer` can be one of the following:
    /// - `simple`: splits tokens on whitespace and punctuation (the default)
    /// - `whitespace`: splits tokens on whitespace
    /// - `raw`: no tokenization
    /// - `ngram`: N-Gram tokenizer
    /// - `lindera/*`: Lindera tokenizer
    /// - `jieba/*`: Jieba tokenizer
    ///
    /// The `language` is used for stemming and removing stop words;
    /// it is not used by the `lindera/*` and `jieba/*` tokenizers.
    /// Defaults to `English`.
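    ///
    /// # Example
    ///
    /// A minimal sketch (the import paths are assumed here):
    ///
    /// ```ignore
    /// use lance_index::scalar::inverted::tokenizer::InvertedIndexParams;
    /// use tantivy::tokenizer::Language;
    ///
    /// let params = InvertedIndexParams::new("whitespace".to_owned(), Language::French);
    /// ```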
    pub fn new(base_tokenizer: String, language: tantivy::tokenizer::Language) -> Self {
        Self {
            lance_tokenizer: None,
            base_tokenizer,
            language,
            with_position: false,
            max_token_length: Some(40),
            lower_case: true,
            stem: true,
            remove_stop_words: true,
            custom_stop_words: None,
            ascii_folding: true,
            min_ngram_length: default_min_ngram_length(),
            max_ngram_length: default_max_ngram_length(),
            prefix_only: false,
        }
    }

    pub fn lance_tokenizer(mut self, lance_tokenizer: String) -> Self {
        self.lance_tokenizer = Some(lance_tokenizer);
        self
    }

    pub fn base_tokenizer(mut self, base_tokenizer: String) -> Self {
        self.base_tokenizer = base_tokenizer;
        self
    }

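    /// Set the language by name, e.g. `"English"` or `"French"`.
    /// The name must deserialize into a `tantivy::tokenizer::Language` variant,
    /// otherwise an error is returned.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let params = InvertedIndexParams::default().language("French")?;
    /// ```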
    pub fn language(mut self, language: &str) -> Result<Self> {
        // `Language` is deserialized with serde, so quote the name to form a valid JSON string
        let language = serde_json::from_str(format!("\"{}\"", language).as_str())?;
        self.language = language;
        Ok(self)
    }

    /// Set whether to store the position of each term in the document.
    /// This can significantly increase the size of the index.
    /// If false, only the frequency of each term is stored.
    /// This doesn't work with the `ngram` tokenizer.
    /// Defaults to `false`.
    pub fn with_position(mut self, with_position: bool) -> Self {
        self.with_position = with_position;
        self
    }

    pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
        self.max_token_length = max_token_length;
        self
    }

    pub fn lower_case(mut self, lower_case: bool) -> Self {
        self.lower_case = lower_case;
        self
    }

    pub fn stem(mut self, stem: bool) -> Self {
        self.stem = stem;
        self
    }

    pub fn remove_stop_words(mut self, remove_stop_words: bool) -> Self {
        self.remove_stop_words = remove_stop_words;
        self
    }

    pub fn custom_stop_words(mut self, custom_stop_words: Option<Vec<String>>) -> Self {
        self.custom_stop_words = custom_stop_words;
        self
    }

    pub fn ascii_folding(mut self, ascii_folding: bool) -> Self {
        self.ascii_folding = ascii_folding;
        self
    }

    /// Set the minimum N-Gram length, only used when `base_tokenizer` is `ngram`.
    /// Must be greater than 0 and not greater than `max_ngram_length`.
    /// Defaults to `3`.
    pub fn ngram_min_length(mut self, min_length: u32) -> Self {
        self.min_ngram_length = min_length;
        self
    }

    /// Set the maximum N-Gram length, only used when `base_tokenizer` is `ngram`.
    /// Must be greater than 0 and not less than `min_ngram_length`.
    /// Defaults to `3`.
    pub fn ngram_max_length(mut self, max_length: u32) -> Self {
        self.max_ngram_length = max_length;
        self
    }

    /// Set whether only prefix N-Grams are generated, only used when `base_tokenizer` is `ngram`.
    /// Defaults to `false`.
    pub fn ngram_prefix_only(mut self, prefix_only: bool) -> Self {
        self.prefix_only = prefix_only;
        self
    }

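    /// Build the configured [`LanceTokenizer`].
    ///
    /// Applies the optional filters (token length limit, lowercasing,
    /// stemming, stop-word removal, ASCII folding) on top of the base
    /// tokenizer, then wraps the result in a text or JSON tokenizer
    /// according to `lance_tokenizer`.
    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// let tokenizer = InvertedIndexParams::default().build()?;
    /// ```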
    pub fn build(&self) -> Result<Box<dyn LanceTokenizer>> {
        let mut builder = self.build_base_tokenizer()?;
        if let Some(max_token_length) = self.max_token_length {
            builder = builder.filter_dynamic(tantivy::tokenizer::RemoveLongFilter::limit(
                max_token_length,
            ));
        }
        if self.lower_case {
            builder = builder.filter_dynamic(tantivy::tokenizer::LowerCaser);
        }
        if self.stem {
            builder = builder.filter_dynamic(tantivy::tokenizer::Stemmer::new(self.language));
        }
        if self.remove_stop_words {
            let stop_word_filter = match &self.custom_stop_words {
                Some(words) => tantivy::tokenizer::StopWordFilter::remove(words.iter().cloned()),
                None => {
                    tantivy::tokenizer::StopWordFilter::new(self.language).ok_or_else(|| {
                        Error::invalid_input(
                            format!(
                                "removing stop words for language {:?} is not supported yet",
                                self.language
                            ),
                            location!(),
                        )
                    })?
                }
            };
            builder = builder.filter_dynamic(stop_word_filter);
        }
        if self.ascii_folding {
            builder = builder.filter_dynamic(tantivy::tokenizer::AsciiFoldingFilter);
        }
        let tokenizer = builder.build();

        match self.lance_tokenizer.as_deref() {
            Some("text") | None => Ok(Box::new(TextTokenizer::new(tokenizer))),
            Some("json") => Ok(Box::new(JsonTokenizer::new(tokenizer))),
            Some(other) => Err(Error::invalid_input(
                format!("unknown lance tokenizer {}", other),
                location!(),
            )),
        }
    }

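    /// Map `base_tokenizer` to a tantivy `TextAnalyzerBuilder`.
    ///
    /// The `lindera/*` and `jieba/*` tokenizers are feature-gated and load
    /// their models from a subdirectory of [`language_model_home`].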
    fn build_base_tokenizer(&self) -> Result<tantivy::tokenizer::TextAnalyzerBuilder> {
        match self.base_tokenizer.as_str() {
            "simple" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::SimpleTokenizer::default(),
            )
            .dynamic()),
            "whitespace" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::WhitespaceTokenizer::default(),
            )
            .dynamic()),
            "raw" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::RawTokenizer::default(),
            )
            .dynamic()),
            "ngram" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::NgramTokenizer::new(
                    self.min_ngram_length as usize,
                    self.max_ngram_length as usize,
                    self.prefix_only,
                )
                .map_err(|e| Error::invalid_input(e.to_string(), location!()))?,
            )
            .dynamic()),
            #[cfg(feature = "tokenizer-lindera")]
            s if s.starts_with("lindera/") => {
                let Some(home) = language_model_home() else {
                    return Err(Error::invalid_input(
                        format!(
                            "language model home not found for tokenizer {}",
                            self.base_tokenizer
                        ),
                        location!(),
                    ));
                };
                lindera::LinderaBuilder::load(&home.join(s))?.build()
            }
            #[cfg(feature = "tokenizer-jieba")]
            s if s.starts_with("jieba/") || s == "jieba" => {
                let s = if s == "jieba" { "jieba/default" } else { s };
                let Some(home) = language_model_home() else {
                    return Err(Error::invalid_input(
                        format!(
                            "language model home not found for tokenizer {}",
                            self.base_tokenizer
                        ),
                        location!(),
                    ));
                };
                jieba::JiebaBuilder::load(&home.join(s))?.build()
            }
            _ => Err(Error::invalid_input(
                format!("unknown base tokenizer {}", self.base_tokenizer),
                location!(),
            )),
        }
    }
}

pub const LANCE_LANGUAGE_MODEL_HOME_ENV_KEY: &str = "LANCE_LANGUAGE_MODEL_HOME";

pub const LANCE_LANGUAGE_MODEL_DEFAULT_DIRECTORY: &str = "lance/language_models";

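/// Resolve the directory that holds the language models for the Lindera
/// and Jieba tokenizers.
///
/// Uses the `LANCE_LANGUAGE_MODEL_HOME` environment variable if set;
/// otherwise falls back to `lance/language_models` under the platform's
/// local data directory (e.g. `~/.local/share` on Linux).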
pub fn language_model_home() -> Option<PathBuf> {
    match env::var(LANCE_LANGUAGE_MODEL_HOME_ENV_KEY) {
        Ok(p) => Some(PathBuf::from(p)),
        Err(_) => dirs::data_local_dir().map(|p| p.join(LANCE_LANGUAGE_MODEL_DEFAULT_DIRECTORY)),
    }
}
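
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch: round-trip the params through the protobuf details
    // and check that the persisted fields survive the conversion.
    #[test]
    fn test_params_details_roundtrip() -> Result<()> {
        let params = InvertedIndexParams::default()
            .base_tokenizer("ngram".to_owned())
            .ngram_min_length(2)
            .ngram_max_length(4);
        let details = pbold::InvertedIndexDetails::try_from(&params)?;
        let restored = InvertedIndexParams::try_from(&details)?;
        assert_eq!(params, restored);
        Ok(())
    }
}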