rten_text/
normalizers.rs

1//! Tools for performing string normalization prior to tokenization.
2
3use std::error::Error;
4use std::fmt;
5
6use fancy_regex::Regex;
7use unicode_categories::UnicodeCategories;
8use unicode_normalization::char::{compose, decompose_canonical, decompose_compatible};
9
10struct CharNormalizer {
11    normalized: Vec<char>,
12
13    /// Temporary buffer that holds the output of a normalization step until
14    /// it is copied back to `normalized`.
15    tmp: Vec<char>,
16}
17
18impl CharNormalizer {
19    fn new() -> CharNormalizer {
20        CharNormalizer {
21            normalized: Vec::new(),
22            tmp: Vec::new(),
23        }
24    }
25
26    /// Set the input character to normalize.
27    fn set_char(&mut self, ch: char) {
28        self.tmp.push(ch);
29        self.update_normalized_from_tmp();
30    }
31
32    /// Lowercase the normalized characters.
33    fn lower_case(&mut self) {
34        for ch in &self.normalized {
35            for lower_ch in ch.to_lowercase() {
36                self.tmp.push(lower_ch);
37            }
38        }
39        self.update_normalized_from_tmp();
40    }
41
42    /// Decompose the input into NFD form and then remove any characters in
43    /// the Unicode non-spacing mark ("Mn") category.
44    fn strip_accents(&mut self) {
45        for ch in &self.normalized {
46            decompose_canonical(*ch, |decomposed| {
47                if !decomposed.is_mark_nonspacing() {
48                    self.tmp.push(decomposed);
49                }
50            });
51        }
52        self.update_normalized_from_tmp();
53    }
54
55    /// Return the normalized characters.
56    fn normalized(&self) -> &[char] {
57        &self.normalized
58    }
59
60    fn update_normalized_from_tmp(&mut self) {
61        self.normalized.clear();
62        self.normalized.extend(self.tmp.iter());
63        self.tmp.clear();
64    }
65}
66
67/// Errors occuring while normalizing text during the first phase of
68/// tokenization.
69#[derive(Clone, Debug)]
70pub enum NormalizeError {
71    RegexError(Box<fancy_regex::Error>),
72}
73
74impl fmt::Display for NormalizeError {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        match self {
77            Self::RegexError(err) => write!(f, "regex failed {}", err),
78        }
79    }
80}
81
82impl Error for NormalizeError {
83    fn source(&self) -> Option<&(dyn Error + 'static)> {
84        match self {
85            Self::RegexError(err) => Some(err),
86        }
87    }
88}
89
90impl From<fancy_regex::Error> for NormalizeError {
91    fn from(val: fancy_regex::Error) -> Self {
92        Self::RegexError(Box::new(val))
93    }
94}
95
96/// A normalizer applies normalization such as Unicode normalization and
97/// lower-casing to strings.
98///
99/// In addition to the normalized text, Normalizer methods also return mappings
100/// from positions in the normalized string back to the original string. This
101/// is useful for post-processing in NLP tasks to map machine learning model
102/// outputs back to the location in the original text.
103pub trait Normalizer: std::fmt::Debug {
104    /// Apply normalization to a string.
105    ///
106    /// Returns a tuple of `(normalized_string, offset_map)` where `offset_map`
107    /// is a mapping from byte offsets in the normalized string to corresponding
108    /// offsets in the original string.
109    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError>;
110}
111
/// A [`Normalizer`] that applies the normalization used by BERT and
/// BERT-derived models.
#[derive(Clone, Debug)]
pub struct Bert {
    lowercase: bool,
    strip_accents: bool,
}

/// Configuration for a [`Bert`] normalizer.
#[derive(Clone, Debug, Default)]
pub struct BertOptions {
    /// If true, convert all text to lowercase using [`char::to_lowercase`].
    pub lowercase: bool,

    /// Whether to strip accents when tokenizing. An "accent" is defined as
    /// any unicode character in the Nonspacing Mark ("Mn") category.
    pub strip_accents: bool,
}

impl Bert {
    /// Create a normalizer configured by `opts`.
    pub fn new(opts: BertOptions) -> Bert {
        let BertOptions {
            lowercase,
            strip_accents,
        } = opts;
        Bert {
            lowercase,
            strip_accents,
        }
    }

    /// Return true if every transformation is disabled, meaning the
    /// normalizer returns its input unchanged.
    fn is_noop(&self) -> bool {
        !(self.lowercase || self.strip_accents)
    }
}
144
145impl Normalizer for Bert {
146    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError> {
147        if self.is_noop() {
148            let offsets = (0..text.len()).collect();
149            return Ok((text.to_string(), offsets));
150        }
151
152        let mut normalized = String::with_capacity(text.len());
153        let mut offsets = Vec::with_capacity(text.len());
154        let mut char_normalizer = CharNormalizer::new();
155
156        for (offset, ch) in text.char_indices() {
157            char_normalizer.set_char(ch);
158
159            if self.strip_accents {
160                char_normalizer.strip_accents();
161            }
162
163            if self.lowercase {
164                char_normalizer.lower_case();
165            }
166
167            for ch in char_normalizer.normalized() {
168                normalized.push(*ch);
169                for _ in 0..ch.len_utf8() {
170                    offsets.push(offset);
171                }
172            }
173        }
174
175        Ok((normalized, offsets))
176    }
177}
178
179/// Replaces occurrences of a pattern with a given string.
180#[derive(Clone, Debug)]
181pub struct Replace {
182    regex: Regex,
183    content: String,
184}
185
186impl Replace {
187    /// Replaces occurrences of `pattern` with `content`.
188    ///
189    /// `pattern` is a regex pattern. See the
190    /// [fancy-regex](https://docs.rs/fancy-regex/) docs for supported syntax.
191    pub fn new(pattern: &str, content: String) -> Result<Replace, NormalizeError> {
192        Ok(Replace {
193            regex: Regex::new(pattern)?,
194            content,
195        })
196    }
197}
198
199impl Normalizer for Replace {
200    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError> {
201        let mut normalized = String::with_capacity(text.len());
202        let mut offsets = Vec::with_capacity(text.len());
203
204        let mut last_match_end = 0;
205        for match_ in self.regex.find_iter(text) {
206            let match_ = match_?;
207
208            let before_match = &text[last_match_end..match_.range().start];
209            normalized.push_str(before_match);
210            offsets.extend(last_match_end..match_.range().start);
211
212            normalized.push_str(&self.content);
213            offsets.extend(std::iter::repeat(match_.range().start).take(self.content.len()));
214
215            last_match_end = match_.range().end;
216        }
217
218        normalized.push_str(&text[last_match_end..]);
219        offsets.extend(last_match_end..text.len());
220
221        Ok((normalized, offsets))
222    }
223}
224
225/// Run a series of normalizers in sequence.
226#[derive(Debug)]
227pub struct Sequence {
228    normalizers: Vec<Box<dyn Normalizer>>,
229}
230
231impl Sequence {
232    pub fn from_vec(normalizers: Vec<Box<dyn Normalizer>>) -> Self {
233        Sequence { normalizers }
234    }
235}
236
237impl Normalizer for Sequence {
238    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError> {
239        let mut normalized = text.to_string();
240        let mut offsets: Vec<usize> = (0..text.len()).collect();
241
242        for normalizer in &self.normalizers {
243            let (next_normalized, mut next_offsets) = normalizer.normalize(&normalized)?;
244            for offset in next_offsets.iter_mut() {
245                *offset = offsets[*offset];
246            }
247            normalized = next_normalized;
248            offsets = next_offsets;
249        }
250
251        Ok((normalized, offsets))
252    }
253}
254
255/// Temporary buffer used while normalizing text.
256struct UnicodeBuf {
257    // Work-in-progress normalized text.
258    normalized: String,
259
260    // Offset from char position in `normalized` to byte position in
261    // original text.
262    char_offsets: Vec<usize>,
263}
264
265impl UnicodeBuf {
266    fn with_capacity(len: usize) -> Self {
267        UnicodeBuf {
268            normalized: String::with_capacity(len),
269            char_offsets: Vec::with_capacity(len),
270        }
271    }
272
273    /// Add a character and its associated byte offset in the original text to
274    /// the work-in-progress buffer.
275    fn push(&mut self, ch: char, offset: usize) {
276        self.normalized.push(ch);
277        self.char_offsets.push(offset);
278    }
279
280    /// Compose `ch` with the last char in the buffer if possible, otherwise
281    /// add it the same as `push`.
282    fn push_compose(&mut self, ch: char, offset: usize) {
283        if let (Some(prev_ch), Some(prev_offset)) = (self.normalized.pop(), self.char_offsets.pop())
284        {
285            if let Some(composed_ch) = compose(prev_ch, ch) {
286                self.push(composed_ch, prev_offset);
287            } else {
288                self.push(prev_ch, prev_offset);
289                self.push(ch, offset);
290            }
291        } else {
292            self.push(ch, offset);
293        }
294    }
295
296    fn into_string_with_byte_offsets(self) -> (String, Vec<usize>) {
297        // Convert offsets from char positions in normalized text to byte
298        // positions in normalized text.
299        let UnicodeBuf {
300            normalized,
301            char_offsets,
302        } = self;
303        let mut byte_offsets = Vec::with_capacity(char_offsets.len());
304        for (ch, offset) in normalized.chars().zip(char_offsets) {
305            for _ in 0..ch.len_utf8() {
306                byte_offsets.push(offset);
307            }
308        }
309        (normalized, byte_offsets)
310    }
311}
312
313/// Normalize text into one of the standard Unicode normalization forms.
314#[derive(Clone, Debug)]
315pub enum Unicode {
316    /// Canonical composition
317    Nfc,
318    /// Canonical decomposition
319    Nfd,
320    /// Compatibility decomposition, followed by canonical composition
321    Nfkc,
322    /// Compatibility decomposition
323    Nfkd,
324}
325
326impl Normalizer for Unicode {
327    fn normalize(&self, text: &str) -> Result<(String, Vec<usize>), NormalizeError> {
328        let mut tmp = UnicodeBuf::with_capacity(text.len());
329
330        for (offset, ch) in text.char_indices() {
331            match self {
332                Self::Nfc => {
333                    tmp.push_compose(ch, offset);
334                }
335                Self::Nfd => {
336                    decompose_canonical(ch, |decomposed| {
337                        tmp.push(decomposed, offset);
338                    });
339                }
340                Self::Nfkc => {
341                    decompose_compatible(ch, |ch| {
342                        tmp.push_compose(ch, offset);
343                    });
344                }
345                Self::Nfkd => {
346                    decompose_compatible(ch, |decomposed| {
347                        tmp.push(decomposed, offset);
348                    });
349                }
350            }
351        }
352
353        Ok(tmp.into_string_with_byte_offsets())
354    }
355}
356
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{Bert, BertOptions, Normalizer, Replace, Sequence, Unicode};

    #[test]
    fn test_bert_noop() {
        let normalizer = Bert::new(BertOptions::default());
        let inputs = [
            "Hello world!", // Mixed case
            "Motörhead",    // Accented
            "lowercase",
        ];
        for input in inputs {
            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(normalized, input);
            assert_eq!(offsets, (0..input.len()).collect::<Vec<_>>());
        }
    }

    #[test]
    fn test_bert_lowercase() {
        let normalizer = Bert::new(BertOptions {
            lowercase: true,
            ..Default::default()
        });

        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let cases = [
            // Simple text where chars map 1:1 to lower-case version
            Case {
                input: "Hello World!",
                expected: "hello world!",
                expected_offsets: (0.."hello world!".len()).collect(),
            },
            // Text with chars which expand when lower-cased
            Case {
                input: "İİAB",
                expected: "i\u{307}i\u{307}ab",

                // The "İ" char requires two bytes in the input and expands into
                // two characters which require one and two bytes
                // respectively. Hence the offsets contain two groups of three
                // equal offsets, with values separated by two.
                expected_offsets: vec![0, 0, 0, 2, 2, 2, 4, 5],
            },
        ];

        cases.test_each(|case| {
            let Case {
                input,
                expected,
                expected_offsets,
            } = case;

            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(normalized, *expected);
            assert_eq!(offsets, *expected_offsets);
        })
    }

    #[test]
    fn test_bert_strip_accents() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            lowercase: bool,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let cases = [
            // Strip accents only
            Case {
                input: "Motörhead",
                lowercase: false,
                expected: "Motorhead",
                // Note the jump in offsets where the two-byte UTF-8 char "ö"
                // is replaced with "o".
                expected_offsets: vec![0, 1, 2, 3, 5, 6, 7, 8, 9],
            },
            // Combined lowercase + strip accents
            Case {
                input: "Motörhead",
                lowercase: true,
                expected: "motorhead",
                // Note the jump in offsets where the two-byte UTF-8 char "ö"
                // is replaced with "o".
                expected_offsets: vec![0, 1, 2, 3, 5, 6, 7, 8, 9],
            },
        ];

        cases.test_each(|case| {
            let Case {
                input,
                lowercase,
                expected,
                expected_offsets,
            } = case;

            let normalizer = Bert::new(BertOptions {
                lowercase: *lowercase,
                strip_accents: true,
            });

            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(normalized, *expected);
            assert_eq!(offsets, *expected_offsets);
        })
    }

    #[test]
    fn test_replace() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            pattern: &'a str,
            content: &'a str,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let cases = [
            // No-op replacement
            Case {
                input: "nothing to do here",
                pattern: "does-not-match",
                content: "replacement",
                expected: "nothing to do here",
                expected_offsets: (0.."nothing to do here".len()).collect(),
            },
            // Whitespace simplification
            Case {
                input: "foo  bar  baz",
                pattern: r"\s+",
                content: " ",
                expected: "foo bar baz",
                expected_offsets: [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12].into(),
            },
            // Pattern with overlapping matches
            Case {
                input: "foo   bar   baz",
                pattern: r"  ",
                content: " ",
                expected: "foo  bar  baz",
                expected_offsets: [0, 1, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13, 14].into(),
            },
        ];

        cases.test_each(|case| {
            let Case {
                input,
                pattern,
                content,
                expected,
                expected_offsets,
            } = case;

            let normalizer = Replace::new(pattern, content.to_string()).unwrap();
            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(offsets.len(), normalized.len());
            assert_eq!(normalized, *expected);
            assert_eq!(offsets, *expected_offsets);
        })
    }

    fn lowercase_normalizer() -> Box<dyn Normalizer> {
        Box::new(Bert::new(BertOptions {
            lowercase: true,
            strip_accents: false,
        }))
    }

    fn nfc_normalizer() -> Box<dyn Normalizer> {
        Box::new(Unicode::Nfc)
    }

    fn replace_normalizer(pattern: &str, content: &str) -> Box<dyn Normalizer> {
        Box::new(Replace::new(pattern, content.to_string()).unwrap())
    }

    #[test]
    fn test_sequence() {
        use std::panic::AssertUnwindSafe;

        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            normalizers: AssertUnwindSafe<Vec<Box<dyn Normalizer>>>,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let cases = [
            // NFC + Lowercase + whitespace simplification.
            //
            // This is the sequence used by CLIP.
            Case {
                input: "FOO  BAR  BAZ",
                normalizers: AssertUnwindSafe(
                    [
                        nfc_normalizer(),
                        lowercase_normalizer(),
                        replace_normalizer(r"\s+", " "),
                    ]
                    .into(),
                ),
                expected: "foo bar baz",
                expected_offsets: [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12].into(),
            },
            // Multiple normalizers that modify offsets.
            Case {
                input: "FOO BAR BAZ",
                normalizers: AssertUnwindSafe(
                    [
                        replace_normalizer(" ", "--"),
                        replace_normalizer("--", "_"),
                        lowercase_normalizer(),
                    ]
                    .into(),
                ),
                expected: "foo_bar_baz",
                expected_offsets: (0.."foo bar baz".len()).collect(),
            },
            // Empty sequence
            Case {
                input: "foo bar baz",
                normalizers: AssertUnwindSafe(Vec::new()),
                expected: "foo bar baz",
                expected_offsets: (0.."foo bar baz".len()).collect(),
            },
        ];

        cases.test_each_value(|case| {
            let Case {
                input,
                normalizers,
                expected,
                expected_offsets,
            } = case;

            let seq = Sequence::from_vec(normalizers.0);
            let (normalized, offsets) = seq.normalize(input).unwrap();
            assert_eq!(normalized, expected);
            assert_eq!(offsets, expected_offsets);
        })
    }

    #[test]
    fn test_unicode() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            normalizer: Unicode,
            expected: &'a str,
            expected_offsets: Vec<usize>,
        }

        let noop_case = |normalizer| Case {
            input: "abc",
            normalizer,
            expected: "abc",
            expected_offsets: [0, 1, 2].into(),
        };

        let cases = [
            // No-op compositions and decompositions
            noop_case(Unicode::Nfc),
            noop_case(Unicode::Nfd),
            noop_case(Unicode::Nfkc),
            noop_case(Unicode::Nfkd),
            // Composition
            Case {
                input: "I\u{307}ab",
                normalizer: Unicode::Nfc,
                expected: "İab",
                expected_offsets: [0, 0, 3, 4].into(),
            },
            // Canonical decomposition
            Case {
                input: "İa",
                normalizer: Unicode::Nfd,
                expected: "I\u{307}a",
                expected_offsets: [0, 0, 0, 2].into(),
            },
            // Compatible decomposition, followed by composition
            Case {
                input: "①",
                normalizer: Unicode::Nfkc,
                expected: "1",
                expected_offsets: [0].into(),
            },
            Case {
                input: "Éab",
                normalizer: Unicode::Nfkc,
                expected: "Éab",
                expected_offsets: [0, 0, 2, 3].into(),
            },
            // Compatible decomposition
            Case {
                input: "Éab",
                normalizer: Unicode::Nfkd,
                expected: "E\u{301}ab",
                expected_offsets: [0, 0, 0, 2, 3].into(),
            },
        ];

        cases.test_each(|case| {
            let Case {
                input,
                normalizer,
                expected,
                expected_offsets,
            } = case;

            let (normalized, offsets) = normalizer.normalize(input).unwrap();
            assert_eq!(normalized, *expected);
            assert_eq!(normalized.len(), offsets.len());
            assert_eq!(offsets, *expected_offsets);
        })
    }
}