#![warn(clippy::all, clippy::pedantic)]
#![doc = include_str!("../README.md")]

#[cfg(any(
    feature = "southeast-asian",
    feature = "japanese-icu",
    feature = "chinese-icu"
))]
use icu_segmenter::{WordSegmenter, options::WordBreakInvariantOptions};
#[cfg(any(
    feature = "southeast-asian",
    feature = "japanese-icu",
    feature = "chinese-icu"
))]
use itertools::Itertools;
#[cfg(any(
    feature = "japanese-ipadic-neologd-lindera",
    feature = "japanese-ipadic-lindera",
    feature = "japanese-unidic-lindera",
    feature = "chinese-lindera",
    feature = "korean-lindera"
))]
use lindera::{
    dictionary::load_dictionary, mode::Mode, segmenter::Segmenter, tokenizer::Tokenizer,
};
use num_enum::{FromPrimitive, IntoPrimitive};
#[cfg(feature = "serde")]
use serde::{
    de::{self, SeqAccess, Visitor},
    ser::SerializeTuple,
    {Deserialize, Deserializer, Serialize, Serializer},
};
#[cfg(feature = "serde")]
use std::fmt;
#[cfg(feature = "snowball")]
use std::mem::transmute;
use strum_macros::Display;
use thiserror::Error;
#[cfg(feature = "snowball")]
use unicode_normalization::UnicodeNormalization;
#[cfg(feature = "snowball")]
use unicode_segmentation::UnicodeSegmentation;
#[cfg(feature = "snowball")]
use waken_snowball::{Algorithm as SnowballAlgorithm, stem};

#[cfg(all(
    feature = "japanese-ipadic-neologd-lindera",
    any(
        feature = "japanese-ipadic-lindera",
        feature = "japanese-unidic-lindera",
        feature = "japanese-icu",
    )
))]
compile_error!("Only one Japanese tokenizer feature may be enabled at a time.");

#[cfg(all(
    feature = "japanese-ipadic-lindera",
    any(
        feature = "japanese-ipadic-neologd-lindera",
        feature = "japanese-unidic-lindera",
        feature = "japanese-icu",
    )
))]
compile_error!("Only one Japanese tokenizer feature may be enabled at a time.");

#[cfg(all(
    feature = "japanese-unidic-lindera",
    any(
        feature = "japanese-ipadic-neologd-lindera",
        feature = "japanese-ipadic-lindera",
        feature = "japanese-icu",
    )
))]
compile_error!("Only one Japanese tokenizer feature may be enabled at a time.");

#[cfg(all(
    feature = "japanese-icu",
    any(
        feature = "japanese-ipadic-neologd-lindera",
        feature = "japanese-ipadic-lindera",
        feature = "japanese-unidic-lindera",
    )
))]
compile_error!("Only one Japanese tokenizer feature may be enabled at a time.");

#[cfg(all(feature = "chinese-lindera", feature = "chinese-icu"))]
compile_error!("Only one Chinese tokenizer feature may be enabled at a time.");
#[cfg(any(
    feature = "japanese-ipadic-neologd-lindera",
    feature = "japanese-ipadic-lindera",
    feature = "japanese-unidic-lindera",
    feature = "chinese-lindera",
    feature = "korean-lindera"
))]
thread_local! {
    static JAPANESE_TOKENIZER: Tokenizer = Tokenizer::new(Segmenter::new(
        Mode::Normal,
        load_dictionary(
            #[cfg(feature = "japanese-ipadic-neologd-lindera")]
            "embedded://ipadic-neologd",
            #[cfg(feature = "japanese-ipadic-lindera")]
            "embedded://ipadic",
            #[cfg(feature = "japanese-unidic-lindera")]
            "embedded://unidic",
            // Placeholder so this static still compiles when no Japanese
            // dictionary feature is enabled; it is never accessed in that case.
            #[cfg(not(any(
                feature = "japanese-ipadic-neologd-lindera",
                feature = "japanese-ipadic-lindera",
                feature = "japanese-unidic-lindera"
            )))]
            "",
        )
        .unwrap(),
        None,
    ));
    static KOREAN_TOKENIZER: Tokenizer = Tokenizer::new(Segmenter::new(
        Mode::Normal,
        load_dictionary("embedded://ko-dic").unwrap(),
        None,
    ));
    static CHINESE_TOKENIZER: Tokenizer = Tokenizer::new(Segmenter::new(
        Mode::Normal,
        load_dictionary("embedded://cc-cedict").unwrap(),
        None,
    ));
}

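/// Language selector for [`tokenize`]: a Snowball stemming language, a CJK
/// segmenter, or a Southeast-Asian segmenter. `None` (the fallback for
/// out-of-range discriminants) selects nothing. Represented as an `i8`, and
/// serialized as one when the `serde` feature is enabled.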
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Display, FromPrimitive, IntoPrimitive,
)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[cfg_attr(feature = "serde", serde(into = "i8", try_from = "i8"))]
#[repr(i8)]
pub enum Algorithm {
    #[default]
    None = -1,

    Arabic,
    Armenian,
    Basque,
    Catalan,
    Danish,
    Dutch,
    DutchPorter,
    English,
    Esperanto,
    Estonian,
    Finnish,
    French,
    German,
    Greek,
    Hindi,
    Hungarian,
    Indonesian,
    Irish,
    Italian,
    Lithuanian,
    Lovins,
    Nepali,
    Norwegian,
    Porter,
    Portuguese,
    Romanian,
    Russian,
    Serbian,
    Spanish,
    Swedish,
    Tamil,
    Turkish,
    Yiddish,

    Japanese,
    Chinese,
    Korean,

    Thai,
    Burmese,
    Lao,
    Khmer,
}

impl Algorithm {
    pub const fn is_snowball(self) -> bool {
        // `None` selects no algorithm, so it must never reach the Snowball
        // transmute in `tokenize_snowball`.
        !matches!(self, Self::None) && !self.is_cjk() && !self.is_southeast_asian()
    }

    pub const fn is_cjk(self) -> bool {
        matches!(self, Self::Japanese | Self::Chinese | Self::Korean)
    }

    pub const fn is_southeast_asian(self) -> bool {
        matches!(self, Self::Thai | Self::Burmese | Self::Lao | Self::Khmer)
    }
}

#[derive(Debug, Error)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum Error {
    #[error(
        "No tokenizer found for algorithm {0:?}; you may need to enable the crate feature for the desired language."
    )]
    NoTokenizer(Algorithm),
}

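/// How [`find_match`] and [`find_all_matches`] compare tokens: exact token
/// equality, averaged normalized-Levenshtein similarity at or above
/// `threshold`, or exact matching with a fuzzy fallback.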
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum MatchMode {
    Exact,
    Fuzzy { threshold: f64 },
    Both { threshold: f64 },
}

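/// One tokenized word: the stemmed/normalized text plus the `char`-based
/// offset and length of the original word within the source string.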
#[derive(Debug, Clone)]
pub struct Token {
    pub text: String,
    pub start: u32,
    pub len: u32,
}

impl<T> PartialEq<T> for Token
where
    T: AsRef<str>,
{
    fn eq(&self, other: &T) -> bool {
        self.text == other.as_ref()
    }
}

impl PartialEq for Token {
    fn eq(&self, other: &Self) -> bool {
        self.text == other.text
    }
}

impl Eq for Token {}

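/// A located match. `offset` is the `char` offset of the first matched token
/// and `len` the summed `char` lengths of the matched tokens (gaps between
/// tokens are not counted); `Fuzzy` adds the averaged similarity `score`.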
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum MatchResult {
    Exact { offset: u32, len: u32 },
    Fuzzy { offset: u32, len: u32, score: f64 },
}

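// Compact serde encoding: `Exact` as `(offset, len)` and `Fuzzy` as
// `(offset, len, score)`; deserialization tells them apart by sequence length.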
#[cfg(feature = "serde")]
impl Serialize for MatchResult {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match *self {
            MatchResult::Exact { offset, len } => (offset, len).serialize(serializer),
            MatchResult::Fuzzy { offset, len, score } => (offset, len, score).serialize(serializer),
        }
    }
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for MatchResult {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct MatchResultVisitor;

        impl<'de> Visitor<'de> for MatchResultVisitor {
            type Value = MatchResult;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a tuple of length 2 or 3")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let offset: u32 = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
                let len: u32 = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;

                // A third element marks a fuzzy match; its absence an exact one.
                if let Some(score) = seq.next_element::<f64>()? {
                    Ok(MatchResult::Fuzzy { offset, len, score })
                } else {
                    Ok(MatchResult::Exact { offset, len })
                }
            }
        }

        deserializer.deserialize_seq(MatchResultVisitor)
    }
}

#[cfg(feature = "serde")]
impl Serialize for MatchMode {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Fixed-size `(tag, threshold)` encoding; `Exact` has no threshold,
        // so it serializes a placeholder of 0.0.
        let mut tup = serializer.serialize_tuple(2)?;
        match *self {
            MatchMode::Exact => {
                tup.serialize_element(&0u8)?;
                tup.serialize_element(&0.0f64)?;
            }
            MatchMode::Fuzzy { threshold } => {
                tup.serialize_element(&1u8)?;
                tup.serialize_element(&threshold)?;
            }
            MatchMode::Both { threshold } => {
                tup.serialize_element(&2u8)?;
                tup.serialize_element(&threshold)?;
            }
        }
        tup.end()
    }
}

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for MatchMode {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct MatchModeVisitor;

        impl<'de> Visitor<'de> for MatchModeVisitor {
            type Value = MatchMode;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a tuple [u8, f64]")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let tag: u8 = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;

                let threshold: f64 = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;

                match tag {
                    0 => Ok(MatchMode::Exact),
                    1 => Ok(MatchMode::Fuzzy { threshold }),
                    2 => Ok(MatchMode::Both { threshold }),
                    _ => Err(de::Error::custom(format!("invalid MatchMode tag: {tag}"))),
                }
            }
        }

        deserializer.deserialize_tuple(2, MatchModeVisitor)
    }
}

#[cfg(feature = "snowball")]
fn normalize_punctuation(s: &str) -> String {
    s.chars()
        .map(|c| match c as u32 {
            // Unicode hyphens and dashes -> ASCII hyphen-minus.
            0x2010..=0x2015 => '-',
            // Curly single quotes -> ASCII apostrophe.
            0x2018..=0x201B => '\'',
            // Curly double quotes -> ASCII double quote.
            0x201C..=0x201F => '"',
            _ => c,
        })
        .collect()
}

#[cfg(feature = "snowball")]
fn tokenize_snowball(text: &str, algorithm: Algorithm, case_sensitive: bool) -> Vec<Token> {
    let mut tokens = Vec::new();

    for (byte_start, word) in text.unicode_word_indices() {
        let trimmed = word.trim_matches('\'');

        if !trimmed.chars().any(|c| c.is_alphabetic() || c.is_numeric()) {
            continue;
        }

        // Token positions are char-based, not byte-based.
        let start = text[..byte_start].chars().count();
        let len = trimmed.chars().count();

        let normalized: String = trimmed.nfkc().collect();
        let normalized = normalize_punctuation(&normalized);
        let normalized = if case_sensitive {
            normalized
        } else {
            normalized.to_lowercase()
        };

        // SAFETY: relies on `SnowballAlgorithm` sharing discriminants with the
        // Snowball subset of `Algorithm`; `is_snowball` filters out the rest.
        let token_text = stem(
            unsafe { transmute::<Algorithm, SnowballAlgorithm>(algorithm) },
            &normalized,
        )
        .into_owned();

        tokens.push(Token {
            text: token_text,
            start: start as u32,
            len: len as u32,
        });
    }

    tokens
}

#[cfg(any(
    feature = "japanese-ipadic-neologd-lindera",
    feature = "japanese-ipadic-lindera",
    feature = "japanese-unidic-lindera",
    feature = "chinese-lindera",
    feature = "korean-lindera",
    feature = "japanese-icu",
    feature = "chinese-icu"
))]
fn tokenize_cjk(text: &str, algorithm: Algorithm) -> Result<Vec<Token>, Error> {
    // Shared Lindera path: convert byte offsets into char-based offsets.
    #[cfg(any(
        feature = "japanese-ipadic-neologd-lindera",
        feature = "japanese-ipadic-lindera",
        feature = "japanese-unidic-lindera",
        feature = "chinese-lindera",
        feature = "korean-lindera"
    ))]
    fn lindera_tokens(text: &str, tokenizer: &Tokenizer) -> Vec<Token> {
        tokenizer
            .tokenize(text)
            .unwrap()
            .into_iter()
            .map(|tok| {
                let start = text[..tok.byte_start].chars().count();
                let len = tok.surface.chars().count();

                Token {
                    text: tok.surface.into_owned(),
                    start: start as u32,
                    len: len as u32,
                }
            })
            .collect()
    }

    match algorithm {
        Algorithm::Chinese => {
            #[cfg(feature = "chinese-lindera")]
            return Ok(CHINESE_TOKENIZER.with(|t| lindera_tokens(text, t)));

            #[cfg(feature = "chinese-icu")]
            return Ok(tokenize_cjk_icu(text, algorithm));

            // Reachable when only non-Chinese tokenizer features are enabled.
            #[cfg(not(any(feature = "chinese-lindera", feature = "chinese-icu")))]
            return Err(Error::NoTokenizer(algorithm));
        }

        Algorithm::Japanese => {
            #[cfg(any(
                feature = "japanese-ipadic-neologd-lindera",
                feature = "japanese-ipadic-lindera",
                feature = "japanese-unidic-lindera",
            ))]
            return Ok(JAPANESE_TOKENIZER.with(|t| lindera_tokens(text, t)));

            #[cfg(feature = "japanese-icu")]
            return Ok(tokenize_cjk_icu(text, algorithm));

            #[cfg(not(any(
                feature = "japanese-ipadic-neologd-lindera",
                feature = "japanese-ipadic-lindera",
                feature = "japanese-unidic-lindera",
                feature = "japanese-icu"
            )))]
            return Err(Error::NoTokenizer(algorithm));
        }

        Algorithm::Korean => {
            #[cfg(feature = "korean-lindera")]
            return Ok(KOREAN_TOKENIZER.with(|t| lindera_tokens(text, t)));

            #[cfg(not(feature = "korean-lindera"))]
            return Err(Error::NoTokenizer(algorithm));
        }

        _ => unreachable!(),
    }
}

#[cfg(any(feature = "japanese-icu", feature = "chinese-icu"))]
fn tokenize_cjk_icu(text: &str, _algorithm: Algorithm) -> Vec<Token> {
    let segmenter = WordSegmenter::new_auto(WordBreakInvariantOptions::default());

    segmenter
        .segment_str(text)
        .tuple_windows()
        .map(|(i, j)| {
            let slice = &text[i..j];

            Token {
                text: slice.to_owned(),
                start: text[..i].chars().count() as u32,
                len: slice.chars().count() as u32,
            }
        })
        .collect()
}

#[cfg(feature = "southeast-asian")]
fn tokenize_southeast_asian(text: &str, _algorithm: Algorithm) -> Vec<Token> {
    let segmenter = WordSegmenter::new_lstm(WordBreakInvariantOptions::default());

    segmenter
        .segment_str(text)
        .tuple_windows()
        .map(|(i, j)| {
            let slice = &text[i..j];

            Token {
                text: slice.to_owned(),
                start: text[..i].chars().count() as u32,
                len: slice.chars().count() as u32,
            }
        })
        .collect()
}

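/// Tokenizes `text` according to `algorithm`: Snowball languages are
/// word-split, NFKC-normalized, and stemmed (lowercased unless
/// `case_sensitive`); CJK and Southeast-Asian languages are segmented.
/// Returns [`Error::NoTokenizer`] if no enabled feature covers `algorithm`.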
pub fn tokenize(
    text: &str,
    algorithm: Algorithm,
    case_sensitive: bool,
) -> Result<Vec<Token>, Error> {
    if algorithm.is_snowball() {
        #[cfg(feature = "snowball")]
        return Ok(tokenize_snowball(text, algorithm, case_sensitive));
    } else if algorithm.is_cjk() {
        #[cfg(any(
            feature = "japanese-ipadic-neologd-lindera",
            feature = "japanese-ipadic-lindera",
            feature = "japanese-unidic-lindera",
            feature = "chinese-lindera",
            feature = "korean-lindera",
            feature = "japanese-icu",
            feature = "chinese-icu"
        ))]
        return tokenize_cjk(text, algorithm);
    } else if algorithm.is_southeast_asian() {
        #[cfg(feature = "southeast-asian")]
        return Ok(tokenize_southeast_asian(text, algorithm));
    }

    Err(Error::NoTokenizer(algorithm))
}

fn find_exact_match(haystack: &[Token], needle: &[Token], permissive: bool) -> Option<MatchResult> {
    haystack.windows(needle.len()).find_map(|window| {
        let matches = if permissive {
            // Case-insensitive comparison, but only "upcasting": a lowercase
            // needle may match a capitalized haystack token, not vice versa.
            window.iter().zip(needle).all(|(a, b)| {
                let a_lower = a.text.to_lowercase();
                let b_lower = b.text.to_lowercase();

                if a_lower == b_lower {
                    let a_upper_count = a.text.chars().filter(|c| c.is_uppercase()).count();
                    let b_upper_count = b.text.chars().filter(|c| c.is_uppercase()).count();

                    a_upper_count >= b_upper_count
                } else {
                    false
                }
            })
        } else {
            window == needle
        };

        matches.then_some(MatchResult::Exact {
            offset: window[0].start,
            len: needle.iter().map(|token| token.len).sum(),
        })
    })
}

fn find_fuzzy_match(
    haystack: &[Token],
    needle: &[Token],
    threshold: f64,
    permissive: bool,
    _collapse: bool,
) -> Option<MatchResult> {
    haystack.windows(needle.len()).find_map(|window| {
        // Average the normalized Levenshtein similarity over the window.
        let score = window
            .iter()
            .zip(needle)
            .map(|(a, b)| {
                if permissive {
                    strsim::normalized_levenshtein(&a.text.to_lowercase(), &b.text.to_lowercase())
                } else {
                    strsim::normalized_levenshtein(&a.text, &b.text)
                }
            })
            .sum::<f64>()
            / needle.len() as f64;

        let passes_threshold = if score >= threshold && permissive {
            // As in `find_exact_match`, permissive matching still requires at
            // least as many uppercase chars in the haystack as in the needle.
            window.iter().zip(needle).all(|(a, b)| {
                let a_upper_count = a.text.chars().filter(|c| c.is_uppercase()).count();
                let b_upper_count = b.text.chars().filter(|c| c.is_uppercase()).count();

                a_upper_count >= b_upper_count
            })
        } else {
            score >= threshold
        };

        passes_threshold.then_some(MatchResult::Fuzzy {
            offset: window[0].start,
            len: window.iter().map(|token| token.len).sum(),
            score,
        })
    })
}

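/// Returns the first window of `haystack` that matches `needle` under `mode`.
/// With `permissive`, comparison is case-insensitive but a haystack token must
/// contain at least as many uppercase characters as its needle counterpart.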
pub fn find_match(
    haystack: &[Token],
    needle: &[Token],
    mode: MatchMode,
    permissive: bool,
) -> Option<MatchResult> {
    if needle.is_empty() || needle.len() > haystack.len() {
        return None;
    }

    match mode {
        MatchMode::Exact => find_exact_match(haystack, needle, permissive),
        MatchMode::Fuzzy { threshold } => {
            find_fuzzy_match(haystack, needle, threshold, permissive, false)
        }
        MatchMode::Both { threshold } => find_exact_match(haystack, needle, permissive)
            .or_else(|| find_fuzzy_match(haystack, needle, threshold, permissive, false)),
    }
}

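/// Collects every match of `needle` in `haystack`, resuming the search one
/// token after the start of each match, so matches may overlap past their
/// first token.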
pub fn find_all_matches(
    haystack: &[Token],
    needle: &[Token],
    mode: MatchMode,
    permissive: bool,
) -> Vec<MatchResult> {
    if needle.is_empty() || needle.len() > haystack.len() {
        return Vec::new();
    }

    let mut results = Vec::new();
    let mut index = 0usize;

    while index < haystack.len() {
        let slice = &haystack[index..];

        match find_match(slice, needle, mode, permissive) {
            Some(found) => {
                // `MatchResult` stores the char offset of the first matched
                // token, not its index; map it back to a token index so the
                // next search resumes one token past the start of this match.
                let char_offset = match found {
                    MatchResult::Exact { offset, .. } | MatchResult::Fuzzy { offset, .. } => offset,
                };
                let token_index = slice
                    .iter()
                    .position(|token| token.start == char_offset)
                    .unwrap_or(0);

                index += token_index + 1;
                results.push(found);
            }
            None => break,
        }
    }

    results
}
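
// Minimal sketches of the matching behavior on hand-built tokens; they use
// only the public API above and no tokenizer features, so the expected values
// follow directly from the definitions of `find_match` and `Token`.
#[cfg(test)]
mod tests {
    use super::*;

    fn token(text: &str, start: u32) -> Token {
        Token {
            text: text.to_owned(),
            start,
            len: text.chars().count() as u32,
        }
    }

    #[test]
    fn exact_match_reports_char_offsets() {
        let haystack = [token("the", 0), token("quick", 4), token("fox", 10)];
        let needle = [token("quick", 0), token("fox", 6)];

        // `Token` equality compares text only, so the needle's own offsets do
        // not matter; the result carries the haystack's char offsets.
        assert_eq!(
            find_match(&haystack, &needle, MatchMode::Exact, false),
            Some(MatchResult::Exact { offset: 4, len: 8 })
        );
    }

    #[test]
    fn fuzzy_match_scores_above_threshold() {
        let haystack = [token("color", 0)];
        let needle = [token("colour", 0)];

        // normalized_levenshtein("color", "colour") = 1 - 1/6, about 0.83.
        let found = find_match(
            &haystack,
            &needle,
            MatchMode::Fuzzy { threshold: 0.8 },
            false,
        );
        assert!(matches!(found, Some(MatchResult::Fuzzy { score, .. }) if score >= 0.8));
    }
}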