1use std::collections::{HashMap, VecDeque};
2
3use crate::expand_tasks::get_tasks_for_language;
4use crate::lang_detect::StreamingLanguageDetector;
5use crate::semantic::Language;
6
/// A classified piece of text: a word tagged with its language, or a
/// whitespace / punctuation marker.
///
/// The variant descriptions below reflect how [`TextUnit::from_expand_unit`]
/// builds each variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TextUnit {
    /// Text of an `ExpandUnit::Word` or `ExpandUnit::Number`, together with
    /// the language attributed to it.
    Word(String, Language),
    /// A whitespace mark (any character for which `char::is_whitespace` holds).
    Space,
    /// One of `,` `.` `!` `?` `;` `:`.
    ClauseBoundary(char),
    /// Any other mark: neither whitespace nor a clause boundary.
    Punctuation(char),
}
14
15impl TextUnit {
16 pub fn from_expand_unit(unit: ExpandUnit, language: Language) -> Self {
19 match unit {
20 ExpandUnit::Word(s) | ExpandUnit::Number(s) => TextUnit::Word(s, language),
21 ExpandUnit::Mark(c) if c.is_whitespace() => TextUnit::Space,
22 ExpandUnit::Mark(c) if matches!(c, ',' | '.' | '!' | '?' | ';' | ':') => {
23 TextUnit::ClauseBoundary(c)
24 }
25 ExpandUnit::Mark(c) => TextUnit::Punctuation(c),
26 }
27 }
28}
29
/// A token of input text as seen by the expansion tasks.
///
/// The variant descriptions below describe what [`ExpandUnit::tokenize`]
/// produces.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExpandUnit {
    /// A maximal run of alphabetic characters and apostrophes (`'`).
    Word(String),
    /// A single character that is neither alphabetic, an apostrophe, nor an
    /// ASCII digit (whitespace, punctuation, non-ASCII digits, ...).
    Mark(char),
    /// A maximal run of ASCII digits.
    Number(String),
}

impl ExpandUnit {
    /// Splits `input` into words, numbers and single-character marks.
    ///
    /// Consecutive alphabetic characters and apostrophes form one `Word`,
    /// consecutive ASCII digits form one `Number`, and every other character
    /// becomes its own `Mark`. A switch between letters and digits starts a
    /// new unit, so `"abc123"` yields `Word("abc")` followed by
    /// `Number("123")`. An empty input yields an empty vector.
    pub fn tokenize(input: &str) -> Vec<Self> {
        /// How a character participates in tokenization.
        #[derive(Clone, Copy, PartialEq, Eq)]
        enum CharClass {
            Letter,
            Digit,
            Other,
        }

        fn classify(ch: char) -> CharClass {
            if ch.is_alphabetic() || ch == '\'' {
                CharClass::Letter
            } else if ch.is_ascii_digit() {
                CharClass::Digit
            } else {
                CharClass::Other
            }
        }

        let mut units = Vec::new();
        let mut chars = input.chars().peekable();

        while let Some(first) = chars.next() {
            let class = classify(first);

            // Marks are emitted one character at a time; letters and digits
            // start a run that is wrapped in the matching variant.
            let wrap: fn(String) -> ExpandUnit = match class {
                CharClass::Other => {
                    units.push(ExpandUnit::Mark(first));
                    continue;
                }
                CharClass::Letter => ExpandUnit::Word,
                CharClass::Digit => ExpandUnit::Number,
            };

            // Extend the run while the following characters share the class.
            let mut run = String::new();
            run.push(first);
            while let Some(&next) = chars.peek() {
                if classify(next) != class {
                    break;
                }
                run.push(next);
                chars.next();
            }
            units.push(wrap(run));
        }

        units
    }
}
78
/// Outcome reported by an [`ExpandTask`] after inspecting the input queue.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExpandResult {
    /// The task cannot decide yet. While input is still streaming,
    /// `TextExpand` stops expanding and waits for more units; once input is
    /// final, this result is ignored and the next task is tried.
    Maybe,
    /// Remove the given number of units from the front of the queue and put
    /// the given units in their place (at the front, in order). The count is
    /// expected to be at least one.
    Replace(usize, Vec<ExpandUnit>),
}
84
/// A rewrite rule applied to the front of the pending input queue.
pub trait ExpandTask: Send + Sync {
    /// Inspects `queue` (pending units, oldest first) and reports what to do.
    ///
    /// Returning `None` means the task does not apply; see [`ExpandResult`]
    /// for the meaning of the `Some` cases.
    fn expand(&self, queue: &VecDeque<ExpandUnit>) -> Option<ExpandResult>;
}
88
/// Streaming text expander.
///
/// Characters are pushed one at a time, grouped into [`ExpandUnit`]s, run
/// through the [`ExpandTask`]s registered for each unit's language, and
/// handed back as `(unit, language)` pairs.
pub struct TextExpand {
    /// Expansion tasks registered for each language.
    tasks_by_lang: HashMap<Language, Vec<Box<dyn ExpandTask>>>,
    /// Language assigned to units when no detector is configured; also used
    /// as a fallback when a unit has no recorded language.
    current_language: Language,

    /// Optional detector that is fed every unit and returns its language.
    lang_detector: Option<StreamingLanguageDetector>,

    /// Units waiting to be expanded, oldest first.
    input_units: VecDeque<ExpandUnit>,
    /// Language of each entry of `input_units`; kept index-aligned with it.
    input_langs: VecDeque<Language>,
    /// Finished units waiting to be returned by `push` / `finish`.
    output_units: VecDeque<(ExpandUnit, Language)>,

    /// Characters of the word or number currently being accumulated.
    buffer: String,
    /// `true` when `buffer` holds digits, `false` when it holds word characters.
    buffer_is_number: bool,
}
105
106impl TextExpand {
107 pub fn with_language(language: Language) -> Self {
109 let mut tasks_by_lang = HashMap::new();
110 tasks_by_lang.insert(language, get_tasks_for_language(language));
111 Self {
112 tasks_by_lang,
113 current_language: language,
114 lang_detector: None,
115 input_units: VecDeque::new(),
116 input_langs: VecDeque::new(),
117 output_units: VecDeque::new(),
118 buffer: String::new(),
119 buffer_is_number: false,
120 }
121 }
122
123 pub fn with_detector(
125 languages: &[Language],
126 default_language: Language,
127 detector: StreamingLanguageDetector,
128 ) -> Self {
129 let mut tasks_by_lang = HashMap::new();
130 for &lang in languages {
131 tasks_by_lang.insert(lang, get_tasks_for_language(lang));
132 }
133 Self {
134 tasks_by_lang,
135 current_language: default_language,
136 lang_detector: Some(detector),
137 input_units: VecDeque::new(),
138 input_langs: VecDeque::new(),
139 output_units: VecDeque::new(),
140 buffer: String::new(),
141 buffer_is_number: false,
142 }
143 }
144
145 pub fn new(tasks: Vec<Box<dyn ExpandTask>>) -> Self {
147 let mut tasks_by_lang = HashMap::new();
148 tasks_by_lang.insert(Language::English, tasks);
149 Self {
150 tasks_by_lang,
151 current_language: Language::English,
152 lang_detector: None,
153 input_units: VecDeque::new(),
154 input_langs: VecDeque::new(),
155 output_units: VecDeque::new(),
156 buffer: String::new(),
157 buffer_is_number: false,
158 }
159 }
160
161 pub fn push(&mut self, ch: char) -> Option<(ExpandUnit, Language)> {
162 self.process_char(ch);
163 self.try_expand(false);
164 self.output_units.pop_front()
165 }
166
167 pub fn finish(&mut self) -> Option<(ExpandUnit, Language)> {
168 self.flush_buffer();
169 self.try_expand(true);
170 self.output_units.pop_front()
171 }
172
173 fn process_char(&mut self, ch: char) {
174 if ch.is_alphabetic() || ch == '\'' {
175 if !self.buffer.is_empty() && self.buffer_is_number {
176 self.flush_buffer();
177 }
178 self.buffer.push(ch);
179 self.buffer_is_number = false;
180 } else if ch.is_ascii_digit() {
181 if !self.buffer.is_empty() && !self.buffer_is_number {
182 self.flush_buffer();
183 }
184 self.buffer.push(ch);
185 self.buffer_is_number = true;
186 } else {
187 self.flush_buffer();
188 let mark = ExpandUnit::Mark(ch);
189 let lang = if let Some(detector) = &mut self.lang_detector {
190 let lang = detector.push(&mark);
191 if matches!(ch, '.' | '?' | '!') {
192 detector.reset_context();
193 }
194 lang
195 } else {
196 self.current_language
197 };
198 self.input_units.push_back(mark);
199 self.input_langs.push_back(lang);
200 }
201 }
202
203 fn flush_buffer(&mut self) {
204 if self.buffer.is_empty() {
205 return;
206 }
207 let content = std::mem::take(&mut self.buffer);
208 let unit = if self.buffer_is_number {
209 ExpandUnit::Number(content)
210 } else {
211 ExpandUnit::Word(content)
212 };
213
214 let lang = if let Some(detector) = &mut self.lang_detector {
215 detector.push(&unit)
216 } else {
217 self.current_language
218 };
219
220 self.input_units.push_back(unit);
221 self.input_langs.push_back(lang);
222 }
223
224 fn try_expand(&mut self, is_final: bool) {
225 'outer: while !self.input_units.is_empty() {
226 debug_assert_eq!(
227 self.input_units.len(),
228 self.input_langs.len(),
229 "parallel queue invariant violated"
230 );
231
232 let front_lang = self.input_langs[0];
233 let tasks = self
234 .tasks_by_lang
235 .get(&front_lang)
236 .map(Vec::as_slice)
237 .unwrap_or(&[]);
238
239 for task in tasks {
240 match task.expand(&self.input_units) {
241 Some(ExpandResult::Maybe) => {
242 if !is_final {
243 break 'outer;
244 }
245 }
246 Some(ExpandResult::Replace(n, new_units)) => {
247 debug_assert!(n > 0, "ExpandTask::expand must consume at least one unit");
248 for _ in 0..n {
249 self.input_units.pop_front();
250 self.input_langs.pop_front();
251 }
252 for unit in new_units.into_iter().rev() {
254 self.input_units.push_front(unit);
255 self.input_langs.push_front(front_lang);
256 }
257 continue 'outer;
258 }
259 None => {}
260 }
261 }
262
263 if let Some(unit) = self.input_units.pop_front() {
265 let lang = self
266 .input_langs
267 .pop_front()
268 .unwrap_or(self.current_language);
269 self.output_units.push_back((unit, lang));
270 }
271 }
272 }
273}
274
275#[cfg(test)]
276mod tests {
277 use super::*;
278 use crate::semantic::Language;
279
280 fn run_test(lang: Language, input: &str, expected: Vec<ExpandUnit>) {
281 let mut expander = TextExpand::with_language(lang);
282 let mut units = Vec::new();
283 for ch in input.chars() {
284 if let Some((unit, _lang)) = expander.push(ch) {
285 units.push(unit);
286 }
287 }
288 while let Some((unit, _lang)) = expander.finish() {
289 units.push(unit);
290 }
291 assert_eq!(
292 units, expected,
293 "Failed for input: '{}' in {:?}",
294 input, lang
295 );
296 }
297
298 #[test]
299 fn test_text_expand_cases_en() {
300 let cases = vec![
301 (
302 "12:30",
303 vec![
304 ExpandUnit::Word("twelve".into()),
305 ExpandUnit::Word("thirty".into()),
306 ],
307 ),
308 (
309 "12:00",
310 vec![
311 ExpandUnit::Word("twelve".into()),
312 ExpandUnit::Word("o'clock".into()),
313 ],
314 ),
315 (
316 "12:05",
317 vec![
318 ExpandUnit::Word("twelve".into()),
319 ExpandUnit::Word("oh".into()),
320 ExpandUnit::Word("five".into()),
321 ],
322 ),
323 (
324 "24/03/2026",
325 vec![
326 ExpandUnit::Word("March".into()),
327 ExpandUnit::Word("twenty".into()),
328 ExpandUnit::Word("fourth".into()),
329 ExpandUnit::Mark(','),
330 ExpandUnit::Word("two".into()),
331 ExpandUnit::Word("thousand".into()),
332 ExpandUnit::Word("and".into()),
333 ExpandUnit::Word("twenty".into()),
334 ExpandUnit::Word("six".into()),
335 ],
336 ),
337 (
338 "hello 123",
339 vec![
340 ExpandUnit::Word("hello".into()),
341 ExpandUnit::Mark(' '),
342 ExpandUnit::Word("one".into()),
343 ExpandUnit::Word("hundred".into()),
344 ExpandUnit::Word("and".into()),
345 ExpandUnit::Word("twenty".into()),
346 ExpandUnit::Word("three".into()),
347 ],
348 ),
349 (
350 "ABC HFP",
351 vec![
352 ExpandUnit::Word("A".into()),
353 ExpandUnit::Mark(' '),
354 ExpandUnit::Word("B".into()),
355 ExpandUnit::Mark(' '),
356 ExpandUnit::Word("C".into()),
357 ExpandUnit::Mark(' '),
358 ExpandUnit::Word("H".into()),
359 ExpandUnit::Mark(' '),
360 ExpandUnit::Word("F".into()),
361 ExpandUnit::Mark(' '),
362 ExpandUnit::Word("P".into()),
363 ],
364 ),
365 (
366 "Dr Smith vs Mr John",
367 vec![
368 ExpandUnit::Word("doctor".into()),
369 ExpandUnit::Mark(' '),
370 ExpandUnit::Word("Smith".into()),
371 ExpandUnit::Mark(' '),
372 ExpandUnit::Word("versus".into()),
373 ExpandUnit::Mark(' '),
374 ExpandUnit::Word("mister".into()),
375 ExpandUnit::Mark(' '),
376 ExpandUnit::Word("John".into()),
377 ],
378 ),
379 ];
380
381 for (input, expected) in cases {
382 run_test(Language::English, input, expected);
383 }
384 }
385
386 #[test]
387 fn test_text_expand_cases_vi() {
388 let cases = vec![
389 (
390 "12:30",
391 vec![
392 ExpandUnit::Word("mười".into()),
393 ExpandUnit::Word("hai".into()),
394 ExpandUnit::Word("giờ".into()),
395 ExpandUnit::Word("ba".into()),
396 ExpandUnit::Word("mươi".into()),
397 ExpandUnit::Word("phút".into()),
398 ],
399 ),
400 (
401 "24/03",
402 vec![
403 ExpandUnit::Word("ngày".into()),
404 ExpandUnit::Word("hai".into()),
405 ExpandUnit::Word("mươi".into()),
406 ExpandUnit::Word("tư".into()),
407 ExpandUnit::Word("tháng".into()),
408 ExpandUnit::Word("ba".into()),
409 ],
410 ),
411 (
412 "105",
413 vec![
414 ExpandUnit::Word("một".into()),
415 ExpandUnit::Word("trăm".into()),
416 ExpandUnit::Word("linh".into()),
417 ExpandUnit::Word("năm".into()),
418 ],
419 ),
420 (
421 "21",
422 vec![
423 ExpandUnit::Word("hai".into()),
424 ExpandUnit::Word("mươi".into()),
425 ExpandUnit::Word("mốt".into()),
426 ],
427 ),
428 (
429 "15",
430 vec![
431 ExpandUnit::Word("mười".into()),
432 ExpandUnit::Word("lăm".into()),
433 ],
434 ),
435 (
436 "FPT abc TP hcm v.v.",
437 vec![
438 ExpandUnit::Word("F".into()),
439 ExpandUnit::Mark(' '),
440 ExpandUnit::Word("P".into()),
441 ExpandUnit::Mark(' '),
442 ExpandUnit::Word("T".into()),
443 ExpandUnit::Mark(' '),
444 ExpandUnit::Word("abc".into()),
445 ExpandUnit::Mark(' '),
446 ExpandUnit::Word("thành".into()),
447 ExpandUnit::Mark(' '),
448 ExpandUnit::Word("phố".into()),
449 ExpandUnit::Mark(' '),
450 ExpandUnit::Word("hcm".into()), ExpandUnit::Mark(' '),
452 ExpandUnit::Word("v".into()),
453 ExpandUnit::Mark('.'),
454 ExpandUnit::Word("v".into()),
455 ExpandUnit::Mark('.'),
456 ],
457 ),
458 ];
459
460 for (input, expected) in cases {
461 run_test(Language::Vietnamese, input, expected);
462 }
463 }
464}