// lance_tokenizer/ngram_tokenizer.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors
// SPDX-License-Identifier: MIT
// Adapted from Tantivy v0.24.2 ngram tokenizer.
// Copyright (c) 2017-present Tantivy contributors.
6
7use std::fmt::{Display, Formatter};
8
9use crate::{Token, TokenStream, Tokenizer};
10
/// Error returned when an [`NgramTokenizer`] is constructed with invalid
/// `min_gram`/`max_gram` arguments.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NgramError {
    message: String,
}

impl NgramError {
    /// Builds an error carrying a human-readable description of the
    /// invalid constructor argument.
    fn invalid_argument(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl Display for NgramError {
    /// Formats the error as its bare message text.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.message)
    }
}

impl std::error::Error for NgramError {}
31
32#[derive(Clone, Debug)]
33pub struct NgramTokenizer {
34    min_gram: usize,
35    max_gram: usize,
36    prefix_only: bool,
37    token: Token,
38}
39
40impl NgramTokenizer {
41    pub fn new(min_gram: usize, max_gram: usize, prefix_only: bool) -> Result<Self, NgramError> {
42        if min_gram == 0 {
43            return Err(NgramError::invalid_argument(
44                "min_gram must be greater than 0",
45            ));
46        }
47        if min_gram > max_gram {
48            return Err(NgramError::invalid_argument(
49                "min_gram must not be greater than max_gram",
50            ));
51        }
52        Ok(Self {
53            min_gram,
54            max_gram,
55            prefix_only,
56            token: Token::default(),
57        })
58    }
59
60    pub fn all_ngrams(min_gram: usize, max_gram: usize) -> Result<Self, NgramError> {
61        Self::new(min_gram, max_gram, false)
62    }
63
64    pub fn prefix_only(min_gram: usize, max_gram: usize) -> Result<Self, NgramError> {
65        Self::new(min_gram, max_gram, true)
66    }
67}
68
/// Token stream produced by [`NgramTokenizer`].
///
/// Pulls `(start, stop)` byte-offset pairs from the ngram iterator and
/// materializes each ngram as a [`Token`] sliced from the original text.
pub struct NgramTokenStream<'a> {
    /// Yields `(offset_from, offset_to)` byte offsets for each ngram.
    ngram_charidx_iterator: StutteringIterator<CodepointFrontiers<'a>>,
    /// When true, the stream stops at the first ngram not starting at offset 0.
    prefix_only: bool,
    /// The text being tokenized; ngram text is copied out of it.
    text: &'a str,
    /// Scratch token (owned by the tokenizer) reused for every emitted ngram.
    token: &'a mut Token,
}
75
76impl Tokenizer for NgramTokenizer {
77    type TokenStream<'a> = NgramTokenStream<'a>;
78
79    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
80        self.token.reset();
81        NgramTokenStream {
82            ngram_charidx_iterator: StutteringIterator::new(
83                CodepointFrontiers::for_str(text),
84                self.min_gram,
85                self.max_gram,
86            ),
87            prefix_only: self.prefix_only,
88            text,
89            token: &mut self.token,
90        }
91    }
92}
93
94impl TokenStream for NgramTokenStream<'_> {
95    fn advance(&mut self) -> bool {
96        if let Some((offset_from, offset_to)) = self.ngram_charidx_iterator.next() {
97            if self.prefix_only && offset_from > 0 {
98                return false;
99            }
100            self.token.position = 0;
101            self.token.offset_from = offset_from;
102            self.token.offset_to = offset_to;
103            self.token.text.clear();
104            self.token.text.push_str(&self.text[offset_from..offset_to]);
105            true
106        } else {
107            false
108        }
109    }
110
111    fn token(&self) -> &Token {
112        self.token
113    }
114
115    fn token_mut(&mut self) -> &mut Token {
116        self.token
117    }
118}
119
/// Adapter turning a stream of codepoint frontiers (byte offsets) into
/// `(start, stop)` byte ranges for every ngram whose length in codepoints
/// lies within `[min_gram, max_gram]`.
///
/// For each start frontier it "stutters": it emits one range per gram
/// length from `min_gram` up to `max_gram` before sliding to the next start.
struct StutteringIterator<T> {
    underlying: T,
    min_gram: usize,
    /// Effective maximum gram length; shrinks as `underlying` runs dry near
    /// the end of the text.
    max_gram: usize,
    /// Ring buffer holding the most recent `max_gram + 1` frontiers.
    memory: Vec<usize>,
    /// Index of the current start frontier inside `memory`.
    cursor: usize,
    /// Gram length to emit next for the current start.
    gram_len: usize,
}

impl<T> StutteringIterator<T>
where
    T: Iterator<Item = usize>,
{
    fn new(mut underlying: T, min_gram: usize, max_gram: usize) -> Self {
        debug_assert!(min_gram > 0, "min_gram must be positive");
        // Pre-fill the ring buffer with up to `max_gram + 1` frontiers.
        let memory: Vec<usize> = underlying.by_ref().take(max_gram + 1).collect();
        // Too few frontiers for even one `min_gram`-sized ngram: set up a
        // degenerate state (max_gram < min_gram) so `next` yields nothing.
        let (min_gram, max_gram, gram_len) = if memory.len() <= min_gram {
            (1, 0, 0)
        } else {
            (min_gram, memory.len() - 1, min_gram)
        };
        Self {
            underlying,
            min_gram,
            max_gram,
            memory,
            cursor: 0,
            gram_len,
        }
    }
}

impl<T> Iterator for StutteringIterator<T>
where
    T: Iterator<Item = usize>,
{
    type Item = (usize, usize);

    fn next(&mut self) -> Option<Self::Item> {
        if self.gram_len > self.max_gram {
            // All gram lengths emitted for this start: slide to the next one.
            self.gram_len = self.min_gram;
            match self.underlying.next() {
                // Overwrite the now-stale oldest frontier in the ring buffer.
                Some(frontier) => self.memory[self.cursor] = frontier,
                // Source exhausted: the reachable gram length shrinks by one.
                None => self.max_gram -= 1,
            }
            self.cursor += 1;
            if self.cursor >= self.memory.len() {
                self.cursor = 0;
            }
        }
        if self.max_gram < self.min_gram {
            return None;
        }
        let ring_len = self.memory.len();
        let start = self.memory[self.cursor % ring_len];
        let stop = self.memory[(self.cursor + self.gram_len) % ring_len];
        self.gram_len += 1;
        Some((start, stop))
    }
}
186
/// Iterator over the byte offsets of codepoint boundaries in a string.
///
/// A text of `n` codepoints yields `n + 1` offsets: 0, the offset after each
/// codepoint, and finally `text.len()`. The empty string yields just `0`.
struct CodepointFrontiers<'a> {
    /// Remaining (not yet consumed) suffix of the text.
    text: &'a str,
    /// Next frontier to emit; `None` once all frontiers have been produced.
    next_offset: Option<usize>,
}

impl<'a> CodepointFrontiers<'a> {
    fn for_str(text: &'a str) -> Self {
        Self {
            text,
            next_offset: Some(0),
        }
    }
}

impl Iterator for CodepointFrontiers<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        let offset = self.next_offset?;
        match self.text.as_bytes().first() {
            // Whole text consumed: emit the final frontier, then stop.
            None => self.next_offset = None,
            Some(&lead_byte) => {
                let width = utf8_codepoint_width(lead_byte);
                self.text = &self.text[width..];
                self.next_offset = Some(offset + width);
            }
        }
        Some(offset)
    }
}

/// Sequence length of a UTF-8 codepoint indexed by the high nibble of its
/// leading byte: 0x0-0x7 -> 1 (ASCII), 0xC-0xD -> 2, 0xE -> 3, 0xF -> 4.
/// (Entries 8-13 are never hit for valid lead bytes of a `&str`.)
const CODEPOINT_UTF8_WIDTH: [u8; 16] = [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 4];

fn utf8_codepoint_width(byte: u8) -> usize {
    usize::from(CODEPOINT_UTF8_WIDTH[usize::from(byte >> 4)])
}