use lance_core::{Error, Result};
use serde::{Deserialize, Serialize};
use std::{env, path::PathBuf};
#[cfg(feature = "tokenizer-jieba")]
mod jieba;
pub mod lance_tokenizer;
#[cfg(feature = "tokenizer-lindera")]
mod lindera;
#[cfg(feature = "tokenizer-jieba")]
use jieba::JiebaTokenizerBuilder;
#[cfg(feature = "tokenizer-lindera")]
use lindera::LinderaTokenizerBuilder;
use crate::pbold;
use crate::scalar::inverted::tokenizer::lance_tokenizer::{
JsonTokenizer, LanceTokenizer, TextTokenizer,
};
/// User-facing configuration for building an inverted (full-text) index.
///
/// Serialized via serde; `memory_limit_mb` and `num_workers` are build-time
/// knobs that are skipped when serializing (see `to_training_json`, which
/// re-adds them under their public names).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct InvertedIndexParams {
    /// Lance-level tokenizer wrapper name ("text" or "json"); `None` behaves as "text".
    pub(crate) lance_tokenizer: Option<String>,
    /// Tantivy base tokenizer name: "simple", "whitespace", "raw", "ngram",
    /// or a feature-gated "lindera/..." / "jieba..." model path.
    pub(crate) base_tokenizer: String,
    /// Language used for stemming and the default stop-word list.
    pub(crate) language: tantivy::tokenizer::Language,
    /// Record token positions (needed for phrase queries); defaults to false.
    #[serde(default)]
    pub(crate) with_position: bool,
    /// Drop tokens longer than this many characters; `None` disables the limit.
    pub(crate) max_token_length: Option<usize>,
    /// Lowercase tokens; defaults to true.
    #[serde(default = "bool_true")]
    pub(crate) lower_case: bool,
    /// Apply language-specific stemming; defaults to true.
    #[serde(default = "bool_true")]
    pub(crate) stem: bool,
    /// Remove stop words; defaults to true.
    #[serde(default = "bool_true")]
    pub(crate) remove_stop_words: bool,
    /// Custom stop-word list overriding the language default.
    pub(crate) custom_stop_words: Option<Vec<String>>,
    /// Fold accented characters to their ASCII equivalents; defaults to true.
    #[serde(default = "bool_true")]
    pub(crate) ascii_folding: bool,
    /// Minimum ngram length (used by the "ngram" base tokenizer).
    #[serde(default = "default_min_ngram_length")]
    pub(crate) min_ngram_length: u32,
    /// Maximum ngram length (used by the "ngram" base tokenizer).
    #[serde(default = "default_max_ngram_length")]
    pub(crate) max_ngram_length: u32,
    /// Emit only ngrams anchored at the token start.
    #[serde(default)]
    pub(crate) prefix_only: bool,
    /// Build-only memory cap (MB, per the field name); never serialized,
    /// accepted on input as "memory_limit" or the legacy "worker_memory_limit_mb".
    #[serde(
        rename = "memory_limit",
        skip_serializing,
        default,
        alias = "worker_memory_limit_mb"
    )]
    pub(crate) memory_limit_mb: Option<u64>,
    /// Build-only worker count; never serialized, accepted on input as "num_workers".
    #[serde(rename = "num_workers", skip_serializing, default)]
    pub(crate) num_workers: Option<usize>,
}
impl TryFrom<&InvertedIndexParams> for pbold::InvertedIndexDetails {
    type Error = Error;

    /// Convert index params into the protobuf details persisted with the
    /// index. Build-only fields (`memory_limit_mb`, `num_workers`),
    /// `lance_tokenizer`, and `custom_stop_words` are intentionally not
    /// persisted.
    ///
    /// # Errors
    /// Fails if the language cannot be serialized to JSON.
    fn try_from(params: &InvertedIndexParams) -> Result<Self> {
        Ok(Self {
            base_tokenizer: Some(params.base_tokenizer.clone()),
            // Store the language as its serde JSON representation so it
            // round-trips through `serde_json::from_str` on load.
            language: serde_json::to_string(&params.language)?,
            with_position: params.with_position,
            max_token_length: params.max_token_length.map(|l| l as u32),
            lower_case: params.lower_case,
            stem: params.stem,
            remove_stop_words: params.remove_stop_words,
            ascii_folding: params.ascii_folding,
            min_ngram_length: params.min_ngram_length,
            max_ngram_length: params.max_ngram_length,
            prefix_only: params.prefix_only,
        })
    }
}
impl TryFrom<&pbold::InvertedIndexDetails> for InvertedIndexParams {
    type Error = Error;

    /// Reconstruct index params from persisted protobuf details.
    ///
    /// Fields that are not persisted (lance tokenizer, custom stop words,
    /// build-time memory/worker limits) fall back to their `Default` values.
    ///
    /// # Errors
    /// Fails if the stored `language` string is not valid JSON for
    /// `tantivy::tokenizer::Language`.
    fn try_from(details: &pbold::InvertedIndexDetails) -> Result<Self> {
        let defaults = Self::default();
        Ok(Self {
            lance_tokenizer: defaults.lance_tokenizer,
            // `.clone()` on the Option replaces the redundant
            // `.as_ref().cloned()` of the original.
            base_tokenizer: details
                .base_tokenizer
                .clone()
                .unwrap_or(defaults.base_tokenizer),
            language: serde_json::from_str(details.language.as_str())?,
            with_position: details.with_position,
            max_token_length: details.max_token_length.map(|l| l as usize),
            lower_case: details.lower_case,
            stem: details.stem,
            remove_stop_words: details.remove_stop_words,
            custom_stop_words: defaults.custom_stop_words,
            ascii_folding: details.ascii_folding,
            min_ngram_length: details.min_ngram_length,
            max_ngram_length: details.max_ngram_length,
            prefix_only: details.prefix_only,
            memory_limit_mb: defaults.memory_limit_mb,
            num_workers: defaults.num_workers,
        })
    }
}
/// Serde default for the boolean options that default to on.
fn bool_true() -> bool {
    true
}

/// Serde default for `min_ngram_length`.
fn default_min_ngram_length() -> u32 {
    3
}

/// Serde default for `max_ngram_length`.
fn default_max_ngram_length() -> u32 {
    3
}
impl Default for InvertedIndexParams {
fn default() -> Self {
Self::new("simple".to_owned(), tantivy::tokenizer::Language::English)
}
}
impl InvertedIndexParams {
pub fn new(base_tokenizer: String, language: tantivy::tokenizer::Language) -> Self {
Self {
lance_tokenizer: None,
base_tokenizer,
language,
with_position: false,
max_token_length: Some(40),
lower_case: true,
stem: true,
remove_stop_words: true,
custom_stop_words: None,
ascii_folding: true,
min_ngram_length: default_min_ngram_length(),
max_ngram_length: default_max_ngram_length(),
prefix_only: false,
memory_limit_mb: None,
num_workers: None,
}
}
pub fn lance_tokenizer(mut self, lance_tokenizer: String) -> Self {
self.lance_tokenizer = Some(lance_tokenizer);
self
}
pub fn base_tokenizer(mut self, base_tokenizer: String) -> Self {
self.base_tokenizer = base_tokenizer;
self
}
pub fn language(mut self, language: &str) -> Result<Self> {
let language = serde_json::from_str(format!("\"{}\"", language).as_str())?;
self.language = language;
Ok(self)
}
pub fn with_position(mut self, with_position: bool) -> Self {
self.with_position = with_position;
self
}
pub fn has_positions(&self) -> bool {
self.with_position
}
pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
self.max_token_length = max_token_length;
self
}
pub fn lower_case(mut self, lower_case: bool) -> Self {
self.lower_case = lower_case;
self
}
pub fn stem(mut self, stem: bool) -> Self {
self.stem = stem;
self
}
pub fn remove_stop_words(mut self, remove_stop_words: bool) -> Self {
self.remove_stop_words = remove_stop_words;
self
}
pub fn custom_stop_words(mut self, custom_stop_words: Option<Vec<String>>) -> Self {
self.custom_stop_words = custom_stop_words;
self
}
pub fn ascii_folding(mut self, ascii_folding: bool) -> Self {
self.ascii_folding = ascii_folding;
self
}
pub fn ngram_min_length(mut self, min_length: u32) -> Self {
self.min_ngram_length = min_length;
self
}
pub fn ngram_max_length(mut self, max_length: u32) -> Self {
self.max_ngram_length = max_length;
self
}
pub fn ngram_prefix_only(mut self, prefix_only: bool) -> Self {
self.prefix_only = prefix_only;
self
}
pub fn memory_limit_mb(mut self, memory_limit_mb: u64) -> Self {
self.memory_limit_mb = Some(memory_limit_mb);
self
}
pub fn num_workers(mut self, num_workers: usize) -> Self {
self.num_workers = Some(num_workers);
self
}
pub fn to_training_json(&self) -> serde_json::Result<serde_json::Value> {
let mut value = serde_json::to_value(self)?;
let object = value
.as_object_mut()
.expect("inverted index params should serialize to a JSON object");
if let Some(memory_limit_mb) = self.memory_limit_mb {
object.insert(
"memory_limit".to_string(),
serde_json::Value::from(memory_limit_mb),
);
}
if let Some(num_workers) = self.num_workers {
object.insert(
"num_workers".to_string(),
serde_json::Value::from(num_workers),
);
}
Ok(value)
}
pub fn build(&self) -> Result<Box<dyn LanceTokenizer>> {
let mut builder = self.build_base_tokenizer()?;
if let Some(max_token_length) = self.max_token_length {
builder = builder.filter_dynamic(tantivy::tokenizer::RemoveLongFilter::limit(
max_token_length,
));
}
if self.lower_case {
builder = builder.filter_dynamic(tantivy::tokenizer::LowerCaser);
}
if self.stem {
builder = builder.filter_dynamic(tantivy::tokenizer::Stemmer::new(self.language));
}
if self.remove_stop_words {
let stop_word_filter = match &self.custom_stop_words {
Some(words) => tantivy::tokenizer::StopWordFilter::remove(words.iter().cloned()),
None => {
tantivy::tokenizer::StopWordFilter::new(self.language).ok_or_else(|| {
Error::invalid_input(format!(
"removing stop words for language {:?} is not supported yet",
self.language
))
})?
}
};
builder = builder.filter_dynamic(stop_word_filter);
}
if self.ascii_folding {
builder = builder.filter_dynamic(tantivy::tokenizer::AsciiFoldingFilter);
}
let tokenizer = builder.build();
match self.lance_tokenizer {
Some(ref t) if t == "text" => Ok(Box::new(TextTokenizer::new(tokenizer))),
Some(ref t) if t == "json" => Ok(Box::new(JsonTokenizer::new(tokenizer))),
None => Ok(Box::new(TextTokenizer::new(tokenizer))),
_ => Err(Error::invalid_input(format!(
"unknown lance tokenizer {}",
self.lance_tokenizer.as_ref().unwrap()
))),
}
}
fn build_base_tokenizer(&self) -> Result<tantivy::tokenizer::TextAnalyzerBuilder> {
match self.base_tokenizer.as_str() {
"simple" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
tantivy::tokenizer::SimpleTokenizer::default(),
)
.dynamic()),
"whitespace" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
tantivy::tokenizer::WhitespaceTokenizer::default(),
)
.dynamic()),
"raw" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
tantivy::tokenizer::RawTokenizer::default(),
)
.dynamic()),
"ngram" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
tantivy::tokenizer::NgramTokenizer::new(
self.min_ngram_length as usize,
self.max_ngram_length as usize,
self.prefix_only,
)
.map_err(|e| Error::invalid_input(e.to_string()))?,
)
.dynamic()),
#[cfg(feature = "tokenizer-lindera")]
s if s.starts_with("lindera/") => {
let Some(home) = language_model_home() else {
return Err(Error::invalid_input(format!(
"unknown base tokenizer {}",
self.base_tokenizer
)));
};
lindera::LinderaBuilder::load(&home.join(s))?.build()
}
#[cfg(feature = "tokenizer-jieba")]
s if s.starts_with("jieba/") || s == "jieba" => {
let s = if s == "jieba" { "jieba/default" } else { s };
let Some(home) = language_model_home() else {
return Err(Error::invalid_input(format!(
"unknown base tokenizer {}",
self.base_tokenizer
)));
};
jieba::JiebaBuilder::load(&home.join(s))?.build()
}
_ => Err(Error::invalid_input(format!(
"unknown base tokenizer {}",
self.base_tokenizer
))),
}
}
}
/// Environment variable that overrides where language model assets are looked up.
pub const LANCE_LANGUAGE_MODEL_HOME_ENV_KEY: &str = "LANCE_LANGUAGE_MODEL_HOME";
/// Default subdirectory (under the platform's local data dir) for language models.
pub const LANCE_LANGUAGE_MODEL_DEFAULT_DIRECTORY: &str = "lance/language_models";
/// Resolve the directory that stores language model assets.
///
/// The `LANCE_LANGUAGE_MODEL_HOME` environment variable takes precedence;
/// otherwise fall back to the platform-local data directory joined with
/// `lance/language_models`. Returns `None` only when neither is available.
pub fn language_model_home() -> Option<PathBuf> {
    env::var(LANCE_LANGUAGE_MODEL_HOME_ENV_KEY)
        .map(PathBuf::from)
        .ok()
        .or_else(|| dirs::data_local_dir().map(|d| d.join(LANCE_LANGUAGE_MODEL_DEFAULT_DIRECTORY)))
}
#[cfg(test)]
mod tests {
    use super::InvertedIndexParams;

    /// Build-only fields must not appear in plain serde output.
    #[test]
    fn test_build_only_fields_are_not_serialized() {
        let params = InvertedIndexParams::default()
            .memory_limit_mb(4096)
            .num_workers(7);
        // Fixes the garbled `to_value(¶ms)` — borrow the params.
        let json = serde_json::to_value(&params).unwrap();
        assert!(json.get("memory_limit").is_none());
        assert!(json.get("num_workers").is_none());
    }

    /// The legacy "worker_memory_limit_mb" field name still deserializes
    /// into `memory_limit_mb` via the serde alias.
    #[test]
    fn test_memory_limit_serde_accepts_legacy_worker_field_name() {
        let mut json = serde_json::to_value(InvertedIndexParams::default()).unwrap();
        let obj = json.as_object_mut().unwrap();
        obj.remove("memory_limit");
        obj.insert(
            "worker_memory_limit_mb".to_string(),
            serde_json::Value::from(2048),
        );
        let params: InvertedIndexParams = serde_json::from_value(json).unwrap();
        assert_eq!(params.memory_limit_mb, Some(2048));
    }

    /// Build-only fields deserialize from their public (renamed) names.
    #[test]
    fn test_build_only_fields_deserialize_from_public_names() {
        let mut json = serde_json::to_value(InvertedIndexParams::default()).unwrap();
        let obj = json.as_object_mut().unwrap();
        obj.insert("memory_limit".to_string(), serde_json::Value::from(4096));
        obj.insert("num_workers".to_string(), serde_json::Value::from(3));
        let params: InvertedIndexParams = serde_json::from_value(json).unwrap();
        assert_eq!(params.memory_limit_mb, Some(4096));
        assert_eq!(params.num_workers, Some(3));
    }

    /// `to_training_json` re-adds the build-only fields under their public names.
    #[test]
    fn test_training_json_serializes_build_only_fields() {
        let params = InvertedIndexParams::default()
            .memory_limit_mb(4096)
            .num_workers(3);
        let json = params.to_training_json().unwrap();
        assert_eq!(
            json.get("memory_limit"),
            Some(&serde_json::Value::from(4096))
        );
        assert_eq!(json.get("num_workers"), Some(&serde_json::Value::from(3)));
    }
}