tocken/
tokenizer.rs

1//! Naive tokenizer for Unicode.
2
3use std::{
4    borrow::Cow,
5    collections::{HashMap, HashSet},
6    time::Instant,
7};
8
9use log::debug;
10use serde::{Deserialize, Serialize};
11use tantivy_stemmers::algorithms::english_porter as stemmer;
12use unicode_normalization::UnicodeNormalization;
13use unicode_segmentation::UnicodeSegmentation;
14
15use crate::stopwords::ENGLISH_NLTK;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18/// Unicode normalization.
19/// Ref: https://unicode.org/reports/tr15/
20pub enum Normalization {
21    /// Canonical Decomposition
22    NFD,
23    /// Canonical Decomposition, followed by Canonical Composition
24    NFC,
25    /// Compatibility Decomposition
26    NFKD,
27    /// Compatibility Decomposition, followed by Canonical Composition
28    NFKC,
29    /// No normalization
30    None,
31}
32
33impl Normalization {
34    /// Normalize the text.
35    pub fn normalize(&self, text: &str) -> String {
36        match self {
37            Normalization::NFD => text.nfd().collect(),
38            Normalization::NFC => text.nfc().collect(),
39            Normalization::NFKD => text.nfkd().collect(),
40            Normalization::NFKC => text.nfkc().collect(),
41            Normalization::None => text.to_string(),
42        }
43    }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47/// Stemmer method.
48pub enum Stemmer {
49    /// https://snowballstem.org/algorithms/
50    Snowball,
51    /// No stemmer
52    None,
53}
54
55impl Stemmer {
56    /// Stem the text.
57    pub fn stem<'a>(&self, text: &'a str) -> Cow<'a, str> {
58        match self {
59            Stemmer::Snowball => stemmer(text),
60            Stemmer::None => Cow::Borrowed(text),
61        }
62    }
63}
64
/// English possessive filter.
///
/// Strips a trailing possessive `'s` from `text`. The apostrophe may be the
/// ASCII `'`, the right single quotation mark `’` (U+2019) or the fullwidth
/// apostrophe `＇` (U+FF07). Only a lowercase `s` is recognized.
///
/// Returns `Some(stem)` when the suffix was removed and a non-empty stem
/// remains, and `None` otherwise (no possessive suffix, or `text` consists of
/// the suffix alone, e.g. `"'s"` or `"’s"`).
pub fn english_possessive_filter(text: &str) -> Option<String> {
    // Work on chars via `strip_suffix` rather than byte lengths: the non-ASCII
    // apostrophes are 3 bytes long, so a byte-length guard would let `"’s"`
    // through and produce an empty stem.
    let stem = text
        .strip_suffix('s')?
        .strip_suffix(|c: char| matches!(c, '\'' | '\u{2019}' | '\u{FF07}'))?;
    if stem.is_empty() {
        None
    } else {
        Some(stem.to_string())
    }
}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81/// Tokenizer for text keyword match.
82pub struct Tokenizer {
83    /// The minimum frequency of the token to be kept when calling `trim`.
84    pub min_freq: u32,
85    /// The stopwords set.
86    pub stopwords: HashSet<String>,
87    /// The normalization method.
88    pub norm: Normalization,
89    /// The stemmer method.
90    pub stemmer: Stemmer,
91    table: HashMap<String, u32>,
92    counter: Vec<u32>,
93}
94
95impl Default for Tokenizer {
96    fn default() -> Self {
97        Self {
98            stopwords: HashSet::from_iter(ENGLISH_NLTK.iter().map(|&s| s.to_string())),
99            norm: Normalization::None,
100            stemmer: Stemmer::Snowball,
101            table: HashMap::new(),
102            counter: Vec::new(),
103            min_freq: 5,
104        }
105    }
106}
107
108impl Tokenizer {
109    fn get_token(&self, content: &str) -> Vec<String> {
110        let lowercase = content.to_lowercase();
111        let mut tokens = Vec::new();
112        for word in lowercase.unicode_words() {
113            let word = match english_possessive_filter(word) {
114                Some(w) => w,
115                None => word.to_string(),
116            };
117            if self.stopwords.contains(&word) {
118                continue;
119            }
120            let token = self.norm.normalize(self.stemmer.stem(&word).as_ref());
121            if token.is_empty() {
122                continue;
123            }
124            tokens.push(token);
125        }
126
127        tokens
128    }
129
130    /// Fit the tokenizer with the contents. Re-call this function will update the tokenizer.
131    pub fn fit(&mut self, contents: &[String]) {
132        let instant = Instant::now();
133        let exist_token = self.table.len();
134        for content in contents {
135            let tokens = self.get_token(content);
136            for token in tokens {
137                let length = self.table.len();
138                let entry = self.table.entry(token).or_insert(length as u32);
139                if *entry == self.counter.len() as u32 {
140                    self.counter.push(0);
141                }
142                self.counter[*entry as usize] += 1;
143            }
144        }
145        debug!(
146            "fitting took {:?}, parsed {:?} lines of text, found {:?} tokens",
147            instant.elapsed().as_secs_f32(),
148            contents.len(),
149            self.table.len() - exist_token
150        );
151    }
152
153    /// Tokenize the content and return the token ids.
154    pub fn tokenize(&self, content: &str) -> Vec<u32> {
155        let tokens = self.get_token(content);
156        let mut ids = Vec::with_capacity(tokens.len());
157        for token in tokens {
158            if let Some(&id) = self.table.get(&token) {
159                ids.push(id);
160            }
161        }
162        ids
163    }
164
165    /// This will trim the `table` according to the `min_freq` and clean the `counter`.
166    pub fn trim(&mut self) {
167        let mut selected = HashMap::new();
168        for (token, &id) in self.table.iter() {
169            if self.counter[id as usize] >= self.min_freq {
170                selected.insert(token.clone(), selected.len() as u32);
171            }
172        }
173        debug!(
174            "trim {:?} tokens into {:?} tokens",
175            self.table.len(),
176            selected.len()
177        );
178        self.table = selected;
179        self.counter.clear();
180    }
181
182    /// Serialize the tokenizer into a JSON string.
183    pub fn dumps(&self) -> String {
184        serde_json::to_string(self).expect("failed to serialize")
185    }
186
187    /// Serialize the tokenizer into a JSON file.
188    pub fn dump(&self, path: &impl AsRef<std::path::Path>) {
189        std::fs::write(path, self.dumps()).expect("failed to write");
190    }
191
192    /// Deserialize the tokenizer from a JSON string.
193    pub fn loads(data: &str) -> Self {
194        serde_json::from_str(data).unwrap()
195    }
196
197    /// Deserialize the tokenizer from a JSON file.
198    pub fn load(path: &impl AsRef<std::path::Path>) -> Self {
199        serde_json::from_slice(&std::fs::read(path).expect("failed to read"))
200            .expect("failed to deserialize")
201    }
202
203    /// Get the total token number.
204    pub fn vocab_len(&self) -> usize {
205        self.table.len()
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::english_possessive_filter;

    #[test]
    fn test_english_possessive_filter() {
        // (input, expected): inputs without a possessive suffix make the filter
        // return `None`, so they are expected to pass through unchanged.
        let cases = [
            ("John's", "John"),
            ("John\u{2019}s", "John"),
            ("John\u{FF07}s", "John"),
            ("Johns", "Johns"),
            ("John", "John"),
            ("Johns'", "Johns'"),
            ("John'ss", "John'ss"),
            ("'s", "'s"),
        ];

        for &(text, expected) in cases.iter() {
            let filtered = english_possessive_filter(text).unwrap_or_else(|| text.to_string());
            assert_eq!(filtered, expected, "input: {:?}", text);
        }
    }
}
232}