1use std::{
4 borrow::Cow,
5 collections::{HashMap, HashSet},
6 time::Instant,
7};
8
9use log::debug;
10use serde::{Deserialize, Serialize};
11use tantivy_stemmers::algorithms::english_porter as stemmer;
12use unicode_normalization::UnicodeNormalization;
13use unicode_segmentation::UnicodeSegmentation;
14
15use crate::stopwords::ENGLISH_NLTK;
16
/// Unicode normalization form applied to each token after stemming.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Normalization {
    /// Canonical Decomposition.
    NFD,
    /// Canonical Decomposition, followed by Canonical Composition.
    NFC,
    /// Compatibility Decomposition.
    NFKD,
    /// Compatibility Decomposition, followed by Canonical Composition.
    NFKC,
    /// Leave the text unchanged.
    None,
}
32
33impl Normalization {
34 pub fn normalize(&self, text: &str) -> String {
36 match self {
37 Normalization::NFD => text.nfd().collect(),
38 Normalization::NFC => text.nfc().collect(),
39 Normalization::NFKD => text.nfkd().collect(),
40 Normalization::NFKC => text.nfkc().collect(),
41 Normalization::None => text.to_string(),
42 }
43 }
44}
45
/// Stemming algorithm applied to each token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Stemmer {
    /// English Porter (Snowball) stemmer.
    Snowball,
    /// No stemming; tokens pass through unchanged.
    None,
}
54
55impl Stemmer {
56 pub fn stem<'a>(&self, text: &'a str) -> Cow<'a, str> {
58 match self {
59 Stemmer::Snowball => stemmer(text),
60 Stemmer::None => Cow::Borrowed(text),
61 }
62 }
63}
64
/// Strip an English possessive suffix (`'s` with a straight, curly, or
/// fullwidth apostrophe) from `text`.
///
/// Returns `Some(stem)` when the suffix was removed and at least one
/// character precedes it; returns `None` otherwise (callers then keep
/// the word as-is). Only lowercase `s` is treated as possessive.
pub fn english_possessive_filter(text: &str) -> Option<String> {
    // Apostrophe variants: ASCII, RIGHT SINGLE QUOTATION MARK (U+2019),
    // FULLWIDTH APOSTROPHE (U+FF07).
    const SUFFIXES: [&str; 3] = ["'s", "\u{2019}s", "\u{FF07}s"];
    for suffix in SUFFIXES {
        if let Some(stem) = text.strip_suffix(suffix) {
            // Bug fix: the previous guard compared *byte* length
            // (`text.len() > 2`), so a bare multi-byte "\u{2019}s"
            // slipped through and produced an empty token. Require a
            // non-empty stem regardless of encoding instead.
            return (!stem.is_empty()).then(|| stem.to_string());
        }
    }
    None
}
79
/// Vocabulary-building word tokenizer: lowercases, splits on Unicode
/// word boundaries, strips English possessives, drops stopwords, then
/// stems and normalizes each token.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tokenizer {
    /// Minimum occurrence count for a token to survive `trim`.
    pub min_freq: u32,
    /// Words discarded during tokenization (checked before stemming).
    pub stopwords: HashSet<String>,
    /// Unicode normalization applied after stemming.
    pub norm: Normalization,
    /// Stemming algorithm applied to each token.
    pub stemmer: Stemmer,
    // token -> numeric id, assigned in first-seen order by `fit`
    table: HashMap<String, u32>,
    // occurrence count per token id; emptied by `trim`
    counter: Vec<u32>,
}
94
95impl Default for Tokenizer {
96 fn default() -> Self {
97 Self {
98 stopwords: HashSet::from_iter(ENGLISH_NLTK.iter().map(|&s| s.to_string())),
99 norm: Normalization::None,
100 stemmer: Stemmer::Snowball,
101 table: HashMap::new(),
102 counter: Vec::new(),
103 min_freq: 5,
104 }
105 }
106}
107
108impl Tokenizer {
109 fn get_token(&self, content: &str) -> Vec<String> {
110 let lowercase = content.to_lowercase();
111 let mut tokens = Vec::new();
112 for word in lowercase.unicode_words() {
113 let word = match english_possessive_filter(word) {
114 Some(w) => w,
115 None => word.to_string(),
116 };
117 if self.stopwords.contains(&word) {
118 continue;
119 }
120 let token = self.norm.normalize(self.stemmer.stem(&word).as_ref());
121 if token.is_empty() {
122 continue;
123 }
124 tokens.push(token);
125 }
126
127 tokens
128 }
129
130 pub fn fit(&mut self, contents: &[String]) {
132 let instant = Instant::now();
133 let exist_token = self.table.len();
134 for content in contents {
135 let tokens = self.get_token(content);
136 for token in tokens {
137 let length = self.table.len();
138 let entry = self.table.entry(token).or_insert(length as u32);
139 if *entry == self.counter.len() as u32 {
140 self.counter.push(0);
141 }
142 self.counter[*entry as usize] += 1;
143 }
144 }
145 debug!(
146 "fitting took {:?}, parsed {:?} lines of text, found {:?} tokens",
147 instant.elapsed().as_secs_f32(),
148 contents.len(),
149 self.table.len() - exist_token
150 );
151 }
152
153 pub fn tokenize(&self, content: &str) -> Vec<u32> {
155 let tokens = self.get_token(content);
156 let mut ids = Vec::with_capacity(tokens.len());
157 for token in tokens {
158 if let Some(&id) = self.table.get(&token) {
159 ids.push(id);
160 }
161 }
162 ids
163 }
164
165 pub fn trim(&mut self) {
167 let mut selected = HashMap::new();
168 for (token, &id) in self.table.iter() {
169 if self.counter[id as usize] >= self.min_freq {
170 selected.insert(token.clone(), selected.len() as u32);
171 }
172 }
173 debug!(
174 "trim {:?} tokens into {:?} tokens",
175 self.table.len(),
176 selected.len()
177 );
178 self.table = selected;
179 self.counter.clear();
180 }
181
182 pub fn dumps(&self) -> String {
184 serde_json::to_string(self).expect("failed to serialize")
185 }
186
187 pub fn dump(&self, path: &impl AsRef<std::path::Path>) {
189 std::fs::write(path, self.dumps()).expect("failed to write");
190 }
191
192 pub fn loads(data: &str) -> Self {
194 serde_json::from_str(data).unwrap()
195 }
196
197 pub fn load(path: &impl AsRef<std::path::Path>) -> Self {
199 serde_json::from_slice(&std::fs::read(path).expect("failed to read"))
200 .expect("failed to deserialize")
201 }
202
203 pub fn vocab_len(&self) -> usize {
205 self.table.len()
206 }
207}
208
#[cfg(test)]
mod tests {
    use crate::tokenizer::english_possessive_filter;

    /// The filter must strip a possessive suffix only when a non-empty
    /// stem remains, and leave every other word untouched.
    #[test]
    fn test_english_possessive_filter() {
        let cases = [
            ("John's", "John"),
            ("John\u{2019}s", "John"),
            ("John\u{FF07}s", "John"),
            ("Johns", "Johns"),
            ("John", "John"),
            ("Johns'", "Johns'"),
            ("John'ss", "John'ss"),
            ("'s", "'s"),
        ];

        for (text, expected) in cases.iter() {
            // Bug fix: the old `if let Some(res)` body silently skipped
            // every case where the filter returned `None`, so all the
            // negative cases were never actually asserted. Treat `None`
            // as "word unchanged" and check every case.
            let got = english_possessive_filter(text).unwrap_or_else(|| text.to_string());
            assert_eq!(&got, expected, "input: {:?}", text);
        }
    }
}