Skip to main content

lance_index/scalar/inverted/tokenizer/
document_tokenizer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use arrow_schema::{DataType, Field};
5use lance_arrow::ARROW_EXT_NAME_KEY;
6use lance_arrow::json::JSON_EXT_NAME;
7use lance_tokenizer::{BoxTokenStream, TextAnalyzer, Token, TokenStream};
8use serde_json::Value;
9
/// Kind of document a full text search tokenizer operates on.
///
/// The variant decides the shape of the tokens: bare terms for [`DocType::Text`],
/// `<path>,<type>,<value>` triplets for [`DocType::Json`].
///
/// Implements `PartialEq`/`Eq` so callers can compare document types directly
/// (for example `tokenizer.doc_type() == DocType::Json`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DocType {
    /// Plain text document; a token is just the term itself.
    Text,
    /// JSON document; a token is a `<path>,<type>,<value>` triplet.
    Json,
}
16
17impl AsRef<str> for DocType {
18    fn as_ref(&self) -> &str {
19        match self {
20            Self::Text => "text",
21            Self::Json => "json",
22        }
23    }
24}
25
26impl TryFrom<&Field> for DocType {
27    type Error = lance_core::Error;
28
29    fn try_from(field: &Field) -> Result<Self, Self::Error> {
30        match field.data_type() {
31            DataType::Utf8 | DataType::LargeUtf8 => Ok(Self::Text),
32            DataType::List(field) | DataType::LargeList(field)
33                if matches!(field.data_type(), DataType::Utf8 | DataType::LargeUtf8) =>
34            {
35                Ok(Self::Text)
36            }
37            DataType::LargeBinary => match field.metadata().get(ARROW_EXT_NAME_KEY) {
38                Some(name) if name.as_str() == JSON_EXT_NAME => Ok(Self::Json),
39                _ => Err(lance_core::Error::invalid_input_source(
40                    format!("field {} is not json", field.name()).into(),
41                )),
42            },
43            _ => Err(lance_core::Error::invalid_input_source(
44                format!("field {} is not json", field.name()).into(),
45            )),
46        }
47    }
48}
49
50impl DocType {
51    /// Get the length of the prefix before value.
52    ///  - JSON Token: path,type,value
53    ///  - Text Token: value
54    pub fn prefix_len(&self, token: &str) -> usize {
55        match self {
56            Self::Json => {
57                if let Some(pos) = token.find(',')
58                    && let Some(second_pos) = token[pos + 1..].find(',')
59                {
60                    return pos + second_pos + 2;
61                }
62                panic!("json token must be in format of <path>,<type>,<value>")
63            }
64            Self::Text => 0,
65        }
66    }
67}
68
69/// Lance full text search tokenizer.
70///
71/// `LanceTokenizer` defines 2 methods for tokenization, normally they are the same, but sometimes
72/// tokenizer needs different behavior for search and index. Take json document as an example:
73/// 1. Query text is a triplet <path,type,value>, something like `a.b,str,123`. We shouldn't use
74///    json in search, because it would be too complicated.
75/// 2. Document text is a json string.
76pub trait LanceTokenizer: Send + Sync + std::fmt::Debug {
77    /// Tokenize query text for search.
78    fn token_stream_for_search<'a>(&'a mut self, query_text: &'a str) -> BoxTokenStream<'a>;
79    /// Tokenize document text for index.
80    fn token_stream_for_doc<'a>(&'a mut self, text: &'a str) -> BoxTokenStream<'a>;
81    /// Clone the tokenizer.
82    fn box_clone(&self) -> Box<dyn LanceTokenizer>;
83    /// Get document type.
84    fn doc_type(&self) -> DocType;
85}
86
87impl Clone for Box<dyn LanceTokenizer> {
88    fn clone(&self) -> Self {
89        self.box_clone()
90    }
91}
92
93#[derive(Clone)]
94pub struct TextTokenizer {
95    tokenizer: TextAnalyzer,
96}
97
98impl std::fmt::Debug for TextTokenizer {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        write!(f, "TextTokenizer")
101    }
102}
103
104impl TextTokenizer {
105    pub fn new(tokenizer: TextAnalyzer) -> Self {
106        Self { tokenizer }
107    }
108}
109
110impl LanceTokenizer for TextTokenizer {
111    fn token_stream_for_search<'a>(&'a mut self, query_text: &'a str) -> BoxTokenStream<'a> {
112        self.tokenizer.token_stream(query_text)
113    }
114
115    fn token_stream_for_doc<'a>(&'a mut self, text: &'a str) -> BoxTokenStream<'a> {
116        self.tokenizer.token_stream(text)
117    }
118
119    fn box_clone(&self) -> Box<dyn LanceTokenizer> {
120        Box::new(self.clone())
121    }
122
123    fn doc_type(&self) -> DocType {
124        DocType::Text
125    }
126}
127
128#[derive(Clone)]
129pub struct JsonTokenizer {
130    tokenizer: TextAnalyzer,
131}
132
133impl JsonTokenizer {
134    pub fn new(tokenizer: TextAnalyzer) -> Self {
135        Self { tokenizer }
136    }
137}
138
139impl std::fmt::Debug for JsonTokenizer {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        write!(f, "JsonTokenizer")
142    }
143}
144
145impl LanceTokenizer for JsonTokenizer {
146    fn token_stream_for_search<'a>(&'a mut self, query_text: &'a str) -> BoxTokenStream<'a> {
147        let tokens = flatten_triplet(query_text, &mut self.tokenizer).unwrap();
148        BoxTokenStream::new(TTStream { tokens, index: 0 })
149    }
150
151    fn token_stream_for_doc<'a>(&'a mut self, text: &'a str) -> BoxTokenStream<'a> {
152        let value: Value = match serde_json::from_slice(text.as_bytes()) {
153            Ok(v) => v,
154            Err(e) => {
155                panic!("JSON parse error: {:?}", e);
156            }
157        };
158        let mut tokens = vec![];
159        let mut position = 0;
160        flatten_json(&value, "", &mut tokens, &mut position, &mut self.tokenizer);
161        BoxTokenStream::new(TTStream { tokens, index: 0 })
162    }
163
164    fn box_clone(&self) -> Box<dyn LanceTokenizer> {
165        Box::new(self.clone())
166    }
167
168    fn doc_type(&self) -> DocType {
169        DocType::Json
170    }
171}
172
173fn flatten_triplet(text: &str, tokenizer: &mut TextAnalyzer) -> lance_core::Result<Vec<Token>> {
174    let mut token_vec = Vec::new();
175    let mut idx = 0;
176
177    for triple in text.split(';') {
178        let parts: Vec<&str> = triple.splitn(3, ',').collect();
179        if parts.len() != 3 {
180            return Err(lance_core::Error::invalid_input_source(
181                format!("Invalid triple format: {}", triple).into(),
182            ));
183        }
184        let field = parts[0];
185        let v_type = parts[1];
186        let value = parts[2];
187
188        match v_type {
189            "number" | "bool" | "null" => {
190                let token = Token {
191                    offset_from: 0,
192                    offset_to: 0,
193                    position: idx,
194                    text: format!("{},{},{}", field, v_type, value),
195                    position_length: 1,
196                };
197                token_vec.push(token);
198                idx += 1;
199            }
200            "str" => {
201                let mut tokens = tokenizer.token_stream(value);
202                while let Some(token) = tokens.next() {
203                    token_vec.push(Token {
204                        offset_from: 0,
205                        offset_to: 0,
206                        position: idx,
207                        text: format!("{},{},{}", field, v_type, token.text),
208                        position_length: 1,
209                    });
210                    idx += 1;
211                }
212            }
213            _ => {
214                return Err(lance_core::Error::invalid_input_source(
215                    format!("Invalid triple type: {}", v_type).into(),
216                ));
217            }
218        }
219    }
220    Ok(token_vec)
221}
222
223fn flatten_json(
224    value: &Value,
225    prefix: &str,
226    out: &mut Vec<Token>,
227    position: &mut usize,
228    tokenizer: &mut TextAnalyzer,
229) {
230    match value {
231        Value::Object(map) => {
232            for (k, v) in map {
233                let next_prefix = if prefix.is_empty() {
234                    k.clone()
235                } else {
236                    format!("{}.{}", prefix, k)
237                };
238                flatten_json(v, &next_prefix, out, position, tokenizer);
239            }
240        }
241        Value::Array(arr) => {
242            for v in arr.iter() {
243                flatten_json(v, prefix, out, position, tokenizer);
244            }
245        }
246        Value::String(text) => {
247            let mut tokens = tokenizer.token_stream(text);
248            while let Some(token) = tokens.next() {
249                let token = Token {
250                    offset_from: 0,
251                    offset_to: 0,
252                    position: *position,
253                    text: format!("{},{},{}", prefix, "str", token.text),
254                    position_length: 1,
255                };
256                *position += 1;
257                out.push(token);
258            }
259        }
260        _ => {
261            let value_type = match value {
262                Value::Null => "null",
263                Value::Bool(_) => "bool",
264                Value::Number(_) => "number",
265                _ => unreachable!(),
266            };
267            let token = Token {
268                offset_from: 0,
269                offset_to: 0,
270                position: *position,
271                text: format!("{},{},{}", prefix, value_type, value),
272                position_length: 1,
273            };
274            *position += 1;
275            out.push(token);
276        }
277    }
278}
279
280struct TTStream {
281    tokens: Vec<Token>,
282    index: usize,
283}
284
285impl TokenStream for TTStream {
286    fn advance(&mut self) -> bool {
287        if self.index < self.tokens.len() {
288            self.index += 1;
289            true
290        } else {
291            false
292        }
293    }
294
295    fn token(&self) -> &Token {
296        &self.tokens[self.index - 1]
297    }
298
299    fn token_mut(&mut self) -> &mut Token {
300        &mut self.tokens[self.index - 1]
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::{JsonTokenizer, LanceTokenizer, flatten_json, flatten_triplet};
    use lance_tokenizer::{SimpleTokenizer, TextAnalyzer, Token};
    use serde_json::Value;

    /// Builds the analyzer shared by all tests: a bare `SimpleTokenizer`.
    fn simple_analyzer() -> TextAnalyzer {
        TextAnalyzer::builder(SimpleTokenizer::default()).build()
    }

    /// Checks a single token's position and text.
    fn assert_token(token: &Token, position: usize, text: &str) {
        assert_eq!(
            token.position, position,
            "expected position {position} but {token:?}"
        );
        assert_eq!(
            token.text.as_str(),
            text,
            "expected text {text} but {token:?}"
        );
    }

    /// Checks that `tokens` matches `expected` one-to-one, with positions 0, 1, 2, ...
    fn assert_tokens(tokens: &[Token], expected: &[&str]) {
        assert_eq!(tokens.len(), expected.len());
        for (position, (token, text)) in tokens.iter().zip(expected).enumerate() {
            assert_token(token, position, text);
        }
    }

    /// Indexing a JSON document through `JsonTokenizer` yields one token per leaf.
    #[test]
    fn test_json_tokenizer() {
        let text = r#"{
          "a": 1,
          "b": [
            {"c": "d"},
            {"c": "e"}
          ]
        }"#;
        let mut tokenizer = JsonTokenizer::new(simple_analyzer());
        let mut stream = tokenizer.token_stream_for_doc(text);

        let mut tokens: Vec<Token> = Vec::new();
        while let Some(token) = stream.next() {
            tokens.push(token.clone());
        }

        assert_tokens(&tokens, &["a,number,1", "b.c,str,d", "b.c,str,e"]);
    }

    /// `flatten_json` covers nested objects, arrays, multi-word strings and scalars.
    #[test]
    fn test_flatten_json_text() {
        let json = r#"{
              "a": 1,
              "b": [
                {"c": "hello world"},
                {"c": "e"}
              ],
              "c": true,
              "d": null,
              "e": {
                "f": 1.0
              }
          }"#;
        let value: Value = serde_json::from_str(json).unwrap();

        let mut tokenizer = simple_analyzer();
        let mut tokens = Vec::new();
        let mut position = 0;
        flatten_json(&value, "", &mut tokens, &mut position, &mut tokenizer);

        assert_tokens(
            &tokens,
            &[
                "a,number,1",
                "b.c,str,hello",
                "b.c,str,world",
                "b.c,str,e",
                "c,bool,true",
                "d,null,null",
                "e.f,number,1.0",
            ],
        );
    }

    /// `flatten_triplet` keeps scalar triplets as-is and splits `str` values into terms.
    #[test]
    fn test_flatten_triplet() {
        let text = r#"a,number,1;b.c,str,d;b.c,str,e;d,str,hello world;e,number,1.0"#;
        let mut tokenizer = simple_analyzer();
        let tokens = flatten_triplet(text, &mut tokenizer).unwrap();

        assert_tokens(
            &tokens,
            &[
                "a,number,1",
                "b.c,str,d",
                "b.c,str,e",
                "d,str,hello",
                "d,str,world",
                "e,number,1.0",
            ],
        );
    }
}