Skip to main content

lance_index/scalar/inverted/tokenizer/
lance_tokenizer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow_schema::{DataType, Field};
5use lance_arrow::json::JSON_EXT_NAME;
6use lance_arrow::ARROW_EXT_NAME_KEY;
7use serde_json::Value;
8use snafu::location;
9use tantivy::tokenizer::{BoxTokenStream, Token, TokenStream};
10
/// Document type for full text search.
///
/// Determines how documents are tokenized for indexing; see [`DocType::prefix_len`] for how
/// the resulting index tokens differ between the variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DocType {
    /// Plain text, fed directly to the text analyzer.
    Text,
    /// JSON documents, flattened into `<path>,<type>,<value>` tokens.
    Json,
}
17
18impl AsRef<str> for DocType {
19    fn as_ref(&self) -> &str {
20        match self {
21            Self::Text => "text",
22            Self::Json => "json",
23        }
24    }
25}
26
27impl TryFrom<&Field> for DocType {
28    type Error = lance_core::Error;
29
30    fn try_from(field: &Field) -> Result<Self, Self::Error> {
31        match field.data_type() {
32            DataType::Utf8 | DataType::LargeUtf8 => Ok(Self::Text),
33            DataType::List(field) | DataType::LargeList(field)
34                if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) =>
35            {
36                Ok(Self::Text)
37            }
38            DataType::LargeBinary => match field.metadata().get(ARROW_EXT_NAME_KEY) {
39                Some(name) if name.as_str() == JSON_EXT_NAME => Ok(Self::Json),
40                _ => Err(lance_core::Error::InvalidInput {
41                    source: format!("field {} is not json", field.name()).into(),
42                    location: location!(),
43                }),
44            },
45            _ => Err(lance_core::Error::InvalidInput {
46                source: format!("field {} is not json", field.name()).into(),
47                location: location!(),
48            }),
49        }
50    }
51}
52
53impl DocType {
54    /// Get the length of the prefix before value.
55    ///  - JSON Token: path,type,value
56    ///  - Text Token: value
57    pub fn prefix_len(&self, token: &str) -> usize {
58        match self {
59            Self::Json => {
60                if let Some(pos) = token.find(',') {
61                    if let Some(second_pos) = token[pos + 1..].find(',') {
62                        return pos + second_pos + 2;
63                    }
64                }
65                panic!("json token must be in format of <path>,<type>,<value>")
66            }
67            Self::Text => 0,
68        }
69    }
70}
71
72/// Lance full text search tokenizer.
73///
74/// `LanceTokenizer` defines 2 methods for tokenization, normally they are the same, but sometimes
75/// tokenizer needs different behavior for search and index. Take json document as an example:
76/// 1. Query text is a triplet <path,type,value>, something like `a.b,str,123`. We shouldn't use
77///    json in search, because it would be too complicated.
78/// 2. Document text is a json string.
79pub trait LanceTokenizer: Send + Sync {
80    /// Tokenize query text for search.
81    fn token_stream_for_search<'a>(&'a mut self, query_text: &'a str) -> BoxTokenStream<'a>;
82    /// Tokenize document text for index.
83    fn token_stream_for_doc<'a>(&'a mut self, text: &'a str) -> BoxTokenStream<'a>;
84    /// Clone the tokenizer.
85    fn box_clone(&self) -> Box<dyn LanceTokenizer>;
86    /// Get document type.
87    fn doc_type(&self) -> DocType;
88}
89
90impl Clone for Box<dyn LanceTokenizer> {
91    fn clone(&self) -> Self {
92        self.box_clone()
93    }
94}
95
96#[derive(Clone)]
97pub struct TextTokenizer {
98    tokenizer: tantivy::tokenizer::TextAnalyzer,
99}
100
101impl TextTokenizer {
102    pub fn new(tokenizer: tantivy::tokenizer::TextAnalyzer) -> Self {
103        Self { tokenizer }
104    }
105}
106
107impl LanceTokenizer for TextTokenizer {
108    fn token_stream_for_search<'a>(&'a mut self, query_text: &'a str) -> BoxTokenStream<'a> {
109        self.tokenizer.token_stream(query_text)
110    }
111
112    fn token_stream_for_doc<'a>(&'a mut self, text: &'a str) -> BoxTokenStream<'a> {
113        self.tokenizer.token_stream(text)
114    }
115
116    fn box_clone(&self) -> Box<dyn LanceTokenizer> {
117        Box::new(self.clone())
118    }
119
120    fn doc_type(&self) -> DocType {
121        DocType::Text
122    }
123}
124
125#[derive(Clone)]
126pub struct JsonTokenizer {
127    tokenizer: tantivy::tokenizer::TextAnalyzer,
128}
129
130impl JsonTokenizer {
131    pub fn new(tokenizer: tantivy::tokenizer::TextAnalyzer) -> Self {
132        Self { tokenizer }
133    }
134}
135
136impl LanceTokenizer for JsonTokenizer {
137    fn token_stream_for_search<'a>(&'a mut self, query_text: &'a str) -> BoxTokenStream<'a> {
138        let tokens = flatten_triplet(query_text, &mut self.tokenizer).unwrap();
139        BoxTokenStream::new(TTStream { tokens, index: 0 })
140    }
141
142    fn token_stream_for_doc<'a>(&'a mut self, text: &'a str) -> BoxTokenStream<'a> {
143        let value: Value = match serde_json::from_slice(text.as_bytes()) {
144            Ok(v) => v,
145            Err(e) => {
146                panic!("JSON parse error: {:?}", e);
147            }
148        };
149        let mut tokens = vec![];
150        let mut position = 0;
151        flatten_json(&value, "", &mut tokens, &mut position, &mut self.tokenizer);
152        BoxTokenStream::new(TTStream { tokens, index: 0 })
153    }
154
155    fn box_clone(&self) -> Box<dyn LanceTokenizer> {
156        Box::new(self.clone())
157    }
158
159    fn doc_type(&self) -> DocType {
160        DocType::Json
161    }
162}
163
164fn flatten_triplet(
165    text: &str,
166    tokenizer: &mut tantivy::tokenizer::TextAnalyzer,
167) -> lance_core::Result<Vec<Token>> {
168    let mut token_vec = Vec::new();
169    let mut idx = 0;
170
171    for triple in text.split(';') {
172        let parts: Vec<&str> = triple.splitn(3, ',').collect();
173        if parts.len() != 3 {
174            return Err(lance_core::Error::InvalidInput {
175                source: format!("Invalid triple format: {}", triple).into(),
176                location: location!(),
177            });
178        }
179        let field = parts[0];
180        let v_type = parts[1];
181        let value = parts[2];
182
183        match v_type {
184            "number" | "bool" | "null" => {
185                let token = Token {
186                    offset_from: 0,
187                    offset_to: 0,
188                    position: idx,
189                    text: format!("{},{},{}", field, v_type, value),
190                    position_length: 1,
191                };
192                token_vec.push(token);
193                idx += 1;
194            }
195            "str" => {
196                let mut tokens = tokenizer.token_stream(value);
197                while let Some(token) = tokens.next() {
198                    token_vec.push(Token {
199                        offset_from: 0,
200                        offset_to: 0,
201                        position: idx,
202                        text: format!("{},{},{}", field, v_type, token.text),
203                        position_length: 1,
204                    });
205                    idx += 1;
206                }
207            }
208            _ => {
209                return Err(lance_core::Error::InvalidInput {
210                    source: format!("Invalid triple type: {}", v_type).into(),
211                    location: location!(),
212                })
213            }
214        }
215    }
216    Ok(token_vec)
217}
218
219fn flatten_json(
220    value: &Value,
221    prefix: &str,
222    out: &mut Vec<Token>,
223    position: &mut usize,
224    tokenizer: &mut tantivy::tokenizer::TextAnalyzer,
225) {
226    match value {
227        Value::Object(map) => {
228            for (k, v) in map {
229                let next_prefix = if prefix.is_empty() {
230                    k.clone()
231                } else {
232                    format!("{}.{}", prefix, k)
233                };
234                flatten_json(v, &next_prefix, out, position, tokenizer);
235            }
236        }
237        Value::Array(arr) => {
238            for v in arr.iter() {
239                flatten_json(v, prefix, out, position, tokenizer);
240            }
241        }
242        Value::String(text) => {
243            let mut tokens = tokenizer.token_stream(text);
244            while let Some(token) = tokens.next() {
245                let token = Token {
246                    offset_from: 0,
247                    offset_to: 0,
248                    position: *position,
249                    text: format!("{},{},{}", prefix, "str", token.text),
250                    position_length: 1,
251                };
252                *position += 1;
253                out.push(token);
254            }
255        }
256        _ => {
257            let value_type = match value {
258                Value::Null => "null",
259                Value::Bool(_) => "bool",
260                Value::Number(_) => "number",
261                _ => unreachable!(),
262            };
263            let token = Token {
264                offset_from: 0,
265                offset_to: 0,
266                position: *position,
267                text: format!("{},{},{}", prefix, value_type, value),
268                position_length: 1,
269            };
270            *position += 1;
271            out.push(token);
272        }
273    }
274}
275
276struct TTStream {
277    tokens: Vec<Token>,
278    index: usize,
279}
280
281impl TokenStream for TTStream {
282    fn advance(&mut self) -> bool {
283        if self.index < self.tokens.len() {
284            self.index += 1;
285            true
286        } else {
287            false
288        }
289    }
290
291    fn token(&self) -> &Token {
292        &self.tokens[self.index - 1]
293    }
294
295    fn token_mut(&mut self) -> &mut Token {
296        &mut self.tokens[self.index - 1]
297    }
298}
299
#[cfg(test)]
mod tests {
    use crate::scalar::inverted::tokenizer::lance_tokenizer::{
        flatten_json, flatten_triplet, DocType, JsonTokenizer, LanceTokenizer,
    };
    use serde_json::Value;
    use tantivy::tokenizer::{SimpleTokenizer, Token};

    #[test]
    fn test_json_tokenizer() {
        let text = r#"{
          "a": 1,
          "b": [
            {"c": "d"},
            {"c": "e"}
          ]
        }"#;
        let mut tokenizer = JsonTokenizer::new(
            tantivy::tokenizer::TextAnalyzer::builder(SimpleTokenizer::default()).build(),
        );
        let mut stream = tokenizer.token_stream_for_doc(text);

        let mut tokens: Vec<Token> = vec![];
        while let Some(token) = stream.next() {
            tokens.push(token.clone());
        }

        assert_eq!(tokens.len(), 3);
        assert_token(&tokens[0], 0, "a,number,1");
        assert_token(&tokens[1], 1, "b.c,str,d");
        assert_token(&tokens[2], 2, "b.c,str,e");
    }

    #[test]
    fn test_flatten_json_text() {
        let json = r#"{
              "a": 1,
              "b": [
                {"c": "hello world"},
                {"c": "e"}
              ],
              "c": true,
              "d": null,
              "e": {
                "f": 1.0
              }
          }"#;
        let value: Value = serde_json::from_str(json).unwrap();

        let mut tokens = vec![];
        let mut tokenizer =
            tantivy::tokenizer::TextAnalyzer::builder(SimpleTokenizer::default()).build();
        let mut position = 0;
        flatten_json(&value, "", &mut tokens, &mut position, &mut tokenizer);

        assert_eq!(tokens.len(), 7);
        assert_token(&tokens[0], 0, "a,number,1");
        assert_token(&tokens[1], 1, "b.c,str,hello");
        assert_token(&tokens[2], 2, "b.c,str,world");
        assert_token(&tokens[3], 3, "b.c,str,e");
        assert_token(&tokens[4], 4, "c,bool,true");
        assert_token(&tokens[5], 5, "d,null,null");
        assert_token(&tokens[6], 6, "e.f,number,1.0");
    }

    #[test]
    fn test_flatten_triplet() {
        let text = r#"a,number,1;b.c,str,d;b.c,str,e;d,str,hello world;e,number,1.0"#;
        let mut tokenizer =
            tantivy::tokenizer::TextAnalyzer::builder(SimpleTokenizer::default()).build();
        let tokens = flatten_triplet(text, &mut tokenizer).unwrap();

        assert_eq!(tokens.len(), 6);
        assert_token(&tokens[0], 0, "a,number,1");
        assert_token(&tokens[1], 1, "b.c,str,d");
        assert_token(&tokens[2], 2, "b.c,str,e");
        assert_token(&tokens[3], 3, "d,str,hello");
        assert_token(&tokens[4], 4, "d,str,world");
        assert_token(&tokens[5], 5, "e,number,1.0");
    }

    #[test]
    fn test_flatten_triplet_rejects_malformed_input() {
        let mut tokenizer =
            tantivy::tokenizer::TextAnalyzer::builder(SimpleTokenizer::default()).build();
        // Fewer than three comma-separated parts.
        assert!(flatten_triplet("a,number", &mut tokenizer).is_err());
        // Unknown value type.
        assert!(flatten_triplet("a,date,2024-01-01", &mut tokenizer).is_err());
    }

    #[test]
    fn test_prefix_len() {
        assert_eq!(DocType::Json.prefix_len("a.b,str,hello"), "a.b,str,".len());
        // Commas after the second one belong to the value, not the prefix.
        assert_eq!(DocType::Json.prefix_len("a,str,x,y"), "a,str,".len());
        assert_eq!(DocType::Text.prefix_len("a,str,x"), 0);
    }

    #[test]
    #[should_panic(expected = "json token must be in format")]
    fn test_prefix_len_panics_on_malformed_json_token() {
        DocType::Json.prefix_len("a,str");
    }

    /// Asserts that `token` sits at `position` and has exactly `text`.
    fn assert_token(token: &Token, position: usize, text: &str) {
        assert_eq!(
            token.position, position,
            "expected position {position} but {token:?}"
        );
        assert_eq!(
            token.text.as_str(),
            text,
            "expected text {text} but {token:?}"
        );
    }
}