use bitflags::bitflags;
use cached::proc_macro::cached;
use rust_stemmers::{Algorithm as StemmingAlgorithm, Stemmer};
use std::{
    borrow::Cow,
    collections::HashSet,
    fmt::{self, Debug},
};
use stop_words::LANGUAGE as StopWordLanguage;
#[cfg(feature = "language_detection")]
use whichlang::Lang as DetectedLanguage;

use crate::tokenizer::Tokenizer;

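/// Languages supported by the default tokenizer.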
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum Language {
    Arabic,
    Danish,
    Dutch,
    English,
    French,
    German,
    Greek,
    Hungarian,
    Italian,
    Norwegian,
    Portuguese,
    Romanian,
    Russian,
    Spanish,
    Swedish,
    Tamil,
    Turkish,
}

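/// Determines how the tokenizer picks the language used for stopword removal
/// and stemming.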
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum LanguageMode {
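    /// Detect the language of each input text at tokenization time.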
    #[cfg(feature = "language_detection")]
    Detect,
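    /// Always use the given language.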
    Fixed(Language),
}

impl Default for LanguageMode {
    fn default() -> Self {
        LanguageMode::Fixed(Language::English)
    }
}

impl From<Language> for LanguageMode {
    fn from(language: Language) -> Self {
        LanguageMode::Fixed(language)
    }
}

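/// Error returned when a language detected by `whichlang` has no [`Language`]
/// counterpart.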
#[cfg(feature = "language_detection")]
#[derive(Debug, PartialEq)]
pub struct LanguageFromDetectionError {
    pub lang: DetectedLanguage,
}

#[cfg(feature = "language_detection")]
impl std::fmt::Display for LanguageFromDetectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Unsupported detected language: {:?}", self.lang)
    }
}

#[cfg(feature = "language_detection")]
impl std::error::Error for LanguageFromDetectionError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}

#[cfg(feature = "language_detection")]
impl TryFrom<DetectedLanguage> for Language {
    type Error = LanguageFromDetectionError;

    fn try_from(detected_language: DetectedLanguage) -> Result<Self, Self::Error> {
        match detected_language {
            DetectedLanguage::Ara => Ok(Language::Arabic),
            DetectedLanguage::Cmn => Err(LanguageFromDetectionError {
                lang: detected_language,
            }),
            DetectedLanguage::Deu => Ok(Language::German),
            DetectedLanguage::Eng => Ok(Language::English),
            DetectedLanguage::Fra => Ok(Language::French),
            DetectedLanguage::Hin => Err(LanguageFromDetectionError {
                lang: detected_language,
            }),
            DetectedLanguage::Ita => Ok(Language::Italian),
            DetectedLanguage::Jpn => Err(LanguageFromDetectionError {
                lang: detected_language,
            }),
            DetectedLanguage::Kor => Err(LanguageFromDetectionError {
                lang: detected_language,
            }),
            DetectedLanguage::Nld => Ok(Language::Dutch),
            DetectedLanguage::Por => Ok(Language::Portuguese),
            DetectedLanguage::Rus => Ok(Language::Russian),
            DetectedLanguage::Spa => Ok(Language::Spanish),
            DetectedLanguage::Swe => Ok(Language::Swedish),
            DetectedLanguage::Tur => Ok(Language::Turkish),
            DetectedLanguage::Vie => Err(LanguageFromDetectionError {
                lang: detected_language,
            }),
        }
    }
}

impl From<&Language> for StemmingAlgorithm {
    fn from(language: &Language) -> Self {
        match language {
            Language::Arabic => StemmingAlgorithm::Arabic,
            Language::Danish => StemmingAlgorithm::Danish,
            Language::Dutch => StemmingAlgorithm::Dutch,
            Language::English => StemmingAlgorithm::English,
            Language::French => StemmingAlgorithm::French,
            Language::German => StemmingAlgorithm::German,
            Language::Greek => StemmingAlgorithm::Greek,
            Language::Hungarian => StemmingAlgorithm::Hungarian,
            Language::Italian => StemmingAlgorithm::Italian,
            Language::Norwegian => StemmingAlgorithm::Norwegian,
            Language::Portuguese => StemmingAlgorithm::Portuguese,
            Language::Romanian => StemmingAlgorithm::Romanian,
            Language::Russian => StemmingAlgorithm::Russian,
            Language::Spanish => StemmingAlgorithm::Spanish,
            Language::Swedish => StemmingAlgorithm::Swedish,
            Language::Tamil => StemmingAlgorithm::Tamil,
            Language::Turkish => StemmingAlgorithm::Turkish,
        }
    }
}

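/// Error returned when no stopword list is available for a [`Language`].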
#[derive(Debug, PartialEq)]
pub struct StopWordLanguageError {
    pub lang: Language,
}

impl std::fmt::Display for StopWordLanguageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Unsupported language: {:?}", self.lang)
    }
}

impl std::error::Error for StopWordLanguageError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}

impl TryFrom<&Language> for StopWordLanguage {
    type Error = StopWordLanguageError;

    fn try_from(language: &Language) -> Result<Self, Self::Error> {
        match language {
            Language::Arabic => Ok(StopWordLanguage::Arabic),
            Language::Danish => Ok(StopWordLanguage::Danish),
            Language::Dutch => Ok(StopWordLanguage::Dutch),
            Language::English => Ok(StopWordLanguage::English),
            Language::French => Ok(StopWordLanguage::French),
            Language::German => Ok(StopWordLanguage::German),
            Language::Greek => Ok(StopWordLanguage::Greek),
            Language::Hungarian => Ok(StopWordLanguage::Hungarian),
            Language::Italian => Ok(StopWordLanguage::Italian),
            Language::Norwegian => Ok(StopWordLanguage::Norwegian),
            Language::Portuguese => Ok(StopWordLanguage::Portuguese),
            Language::Romanian => Ok(StopWordLanguage::Romanian),
            Language::Russian => Ok(StopWordLanguage::Russian),
            Language::Spanish => Ok(StopWordLanguage::Spanish),
            Language::Swedish => Ok(StopWordLanguage::Swedish),
            Language::Tamil => Err(StopWordLanguageError { lang: *language }),
            Language::Turkish => Ok(StopWordLanguage::Turkish),
        }
    }
}

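/// Transliterates `text` to ASCII, replacing characters that cannot be
/// transliterated with "[?]".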
fn normalize(text: &str) -> Cow<'_, str> {
    deunicode::deunicode_with_tofu_cow(text, "[?]")
}

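/// Returns the stopword set for `language`, normalized to ASCII when
/// `normalized` is true; languages without a stopword list yield an empty
/// set. Results are memoized by the `cached` macro.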
#[cached(size = 16)]
fn get_stopwords(language: Language, normalized: bool) -> HashSet<String> {
    match TryInto::<StopWordLanguage>::try_into(&language) {
        Err(_) => HashSet::new(),
        Ok(lang) => stop_words::get(lang)
            .iter()
            .map(|w| match normalized {
                true => normalize(w).into(),
                false => w.to_string(),
            })
            .collect(),
    }
}

fn get_stemmer(language: &Language) -> Stemmer {
    Stemmer::create(language.into())
}

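/// Iterator over the Unicode words of an owned string, tracking the byte
/// offset of the next unread word.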
struct WordIter {
    text: String,
    offset: usize,
}

impl WordIter {
    fn new(text: String) -> Self {
        WordIter { text, offset: 0 }
    }
}

impl Iterator for WordIter {
    type Item = String;

    fn next(&mut self) -> Option<Self::Item> {
        use unicode_segmentation::UnicodeSegmentation;

        let slice = &self.text[self.offset..];
        let mut words = slice.unicode_word_indices();
        let (relative_idx, word) = words.next()?;
        self.offset += relative_idx + word.len();
        Some(word.to_string())
    }
}

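/// Token iterator that borrows its stopword set and stemmer from prebuilt
/// [`Components`]; used when the language mode is fixed.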
struct TokenIterBorrowed<'a> {
    word_iter: WordIter,
    stopwords: &'a HashSet<String>,
    stemmer: Option<&'a Stemmer>,
}

impl<'a> Iterator for TokenIterBorrowed<'a> {
    type Item = String;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let token = self.word_iter.next()?;
            if self.stopwords.contains(&token) {
                continue;
            }
            return Some(match self.stemmer {
                Some(stemmer) => stemmer.stem(&token).to_string(),
                None => token,
            });
        }
    }
}

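/// Token iterator that owns its stopword set and stemmer; used when
/// components are rebuilt per input during language detection.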
#[cfg(feature = "language_detection")]
struct TokenIterOwned {
    word_iter: WordIter,
    stopwords: HashSet<String>,
    stemmer: Option<Stemmer>,
}

#[cfg(feature = "language_detection")]
impl Iterator for TokenIterOwned {
    type Item = String;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let token = self.word_iter.next()?;
            if self.stopwords.contains(&token) {
                continue;
            }
            return Some(match &self.stemmer {
                Some(stemmer) => stemmer.stem(&token).to_string(),
                None => token,
            });
        }
    }
}

bitflags! {
    #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
    struct Settings: u8 {
        const NORMALIZATION = 1 << 0;
        const STEMMING = 1 << 1;
        const STOPWORDS = 1 << 2;
    }
}

impl Settings {
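    /// Builds the flag set from booleans: each `bool as u8` (0 or 1) is
    /// multiplied by the flag's bits, selecting or clearing the flag without
    /// branching.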
    fn new(stemming: bool, stopwords: bool, normalization: bool) -> Self {
        Settings::from_bits_retain(
            (normalization as u8 * Settings::NORMALIZATION.bits())
                | (stemming as u8 * Settings::STEMMING.bits())
                | (stopwords as u8 * Settings::STOPWORDS.bits()),
        )
    }

    fn normalization_enabled(self) -> bool {
        self.contains(Settings::NORMALIZATION)
    }

    fn stemming_enabled(self) -> bool {
        self.contains(Settings::STEMMING)
    }

    fn stopwords_enabled(self) -> bool {
        self.contains(Settings::STOPWORDS)
    }
}

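/// The preprocessing pieces (normalizer, stemmer, stopwords) assembled for a
/// given language and settings.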
struct Components {
    settings: Settings,
    normalizer: fn(&str) -> Cow<str>,
    stemmer: Option<Stemmer>,
    stopwords: HashSet<String>,
}

impl Components {
    fn new(settings: Settings, language: Option<&Language>) -> Self {
        let stemmer = language.and_then(|lang| {
            if settings.stemming_enabled() {
                Some(get_stemmer(lang))
            } else {
                None
            }
        });
        let stopwords = language.map_or_else(HashSet::new, |lang| {
            if settings.stopwords_enabled() {
                get_stopwords(*lang, settings.normalization_enabled())
            } else {
                HashSet::new()
            }
        });
        let normalizer: fn(&str) -> Cow<str> = match settings.normalization_enabled() {
            true => normalize,
            false => |text: &str| Cow::from(text),
        };
        Self {
            settings,
            stemmer,
            stopwords,
            normalizer,
        }
    }
}

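/// Tokenization resources: components prebuilt for a fixed language, or just
/// the settings when components must be rebuilt per input.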
#[non_exhaustive]
enum Resources {
    Static(Components),
    #[cfg(feature = "language_detection")]
    Dynamic(Settings),
}

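/// The default [`Tokenizer`] implementation.
///
/// Lowercases its input, splits it on Unicode word boundaries, and, depending
/// on its settings, transliterates to ASCII, removes stopwords, and stems
/// tokens.
///
/// # Example
///
/// A minimal sketch, assuming this crate's `Tokenizer` trait is in scope:
///
/// ```ignore
/// let tokenizer = DefaultTokenizer::new(Language::English);
/// let tokens: Vec<String> = tokenizer.tokenize("space stations").collect();
/// assert_eq!(tokens, vec!["space", "station"]);
/// ```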
pub struct DefaultTokenizer {
    resources: Resources,
}

impl Debug for DefaultTokenizer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let settings = match &self.resources {
            Resources::Static(components) => components.settings,
            #[cfg(feature = "language_detection")]
            Resources::Dynamic(settings) => *settings,
        };
        write!(f, "DefaultTokenizer({settings:?})")
    }
}

impl DefaultTokenizer {
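    /// Creates a tokenizer for the given language mode with all processing
    /// steps (normalization, stemming, stopword removal) enabled.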
    pub fn new(language_mode: impl Into<LanguageMode>) -> DefaultTokenizer {
        Self::builder().language_mode(language_mode).build()
    }

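    /// Returns a builder for configuring a [`DefaultTokenizer`].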
    pub fn builder() -> DefaultTokenizerBuilder {
        DefaultTokenizerBuilder::new()
    }

    fn _new(
        language_mode: impl Into<LanguageMode>,
        normalization: bool,
        stemming: bool,
        stopwords: bool,
    ) -> DefaultTokenizer {
        let language_mode = &language_mode.into();
        let settings = Settings::new(stemming, stopwords, normalization);
        let resources = match language_mode {
            #[cfg(feature = "language_detection")]
            LanguageMode::Detect => Resources::Dynamic(settings),
            LanguageMode::Fixed(lang) => Resources::Static(Components::new(settings, Some(lang))),
        };
        DefaultTokenizer { resources }
    }

    #[cfg(feature = "language_detection")]
    fn detect_language(text: &str) -> Option<Language> {
        Language::try_from(whichlang::detect_language(text)).ok()
    }

    fn tokenize<'a>(&'a self, input_text: &'a str) -> impl Iterator<Item = String> + 'a {
        enum TokenStream<'a> {
            Borrowed(TokenIterBorrowed<'a>),
            #[cfg(feature = "language_detection")]
            Owned(TokenIterOwned),
        }

        impl<'a> Iterator for TokenStream<'a> {
            type Item = String;

            fn next(&mut self) -> Option<Self::Item> {
                match self {
                    TokenStream::Borrowed(iter) => iter.next(),
                    #[cfg(feature = "language_detection")]
                    TokenStream::Owned(iter) => iter.next(),
                }
            }
        }

        let make_word_iter = |input: &str, normalizer: fn(&str) -> Cow<str>| {
            WordIter::new(normalizer(input).to_lowercase())
        };

        match &self.resources {
            Resources::Static(components) => TokenStream::Borrowed(TokenIterBorrowed {
                word_iter: make_word_iter(input_text, components.normalizer),
                stopwords: &components.stopwords,
                stemmer: components.stemmer.as_ref(),
            }),
            #[cfg(feature = "language_detection")]
            Resources::Dynamic(settings) => {
                let detected_language = Self::detect_language(input_text);
                let components = Components::new(*settings, detected_language.as_ref());

                TokenStream::Owned(TokenIterOwned {
                    word_iter: make_word_iter(input_text, components.normalizer),
                    stopwords: components.stopwords,
                    stemmer: components.stemmer,
                })
            }
        }
    }
}

impl Tokenizer for DefaultTokenizer {
    fn tokenize<'a>(&'a self, input_text: &'a str) -> impl Iterator<Item = String> + 'a {
        DefaultTokenizer::tokenize(self, input_text)
    }
}

impl Default for DefaultTokenizer {
    fn default() -> Self {
        DefaultTokenizer::new(LanguageMode::default())
    }
}

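/// Builder for [`DefaultTokenizer`].
///
/// All processing steps default to enabled, and the language mode defaults to
/// [`LanguageMode::Fixed`] with [`Language::English`].
///
/// # Example
///
/// A minimal sketch mirroring the tests below:
///
/// ```ignore
/// let tokenizer = DefaultTokenizer::builder()
///     .language_mode(Language::German)
///     .stemming(false)
///     .build();
/// ```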
pub struct DefaultTokenizerBuilder {
    language_mode: LanguageMode,
    normalization: bool,
    stemming: bool,
    stopwords: bool,
}

impl Default for DefaultTokenizerBuilder {
    fn default() -> Self {
        DefaultTokenizerBuilder::new()
    }
}

impl DefaultTokenizerBuilder {
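    /// Creates a builder with the default language mode and every processing
    /// step enabled.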
    pub fn new() -> DefaultTokenizerBuilder {
        DefaultTokenizerBuilder {
            language_mode: LanguageMode::default(),
            normalization: true,
            stemming: true,
            stopwords: true,
        }
    }

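    /// Sets the language mode.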
    pub fn language_mode(mut self, language_mode: impl Into<LanguageMode>) -> Self {
        self.language_mode = language_mode.into();
        self
    }

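    /// Enables or disables ASCII transliteration of the input.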
    pub fn normalization(mut self, normalization: bool) -> Self {
        self.normalization = normalization;
        self
    }

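    /// Enables or disables stemming of tokens.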
    pub fn stemming(mut self, stemming: bool) -> Self {
        self.stemming = stemming;
        self
    }

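    /// Enables or disables stopword removal.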
    pub fn stopwords(mut self, stopwords: bool) -> Self {
        self.stopwords = stopwords;
        self
    }

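    /// Builds the configured [`DefaultTokenizer`].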
    pub fn build(self) -> DefaultTokenizer {
        DefaultTokenizer::_new(
            self.language_mode,
            self.normalization,
            self.stemming,
            self.stopwords,
        )
    }
}

#[cfg(test)]
mod tests {
    use crate::test_data_loader::tests::{read_recipes, Recipe};

    use super::*;

    use insta::assert_debug_snapshot;

    fn tokenize_recipes(recipe_file: &str, language_mode: LanguageMode) -> Vec<Vec<String>> {
        let recipes = read_recipes(recipe_file);

        recipes
            .iter()
            .map(|Recipe { recipe, .. }| {
                let tokenizer = DefaultTokenizer::new(language_mode.clone());
                tokenizer.tokenize(recipe).collect::<Vec<_>>()
            })
            .collect()
    }

    #[test]
    fn it_can_tokenize_english() {
        let text = "space station";
        let tokenizer = DefaultTokenizer::new(Language::English);

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(tokens, vec!["space", "station"]);
    }

    #[test]
    fn it_converts_to_lowercase() {
        let text = "SPACE STATION";
        let tokenizer = DefaultTokenizer::new(Language::English);

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(tokens, vec!["space", "station"]);
    }

    #[test]
    fn it_removes_whitespace() {
        let text = "\tspace\r\nstation\n space station";
        let tokenizer = DefaultTokenizer::new(Language::English);

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(tokens, vec!["space", "station", "space", "station"]);
    }

    #[test]
    fn it_removes_stopwords() {
        let text = "i me my myself we our ours ourselves you you're you've you'll you'd";
        let tokenizer = DefaultTokenizer::new(Language::English);

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert!(tokens.is_empty());
    }

    #[test]
    fn it_keeps_numbers() {
        let text = "42 1337 3.14";
        let tokenizer = DefaultTokenizer::new(Language::English);

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(tokens, vec!["42", "1337", "3.14"]);
    }

    #[test]
    fn it_keeps_contracted_words() {
        let text = "can't you're won't let's couldn't've";
        let tokenizer = DefaultTokenizer::builder()
            .language_mode(Language::English)
            .stemming(false)
            .stopwords(false)
            .build();

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(
            tokens,
            vec!["can't", "you're", "won't", "let's", "couldn't've"]
        );
    }

    #[test]
    fn it_removes_punctuation() {
        let test_cases = vec![
            ("space, station!", vec!["space", "station"]),
            ("space,station", vec!["space", "station"]),
            ("!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~", vec![]),
        ];
        let tokenizer = DefaultTokenizer::new(Language::English);

        for (text, expected) in test_cases {
            let tokens: Vec<_> = tokenizer.tokenize(text).collect();
            assert_eq!(tokens, expected);
        }
    }

    #[test]
    fn it_stems_words() {
        let text = "connection connections connective connected connecting connect";
        let tokenizer = DefaultTokenizer::new(Language::English);

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(
            tokens,
            vec!["connect", "connect", "connect", "connect", "connect", "connect"]
        );
    }

    #[test]
    fn it_tokenizes_emojis_as_text() {
        let text = "🍕 🚀 🍋";
        let tokenizer = DefaultTokenizer::new(Language::English);

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(tokens, vec!["pizza", "rocket", "lemon"]);
    }

    #[test]
    fn it_converts_unicode_to_ascii() {
        let text = "gemüse, Gießen";
        let tokenizer = DefaultTokenizer::builder()
            .language_mode(Language::German)
            .stemming(false)
            .build();

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(tokens, vec!["gemuse", "giessen"]);
    }

    #[test]
    #[cfg(feature = "language_detection")]
    fn it_handles_empty_input() {
        let text = "";
        let tokenizer = DefaultTokenizer::new(LanguageMode::Detect);

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert!(tokens.is_empty());
    }

    #[test]
    #[cfg(feature = "language_detection")]
    fn it_detects_english() {
        let tokens_detected = tokenize_recipes("recipes_en.csv", LanguageMode::Detect);
        let tokens_en = tokenize_recipes("recipes_en.csv", LanguageMode::Fixed(Language::English));

        assert_eq!(tokens_detected, tokens_en);
    }

    #[test]
    #[cfg(feature = "language_detection")]
    fn it_detects_german() {
        let tokens_detected = tokenize_recipes("recipes_de.csv", LanguageMode::Detect);
        let tokens_de = tokenize_recipes("recipes_de.csv", LanguageMode::Fixed(Language::German));

        assert_eq!(tokens_detected, tokens_de);
    }

    #[test]
    fn it_matches_snapshot_en() {
        let tokens = tokenize_recipes("recipes_en.csv", LanguageMode::Fixed(Language::English));

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(tokens);
        });
    }

    #[test]
    fn it_matches_snapshot_de() {
        let tokens = tokenize_recipes("recipes_de.csv", LanguageMode::Fixed(Language::German));

        insta::with_settings!({snapshot_path => "../snapshots"}, {
            assert_debug_snapshot!(tokens);
        });
    }

    #[test]
    fn it_does_not_convert_unicode_when_normalization_disabled() {
        let text = "étude";
        let tokenizer = DefaultTokenizer::builder()
            .language_mode(Language::French)
            .normalization(false)
            .stemming(false)
            .build();

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(tokens, vec!["étude"]);
    }

    #[test]
    fn it_does_not_remove_stopwords_when_stopwords_disabled() {
        let text = "i my myself we you have";
        let tokenizer = DefaultTokenizer::builder()
            .language_mode(Language::English)
            .stopwords(false)
            .build();

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(tokens, vec!["i", "my", "myself", "we", "you", "have"]);
    }

    #[test]
    fn it_does_not_stem_when_stemming_disabled() {
        let text = "connection connections connective connect";
        let tokenizer = DefaultTokenizer::builder()
            .language_mode(Language::English)
            .stemming(false)
            .build();

        let tokens: Vec<_> = tokenizer.tokenize(text).collect();

        assert_eq!(
            tokens,
            vec!["connection", "connections", "connective", "connect"]
        );
    }
}