1use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8
9use crate::config::{Language, NormalizerConfig, Operator};
10use crate::contractions::fix_contractions;
11use crate::error::{Result, WeTextError};
12use crate::text_normalizer::FstTextNormalizer;
13use crate::token_parser::TokenParser;
14
15struct FstCache {
17 fsts: HashMap<String, FstTextNormalizer>,
18 fst_dir: PathBuf,
19}
20
21impl FstCache {
22 fn new<P: AsRef<Path>>(fst_dir: P) -> Self {
23 Self {
24 fsts: HashMap::new(),
25 fst_dir: fst_dir.as_ref().to_path_buf(),
26 }
27 }
28
29 fn get_or_load(&mut self, relative_path: &str) -> Result<&FstTextNormalizer> {
30 if !self.fsts.contains_key(relative_path) {
31 let full_path = self.fst_dir.join(relative_path);
32 let normalizer = FstTextNormalizer::from_file(&full_path)?;
33 self.fsts.insert(relative_path.to_string(), normalizer);
34 }
35 Ok(self.fsts.get(relative_path).unwrap())
36 }
37}
38
39pub struct Normalizer {
55 config: NormalizerConfig,
56 cache: FstCache,
57}
58
59impl Normalizer {
60 pub fn new<P: AsRef<Path>>(fst_dir: P, config: NormalizerConfig) -> Self {
66 Self {
67 config,
68 cache: FstCache::new(fst_dir),
69 }
70 }
71
72 pub fn with_defaults<P: AsRef<Path>>(fst_dir: P) -> Self {
74 Self::new(fst_dir, NormalizerConfig::default())
75 }
76
77 pub fn normalize(&mut self, text: &str) -> Result<String> {
79 self.normalize_with_config(text, &self.config.clone())
80 }
81
82 pub fn normalize_with_config(
84 &mut self,
85 text: &str,
86 config: &NormalizerConfig,
87 ) -> Result<String> {
88 let mut text = text.to_string();
89
90 if config.fix_contractions && text.contains('\'') {
92 text = fix_contractions(&text);
93 }
94
95 text = self.preprocess(&text, config)?;
97
98 let lang = if config.lang == Language::Auto {
100 Self::detect_language(&text)
101 } else {
102 config.lang
103 };
104
105 if self.should_normalize(&text, config.operator, config.remove_erhua) {
107 let lang = if lang == Language::En && config.operator == Operator::Itn {
110 Language::Zh
111 } else {
112 lang
113 };
114
115 text = self.tag(&text, lang, config)?;
117
118 text = self.reorder(&text, lang, config.operator)?;
120
121 text = self.verbalize(&text, lang, config)?;
123 }
124
125 text = self.postprocess(&text, config)?;
127
128 Ok(text)
129 }
130
131 fn detect_language(text: &str) -> Language {
143 let mut has_cjk = false;
144 let mut has_alpha = false;
145
146 for ch in text.chars() {
147 if ('\u{3040}'..='\u{309f}').contains(&ch) || ('\u{30a0}'..='\u{30ff}').contains(&ch) {
152 return Language::Ja;
153 }
154
155 if ('\u{4e00}'..='\u{9fff}').contains(&ch) {
159 has_cjk = true;
160 }
161
162 if ch.is_ascii_alphabetic() {
164 has_alpha = true;
165 }
166 }
167
168 if has_cjk {
170 return Language::Zh;
171 }
172
173 if !text.is_empty() && !has_alpha {
176 return Language::Zh;
177 }
178
179 Language::En
180 }
181
182 fn should_normalize(&self, text: &str, operator: Operator, remove_erhua: bool) -> bool {
184 if operator == Operator::Tn {
185 if text.chars().any(|c| c.is_ascii_digit()) {
187 return true;
188 }
189 if remove_erhua && (text.contains('儿') || text.contains('兒')) {
191 return true;
192 }
193 false
194 } else {
195 !text.is_empty()
197 }
198 }
199
200 fn preprocess(&mut self, text: &str, config: &NormalizerConfig) -> Result<String> {
202 let mut result = text.trim().to_string();
203
204 if config.traditional_to_simple {
205 let fst = self.cache.get_or_load("traditional_to_simple.fst")?;
206 result = fst.normalize(&result)?;
207 }
208
209 Ok(result)
210 }
211
212 fn postprocess(&mut self, text: &str, config: &NormalizerConfig) -> Result<String> {
214 let mut result = text.to_string();
215
216 if config.full_to_half {
217 let fst = self.cache.get_or_load("full_to_half.fst")?;
218 result = fst.normalize(&result)?;
219 }
220
221 if config.remove_interjections {
222 let fst = self.cache.get_or_load("remove_interjections.fst")?;
223 result = fst.normalize(&result)?;
224 }
225
226 if config.remove_puncts {
227 let fst = self.cache.get_or_load("remove_puncts.fst")?;
228 result = fst.normalize(&result)?;
229 }
230
231 if config.tag_oov {
232 let fst = self.cache.get_or_load("tag_oov.fst")?;
233 result = fst.normalize(&result)?;
234 }
235
236 Ok(result.trim().to_string())
237 }
238
239 fn tag(&mut self, text: &str, lang: Language, config: &NormalizerConfig) -> Result<String> {
241 let fst_path = match (lang, config.operator) {
242 (Language::En, Operator::Tn) => "en/tn/tagger.fst",
243 (Language::Zh, Operator::Tn) => "zh/tn/tagger.fst",
244 (Language::Zh, Operator::Itn) => {
245 if config.enable_0_to_9 {
246 "zh/itn/tagger_enable_0_to_9.fst"
247 } else {
248 "zh/itn/tagger.fst"
249 }
250 }
251 (Language::Ja, Operator::Tn) => "ja/tn/tagger.fst",
252 (Language::Ja, Operator::Itn) => {
253 if config.enable_0_to_9 {
254 "ja/itn/tagger_enable_0_to_9.fst"
255 } else {
256 "ja/itn/tagger.fst"
257 }
258 }
259 _ => return Err(WeTextError::InvalidLanguage(format!("{:?}", lang))),
260 };
261
262 let fst = self.cache.get_or_load(fst_path)?;
263 let result = fst.normalize(text)?;
264 Ok(result.trim().to_string())
265 }
266
267 fn reorder(&self, text: &str, lang: Language, operator: Operator) -> Result<String> {
269 let parser = TokenParser::new(lang, operator);
270 parser.reorder(text)
271 }
272
273 fn verbalize(
275 &mut self,
276 text: &str,
277 lang: Language,
278 config: &NormalizerConfig,
279 ) -> Result<String> {
280 let fst_path = match (lang, config.operator) {
281 (Language::En, Operator::Tn) => "en/tn/verbalizer.fst",
282 (Language::Zh, Operator::Tn) => {
283 if config.remove_erhua {
284 "zh/tn/verbalizer_remove_erhua.fst"
285 } else {
286 "zh/tn/verbalizer.fst"
287 }
288 }
289 (Language::Zh, Operator::Itn) => "zh/itn/verbalizer.fst",
290 (Language::Ja, Operator::Tn) => "ja/tn/verbalizer.fst",
291 (Language::Ja, Operator::Itn) => "ja/itn/verbalizer.fst",
292 _ => return Err(WeTextError::InvalidLanguage(format!("{:?}", lang))),
293 };
294
295 let fst = self.cache.get_or_load(fst_path)?;
296 let result = fst.normalize(text)?;
297 Ok(result.trim().to_string())
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_language() {
        // (input, expected language), checked in order.
        let cases = [
            // Only ASCII letters -> English.
            ("hello world", Language::En),
            ("Hello, World!", Language::En),
            // Han characters -> Chinese, even when mixed with digits.
            ("你好世界", Language::Zh),
            ("今天是2024年", Language::Zh),
            // Any kana -> Japanese, including kanji + katakana mixes.
            ("こんにちは", Language::Ja),
            ("カタカナ", Language::Ja),
            ("東京タワー", Language::Ja),
            // No letters at all -> Chinese.
            ("123", Language::Zh),
            ("2024", Language::Zh),
            // Empty input -> English.
            ("", Language::En),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(
                Normalizer::detect_language(input),
                expected,
                "input: {:?}",
                input
            );
        }
    }
}