// lance_tokenizer/ascii_folding_filter.rs
use std::mem;
5
6use unicode_normalization::{UnicodeNormalization, char::is_combining_mark};
7
8use crate::{Token, TokenFilter, TokenStream, Tokenizer};
9
/// Token filter that rewrites non-ASCII characters in token text to their
/// closest ASCII equivalents (e.g. `é` -> `e`, `ß` -> `ss`, `Œ` -> `OE`).
///
/// Stateless marker type; attach it to a tokenizer via
/// [`TokenFilter::transform`].
#[derive(Clone)]
pub struct AsciiFoldingFilter;
12
impl TokenFilter for AsciiFoldingFilter {
    type Tokenizer<T: Tokenizer> = AsciiFoldingFilterWrapper<T>;

    /// Wraps `tokenizer` so every token it emits gets ASCII-folded.
    ///
    /// `buffer` is scratch space owned by the wrapper and reused across
    /// tokens, so folding does not allocate a fresh `String` per token.
    fn transform<T: Tokenizer>(self, tokenizer: T) -> Self::Tokenizer<T> {
        AsciiFoldingFilterWrapper {
            tokenizer,
            buffer: String::new(),
        }
    }
}
23
/// A tokenizer combined with ASCII folding; created by
/// [`AsciiFoldingFilter`]'s `transform`.
#[derive(Clone)]
pub struct AsciiFoldingFilterWrapper<T> {
    /// Underlying tokenizer whose tokens are folded.
    tokenizer: T,
    /// Scratch buffer lent to each token stream and reused when a token's
    /// text needs rewriting.
    buffer: String,
}
29
impl<T: Tokenizer> Tokenizer for AsciiFoldingFilterWrapper<T> {
    type TokenStream<'a> = AsciiFoldingFilterTokenStream<'a, T::TokenStream<'a>>;

    /// Builds a token stream over `text` whose tokens are ASCII-folded.
    fn token_stream<'a>(&'a mut self, text: &'a str) -> Self::TokenStream<'a> {
        // Drop any leftover text from a previous stream before lending the
        // scratch buffer to the new stream.
        self.buffer.clear();
        AsciiFoldingFilterTokenStream {
            buffer: &mut self.buffer,
            tail: self.tokenizer.token_stream(text),
        }
    }
}
41
/// Token-stream adapter that ASCII-folds each token produced by the inner
/// stream `tail`, using `buffer` (borrowed from the wrapper) as reusable
/// scratch space for the rewritten text.
pub struct AsciiFoldingFilterTokenStream<'a, T> {
    buffer: &'a mut String,
    tail: T,
}
46
47impl<T: TokenStream> TokenStream for AsciiFoldingFilterTokenStream<'_, T> {
48 fn advance(&mut self) -> bool {
49 if !self.tail.advance() {
50 return false;
51 }
52 if !self.token_mut().text.is_ascii() {
53 to_ascii(&self.tail.token().text, self.buffer);
54 mem::swap(&mut self.tail.token_mut().text, self.buffer);
55 }
56 true
57 }
58
59 fn token(&self) -> &Token {
60 self.tail.token()
61 }
62
63 fn token_mut(&mut self) -> &mut Token {
64 self.tail.token_mut()
65 }
66}
67
68fn to_ascii(text: &str, output: &mut String) {
69 output.clear();
70 for ch in text.chars() {
71 if ch.is_ascii() {
72 output.push(ch);
73 continue;
74 }
75
76 if let Some(mapped) = fold_char(ch) {
77 output.push_str(mapped);
78 continue;
79 }
80
81 let original_len = output.len();
82 for decomposed in ch.nfkd() {
83 if decomposed.is_ascii() {
84 output.push(decomposed);
85 } else if is_combining_mark(decomposed) {
86 continue;
87 } else if let Some(mapped) = fold_char(decomposed) {
88 output.push_str(mapped);
89 }
90 }
91
92 if output.len() == original_len {
93 output.push(ch);
94 }
95 }
96}
97
/// Maps a single character to an explicit ASCII replacement string.
///
/// This table covers ligatures and letters whose Unicode decomposition
/// does not produce an ASCII base letter (e.g. `ß`, `Ø`, `Þ`), so the
/// NFKD fallback in `to_ascii` cannot handle them. Returns `None` for any
/// character not in the table — including all ASCII input, which has no
/// arm here.
fn fold_char(ch: char) -> Option<&'static str> {
    // The catch-all returns early so the table itself stays flat.
    let folded = match ch {
        'ß' => "ss",
        'ẞ' => "SS",
        'Æ' => "AE",
        'æ' => "ae",
        'Œ' => "OE",
        'œ' => "oe",
        'Ø' => "O",
        'ø' => "o",
        'Ł' => "L",
        'ł' => "l",
        'Đ' | 'Ð' => "D",
        'đ' | 'ð' => "d",
        'Þ' => "TH",
        'þ' => "th",
        'Ħ' => "H",
        'ħ' => "h",
        'Ŧ' => "T",
        'ŧ' => "t",
        'Ŋ' => "N",
        'ŋ' => "n",
        'ı' => "i",
        'ĸ' => "k",
        'ſ' => "s",
        _ => return None,
    };
    Some(folded)
}
126
#[cfg(test)]
mod tests {
    use crate::{AsciiFoldingFilter, RawTokenizer, TextAnalyzer, Token};

    /// Runs `text` through a raw tokenizer followed by the ASCII folding
    /// filter and collects every emitted token.
    fn collect_tokens(text: &str) -> Vec<Token> {
        let mut analyzer = TextAnalyzer::builder(RawTokenizer::default())
            .filter(AsciiFoldingFilter)
            .build();
        let mut collected = Vec::new();
        let mut stream = analyzer.token_stream(text);
        stream.process(&mut |tok| collected.push(tok.clone()));
        collected
    }

    /// Combining accents are stripped via NFKD: "café" -> "cafe".
    #[test]
    fn test_ascii_folding_accents() {
        assert_eq!(collect_tokens("café")[0].text, "cafe");
    }

    /// Characters without a useful decomposition use the explicit fold
    /// table: "straße" -> "strasse".
    #[test]
    fn test_ascii_folding_sharp_s() {
        assert_eq!(collect_tokens("straße")[0].text, "strasse");
    }
}