1#![warn(clippy::all, clippy::pedantic)]
2#![doc = include_str!("../README.md")]
3
4#[cfg(any(
5 feature = "southeast-asian",
6 feature = "japanese-icu",
7 feature = "chinese-icu"
8))]
9use icu_segmenter::{options::WordBreakInvariantOptions, WordSegmenter};
10#[cfg(any(
11 feature = "southeast-asian",
12 feature = "japanese-icu",
13 feature = "chinese-icu"
14))]
15use itertools::Itertools;
16#[cfg(any(
17 feature = "japanese-ipadic-neologd-lindera",
18 feature = "japanese-ipadic-lindera",
19 feature = "japanese-unidic-lindera",
20 feature = "chinese-lindera",
21 feature = "korean-lindera"
22))]
23use lindera::{
24 dictionary::load_dictionary, mode::Mode, segmenter::Segmenter, tokenizer::Tokenizer,
25};
26use num_enum::{FromPrimitive, IntoPrimitive};
27#[cfg(feature = "serde")]
28use serde::{
29 de::{self, SeqAccess, Visitor},
30 ser::SerializeTuple,
31 {Deserialize, Deserializer, Serialize, Serializer},
32};
33#[cfg(feature = "serde")]
34use std::fmt;
35#[cfg(feature = "snowball")]
36use std::mem::transmute;
37use strum_macros::Display;
38use thiserror::Error;
39#[cfg(feature = "snowball")]
40use unicode_normalization::UnicodeNormalization;
41#[cfg(feature = "snowball")]
42use unicode_segmentation::UnicodeSegmentation;
43#[cfg(feature = "snowball")]
44use waken_snowball::{stem, Algorithm as SnowballAlgorithm};
45
// Mutually-exclusive backend guards: at most one Japanese tokenizer feature
// and at most one Chinese tokenizer feature may be enabled. Every pairwise
// clash is spelled out separately so the compile error points at the feature
// that caused it from whichever one is "first" in the offending set.
#[cfg(all(
    feature = "japanese-ipadic-neologd-lindera",
    any(
        feature = "japanese-ipadic-lindera",
        feature = "japanese-unidic-lindera",
        feature = "japanese-icu",
    )
))]
compile_error!("Only one Japanese tokenizer feature may be enabled at a time.");

#[cfg(all(
    feature = "japanese-ipadic-lindera",
    any(
        feature = "japanese-ipadic-neologd-lindera",
        feature = "japanese-unidic-lindera",
        feature = "japanese-icu",
    )
))]
compile_error!("Only one Japanese tokenizer feature may be enabled at a time.");

#[cfg(all(
    feature = "japanese-unidic-lindera",
    any(
        feature = "japanese-ipadic-neologd-lindera",
        feature = "japanese-ipadic-lindera",
        feature = "japanese-icu",
    )
))]
compile_error!("Only one Japanese tokenizer feature may be enabled at a time.");

#[cfg(all(
    feature = "japanese-icu",
    any(
        feature = "japanese-ipadic-neologd-lindera",
        feature = "japanese-ipadic-lindera",
        feature = "japanese-unidic-lindera",
    )
))]
compile_error!("Only one Japanese tokenizer feature may be enabled at a time.");

#[cfg(all(feature = "chinese-lindera", feature = "chinese-icu"))]
compile_error!("Only one Chinese tokenizer feature may be enabled at a time.");
88
#[cfg(any(
    feature = "japanese-ipadic-neologd-lindera",
    feature = "japanese-ipadic-lindera",
    feature = "japanese-unidic-lindera",
    feature = "chinese-lindera",
    feature = "korean-lindera"
))]
thread_local! {
    // One lazily-built lindera tokenizer per thread; `thread_local!`
    // initializers only run on first access, so a static that is never
    // touched never runs its fallible `load_dictionary(..).unwrap()`.
    // NOTE(review): all three statics exist whenever *any* lindera feature
    // is enabled, but `tokenize_cjk` only accesses the one matching an
    // enabled feature — confirm no other code path touches the others,
    // since their initializers would panic on a missing dictionary.
    // NOTE(review): presumably `Tokenizer` is not shareable across threads,
    // hence thread-locals rather than a process-wide static — confirm
    // against the lindera docs.
    static JAPANESE_TOKENIZER: Tokenizer =
        Tokenizer::new(Segmenter::new(
            Mode::Normal,
            load_dictionary(
                // Exactly one Japanese dictionary feature can be active
                // (enforced by the compile_error! guards above), so exactly
                // one of these string arguments survives cfg expansion.
                #[cfg(feature = "japanese-ipadic-neologd-lindera")]
                "embedded://ipadic-neologd",

                #[cfg(feature = "japanese-ipadic-lindera")]
                "embedded://ipadic",

                #[cfg(feature = "japanese-unidic-lindera")]
                "embedded://unidic",

                // Placeholder so the call still type-checks when no Japanese
                // lindera feature is enabled; this static is then never
                // accessed, so the initializer (which would fail) never runs.
                #[cfg(not(any(
                    feature = "japanese-ipadic-neologd-lindera",
                    feature = "japanese-ipadic-lindera",
                    feature = "japanese-unidic-lindera"
                )))]
                "",
            ).unwrap(),
            None,
        ));
    static KOREAN_TOKENIZER: Tokenizer =
        Tokenizer::new(Segmenter::new(
            Mode::Normal,
            load_dictionary("embedded://ko-dic").unwrap(),
            None,
        ));
    static CHINESE_TOKENIZER: Tokenizer =
        Tokenizer::new(Segmenter::new(
            Mode::Normal,
            load_dictionary("embedded://cc-cedict").unwrap(),
            None,
        ));
}
132
/// Language / tokenizer-backend selector.
///
/// `#[repr(i8)]` with `None = -1` so the enum round-trips through `i8` for
/// serde (`into = "i8"`, `try_from = "i8"`); `FromPrimitive` maps any
/// unknown value to the `#[default]` variant (`None`).
///
/// NOTE(review): `tokenize_snowball` `transmute`s the Snowball variants
/// directly into `waken_snowball::Algorithm`, so the declaration order and
/// discriminants of the `Arabic..=Yiddish` block must stay in sync with
/// that crate — confirm before reordering or inserting variants.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Display, FromPrimitive, IntoPrimitive,
)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[cfg_attr(feature = "serde", serde(into = "i8", try_from = "i8"))]
#[repr(i8)]
pub enum Algorithm {
    /// No language selected; no tokenizer backend handles this variant.
    #[default]
    None = -1,

    // Snowball stemmer languages (discriminants 0..); see `is_snowball`.
    Arabic,
    Armenian,
    Basque,
    Catalan,
    Danish,
    Dutch,
    DutchPorter,
    English,
    Esperanto,
    Estonian,
    Finnish,
    French,
    German,
    Greek,
    Hindi,
    Hungarian,
    Indonesian,
    Irish,
    Italian,
    Lithuanian,
    Lovins,
    Nepali,
    Norwegian,
    Porter,
    Portuguese,
    Romanian,
    Russian,
    Serbian,
    Spanish,
    Swedish,
    Tamil,
    Turkish,
    Yiddish,

    // CJK languages, tokenized by lindera or ICU depending on features;
    // see `is_cjk` and `tokenize_cjk`.
    Japanese,
    Chinese,
    Korean,

    // Southeast-Asian languages, segmented by the ICU LSTM segmenter;
    // see `is_southeast_asian` and `tokenize_southeast_asian`.
    Thai,
    Burmese,
    Lao,
    Khmer,
}
186
187impl Algorithm {
188 pub const fn is_snowball(self) -> bool {
189 !self.is_cjk() && !self.is_southeast_asian()
190 }
191
192 pub const fn is_cjk(self) -> bool {
193 matches!(self, Self::Japanese | Self::Chinese | Self::Korean)
194 }
195
196 pub const fn is_southeast_asian(self) -> bool {
197 matches!(self, Self::Thai | Self::Burmese | Self::Lao | Self::Khmer)
198 }
199}
200
/// Errors returned by [`tokenize`].
#[derive(Debug, Error)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
pub enum Error {
    /// None of the compiled-in backends can tokenize the requested
    /// language; the matching crate feature must be enabled.
    #[error("No tokenizer found for algorithm {0:?}, you might want to enable a crate feature that corresponds to desired language.")]
    NoTokenizer(Algorithm),
}
207
/// How `find_match` / `find_all_matches` compare the needle against the
/// haystack.
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum MatchMode {
    /// Token texts must match exactly (modulo the `permissive` casing
    /// rules applied by `find_exact_match`).
    Exact,
    /// Accept a window whose mean normalized Levenshtein similarity to the
    /// needle is at least this threshold (expected range 0.0..=1.0).
    Fuzzy(f64),
    /// Try `Exact` first, then fall back to `Fuzzy` with this threshold.
    Both(f64),
}
227
/// A single unit of tokenized text.
///
/// Positions are measured in characters (not bytes) of the original input
/// string: `start` is the character offset where the token begins and
/// `len` its character length.
#[derive(Debug, Clone)]
pub struct Token {
    // Token text; may be stemmed/normalized by the producing tokenizer
    // (see `tokenize_snowball`).
    pub text: String,
    // Character offset of the token in the source text.
    pub start: usize,
    // Character length of the token.
    pub len: usize,
}
234
235impl<T> PartialEq<T> for Token
236where
237 T: AsRef<str>,
238{
239 fn eq(&self, other: &T) -> bool {
240 self.text == other.as_ref()
241 }
242}
243
244impl PartialEq for Token {
245 fn eq(&self, other: &Self) -> bool {
246 self.text == other.text
247 }
248}
249
250impl Eq for Token {}
251
/// Location of a successful match.
///
/// The payload is `(start, len)` measured in characters of the original
/// text (derived from `Token::start` / `Token::len`); `Fuzzy` additionally
/// carries the mean normalized-Levenshtein score that cleared the
/// threshold.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
pub enum MatchResult {
    /// Exact hit: `(start, len)`.
    Exact((usize, usize)),
    /// Fuzzy hit: `(start, len)` plus its similarity score.
    Fuzzy((usize, usize), f64),
}
260
#[cfg(feature = "serde")]
impl Serialize for MatchResult {
    /// Serializes as a bare tuple: `(start, len)` for exact matches,
    /// `(start, len, score)` for fuzzy ones. The matching `Deserialize`
    /// distinguishes the variants by sequence length.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            Self::Exact(span) => span.serialize(serializer),
            Self::Fuzzy((start, len), score) => (start, len, score).serialize(serializer),
        }
    }
}
273
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for MatchResult {
    /// Mirrors `Serialize`: a 2-element sequence decodes as `Exact`, a
    /// 3-element one (trailing `f64` score) as `Fuzzy`.
    ///
    /// NOTE(review): the variant is discovered by probing for an optional
    /// third element, which presumably requires a self-describing format
    /// (JSON and the like) — confirm this is never used with fixed-size
    /// binary formats.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct MatchResultVisitor;

        impl<'de> Visitor<'de> for MatchResultVisitor {
            type Value = MatchResult;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a tuple of length 2 or 3")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                // `(start, len)` span, always present.
                let a: usize = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
                let b: usize = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;

                // Optional trailing score decides the variant.
                if let Some(score) = seq.next_element::<f64>()? {
                    Ok(MatchResult::Fuzzy((a, b), score))
                } else {
                    Ok(MatchResult::Exact((a, b)))
                }
            }
        }

        deserializer.deserialize_seq(MatchResultVisitor)
    }
}
311
#[cfg(feature = "serde")]
impl Serialize for MatchMode {
    /// Serializes as a fixed two-element `(tag, value)` tuple: tag 0 =
    /// `Exact` (with a 0.0 filler value), 1 = `Fuzzy`, 2 = `Both`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let (tag, value): (u8, f64) = match *self {
            MatchMode::Exact => (0, 0.0),
            MatchMode::Fuzzy(v) => (1, v),
            MatchMode::Both(v) => (2, v),
        };

        let mut tup = serializer.serialize_tuple(2)?;
        tup.serialize_element(&tag)?;
        tup.serialize_element(&value)?;
        tup.end()
    }
}
336
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for MatchMode {
    /// Mirrors `Serialize`: expects a fixed `(tag, value)` tuple where
    /// tag 0 = `Exact` (value ignored), 1 = `Fuzzy`, 2 = `Both`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct MatchModeVisitor;

        impl<'de> Visitor<'de> for MatchModeVisitor {
            type Value = MatchMode;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a tuple [u8, f64]")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let tag: u8 = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;

                // The value slot is always present — `Exact` is serialized
                // with a 0.0 filler — so a short tuple is an error.
                let value: f64 = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(1, &self))?;

                match tag {
                    0 => Ok(MatchMode::Exact),
                    1 => Ok(MatchMode::Fuzzy(value)),
                    2 => Ok(MatchMode::Both(value)),
                    _ => Err(de::Error::custom(format!("invalid MatchMode tag: {}", tag))),
                }
            }
        }

        deserializer.deserialize_tuple(2, MatchModeVisitor)
    }
}
376
377#[cfg(feature = "snowball")]
/// Folds common typographic punctuation to its ASCII equivalent so the
/// Snowball stemmer sees uniform input.
///
/// * U+2010..=U+2015 (hyphens and dashes)     -> `-`
/// * U+2018..=U+201B (single quotation marks) -> `'`
/// * U+201C..=U+201F (double quotation marks) -> `"`
fn normalize_punctuation(s: &str) -> String {
    s.chars()
        .map(|c| match c as u32 {
            // Dashes fold to a hyphen and single quotes to an apostrophe;
            // the previous version had these two replacements swapped.
            0x2010..=0x2015 => '-',
            0x2018..=0x201B => '\'',
            0x201C..=0x201F => '"',
            _ => c,
        })
        .collect()
}
388
389#[cfg(feature = "snowball")]
390fn tokenize_snowball(text: &str, algorithm: Algorithm, case_sensitive: bool) -> Vec<Token> {
391 let mut tokens = Vec::new();
392
393 for (byte_start, word) in text.unicode_word_indices() {
395 let trimmed = word.trim_matches('\'');
396
397 if !trimmed.chars().any(|c| c.is_alphabetic() || c.is_numeric()) {
398 continue;
399 }
400
401 let start = text[..byte_start].chars().count();
403 let len = trimmed.chars().count();
404
405 let normalized: String = trimmed.nfkc().collect();
407 let normalized = normalize_punctuation(&normalized);
408
409 let token_text = if case_sensitive {
410 stem(
411 unsafe { transmute::<Algorithm, SnowballAlgorithm>(algorithm) },
412 &normalized,
413 )
414 .into_owned()
415 } else {
416 stem(
417 unsafe { transmute::<Algorithm, SnowballAlgorithm>(algorithm) },
418 &normalized.to_lowercase(),
419 )
420 .into_owned()
421 };
422
423 tokens.push(Token {
424 text: token_text,
425 start,
426 len,
427 });
428 }
429
430 tokens
431}
432
/// Tokenizes Chinese, Japanese, or Korean text with whichever backend
/// (lindera or ICU) the enabled crate features provide.
///
/// `Token::start`/`len` are character offsets/lengths in `text`.
///
/// # Panics
///
/// The caller must only pass an `algorithm` for which `is_cjk()` holds
/// (guaranteed by `tokenize`); any other variant hits `unreachable!()`.
/// Lindera tokenization errors also panic via `unwrap`.
///
/// NOTE(review): a feature set that leaves one of the arms empty (e.g.
/// only `korean-lindera` enabled leaves the Chinese and Japanese arms
/// producing `()`) would not type-check against the `Vec<Token>` return —
/// confirm which feature combinations are actually supported.
#[cfg(any(
    feature = "japanese-ipadic-neologd-lindera",
    feature = "japanese-ipadic-lindera",
    feature = "japanese-unidic-lindera",
    feature = "chinese-lindera",
    feature = "korean-lindera",
    feature = "japanese-icu",
    feature = "chinese-icu"
))]
fn tokenize_cjk(text: &str, algorithm: Algorithm) -> Vec<Token> {
    match algorithm {
        Algorithm::Chinese => {
            #[cfg(feature = "chinese-lindera")]
            {
                CHINESE_TOKENIZER.with(|t| {
                    t.tokenize(text)
                        .unwrap()
                        .into_iter()
                        .map(|tok| {
                            // Convert lindera's byte offset into a character
                            // offset in `text`.
                            let start = text[..tok.byte_start].chars().count();
                            let len = tok.surface.chars().count();

                            Token {
                                text: tok.surface.into_owned(),
                                start,
                                len,
                            }
                        })
                        .collect()
                })
            }

            #[cfg(feature = "chinese-icu")]
            tokenize_cjk_icu(text, algorithm)
        }

        Algorithm::Japanese => {
            #[cfg(any(
                feature = "japanese-ipadic-neologd-lindera",
                feature = "japanese-ipadic-lindera",
                feature = "japanese-unidic-lindera",
            ))]
            {
                JAPANESE_TOKENIZER.with(|t| {
                    t.tokenize(text)
                        .unwrap()
                        .into_iter()
                        .map(|tok| {
                            let start = text[..tok.byte_start].chars().count();
                            let len = tok.surface.chars().count();

                            Token {
                                text: tok.surface.into_owned(),
                                start,
                                len,
                            }
                        })
                        .collect()
                })
            }

            #[cfg(feature = "japanese-icu")]
            tokenize_cjk_icu(text, algorithm)
        }

        Algorithm::Korean =>
        {
            #[cfg(feature = "korean-lindera")]
            KOREAN_TOKENIZER.with(|t| {
                t.tokenize(text)
                    .unwrap()
                    .into_iter()
                    .map(|tok| {
                        let start = text[..tok.byte_start].chars().count();
                        let len = tok.surface.chars().count();

                        Token {
                            text: tok.surface.into_owned(),
                            start,
                            len,
                        }
                    })
                    .collect()
            })
        }

        // Guarded by `Algorithm::is_cjk` in `tokenize`.
        _ => unreachable!(),
    }
}
522
523#[cfg(any(feature = "japanese-icu", feature = "chinese-icu"))]
524fn tokenize_cjk_icu(text: &str, _algorithm: Algorithm) -> Vec<Token> {
525 let segmenter = WordSegmenter::new_auto(WordBreakInvariantOptions::default());
526
527 segmenter
528 .segment_str(text)
529 .tuple_windows()
530 .map(|(i, j)| {
531 let slice = &text[i..j];
532
533 Token {
534 text: slice.to_owned(),
535 start: text[..i].chars().count(),
536 len: slice.chars().count(),
537 }
538 })
539 .collect()
540}
541
542#[cfg(feature = "southeast-asian")]
543fn tokenize_southeast_asian(text: &str, _algorithm: Algorithm) -> Vec<Token> {
544 let segmenter = WordSegmenter::new_lstm(WordBreakInvariantOptions::default());
545
546 segmenter
547 .segment_str(text)
548 .tuple_windows()
549 .map(|(i, j)| {
550 let slice = &text[i..j];
551
552 Token {
553 text: slice.to_owned(),
554 start: text[..i].chars().count(),
555 len: slice.chars().count(),
556 }
557 })
558 .collect()
559}
560
/// Tokenizes `text` with the backend matching `algorithm`.
///
/// `case_sensitive` only affects Snowball stemming; the CJK and
/// Southeast-Asian backends ignore it.
///
/// # Errors
///
/// Returns [`Error::NoTokenizer`] when no enabled crate feature provides
/// a backend for `algorithm`.
pub fn tokenize(
    text: &str,
    algorithm: Algorithm,
    case_sensitive: bool,
) -> Result<Vec<Token>, Error> {
    // Each cfg-gated `return` disappears when its backend feature is
    // disabled, leaving an empty branch that falls through to the error.
    if algorithm.is_snowball() {
        #[cfg(feature = "snowball")]
        return Ok(tokenize_snowball(text, algorithm, case_sensitive));
    } else if algorithm.is_cjk() {
        #[cfg(any(
            feature = "japanese-ipadic-neologd-lindera",
            feature = "japanese-ipadic-lindera",
            feature = "japanese-unidic-lindera",
            feature = "chinese-lindera",
            feature = "korean-lindera",
            feature = "japanese-icu",
            feature = "chinese-icu"
        ))]
        return Ok(tokenize_cjk(text, algorithm));
    } else if algorithm.is_southeast_asian() {
        #[cfg(feature = "southeast-asian")]
        return Ok(tokenize_southeast_asian(text, algorithm));
    }

    Err(Error::NoTokenizer(algorithm))
}
615
616fn find_exact_match(haystack: &[Token], needle: &[Token], permissive: bool) -> Option<MatchResult> {
617 haystack.windows(needle.len()).find_map(|window| {
618 let matches = if permissive {
619 window.iter().zip(needle).all(|(a, b)| {
620 let a_lower = a.text.to_lowercase();
621 let b_lower = b.text.to_lowercase();
622
623 if a_lower == b_lower {
624 let a_upper_count = a.text.chars().filter(|c| c.is_uppercase()).count();
625 let b_upper_count = b.text.chars().filter(|c| c.is_uppercase()).count();
626
627 a_upper_count >= b_upper_count
628 } else {
629 false
630 }
631 })
632 } else {
633 window == needle
634 };
635
636 matches.then_some(MatchResult::Exact((
637 window[0].start,
638 needle.iter().fold(0, |mut acc, a| {
639 acc += a.len;
640 acc
641 }),
642 )))
643 })
644}
645
646fn find_fuzzy_match(
647 haystack: &[Token],
648 needle: &[Token],
649 threshold: f64,
650 permissive: bool,
651 _collapse: bool,
652) -> Option<MatchResult> {
653 haystack.windows(needle.len()).find_map(|window| {
654 let score = window
655 .iter()
656 .zip(needle)
657 .map(|(a, b)| {
658 if permissive {
659 strsim::normalized_levenshtein(&a.text.to_lowercase(), &b.text.to_lowercase())
660 } else {
661 strsim::normalized_levenshtein(&a.text, &b.text)
662 }
663 })
664 .sum::<f64>()
665 / needle.len() as f64;
666
667 let passes_threshold = if score >= threshold && permissive {
668 window.iter().zip(needle).all(|(a, b)| {
669 let a_upper_count = a.text.chars().filter(|c| c.is_uppercase()).count();
670 let b_upper_count = b.text.chars().filter(|c| c.is_uppercase()).count();
671
672 a_upper_count >= b_upper_count
673 })
674 } else {
675 score >= threshold
676 };
677
678 passes_threshold.then_some(MatchResult::Fuzzy(
679 (
680 window[0].start,
681 window.iter().fold(0, |mut acc, a| {
682 acc += a.len;
683 acc
684 }),
685 ),
686 score,
687 ))
688 })
689}
690
691pub fn find_match(
720 haystack: &[Token],
721 needle: &[Token],
722 mode: MatchMode,
723 permissive: bool,
724) -> Option<MatchResult> {
725 if needle.len() == 0 || needle.len() > haystack.len() {
726 return None;
727 }
728
729 match mode {
730 MatchMode::Exact => find_exact_match(&haystack, &needle, permissive),
731 MatchMode::Fuzzy(threshold) => {
732 find_fuzzy_match(&haystack, &needle, threshold, permissive, false)
733 }
734 MatchMode::Both(threshold) => find_exact_match(&haystack, &needle, permissive)
735 .or_else(|| find_fuzzy_match(&haystack, &needle, threshold, permissive, false)),
736 }
737}
738
739pub fn find_all_matches(
767 haystack: &[Token],
768 needle: &[Token],
769 mode: MatchMode,
770 permissive: bool,
771) -> Vec<MatchResult> {
772 if needle.len() == 0 || needle.len() > haystack.len() {
773 return Vec::new();
774 }
775
776 let mut results = Vec::new();
777 let mut offset = 0;
778
779 while offset < haystack.len() {
780 let slice = &haystack[offset..];
781 let found = find_match(slice, needle, mode, permissive);
782
783 match found {
784 Some(t) => {
785 match t {
786 MatchResult::Exact((start, _)) => {
787 let absolute_start = offset + start;
788 offset = absolute_start + 1;
789 }
790 MatchResult::Fuzzy((start, _), _) => {
791 let absolute_start = offset + start;
792 offset = absolute_start + 1;
793 }
794 }
795
796 results.push(t);
797 }
798 None => break,
799 }
800 }
801
802 results
803}