// lance_tokenizer/ascii_folding_filter.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::mem;

use unicode_normalization::{char::is_combining_mark, UnicodeNormalization};

use crate::{Token, TokenFilter, TokenStream, Tokenizer};

/// A token filter that folds non-ASCII characters into their closest ASCII
/// equivalents (accents stripped via NFKD, plus explicit mappings such as
/// "ß" -> "ss"; see `to_ascii` below for the exact strategy).
#[derive(Clone)]
pub struct AsciiFoldingFilter;

impl TokenFilter for AsciiFoldingFilter {
    type Tokenizer<T: Tokenizer> = AsciiFoldingFilterWrapper<T>;

    /// Wraps `tokenizer` so every token it produces is ASCII-folded.
    /// The wrapper carries a scratch `String` that is reused across tokens
    /// to avoid a fresh allocation per folded token.
    fn transform<T: Tokenizer>(self, tokenizer: T) -> Self::Tokenizer<T> {
        AsciiFoldingFilterWrapper {
            tokenizer,
            buffer: String::new(),
        }
    }
}

/// Tokenizer produced by [`AsciiFoldingFilter`]: delegates tokenization to
/// the inner tokenizer and ASCII-folds each resulting token's text.
#[derive(Clone)]
pub struct AsciiFoldingFilterWrapper<T> {
    // Underlying tokenizer whose tokens get folded.
    tokenizer: T,
    // Scratch buffer lent to each token stream; reused to amortize allocations.
    buffer: String,
}

impl<T: Tokenizer> Tokenizer for AsciiFoldingFilterWrapper<T> {
    type TokenStream<'a> = AsciiFoldingFilterTokenStream<'a, T::TokenStream<'a>>;

    /// Builds a folding token stream over `text`.
    ///
    /// The scratch buffer is cleared first so leftover text from a previous
    /// stream can never leak into tokens of this one (the stream swaps the
    /// buffer with token text rather than rebuilding it each time).
    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
        self.buffer.clear();
        AsciiFoldingFilterTokenStream {
            buffer: &mut self.buffer,
            tail: self.tokenizer.token_stream(text),
        }
    }
}

/// Token-stream adapter that ASCII-folds the text of each token coming out
/// of the wrapped stream `tail`.
pub struct AsciiFoldingFilterTokenStream<'a, T> {
    // Scratch buffer borrowed from the wrapper; the folded text is built
    // here and then swapped into the token, reusing both allocations.
    buffer: &'a mut String,
    // The underlying token stream being filtered.
    tail: T,
}

47impl<T: TokenStream> TokenStream for AsciiFoldingFilterTokenStream<'_, T> {
48    fn advance(&mut self) -> bool {
49        if !self.tail.advance() {
50            return false;
51        }
52        if !self.token_mut().text.is_ascii() {
53            to_ascii(&self.tail.token().text, self.buffer);
54            mem::swap(&mut self.tail.token_mut().text, self.buffer);
55        }
56        true
57    }
58
59    fn token(&self) -> &Token {
60        self.tail.token()
61    }
62
63    fn token_mut(&mut self) -> &mut Token {
64        self.tail.token_mut()
65    }
66}
67
68fn to_ascii(text: &str, output: &mut String) {
69    output.clear();
70    for ch in text.chars() {
71        if ch.is_ascii() {
72            output.push(ch);
73            continue;
74        }
75
76        if let Some(mapped) = fold_char(ch) {
77            output.push_str(mapped);
78            continue;
79        }
80
81        let original_len = output.len();
82        for decomposed in ch.nfkd() {
83            if decomposed.is_ascii() {
84                output.push(decomposed);
85            } else if is_combining_mark(decomposed) {
86                continue;
87            } else if let Some(mapped) = fold_char(decomposed) {
88                output.push_str(mapped);
89            }
90        }
91
92        if output.len() == original_len {
93            output.push(ch);
94        }
95    }
96}
97
/// Maps a single Latin character whose NFKD decomposition does not yield
/// ASCII (ligatures, stroked/crossed letters, etc.) to a conventional ASCII
/// spelling. Returns `None` for characters without an explicit mapping,
/// leaving them to the decomposition path in `to_ascii`.
fn fold_char(ch: char) -> Option<&'static str> {
    let folded = match ch {
        'ß' => "ss",
        'ẞ' => "SS",
        'Æ' => "AE",
        'æ' => "ae",
        'Œ' => "OE",
        'œ' => "oe",
        'Ø' => "O",
        'ø' => "o",
        'Ł' => "L",
        'ł' => "l",
        'Đ' | 'Ð' => "D",
        'đ' | 'ð' => "d",
        'Þ' => "TH",
        'þ' => "th",
        'Ħ' => "H",
        'ħ' => "h",
        'Ŧ' => "T",
        'ŧ' => "t",
        'Ŋ' => "N",
        'ŋ' => "n",
        'ı' => "i",
        'ĸ' => "k",
        'ſ' => "s",
        _ => return None,
    };
    Some(folded)
}

#[cfg(test)]
mod tests {
    use crate::{AsciiFoldingFilter, RawTokenizer, TextAnalyzer, Token};

    /// Runs `text` through a raw tokenizer plus the ASCII-folding filter
    /// and collects every emitted token.
    fn collect_tokens(text: &str) -> Vec<Token> {
        let mut analyzer = TextAnalyzer::builder(RawTokenizer::default())
            .filter(AsciiFoldingFilter)
            .build();
        let mut stream = analyzer.token_stream(text);
        let mut tokens = Vec::new();
        stream.process(&mut |token| tokens.push(token.clone()));
        tokens
    }

    // Accented letters fold via NFKD decomposition (combining mark dropped).
    #[test]
    fn test_ascii_folding_accents() {
        let tokens = collect_tokens("café");
        assert_eq!(tokens[0].text, "cafe");
    }

    // Characters with explicit multi-char mappings expand ("ß" -> "ss").
    #[test]
    fn test_ascii_folding_sharp_s() {
        let tokens = collect_tokens("straße");
        assert_eq!(tokens[0].text, "strasse");
    }
}