// lance_index/scalar/inverted/tokenizer.rs

use lance_core::{Error, Result};
use serde::{Deserialize, Serialize};
use snafu::location;
use std::{env, path::PathBuf};

#[cfg(feature = "tokenizer-jieba")]
mod jieba;

pub mod lance_tokenizer;
#[cfg(feature = "tokenizer-lindera")]
mod lindera;

#[cfg(feature = "tokenizer-jieba")]
use jieba::JiebaTokenizerBuilder;

#[cfg(feature = "tokenizer-lindera")]
use lindera::LinderaTokenizerBuilder;

use crate::pbold;
use crate::scalar::inverted::tokenizer::lance_tokenizer::{
    JsonTokenizer, LanceTokenizer, TextTokenizer,
};
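
/// Parameters for building a tokenizer for the inverted (full-text) index.
///
/// A minimal usage sketch (marked `ignore` since it is illustrative only; it
/// uses the builder methods defined below, and `language` must name a
/// `tantivy::tokenizer::Language` variant):
///
/// ```ignore
/// let params = InvertedIndexParams::default()
///     .base_tokenizer("simple".to_owned())
///     .language("English")?
///     .with_position(true)
///     .max_token_length(Some(40));
/// let tokenizer = params.build()?;
/// ```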
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct InvertedIndexParams {
    /// Optional Lance-level tokenizer: "text" (the default) or "json".
    pub(crate) lance_tokenizer: Option<String>,

    /// Name of the base tokenizer: "simple", "whitespace", "raw", "ngram",
    /// or a model-backed tokenizer such as "lindera/..." or "jieba/...".
    pub(crate) base_tokenizer: String,

    /// Language used for stemming and the built-in stop-word list.
    pub(crate) language: tantivy::tokenizer::Language,

    /// Whether to record token positions (required for phrase queries).
    #[serde(default)]
    pub(crate) with_position: bool,

    /// Tokens longer than this are dropped; `None` disables the limit.
    pub(crate) max_token_length: Option<usize>,

    /// Whether to lowercase tokens.
    #[serde(default = "bool_true")]
    pub(crate) lower_case: bool,

    /// Whether to stem tokens for `language`.
    #[serde(default = "bool_true")]
    pub(crate) stem: bool,

    /// Whether to remove stop words.
    #[serde(default = "bool_true")]
    pub(crate) remove_stop_words: bool,

    /// Custom stop-word list; overrides the built-in list for `language`.
    pub(crate) custom_stop_words: Option<Vec<String>>,

    /// Whether to fold accented characters to their ASCII equivalents.
    #[serde(default = "bool_true")]
    pub(crate) ascii_folding: bool,

    /// Minimum gram length for the "ngram" base tokenizer.
    #[serde(default = "default_min_ngram_length")]
    pub(crate) min_ngram_length: u32,

    /// Maximum gram length for the "ngram" base tokenizer.
    #[serde(default = "default_max_ngram_length")]
    pub(crate) max_ngram_length: u32,

    /// If true, the "ngram" tokenizer emits only grams anchored at the start
    /// of each token.
    #[serde(default)]
    pub(crate) prefix_only: bool,
}

impl TryFrom<&InvertedIndexParams> for pbold::InvertedIndexDetails {
    type Error = Error;

    fn try_from(params: &InvertedIndexParams) -> Result<Self> {
        Ok(Self {
            base_tokenizer: Some(params.base_tokenizer.clone()),
            language: serde_json::to_string(&params.language)?,
            with_position: params.with_position,
            max_token_length: params.max_token_length.map(|l| l as u32),
            lower_case: params.lower_case,
            stem: params.stem,
            remove_stop_words: params.remove_stop_words,
            ascii_folding: params.ascii_folding,
            min_ngram_length: params.min_ngram_length,
            max_ngram_length: params.max_ngram_length,
            prefix_only: params.prefix_only,
        })
    }
}

impl TryFrom<&pbold::InvertedIndexDetails> for InvertedIndexParams {
    type Error = Error;

    fn try_from(details: &pbold::InvertedIndexDetails) -> Result<Self> {
        let defaults = Self::default();
        Ok(Self {
            lance_tokenizer: defaults.lance_tokenizer,
            base_tokenizer: details
                .base_tokenizer
                .as_ref()
                .cloned()
                .unwrap_or(defaults.base_tokenizer),
            language: serde_json::from_str(details.language.as_str())?,
            with_position: details.with_position,
            max_token_length: details.max_token_length.map(|l| l as usize),
            lower_case: details.lower_case,
            stem: details.stem,
            remove_stop_words: details.remove_stop_words,
            custom_stop_words: defaults.custom_stop_words,
            ascii_folding: details.ascii_folding,
            min_ngram_length: details.min_ngram_length,
            max_ngram_length: details.max_ngram_length,
            prefix_only: details.prefix_only,
        })
    }
}
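
// Round-trip sketch (illustrative): `pbold::InvertedIndexDetails` does not
// carry `lance_tokenizer` or `custom_stop_words`, so converting params to the
// protobuf details and back resets those two fields to their defaults:
//
//     let details = pbold::InvertedIndexDetails::try_from(&params)?;
//     let restored = InvertedIndexParams::try_from(&details)?;
//     assert_eq!(restored.custom_stop_words, None);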

fn bool_true() -> bool {
    true
}

fn default_min_ngram_length() -> u32 {
    3
}

fn default_max_ngram_length() -> u32 {
    3
}

impl Default for InvertedIndexParams {
    fn default() -> Self {
        Self::new("simple".to_owned(), tantivy::tokenizer::Language::English)
    }
}

impl InvertedIndexParams {
    /// Create params for the given base tokenizer and language.
    ///
    /// Defaults: positions disabled, 40-character token limit, lowercasing,
    /// stemming, stop-word removal, and ASCII folding all enabled.
    pub fn new(base_tokenizer: String, language: tantivy::tokenizer::Language) -> Self {
        Self {
            lance_tokenizer: None,
            base_tokenizer,
            language,
            with_position: false,
            max_token_length: Some(40),
            lower_case: true,
            stem: true,
            remove_stop_words: true,
            custom_stop_words: None,
            ascii_folding: true,
            min_ngram_length: default_min_ngram_length(),
            max_ngram_length: default_max_ngram_length(),
            prefix_only: false,
        }
    }

    /// Set the Lance-level tokenizer: "text" or "json".
    pub fn lance_tokenizer(mut self, lance_tokenizer: String) -> Self {
        self.lance_tokenizer = Some(lance_tokenizer);
        self
    }

    pub fn base_tokenizer(mut self, base_tokenizer: String) -> Self {
        self.base_tokenizer = base_tokenizer;
        self
    }

    pub fn language(mut self, language: &str) -> Result<Self> {
        // `tantivy::tokenizer::Language` deserializes from its variant name,
        // so wrap the input in quotes to make it a JSON string literal.
        let language = serde_json::from_str(format!("\"{}\"", language).as_str())?;
        self.language = language;
        Ok(self)
    }

    /// Record token positions; required for phrase queries.
    pub fn with_position(mut self, with_position: bool) -> Self {
        self.with_position = with_position;
        self
    }

    pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
        self.max_token_length = max_token_length;
        self
    }

    pub fn lower_case(mut self, lower_case: bool) -> Self {
        self.lower_case = lower_case;
        self
    }

    pub fn stem(mut self, stem: bool) -> Self {
        self.stem = stem;
        self
    }

    pub fn remove_stop_words(mut self, remove_stop_words: bool) -> Self {
        self.remove_stop_words = remove_stop_words;
        self
    }

    pub fn custom_stop_words(mut self, custom_stop_words: Option<Vec<String>>) -> Self {
        self.custom_stop_words = custom_stop_words;
        self
    }

    pub fn ascii_folding(mut self, ascii_folding: bool) -> Self {
        self.ascii_folding = ascii_folding;
        self
    }

    /// Minimum gram length for the "ngram" base tokenizer.
    pub fn ngram_min_length(mut self, min_length: u32) -> Self {
        self.min_ngram_length = min_length;
        self
    }

    /// Maximum gram length for the "ngram" base tokenizer.
    pub fn ngram_max_length(mut self, max_length: u32) -> Self {
        self.max_ngram_length = max_length;
        self
    }

    /// Emit only grams anchored at the start of each token.
    pub fn ngram_prefix_only(mut self, prefix_only: bool) -> Self {
        self.prefix_only = prefix_only;
        self
    }

    pub fn build(&self) -> Result<Box<dyn LanceTokenizer>> {
        let mut builder = self.build_base_tokenizer()?;
        if let Some(max_token_length) = self.max_token_length {
            builder = builder.filter_dynamic(tantivy::tokenizer::RemoveLongFilter::limit(
                max_token_length,
            ));
        }
        if self.lower_case {
            builder = builder.filter_dynamic(tantivy::tokenizer::LowerCaser);
        }
        if self.stem {
            builder = builder.filter_dynamic(tantivy::tokenizer::Stemmer::new(self.language));
        }
        if self.remove_stop_words {
            let stop_word_filter = match &self.custom_stop_words {
                Some(words) => tantivy::tokenizer::StopWordFilter::remove(words.iter().cloned()),
                None => {
                    tantivy::tokenizer::StopWordFilter::new(self.language).ok_or_else(|| {
                        Error::invalid_input(
                            format!(
                                "removing stop words for language {:?} is not supported yet",
                                self.language
                            ),
                            location!(),
                        )
                    })?
                }
            };
            builder = builder.filter_dynamic(stop_word_filter);
        }
        if self.ascii_folding {
            builder = builder.filter_dynamic(tantivy::tokenizer::AsciiFoldingFilter);
        }
        let tokenizer = builder.build();

        match self.lance_tokenizer.as_deref() {
            Some("text") | None => Ok(Box::new(TextTokenizer::new(tokenizer))),
            Some("json") => Ok(Box::new(JsonTokenizer::new(tokenizer))),
            Some(other) => Err(Error::invalid_input(
                format!("unknown lance tokenizer {}", other),
                location!(),
            )),
        }
    }
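
    // Selection sketch (illustrative): `lance_tokenizer` picks the Lance-level
    // wrapper around the tantivy analyzer built above, e.g.
    //
    //     let params = InvertedIndexParams::default().lance_tokenizer("json".to_owned());
    //     let tokenizer = params.build()?; // boxed JsonTokenizer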

    fn build_base_tokenizer(&self) -> Result<tantivy::tokenizer::TextAnalyzerBuilder> {
        match self.base_tokenizer.as_str() {
            "simple" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::SimpleTokenizer::default(),
            )
            .dynamic()),
            "whitespace" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::WhitespaceTokenizer::default(),
            )
            .dynamic()),
            "raw" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::RawTokenizer::default(),
            )
            .dynamic()),
            "ngram" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
                tantivy::tokenizer::NgramTokenizer::new(
                    self.min_ngram_length as usize,
                    self.max_ngram_length as usize,
                    self.prefix_only,
                )
                .map_err(|e| Error::invalid_input(e.to_string(), location!()))?,
            )
            .dynamic()),
            #[cfg(feature = "tokenizer-lindera")]
            s if s.starts_with("lindera/") => {
                let Some(home) = language_model_home() else {
                    return Err(Error::invalid_input(
                        format!("unknown base tokenizer {}", self.base_tokenizer),
                        location!(),
                    ));
                };
                LinderaTokenizerBuilder::load(&home.join(s))?.build()
            }
            #[cfg(feature = "tokenizer-jieba")]
            s if s.starts_with("jieba/") || s == "jieba" => {
                let s = if s == "jieba" { "jieba/default" } else { s };
                let Some(home) = language_model_home() else {
                    return Err(Error::invalid_input(
                        format!("unknown base tokenizer {}", self.base_tokenizer),
                        location!(),
                    ));
                };
                JiebaTokenizerBuilder::load(&home.join(s))?.build()
            }
            _ => Err(Error::invalid_input(
                format!("unknown base tokenizer {}", self.base_tokenizer),
                location!(),
            )),
        }
    }
}
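
// Ngram sketch (illustrative): index 2- to 4-character grams anchored at the
// start of each token, e.g. for prefix-style matching:
//
//     let params = InvertedIndexParams::default()
//         .base_tokenizer("ngram".to_owned())
//         .ngram_min_length(2)
//         .ngram_max_length(4)
//         .ngram_prefix_only(true);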

/// Environment variable that overrides where language models are loaded from.
pub const LANCE_LANGUAGE_MODEL_HOME_ENV_KEY: &str = "LANCE_LANGUAGE_MODEL_HOME";

/// Default language-model directory, relative to the platform-local data dir.
pub const LANCE_LANGUAGE_MODEL_DEFAULT_DIRECTORY: &str = "lance/language_models";

pub fn language_model_home() -> Option<PathBuf> {
    match env::var(LANCE_LANGUAGE_MODEL_HOME_ENV_KEY) {
        Ok(p) => Some(PathBuf::from(p)),
        Err(_) => dirs::data_local_dir().map(|p| p.join(LANCE_LANGUAGE_MODEL_DEFAULT_DIRECTORY)),
    }
}
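
// Lookup sketch (illustrative): with LANCE_LANGUAGE_MODEL_HOME=/models, the
// base tokenizer "jieba/default" loads from /models/jieba/default; without
// the variable, the platform-local data dir is used (e.g.
// ~/.local/share/lance/language_models on Linux).
//
//     std::env::set_var(LANCE_LANGUAGE_MODEL_HOME_ENV_KEY, "/models");
//     assert_eq!(language_model_home(), Some(PathBuf::from("/models")));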