1use std::fmt::{Display, Formatter};
8
9use crate::{Token, TokenStream, Tokenizer};
10
/// Error reported when n-gram tokenizer arguments are rejected.
///
/// The error carries a single human-readable message, which is also what
/// its `Display` implementation prints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NgramError {
    /// Description of why the argument was rejected.
    message: String,
}

impl NgramError {
    /// Creates an error for an invalid argument, taking ownership of `message`.
    fn invalid_argument(message: impl Into<String>) -> Self {
        let message: String = message.into();
        NgramError { message }
    }
}

impl Display for NgramError {
    /// Writes the stored message verbatim; formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message.as_str())
    }
}

impl std::error::Error for NgramError {}
31
32#[derive(Clone, Debug)]
33pub struct NgramTokenizer {
34 min_gram: usize,
35 max_gram: usize,
36 prefix_only: bool,
37 token: Token,
38}
39
40impl NgramTokenizer {
41 pub fn new(min_gram: usize, max_gram: usize, prefix_only: bool) -> Result<Self, NgramError> {
42 if min_gram == 0 {
43 return Err(NgramError::invalid_argument(
44 "min_gram must be greater than 0",
45 ));
46 }
47 if min_gram > max_gram {
48 return Err(NgramError::invalid_argument(
49 "min_gram must not be greater than max_gram",
50 ));
51 }
52 Ok(Self {
53 min_gram,
54 max_gram,
55 prefix_only,
56 token: Token::default(),
57 })
58 }
59
60 pub fn all_ngrams(min_gram: usize, max_gram: usize) -> Result<Self, NgramError> {
61 Self::new(min_gram, max_gram, false)
62 }
63
64 pub fn prefix_only(min_gram: usize, max_gram: usize) -> Result<Self, NgramError> {
65 Self::new(min_gram, max_gram, true)
66 }
67}
68
69pub struct NgramTokenStream<'a> {
70 ngram_charidx_iterator: StutteringIterator<CodepointFrontiers<'a>>,
71 prefix_only: bool,
72 text: &'a str,
73 token: &'a mut Token,
74}
75
76impl Tokenizer for NgramTokenizer {
77 type TokenStream<'a> = NgramTokenStream<'a>;
78
79 fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
80 self.token.reset();
81 NgramTokenStream {
82 ngram_charidx_iterator: StutteringIterator::new(
83 CodepointFrontiers::for_str(text),
84 self.min_gram,
85 self.max_gram,
86 ),
87 prefix_only: self.prefix_only,
88 text,
89 token: &mut self.token,
90 }
91 }
92}
93
94impl TokenStream for NgramTokenStream<'_> {
95 fn advance(&mut self) -> bool {
96 if let Some((offset_from, offset_to)) = self.ngram_charidx_iterator.next() {
97 if self.prefix_only && offset_from > 0 {
98 return false;
99 }
100 self.token.position = 0;
101 self.token.offset_from = offset_from;
102 self.token.offset_to = offset_to;
103 self.token.text.clear();
104 self.token.text.push_str(&self.text[offset_from..offset_to]);
105 true
106 } else {
107 false
108 }
109 }
110
111 fn token(&self) -> &Token {
112 self.token
113 }
114
115 fn token_mut(&mut self) -> &mut Token {
116 self.token
117 }
118}
119
/// Adapter that turns a stream of boundary values into every `(start, stop)`
/// pair spanning between `min_gram` and `max_gram` consecutive boundaries.
///
/// For each start boundary, in input order, it yields the pairs for gram
/// lengths `min_gram, min_gram + 1, ..., max_gram` (fewer near the end of the
/// input, where not enough boundaries remain) before moving to the next start.
/// Once it has returned `None` it keeps returning `None`.
struct StutteringIterator<T> {
    /// Source of boundary values not yet loaded into `memory`.
    underlying: T,
    /// Smallest gram length to emit.
    min_gram: usize,
    /// Largest gram length still producible; shrinks once `underlying` runs dry.
    max_gram: usize,
    /// Ring buffer holding the current start boundary and the ones following it.
    memory: Vec<usize>,
    /// Index in `memory` of the current start boundary.
    cursor: usize,
    /// Length of the next gram to emit for the current start.
    gram_len: usize,
}

impl<T> StutteringIterator<T>
where
    T: Iterator<Item = usize>,
{
    /// Creates the adapter, pre-loading up to `max_gram + 1` boundaries.
    ///
    /// `min_gram` must be positive (checked with a debug assertion). If the
    /// input has too few boundaries to form even one gram of length
    /// `min_gram`, the result is an empty iterator.
    fn new(mut underlying: T, min_gram: usize, max_gram: usize) -> Self {
        debug_assert!(min_gram > 0, "min_gram must be positive");
        // `max_gram` boundaries after the start are needed, hence the `+ 1`.
        // Saturate so that `max_gram == usize::MAX` (meaning "no upper bound")
        // neither panics on overflow nor wraps around to an empty window.
        let window = max_gram.saturating_add(1);
        let memory: Vec<usize> = underlying.by_ref().take(window).collect();
        if memory.len() <= min_gram {
            // Not enough boundaries for a single gram: `max_gram < min_gram`
            // encodes the exhausted state.
            Self {
                underlying,
                min_gram: 1,
                max_gram: 0,
                memory,
                cursor: 0,
                gram_len: 0,
            }
        } else {
            Self {
                underlying,
                min_gram,
                // Clamp to what the pre-loaded boundaries can actually span.
                max_gram: memory.len() - 1,
                memory,
                cursor: 0,
                gram_len: min_gram,
            }
        }
    }
}

impl<T> Iterator for StutteringIterator<T>
where
    T: Iterator<Item = usize>,
{
    type Item = (usize, usize);

    /// Returns the next `(start, stop)` pair, or `None` once every gram has
    /// been produced. Calling it again after `None` is safe and returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        // Exhausted (or built empty): stay exhausted. Checking this first
        // prevents `max_gram` from being decremented again on repeated calls,
        // which used to underflow.
        if self.max_gram < self.min_gram {
            return None;
        }
        if self.gram_len > self.max_gram {
            // Every length for the current start was emitted: slide the window.
            self.gram_len = self.min_gram;
            match self.underlying.next() {
                // The old start's slot is free; reuse it for the new last boundary.
                Some(next_val) => self.memory[self.cursor] = next_val,
                // No more input: from here on the window can only shrink.
                None => self.max_gram -= 1,
            }
            self.cursor += 1;
            if self.cursor >= self.memory.len() {
                self.cursor = 0;
            }
            if self.max_gram < self.min_gram {
                return None;
            }
        }
        let len = self.memory.len();
        let start = self.memory[self.cursor % len];
        let stop = self.memory[(self.cursor + self.gram_len) % len];
        self.gram_len += 1;
        Some((start, stop))
    }
}
186
/// Iterator state for walking the codepoint boundaries of a string.
struct CodepointFrontiers<'a> {
    /// Remaining, not yet consumed part of the input.
    text: &'a str,
    /// Byte offset to emit on the next step, or `None` once iteration is over.
    next_offset: Option<usize>,
}

impl<'a> CodepointFrontiers<'a> {
    /// Starts at the beginning of `text`; the first boundary emitted is offset 0.
    fn for_str(text: &'a str) -> Self {
        let next_offset = Some(0);
        CodepointFrontiers { text, next_offset }
    }
}
200
201impl Iterator for CodepointFrontiers<'_> {
202 type Item = usize;
203
204 fn next(&mut self) -> Option<Self::Item> {
205 let offset = self.next_offset?;
206 if self.text.is_empty() {
207 self.next_offset = None;
208 } else {
209 let width = utf8_codepoint_width(self.text.as_bytes()[0]);
210 self.text = &self.text[width..];
211 self.next_offset = Some(offset + width);
212 }
213 Some(offset)
214 }
215}
216
/// Encoded width in bytes, indexed by the high nibble of a byte.
///
/// The values are the UTF-8 sequence lengths when the byte is the first byte
/// of a codepoint.
const CODEPOINT_UTF8_WIDTH: [u8; 16] = [
    1, 1, 1, 1, 1, 1, 1, 1, // high nibble 0x0..=0x7
    2, 2, 2, 2, 2, 2, // high nibble 0x8..=0xD
    3, // high nibble 0xE
    4, // high nibble 0xF
];

/// Looks up the width for `byte` from its high nibble in [`CODEPOINT_UTF8_WIDTH`].
fn utf8_codepoint_width(byte: u8) -> usize {
    let high_nibble = usize::from(byte >> 4);
    usize::from(CODEPOINT_UTF8_WIDTH[high_nibble])
}