use std::error::Error;
use std::fmt;

use fancy_regex::Regex;
use unicode_categories::UnicodeCategories;
use unicode_normalization::char::{compose, decompose_canonical, decompose_compatible};

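/// Helper that applies character-level normalizations (lower-casing, accent
/// removal) to one input character at a time. A single input character may
/// map to zero, one or several output characters.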
struct CharNormalizer {
    normalized: Vec<char>,

    tmp: Vec<char>,
}

impl CharNormalizer {
    fn new() -> CharNormalizer {
        CharNormalizer {
            normalized: Vec::new(),
            tmp: Vec::new(),
        }
    }

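    /// Set the input character and reset the normalized output to it.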
    fn set_char(&mut self, ch: char) {
        self.tmp.push(ch);
        self.update_normalized_from_tmp();
    }

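    /// Replace the current normalized output with its lower-cased form.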
    fn lower_case(&mut self) {
        for ch in &self.normalized {
            for lower_ch in ch.to_lowercase() {
                self.tmp.push(lower_ch);
            }
        }
        self.update_normalized_from_tmp();
    }

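    /// Remove accents from the normalized output by canonically decomposing
    /// each character and dropping nonspacing marks (Unicode category Mn).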
    fn strip_accents(&mut self) {
        for ch in &self.normalized {
            decompose_canonical(*ch, |decomposed| {
                if !decomposed.is_mark_nonspacing() {
                    self.tmp.push(decomposed);
                }
            });
        }
        self.update_normalized_from_tmp();
    }

    fn normalized(&self) -> &[char] {
        &self.normalized
    }

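    /// Move the contents of `tmp` into `normalized`, leaving `tmp` empty.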
    fn update_normalized_from_tmp(&mut self) {
        self.normalized.clear();
        self.normalized.extend(self.tmp.iter());
        self.tmp.clear();
    }
}

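/// Errors that can occur when normalizing text.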
#[derive(Clone, Debug)]
pub enum NormalizeError {
    RegexError(Box<fancy_regex::Error>),
}

impl fmt::Display for NormalizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::RegexError(err) => write!(f, "regex failed: {}", err),
        }
    }
}

impl Error for NormalizeError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::RegexError(err) => Some(err),
        }
    }
}

impl From<fancy_regex::Error> for NormalizeError {
    fn from(val: fancy_regex::Error) -> Self {
        Self::RegexError(Box::new(val))
    }
}

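/// Trait for normalizers, which clean up text (Unicode normalization,
/// lower-casing, etc.) before it is tokenized.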
pub trait Normalizer: std::fmt::Debug {
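    /// Apply normalization to `text`.
    ///
    /// Returns a tuple of `(normalized, offsets)` where `offsets` contains
    /// one entry per byte of `normalized`, giving the byte offset in the
    /// source text that the output byte corresponds to.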
    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError>;
}

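/// Normalizer that implements the text cleanup performed by BERT-style
/// tokenizers: optional lower-casing and optional accent stripping.
///
/// Illustrative usage (a sketch, assuming this module's items are in scope;
/// the expected output matches this module's tests):
///
/// ```ignore
/// let normalizer = Bert::new(BertOptions {
///     lowercase: true,
///     strip_accents: true,
/// });
/// let (normalized, _offsets) = normalizer.normalize("Motörhead").unwrap();
/// assert_eq!(normalized, "motorhead");
/// ```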
#[derive(Clone, Debug)]
pub struct Bert {
    lowercase: bool,
    strip_accents: bool,
}

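/// Configuration options for [`Bert`].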
#[derive(Clone, Debug, Default)]
pub struct BertOptions {
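    /// Convert the text to lower case.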
    pub lowercase: bool,

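    /// Strip accents (nonspacing combining marks) from characters.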
    pub strip_accents: bool,
}

impl Bert {
    pub fn new(opts: BertOptions) -> Bert {
        Bert {
            lowercase: opts.lowercase,
            strip_accents: opts.strip_accents,
        }
    }

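    /// Return true if this normalizer would leave its input unchanged.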
    fn is_noop(&self) -> bool {
        !self.lowercase && !self.strip_accents
    }
}

impl Normalizer for Bert {
    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError> {
        if self.is_noop() {
            let offsets = (0..text.len()).collect();
            return Ok((text.to_string(), offsets));
        }

        let mut normalized = String::with_capacity(text.len());
        let mut offsets = Vec::with_capacity(text.len());
        let mut char_normalizer = CharNormalizer::new();

        for (offset, ch) in text.char_indices() {
            char_normalizer.set_char(ch);

            if self.strip_accents {
                char_normalizer.strip_accents();
            }

            if self.lowercase {
                char_normalizer.lower_case();
            }

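            // Each output byte maps back to the byte offset of the source
            // character it was derived from.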
            for ch in char_normalizer.normalized() {
                normalized.push(*ch);
                for _ in 0..ch.len_utf8() {
                    offsets.push(offset);
                }
            }
        }

        Ok((normalized, offsets))
    }
}

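/// Normalizer that replaces all matches of a regex pattern with a fixed
/// string.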
#[derive(Clone, Debug)]
pub struct Replace {
    regex: Regex,
    content: String,
}

impl Replace {
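    /// Create a normalizer that replaces matches of `pattern` with `content`.
    ///
    /// `pattern` is a regular expression as supported by the fancy-regex
    /// crate. Fails if the pattern is invalid.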
    pub fn new(pattern: &str, content: String) -> Result<Replace, NormalizeError> {
        Ok(Replace {
            regex: Regex::new(pattern)?,
            content,
        })
    }
}

impl Normalizer for Replace {
    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError> {
        let mut normalized = String::with_capacity(text.len());
        let mut offsets = Vec::with_capacity(text.len());

        let mut last_match_end = 0;
        for match_ in self.regex.find_iter(text) {
            let match_ = match_?;

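            // Copy the text between the previous match and this one,
            // preserving identity offsets.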
            let before_match = &text[last_match_end..match_.range().start];
            normalized.push_str(before_match);
            offsets.extend(last_match_end..match_.range().start);

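            // Emit the replacement, mapping every byte of it to the start of
            // the match it replaces.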
            normalized.push_str(&self.content);
            offsets.extend(std::iter::repeat(match_.range().start).take(self.content.len()));

            last_match_end = match_.range().end;
        }

        normalized.push_str(&text[last_match_end..]);
        offsets.extend(last_match_end..text.len());

        Ok((normalized, offsets))
    }
}

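/// Normalizer that chains a sequence of normalizers, applying each one to
/// the output of the previous one and composing the offset mappings so that
/// offsets always refer back to the original input.
///
/// Illustrative usage (a sketch, assuming this module's items are in scope):
///
/// ```ignore
/// let normalizers: Vec<Box<dyn Normalizer>> = vec![
///     Box::new(Unicode::Nfc),
///     Box::new(Replace::new(r"\s+", " ".to_string()).unwrap()),
/// ];
/// let seq = Sequence::from_vec(normalizers);
/// let (normalized, offsets) = seq.normalize("foo  bar").unwrap();
/// ```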
#[derive(Debug)]
pub struct Sequence {
    normalizers: Vec<Box<dyn Normalizer>>,
}

impl Sequence {
    pub fn from_vec(normalizers: Vec<Box<dyn Normalizer>>) -> Self {
        Sequence { normalizers }
    }
}

impl Normalizer for Sequence {
    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError> {
        let mut normalized = text.to_string();
        let mut offsets: Vec<usize> = (0..text.len()).collect();

        for normalizer in &self.normalizers {
            let (next_normalized, mut next_offsets) = normalizer.normalize(&normalized)?;
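            // `next_offsets` maps into the intermediate string; remap each
            // entry so it points into the original input instead.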
            for offset in next_offsets.iter_mut() {
                *offset = offsets[*offset];
            }
            normalized = next_normalized;
            offsets = next_offsets;
        }

        Ok((normalized, offsets))
    }
}

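/// Buffer that accumulates the output of Unicode normalization along with
/// per-character source offsets.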
struct UnicodeBuf {
    normalized: String,

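    /// Byte offset in the source text of each char in `normalized`.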
    char_offsets: Vec<usize>,
}

impl UnicodeBuf {
    fn with_capacity(len: usize) -> Self {
        UnicodeBuf {
            normalized: String::with_capacity(len),
            char_offsets: Vec::with_capacity(len),
        }
    }

    fn push(&mut self, ch: char, offset: usize) {
        self.normalized.push(ch);
        self.char_offsets.push(offset);
    }

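    /// Push a character, first attempting to canonically compose it with the
    /// previously pushed character (used for the NFC and NFKC forms).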
    fn push_compose(&mut self, ch: char, offset: usize) {
        if let (Some(prev_ch), Some(prev_offset)) = (self.normalized.pop(), self.char_offsets.pop())
        {
            if let Some(composed_ch) = compose(prev_ch, ch) {
                self.push(composed_ch, prev_offset);
            } else {
                self.push(prev_ch, prev_offset);
                self.push(ch, offset);
            }
        } else {
            self.push(ch, offset);
        }
    }

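    /// Consume the buffer, expanding the per-character offsets into per-byte
    /// offsets of the UTF-8 output string.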
    fn into_string_with_byte_offsets(self) -> (String, Vec<usize>) {
        let UnicodeBuf {
            normalized,
            char_offsets,
        } = self;
        let mut byte_offsets = Vec::with_capacity(char_offsets.len());
        for (ch, offset) in normalized.chars().zip(char_offsets) {
            for _ in 0..ch.len_utf8() {
                byte_offsets.push(offset);
            }
        }
        (normalized, byte_offsets)
    }
}

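/// Normalizer that applies a Unicode normalization form: NFC (canonical
/// decomposition followed by composition), NFD (canonical decomposition),
/// NFKC (compatibility decomposition followed by composition) or NFKD
/// (compatibility decomposition).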
#[derive(Clone, Debug)]
pub enum Unicode {
    Nfc,
    Nfd,
    Nfkc,
    Nfkd,
}

impl Normalizer for Unicode {
    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError> {
        let mut tmp = UnicodeBuf::with_capacity(text.len());

        for (offset, ch) in text.char_indices() {
            match self {
                Self::Nfc => {
                    tmp.push_compose(ch, offset);
                }
                Self::Nfd => {
                    decompose_canonical(ch, |decomposed| {
                        tmp.push(decomposed, offset);
                    });
                }
                Self::Nfkc => {
                    decompose_compatible(ch, |ch| {
                        tmp.push_compose(ch, offset);
                    });
                }
                Self::Nfkd => {
                    decompose_compatible(ch, |decomposed| {
                        tmp.push(decomposed, offset);
                    });
                }
            }
        }

        Ok(tmp.into_string_with_byte_offsets())
    }
}

#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{Bert, BertOptions, Normalizer, Replace, Sequence, Unicode};

    #[test]
    fn test_bert_noop() {
        let normalizer = Bert::new(BertOptions::default());
        let inputs = [
            "Hello world!",
            "Motörhead",
            "lowercase",
        ];
        for input in inputs {
            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(normalized, input);
            assert_eq!(offsets, (0..input.len()).collect::<Vec<_>>());
        }
    }

    #[test]
    fn test_bert_lowercase() {
        let normalizer = Bert::new(BertOptions {
            lowercase: true,
            ..Default::default()
        });

        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let cases = [
            Case {
                input: "Hello World!",
                expected: "hello world!",
                expected_offsets: (0.."hello world!".len()).collect(),
            },
            Case {
                input: "İİAB",
                expected: "i\u{307}i\u{307}ab",
                expected_offsets: vec![0, 0, 0, 2, 2, 2, 4, 5],
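                // Lower-casing 'İ' yields "i\u{307}" (3 bytes in UTF-8), so
                // each of those bytes maps back to the offset of the source
                // character (0 for the first 'İ', 2 for the second).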
            },
        ];

        cases.test_each(|case| {
            let Case {
                input,
                expected,
                expected_offsets,
            } = case;

            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(normalized, *expected);
            assert_eq!(offsets, *expected_offsets);
        })
    }

    #[test]
    fn test_bert_strip_accents() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            lowercase: bool,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let cases = [
            Case {
                input: "Motörhead",
                lowercase: false,
                expected: "Motorhead",
                expected_offsets: vec![0, 1, 2, 3, 5, 6, 7, 8, 9],
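                // 'ö' occupies bytes 3..5 in the input; the stripped 'o'
                // maps to offset 3 and the following chars keep their
                // original byte offsets.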
            },
            Case {
                input: "Motörhead",
                lowercase: true,
                expected: "motorhead",
                expected_offsets: vec![0, 1, 2, 3, 5, 6, 7, 8, 9],
            },
        ];

        cases.test_each(|case| {
            let Case {
                input,
                lowercase,
                expected,
                expected_offsets,
            } = case;

            let normalizer = Bert::new(BertOptions {
                lowercase: *lowercase,
                strip_accents: true,
                ..Default::default()
            });

            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(normalized, *expected);
            assert_eq!(offsets, *expected_offsets);
        })
    }

    #[test]
    fn test_replace() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            pattern: &'a str,
            content: &'a str,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let cases = [
            Case {
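                // Pattern that matches nothing; input passes through
                // unchanged.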
                input: "nothing to do here",
                pattern: "does-not-match",
                content: "replacement",
                expected: "nothing to do here",
                expected_offsets: (0.."nothing to do here".len()).collect(),
            },
            Case {
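                // Input has two spaces between words; each run collapses to
                // a single space.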
                input: "foo  bar  baz",
                pattern: r"\s+",
                content: " ",
                expected: "foo bar baz",
                expected_offsets: [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12].into(),
            },
            Case {
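                // Input has three spaces between words; each pair of spaces
                // collapses to one, leaving two.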
                input: "foo   bar   baz",
                pattern: r"  ",
                content: " ",
                expected: "foo  bar  baz",
                expected_offsets: [0, 1, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13, 14].into(),
            },
        ];

        cases.test_each(|case| {
            let Case {
                input,
                pattern,
                content,
                expected,
                expected_offsets,
            } = case;

            let normalizer = Replace::new(pattern, content.to_string()).unwrap();
            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(offsets.len(), normalized.len());
            assert_eq!(normalized, *expected);
            assert_eq!(offsets, *expected_offsets);
        })
    }

    fn lowercase_normalizer() -> Box<dyn Normalizer> {
        Box::new(Bert::new(BertOptions {
            lowercase: true,
            strip_accents: false,
        }))
    }

    fn nfc_normalizer() -> Box<dyn Normalizer> {
        Box::new(Unicode::Nfc)
    }

    fn replace_normalizer(pattern: &str, content: &str) -> Box<dyn Normalizer> {
        Box::new(Replace::new(pattern, content.to_string()).unwrap())
    }

    #[test]
    fn test_sequence() {
        use std::panic::AssertUnwindSafe;

        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            normalizers: AssertUnwindSafe<Vec<Box<dyn Normalizer>>>,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let cases = [
            Case {
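                // Chain where one step (whitespace collapsing) changes the
                // string length. Input has two spaces between words.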
                input: "FOO  BAR  BAZ",
                normalizers: AssertUnwindSafe(
                    [
                        nfc_normalizer(),
                        lowercase_normalizer(),
                        replace_normalizer(r"\s+", " "),
                    ]
                    .into(),
                ),
                expected: "foo bar baz",
                expected_offsets: [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12].into(),
            },
            Case {
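                // Cascading replacements: " " -> "--", then "--" -> "_".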
                input: "FOO BAR BAZ",
                normalizers: AssertUnwindSafe(
                    [
                        replace_normalizer(" ", "--"),
                        replace_normalizer("--", "_"),
                        lowercase_normalizer(),
                    ]
                    .into(),
                ),
                expected: "foo_bar_baz",
                expected_offsets: (0.."foo bar baz".len()).collect(),
            },
            Case {
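                // Empty sequence; the input should pass through unchanged.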
                input: "foo bar baz",
                normalizers: AssertUnwindSafe(Vec::new()),
                expected: "foo bar baz",
                expected_offsets: (0.."foo bar baz".len()).collect(),
            },
        ];

        cases.test_each_value(|case| {
            let Case {
                input,
                normalizers,
                expected,
                expected_offsets,
            } = case;

            let seq = Sequence::from_vec(normalizers.0);
            let (normalized, offsets) = seq.normalize(input).unwrap();
            assert_eq!(normalized, expected);
            assert_eq!(offsets, expected_offsets);
        })
    }

    #[test]
    fn test_unicode() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            normalizer: Unicode,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let noop_case = |normalizer| Case {
            input: "abc",
            normalizer,
            expected: "abc",
            expected_offsets: [0, 1, 2].into(),
        };

        let cases = [
            noop_case(Unicode::Nfc),
            noop_case(Unicode::Nfd),
            noop_case(Unicode::Nfkc),
            noop_case(Unicode::Nfkd),
            Case {
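                // NFC composes 'I' + combining dot above into 'İ'.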
                input: "I\u{307}ab",
                normalizer: Unicode::Nfc,
                expected: "İab",
                expected_offsets: [0, 0, 3, 4].into(),
            },
            Case {
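                // NFD decomposes 'İ' into 'I' + combining dot above.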
                input: "İa",
                normalizer: Unicode::Nfd,
                expected: "I\u{307}a",
                expected_offsets: [0, 0, 0, 2].into(),
            },
            Case {
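                // NFKC replaces compatibility characters such as '①' with
                // their plain equivalents.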
                input: "①",
                normalizer: Unicode::Nfkc,
                expected: "1",
                expected_offsets: [0].into(),
            },
            Case {
                input: "Éab",
                normalizer: Unicode::Nfkc,
                expected: "Éab",
                expected_offsets: [0, 0, 2, 3].into(),
            },
            Case {
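                // NFKD decomposes 'É' into 'E' + combining acute accent.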
                input: "Éab",
                normalizer: Unicode::Nfkd,
                expected: "E\u{301}ab",
                expected_offsets: [0, 0, 0, 2, 3].into(),
            },
        ];

        cases.test_each(|case| {
            let Case {
                input,
                normalizer,
                expected,
                expected_offsets,
            } = case;

            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(normalized, *expected);
            assert_eq!(normalized.len(), offsets.len());
            assert_eq!(offsets, *expected_offsets);
        })
    }
}