/// Configuration flags for `SimpleTokenizer`. Every flag defaults to `false`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct TokenizerOptions {
    /// Lowercase the whole input (`str::to_lowercase`) before tokenizing.
    pub lowercase: bool,
    /// Expand English contractions before tokenizing ("can't" -> "can not", "it's" -> "it 's").
    pub split_contractions: bool,
    /// Drop tokens that `is_stopword` reports as stopwords.
    pub remove_stopwords: bool,
    /// Drop tokens that `is_punctuation` reports as punctuation.
    pub remove_punctuation: bool,
}
27
28pub struct SimpleTokenizer {
29 opts: TokenizerOptions,
30}
31
32impl SimpleTokenizer {
33 pub fn new() -> Self {
34 Self {
35 opts: TokenizerOptions::default(),
36 }
37 }
38
39 pub fn with_options(opts: TokenizerOptions) -> Self {
40 Self { opts }
41 }
42
43 pub fn split(&self, input: &str) -> Vec<String> {
44 if input.is_empty() {
45 return Vec::new();
46 }
47 let mut text = if self.opts.lowercase {
49 input.to_lowercase()
50 } else {
51 input.to_string()
52 };
53 if self.opts.split_contractions {
55 text = self.process_contractions(&text);
56 }
57 let pre = self.process_delimiters(&text);
58 let mut tokens: Vec<String> = pre.split_whitespace().map(|s| s.to_string()).collect();
60 let mut out: Vec<String> = Vec::with_capacity(tokens.len() + 4);
62 for t in tokens.drain(..) {
63 if let Some(last) = t.as_bytes().last() {
64 if *last == b'.' {
65 let stem = &t[..t.len() - 1];
66 if !stem.is_empty() && !is_abbreviation(stem) {
67 out.push(stem.to_string());
68 out.push(".".to_string());
69 continue;
70 }
71 }
72 }
73 if self.opts.remove_stopwords && is_stopword(&t) {
75 continue;
76 }
77 if self.opts.remove_punctuation && is_punctuation(&t) {
79 continue;
80 }
81 out.push(t);
82 }
83 out
84 }
85
86 fn process_delimiters(&self, text: &str) -> String {
87 let mut out = String::with_capacity(text.len() * 2);
88
89 let mut i = 0;
90 let b = text.as_bytes();
91 while i < b.len() {
92 let (cp, len) = decode_utf8(&b[i..]);
93 if is_whitespace(cp) {
94 if out.as_bytes().last().copied() != Some(b' ') {
95 out.push(' ');
96 }
97 } else if cp == b'.' as u32 {
98 let mut j = i;
100 let mut run = 0;
101 while j < b.len() && b[j] == b'.' {
102 j += 1;
103 run += 1;
104 }
105 if run >= 3 {
106 for _ in 0..run {
107 if out.as_bytes().last().copied() != Some(b' ') {
108 out.push(' ');
109 }
110 out.push('.');
111 out.push(' ');
112 }
113 i += run;
114 continue;
115 } else {
116 let mut k = i + 1;
118 while k < b.len() && (b[k] == b' ' || b[k] == b'\t' || b[k] == b'\r') {
119 k += 1;
120 }
121 if k >= b.len() || (k < b.len() && b[k] == b'\n') {
122 if out.as_bytes().last().copied() != Some(b' ') {
123 out.push(' ');
124 }
125 out.push('.');
126 out.push(' ');
127 } else {
128 out.push('.');
129 }
130 }
131 } else if is_word(cp) {
132 out.push_str(unsafe { std::str::from_utf8_unchecked(&b[i..i + len]) });
134 } else {
135 if out.as_bytes().last().copied() != Some(b' ') {
136 out.push(' ');
137 }
138 out.push_str(unsafe { std::str::from_utf8_unchecked(&b[i..i + len]) });
139 out.push(' ');
140 }
141 i += len;
142 }
143 out
144 }
145
146 fn process_contractions(&self, text: &str) -> String {
147 let mut s = text
150 .replace("won't", "will not")
151 .replace("Won't", "Will not")
152 .replace("shan't", "shall not")
153 .replace("Shan't", "Shall not")
154 .replace("can't", "can not")
155 .replace("Can't", "Can not")
156 .replace("ain't", "is not")
157 .replace("Ain't", "Is not")
158 .replace("cannot", "can not")
159 .replace("Cannot", "Can not");
160 s = s.replace("n't", " not");
162 for suf in ["'ll", "'re", "'ve", "'s", "'m", "'d"] {
164 s = s.replace(suf, &format!(" {}", suf));
165 }
166 s
167 }
168}
169
/// Cleans raw text before BM25 tokenization.
///
/// - Removes soft hyphens (U+00AD), zero-width spaces (U+200B) and BOMs (U+FEFF).
/// - A '-' followed by optional spaces/tabs/CRs and then a '\n' is replaced,
///   together with that whitespace and the '\n', by a single space.
/// - Every other control character (tab, CR, LF, form feed, ...) becomes a space.
/// - Each run of whitespace (`char::is_whitespace`) is folded into one ASCII space;
///   leading/trailing whitespace is folded but not trimmed.
pub fn preprocess_bm25(input: &str) -> String {
    /// Appends `c`, folding consecutive whitespace into a single ' '.
    fn emit(buf: &mut String, in_space: &mut bool, c: char) {
        if c.is_whitespace() {
            if !*in_space {
                buf.push(' ');
                *in_space = true;
            }
        } else {
            buf.push(c);
            *in_space = false;
        }
    }

    let mut result = String::with_capacity(input.len());
    let mut in_space = false;
    let mut rest = input.chars();
    while let Some(ch) = rest.next() {
        let mapped = match ch {
            '\u{00AD}' | '\u{200B}' | '\u{FEFF}' => continue,
            '-' => {
                // Probe ahead on a copy; commit the skip only for a real line break.
                let mut probe = rest.clone();
                if probe.find(|&c| !matches!(c, ' ' | '\t' | '\r')) == Some('\n') {
                    rest = probe;
                    ' '
                } else {
                    '-'
                }
            }
            c if c.is_control() => ' ',
            c => c,
        };
        emit(&mut result, &mut in_space, mapped);
    }
    result
}
242
/// Removes a trailing possessive `'s` / `'S` (ASCII apostrophe or U+2019 right
/// single quote) from `s`; returns `s` unchanged otherwise.
fn strip_possessive(s: &str) -> &str {
    let mut tail = s.char_indices().rev();
    match (tail.next(), tail.next()) {
        (Some((_, 's' | 'S')), Some((apostrophe_at, '\'' | '\u{2019}'))) => &s[..apostrophe_at],
        _ => s,
    }
}
263
264pub fn bm25_keep_token(mut tok: &str) -> bool {
265 if tok.is_empty() {
266 return false;
267 }
268 fn is_trim_punct(c: char) -> bool {
270 matches!(
271 c,
272 '.' | ','
273 | ';'
274 | ':'
275 | '"'
276 | '\''
277 | '('
278 | ')'
279 | '['
280 | ']'
281 | '{'
282 | '}'
283 | '!'
284 | '?'
285 | '%'
286 | '+'
287 | '-'
288 | '/'
289 | '\\'
290 | '*'
291 | '&'
292 | '#'
293 | '@'
294 | '~'
295 | '`'
296 | '|'
297 )
298 }
299 tok = tok.trim_matches(is_trim_punct);
300 if tok.len() < 2 {
301 return false;
302 }
303 tok = strip_possessive(tok);
305 if tok.len() < 2 {
306 return false;
307 }
308 if tok.len() >= 4 && tok.as_bytes()[0..4].eq_ignore_ascii_case(b"utm_") {
310 return false;
311 }
312 if tok.contains("---") {
314 return false;
315 }
316 let mut has_ascii_letter = false;
317 let mut upper_seq_only = true;
318 for ch in tok.chars() {
319 if ch.is_ascii_alphabetic() {
320 has_ascii_letter = true;
321 }
322 if !matches!(
323 ch,
324 'A' | 'C'
325 | 'D'
326 | 'E'
327 | 'F'
328 | 'G'
329 | 'H'
330 | 'I'
331 | 'K'
332 | 'L'
333 | 'M'
334 | 'N'
335 | 'P'
336 | 'Q'
337 | 'R'
338 | 'S'
339 | 'T'
340 | 'V'
341 | 'W'
342 | 'Y'
343 | '-'
344 ) {
345 upper_seq_only = false;
346 }
347 }
348 if has_ascii_letter {
349 if upper_seq_only && tok.len() >= 10 {
351 return false;
352 }
353 return true; }
355 for ch in tok.chars() {
357 if !(ch.is_ascii_digit() || matches!(ch, '+' | '-' | '.' | ',' | '/' | '\\')) {
358 return false;
360 }
361 }
362 false
364}
365
366pub fn bm25_normalize_token(tok: &str) -> Option<String> {
368 if tok.is_empty() {
369 return None;
370 }
371 if tok.contains("---") {
373 return None;
374 }
375 fn is_trim_punct(c: char) -> bool {
376 matches!(
377 c,
378 '.' | ','
379 | ';'
380 | ':'
381 | '"'
382 | '\''
383 | '('
384 | ')'
385 | '['
386 | ']'
387 | '{'
388 | '}'
389 | '!'
390 | '?'
391 | '%'
392 | '+'
393 | '-'
394 | '/'
395 | '\\'
396 | '*'
397 | '&'
398 | '#'
399 | '@'
400 | '~'
401 | '`'
402 | '|'
403 )
404 }
405 let mut s = tok.trim_matches(is_trim_punct);
406 if s.is_empty() {
407 return None;
408 }
409 s = strip_possessive(s);
410 if s.len() < 2 {
411 return None;
412 }
413 if s.len() >= 4 && s.as_bytes()[0..4].eq_ignore_ascii_case(b"utm_") {
415 return None;
416 }
417 if s.contains("---") {
418 return None;
419 }
420 let mut has_ascii_letter = false;
422 let mut upper_seq_only = true;
423 for ch in s.chars() {
424 if ch.is_ascii_alphabetic() {
425 has_ascii_letter = true;
426 }
427 if !matches!(
428 ch,
429 'A' | 'C'
430 | 'D'
431 | 'E'
432 | 'F'
433 | 'G'
434 | 'H'
435 | 'I'
436 | 'K'
437 | 'L'
438 | 'M'
439 | 'N'
440 | 'P'
441 | 'Q'
442 | 'R'
443 | 'S'
444 | 'T'
445 | 'V'
446 | 'W'
447 | 'Y'
448 | '-'
449 ) {
450 upper_seq_only = false;
451 }
452 }
453 if has_ascii_letter {
454 if upper_seq_only && s.len() >= 10 {
455 return None;
456 }
457 return Some(s.to_string());
458 }
459 for ch in s.chars() {
461 if !(ch.is_ascii_digit() || matches!(ch, '+' | '-' | '.' | ',' | '/' | '\\')) {
462 return None;
463 }
464 }
465 None
466}
467
/// Decodes the UTF-8 sequence at the start of `s`, returning `(code point, byte width)`.
///
/// Continuation bytes are not validated. If the lead byte is ASCII, is not a valid
/// lead byte, or the slice is too short for the width it announces, the lead byte
/// itself is returned with width 1.
///
/// # Panics
/// Panics if `s` is empty.
fn decode_utf8(s: &[u8]) -> (u32, usize) {
    let lead = s[0];
    let width = match lead.leading_ones() {
        n @ 2..=4 if s.len() >= n as usize => n as usize,
        _ => return (u32::from(lead), 1),
    };
    // The lead byte of a `width`-byte sequence carries `7 - width` payload bits.
    let payload = u32::from(lead & (0x7F >> width));
    let cp = s[1..width]
        .iter()
        .fold(payload, |acc, &cont| (acc << 6) | u32::from(cont & 0x3F));
    (cp, width)
}
493
/// Returns `true` only for ASCII space, tab, line feed and carriage return.
fn is_whitespace(cp: u32) -> bool {
    matches!(char::from_u32(cp), Some(' ' | '\t' | '\n' | '\r'))
}
497
/// Returns `true` for ASCII letters, ASCII digits and `_`.
fn is_ascii_alnum_underscore(cp: u32) -> bool {
    char::from_u32(cp).map_or(false, |c| c.is_ascii_alphanumeric() || c == '_')
}
504
/// Returns `true` for the punctuation marks that may appear inside a word token:
/// `.`, `'`, `-`, `/` and `&`.
fn is_allowed_punct(cp: u32) -> bool {
    matches!(char::from_u32(cp), Some('.' | '\'' | '-' | '/' | '&'))
}
512
513fn is_word(cp: u32) -> bool {
514 if cp >= 0x80 {
515 return true;
516 }
517 if is_ascii_alnum_underscore(cp) {
518 return true;
519 }
520 if is_allowed_punct(cp) {
521 return true;
522 }
523 false
524}
525
526fn is_abbreviation(tok: &str) -> bool {
527 crate::bm25::english_abbreviations::contains(tok)
529}
530
531pub fn is_stopword(tok: &str) -> bool {
532 crate::bm25::english_stop_words::contains(tok)
534}
535
#[cfg(test)]
mod bm25_norm_tests {
    use super::*;

    #[test]
    fn preprocess_dehyphenates_line_breaks_and_controls() {
        let s = "High-\nquality and\tbar\x0C";
        let out = preprocess_bm25(s);
        // The "-\n" break becomes one space, tab and form feed become spaces, and
        // whitespace runs are folded without trimming (hence the trailing space).
        assert_eq!(out, "High quality and bar ");
        assert!(!out.contains('\x0C'));
        assert!(!out.contains("-\n"));
    }

    #[test]
    fn preprocess_keeps_inline_hyphens_and_drops_invisible_chars() {
        assert_eq!(preprocess_bm25(""), "");
        assert_eq!(preprocess_bm25("re-enter"), "re-enter");
        assert_eq!(preprocess_bm25("a - b"), "a - b");
        assert_eq!(preprocess_bm25("x-  \r\n y"), "x y");
        assert_eq!(preprocess_bm25("a\u{00AD}b\u{200B}c\u{FEFF}"), "abc");
    }

    #[test]
    fn normalize_strips_possessive_ascii_and_unicode() {
        assert_eq!(bm25_normalize_token("doctor's").as_deref(), Some("doctor"));
        assert_eq!(bm25_normalize_token("women’s").as_deref(), Some("women"));
    }

    #[test]
    fn normalize_drops_numeric_and_url_tracking() {
        assert_eq!(bm25_normalize_token("-0.03"), None);
        assert_eq!(bm25_normalize_token("utm_campaign"), None);
        assert_eq!(bm25_normalize_token("UTM_Source"), None);
    }

    #[test]
    fn normalize_drops_triple_hyphen_and_sequences() {
        assert_eq!(bm25_normalize_token("---ABC"), None);
        assert_eq!(bm25_normalize_token("a---b"), None);
        // The trailing '-' is trimmed first, leaving a 20-letter amino-acid run.
        assert_eq!(bm25_normalize_token("ACDEFGHIKLMNPQRSTVWY-"), None);
        // The sequence filter starts at 10 characters.
        assert_eq!(
            bm25_normalize_token("ACDEFGHIK").as_deref(),
            Some("ACDEFGHIK")
        );
        assert_eq!(bm25_normalize_token("ACDEFGHIKL"), None);
    }

    #[test]
    fn normalize_keeps_biomedical_patterns() {
        assert_eq!(bm25_normalize_token("il-6").as_deref(), Some("il-6"));
        assert_eq!(bm25_normalize_token("p53").as_deref(), Some("p53"));
        assert_eq!(
            bm25_normalize_token("covid-19").as_deref(),
            Some("covid-19")
        );
    }

    #[test]
    fn normalize_trims_leading_punct() {
        assert_eq!(
            bm25_normalize_token("&chibnall").as_deref(),
            Some("chibnall")
        );
        assert_eq!(
            bm25_normalize_token("'administrators'").as_deref(),
            Some("administrators")
        );
    }

    #[test]
    fn normalize_rejects_empty_and_too_short() {
        assert_eq!(bm25_normalize_token(""), None);
        assert_eq!(bm25_normalize_token("x"), None);
        assert_eq!(bm25_normalize_token("..."), None);
    }

    #[test]
    fn keep_token_agrees_with_normalize() {
        for tok in [
            "doctor's",
            "il-6",
            "p53",
            "&chibnall",
            "-0.03",
            "utm_campaign",
            "x",
            "a---b",
            "ACDEFGHIKL",
        ] {
            assert_eq!(
                bm25_keep_token(tok),
                bm25_normalize_token(tok).is_some(),
                "token {:?}",
                tok
            );
        }
    }
}
593
594fn is_punctuation(tok: &str) -> bool {
595 crate::bm25::english_punctuations::contains(tok)
597}
598
#[cfg(test)]
mod simple_tokenizer_tests {
    use super::*;

    /// Asserts that `tokenizer` splits `input` into exactly `expected`.
    fn assert_tokens(tokenizer: &SimpleTokenizer, input: &str, expected: &[&str]) {
        let actual = tokenizer.split(input);
        assert_eq!(actual, expected, "tokenizing {:?}", input);
    }

    #[test]
    fn basic_tokens() {
        let t = SimpleTokenizer::new();
        assert_tokens(&t, "Hello, world!", &["Hello", ",", "world", "!"]);
        assert_tokens(
            &t,
            "self-driving and/or R&D",
            &["self-driving", "and/or", "R&D"],
        );
        assert_tokens(&t, "End of sentence.", &["End", "of", "sentence", "."]);
    }

    #[test]
    fn unicode() {
        let t = SimpleTokenizer::new();
        assert_tokens(&t, "café naïve", &["café", "naïve"]);
        assert_tokens(&t, "привет мир", &["привет", "мир"]);
    }

    #[test]
    fn contractions_and_stopwords() {
        let t = SimpleTokenizer::with_options(TokenizerOptions {
            lowercase: true,
            split_contractions: true,
            remove_stopwords: true,
            ..TokenizerOptions::default()
        });
        assert_tokens(&t, "I can't and won't do it", &[]);
    }

    #[test]
    fn urls_emails_commas() {
        let t = SimpleTokenizer::new();
        assert_tokens(&t, "one,two,three", &["one", ",", "two", ",", "three"]);
        assert_tokens(
            &t,
            "contact user@example.com today",
            &["contact", "user", "@", "example.com", "today"],
        );
        assert_tokens(
            &t,
            "Visit https://example.com/page",
            &["Visit", "https", ":", "//example.com/page"],
        );
    }

    #[test]
    fn quotes_paren_currency() {
        let t = SimpleTokenizer::new();
        assert_tokens(&t, "\"quoted\"", &["\"", "quoted", "\""]);
        assert_tokens(&t, "(example)", &["(", "example", ")"]);
        assert_tokens(&t, "$100 €50 £25", &["$", "100", "€50", "£25"]);
    }

    #[test]
    fn periods_and_abbrev() {
        let t = SimpleTokenizer::new();
        assert_tokens(&t, "...", &[".", ".", "."]);
        assert_tokens(&t, "Dr. Smith", &["Dr.", "Smith"]);
        assert_tokens(&t, "U.S. government", &["U.S", ".", "government"]);
        assert_tokens(&t, "Dr.", &["Dr", "."]);
    }

    #[test]
    fn whitespace_cases() {
        let t = SimpleTokenizer::new();
        assert_tokens(&t, "", &[]);
        assert_tokens(&t, " \t \n ", &[]);
        assert_tokens(&t, "multiple spaces here", &["multiple", "spaces", "here"]);
        assert_tokens(&t, "line1\nline2\ttab", &["line1", "line2", "tab"]);
        assert_tokens(&t, " \t word1 word2 \n ", &["word1", "word2"]);
    }

    #[test]
    fn numbers_and_mixed() {
        let t = SimpleTokenizer::new();
        assert_tokens(&t, "123 456.78", &["123", "456.78"]);
        assert_tokens(&t, "test123 456test", &["test123", "456test"]);
    }

    #[test]
    fn operators_percent_dates_time() {
        let t = SimpleTokenizer::new();
        assert_tokens(&t, "2+2=4", &["2", "+", "2", "=", "4"]);
        assert_tokens(&t, "100% complete", &["100", "%", "complete"]);
        assert_tokens(&t, "12/25/2024", &["12/25/2024"]);
        assert_tokens(&t, "2024-12-25", &["2024-12-25"]);
        assert_tokens(&t, "3:30pm", &["3", ":", "30pm"]);
    }

    #[test]
    fn multiple_delimiters_and_apostrophes() {
        let t = SimpleTokenizer::new();
        assert_tokens(
            &t,
            "word!!!???...",
            &["word", "!", "!", "!", "?", "?", "?", ".", ".", "."],
        );
        assert_tokens(&t, "it's it's", &["it's", "it's"]);
    }

    #[test]
    fn punctuation_removal() {
        let t = SimpleTokenizer::with_options(TokenizerOptions {
            remove_punctuation: true,
            ..TokenizerOptions::default()
        });
        assert_tokens(&t, "Hello, world!", &["Hello", "world"]);
        assert_tokens(&t, "What? Really! Yes...", &["What", "Really", "Yes"]);
        assert_tokens(
            &t,
            "self-driving and/or R&D",
            &["self-driving", "and/or", "R&D"],
        );
        assert_tokens(
            &t,
            "(example) [test] {code}",
            &["example", "test", "code"],
        );
    }

    #[test]
    fn combined_options() {
        let t = SimpleTokenizer::with_options(TokenizerOptions {
            lowercase: true,
            split_contractions: true,
            remove_stopwords: true,
            remove_punctuation: true,
        });
        assert_tokens(&t, "I can't believe it's working!", &["believe", "working"]);
        assert_tokens(
            &t,
            "The quick brown fox jumps over the lazy dog.",
            &["quick", "brown", "fox", "jumps", "lazy", "dog"],
        );
    }
}