Skip to main content

piper_phoneme_streaming/
text_expand.rs

1use std::collections::{HashMap, VecDeque};
2
3use crate::expand_tasks::get_tasks_for_language;
4use crate::lang_detect::StreamingLanguageDetector;
5use crate::semantic::Language;
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub enum TextUnit {
9    Word(String, Language),
10    Space,
11    ClauseBoundary(char),
12    Punctuation(char),
13}
14
15impl TextUnit {
16    /// Convert an [`ExpandUnit`] to a [`TextUnit`], tagging word/number units
17    /// with the given `language`.
18    pub fn from_expand_unit(unit: ExpandUnit, language: Language) -> Self {
19        match unit {
20            ExpandUnit::Word(s) | ExpandUnit::Number(s) => TextUnit::Word(s, language),
21            ExpandUnit::Mark(c) if c.is_whitespace() => TextUnit::Space,
22            ExpandUnit::Mark(c) if matches!(c, ',' | '.' | '!' | '?' | ';' | ':') => {
23                TextUnit::ClauseBoundary(c)
24            }
25            ExpandUnit::Mark(c) => TextUnit::Punctuation(c),
26        }
27    }
28}
29
/// A token flowing through the expansion pipeline.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExpandUnit {
    /// A run of alphabetic characters and apostrophes.
    Word(String),
    /// A single character that is neither a word nor a digit character.
    Mark(char),
    /// A run of ASCII digits.
    Number(String),
}

impl ExpandUnit {
    /// Splits `input` into [`ExpandUnit`]s without performing any expansion.
    ///
    /// Runs of alphabetic characters and `'` become [`ExpandUnit::Word`], runs
    /// of ASCII digits become [`ExpandUnit::Number`], and every other character
    /// becomes its own [`ExpandUnit::Mark`]. Switching between letters and
    /// digits starts a new unit. The classification mirrors the streaming
    /// tokenizer in `TextExpand::process_char`.
    pub fn tokenize(input: &str) -> Vec<Self> {
        // Moves the pending run (if any) into `out` as a Word or Number.
        fn emit(pending: &mut String, numeric: bool, out: &mut Vec<ExpandUnit>) {
            if pending.is_empty() {
                return;
            }
            let text = std::mem::take(pending);
            out.push(if numeric {
                ExpandUnit::Number(text)
            } else {
                ExpandUnit::Word(text)
            });
        }

        let mut out = Vec::new();
        let mut pending = String::new();
        let mut numeric = false;

        for c in input.chars() {
            let is_word_char = c.is_alphabetic() || c == '\'';
            if is_word_char || c.is_ascii_digit() {
                let wants_number = !is_word_char;
                // A letter after digits (or a digit after letters) closes the run.
                if wants_number != numeric {
                    emit(&mut pending, numeric, &mut out);
                }
                pending.push(c);
                numeric = wants_number;
            } else {
                emit(&mut pending, numeric, &mut out);
                out.push(ExpandUnit::Mark(c));
            }
        }

        emit(&mut pending, numeric, &mut out);
        out
    }
}
78
79#[derive(Debug, Clone, PartialEq, Eq)]
80pub enum ExpandResult {
81    Maybe,
82    Replace(usize, Vec<ExpandUnit>),
83}
84
85pub trait ExpandTask: Send + Sync {
86    fn expand(&self, queue: &VecDeque<ExpandUnit>) -> Option<ExpandResult>;
87}
88
89pub struct TextExpand {
90    tasks_by_lang: HashMap<Language, Vec<Box<dyn ExpandTask>>>,
91    current_language: Language,
92
93    /// Optional streaming language detector. `None` in single-language mode.
94    lang_detector: Option<StreamingLanguageDetector>,
95
96    // Parallel queues — invariant: input_units.len() == input_langs.len()
97    input_units: VecDeque<ExpandUnit>,
98    input_langs: VecDeque<Language>,
99    output_units: VecDeque<(ExpandUnit, Language)>,
100
101    // Tokenizer state
102    buffer: String,
103    buffer_is_number: bool,
104}
105
106impl TextExpand {
107    /// Single-language mode (backward compatible). No detection overhead.
108    pub fn with_language(language: Language) -> Self {
109        let mut tasks_by_lang = HashMap::new();
110        tasks_by_lang.insert(language, get_tasks_for_language(language));
111        Self {
112            tasks_by_lang,
113            current_language: language,
114            lang_detector: None,
115            input_units: VecDeque::new(),
116            input_langs: VecDeque::new(),
117            output_units: VecDeque::new(),
118            buffer: String::new(),
119            buffer_is_number: false,
120        }
121    }
122
123    /// Multi-language mode with a pre-built [`StreamingLanguageDetector`].
124    pub fn with_detector(
125        languages: &[Language],
126        default_language: Language,
127        detector: StreamingLanguageDetector,
128    ) -> Self {
129        let mut tasks_by_lang = HashMap::new();
130        for &lang in languages {
131            tasks_by_lang.insert(lang, get_tasks_for_language(lang));
132        }
133        Self {
134            tasks_by_lang,
135            current_language: default_language,
136            lang_detector: Some(detector),
137            input_units: VecDeque::new(),
138            input_langs: VecDeque::new(),
139            output_units: VecDeque::new(),
140            buffer: String::new(),
141            buffer_is_number: false,
142        }
143    }
144
145    /// Test/internal constructor: flat task list, no detection, English default.
146    pub fn new(tasks: Vec<Box<dyn ExpandTask>>) -> Self {
147        let mut tasks_by_lang = HashMap::new();
148        tasks_by_lang.insert(Language::English, tasks);
149        Self {
150            tasks_by_lang,
151            current_language: Language::English,
152            lang_detector: None,
153            input_units: VecDeque::new(),
154            input_langs: VecDeque::new(),
155            output_units: VecDeque::new(),
156            buffer: String::new(),
157            buffer_is_number: false,
158        }
159    }
160
161    pub fn push(&mut self, ch: char) -> Option<(ExpandUnit, Language)> {
162        self.process_char(ch);
163        self.try_expand(false);
164        self.output_units.pop_front()
165    }
166
167    pub fn finish(&mut self) -> Option<(ExpandUnit, Language)> {
168        self.flush_buffer();
169        self.try_expand(true);
170        self.output_units.pop_front()
171    }
172
173    fn process_char(&mut self, ch: char) {
174        if ch.is_alphabetic() || ch == '\'' {
175            if !self.buffer.is_empty() && self.buffer_is_number {
176                self.flush_buffer();
177            }
178            self.buffer.push(ch);
179            self.buffer_is_number = false;
180        } else if ch.is_ascii_digit() {
181            if !self.buffer.is_empty() && !self.buffer_is_number {
182                self.flush_buffer();
183            }
184            self.buffer.push(ch);
185            self.buffer_is_number = true;
186        } else {
187            self.flush_buffer();
188            let mark = ExpandUnit::Mark(ch);
189            let lang = if let Some(detector) = &mut self.lang_detector {
190                let lang = detector.push(&mark);
191                if matches!(ch, '.' | '?' | '!') {
192                    detector.reset_context();
193                }
194                lang
195            } else {
196                self.current_language
197            };
198            self.input_units.push_back(mark);
199            self.input_langs.push_back(lang);
200        }
201    }
202
203    fn flush_buffer(&mut self) {
204        if self.buffer.is_empty() {
205            return;
206        }
207        let content = std::mem::take(&mut self.buffer);
208        let unit = if self.buffer_is_number {
209            ExpandUnit::Number(content)
210        } else {
211            ExpandUnit::Word(content)
212        };
213
214        let lang = if let Some(detector) = &mut self.lang_detector {
215            detector.push(&unit)
216        } else {
217            self.current_language
218        };
219
220        self.input_units.push_back(unit);
221        self.input_langs.push_back(lang);
222    }
223
224    fn try_expand(&mut self, is_final: bool) {
225        'outer: while !self.input_units.is_empty() {
226            debug_assert_eq!(
227                self.input_units.len(),
228                self.input_langs.len(),
229                "parallel queue invariant violated"
230            );
231
232            let front_lang = self.input_langs[0];
233            let tasks = self
234                .tasks_by_lang
235                .get(&front_lang)
236                .map(Vec::as_slice)
237                .unwrap_or(&[]);
238
239            for task in tasks {
240                match task.expand(&self.input_units) {
241                    Some(ExpandResult::Maybe) => {
242                        if !is_final {
243                            break 'outer;
244                        }
245                    }
246                    Some(ExpandResult::Replace(n, new_units)) => {
247                        debug_assert!(n > 0, "ExpandTask::expand must consume at least one unit");
248                        for _ in 0..n {
249                            self.input_units.pop_front();
250                            self.input_langs.pop_front();
251                        }
252                        // Prepend replacements inheriting the triggering language
253                        for unit in new_units.into_iter().rev() {
254                            self.input_units.push_front(unit);
255                            self.input_langs.push_front(front_lang);
256                        }
257                        continue 'outer;
258                    }
259                    None => {}
260                }
261            }
262
263            // No task matched — emit with language
264            if let Some(unit) = self.input_units.pop_front() {
265                let lang = self
266                    .input_langs
267                    .pop_front()
268                    .unwrap_or(self.current_language);
269                self.output_units.push_back((unit, lang));
270            }
271        }
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semantic::Language;

    /// Streams `input` through a single-language expander and checks the
    /// collected units against `expected`.
    fn run_test(lang: Language, input: &str, expected: Vec<ExpandUnit>) {
        let mut expander = TextExpand::with_language(lang);
        let mut units = Vec::new();
        for ch in input.chars() {
            if let Some((unit, _lang)) = expander.push(ch) {
                units.push(unit);
            }
        }
        while let Some((unit, _lang)) = expander.finish() {
            units.push(unit);
        }
        assert_eq!(
            units, expected,
            "Failed for input: '{}' in {:?}",
            input, lang
        );
    }

    /// Streams `input` through an expander with no tasks, so only tokenization
    /// happens, and returns every emitted unit with its language tag.
    fn stream_without_tasks(input: &str) -> Vec<(ExpandUnit, Language)> {
        let mut expander = TextExpand::new(Vec::new());
        let mut units = Vec::new();
        for ch in input.chars() {
            units.extend(expander.push(ch));
        }
        while let Some(item) = expander.finish() {
            units.push(item);
        }
        units
    }

    #[test]
    fn test_text_expand_cases_en() {
        let cases = vec![
            (
                "12:30",
                vec![
                    ExpandUnit::Word("twelve".into()),
                    ExpandUnit::Word("thirty".into()),
                ],
            ),
            (
                "12:00",
                vec![
                    ExpandUnit::Word("twelve".into()),
                    ExpandUnit::Word("o'clock".into()),
                ],
            ),
            (
                "12:05",
                vec![
                    ExpandUnit::Word("twelve".into()),
                    ExpandUnit::Word("oh".into()),
                    ExpandUnit::Word("five".into()),
                ],
            ),
            (
                "24/03/2026",
                vec![
                    ExpandUnit::Word("March".into()),
                    ExpandUnit::Word("twenty".into()),
                    ExpandUnit::Word("fourth".into()),
                    ExpandUnit::Mark(','),
                    ExpandUnit::Word("two".into()),
                    ExpandUnit::Word("thousand".into()),
                    ExpandUnit::Word("and".into()),
                    ExpandUnit::Word("twenty".into()),
                    ExpandUnit::Word("six".into()),
                ],
            ),
            (
                "hello 123",
                vec![
                    ExpandUnit::Word("hello".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("one".into()),
                    ExpandUnit::Word("hundred".into()),
                    ExpandUnit::Word("and".into()),
                    ExpandUnit::Word("twenty".into()),
                    ExpandUnit::Word("three".into()),
                ],
            ),
            (
                "ABC HFP",
                vec![
                    ExpandUnit::Word("A".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("B".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("C".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("H".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("F".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("P".into()),
                ],
            ),
            (
                "Dr Smith vs Mr John",
                vec![
                    ExpandUnit::Word("doctor".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("Smith".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("versus".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("mister".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("John".into()),
                ],
            ),
        ];

        for (input, expected) in cases {
            run_test(Language::English, input, expected);
        }
    }

    #[test]
    fn test_text_expand_cases_vi() {
        let cases = vec![
            (
                "12:30",
                vec![
                    ExpandUnit::Word("mười".into()),
                    ExpandUnit::Word("hai".into()),
                    ExpandUnit::Word("giờ".into()),
                    ExpandUnit::Word("ba".into()),
                    ExpandUnit::Word("mươi".into()),
                    ExpandUnit::Word("phút".into()),
                ],
            ),
            (
                "24/03",
                vec![
                    ExpandUnit::Word("ngày".into()),
                    ExpandUnit::Word("hai".into()),
                    ExpandUnit::Word("mươi".into()),
                    ExpandUnit::Word("tư".into()),
                    ExpandUnit::Word("tháng".into()),
                    ExpandUnit::Word("ba".into()),
                ],
            ),
            (
                "105",
                vec![
                    ExpandUnit::Word("một".into()),
                    ExpandUnit::Word("trăm".into()),
                    ExpandUnit::Word("linh".into()),
                    ExpandUnit::Word("năm".into()),
                ],
            ),
            (
                "21",
                vec![
                    ExpandUnit::Word("hai".into()),
                    ExpandUnit::Word("mươi".into()),
                    ExpandUnit::Word("mốt".into()),
                ],
            ),
            (
                "15",
                vec![
                    ExpandUnit::Word("mười".into()),
                    ExpandUnit::Word("lăm".into()),
                ],
            ),
            (
                "FPT abc TP hcm v.v.",
                vec![
                    ExpandUnit::Word("F".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("P".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("T".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("abc".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("thành".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("phố".into()),
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("hcm".into()), // lower case not an acronym
                    ExpandUnit::Mark(' '),
                    ExpandUnit::Word("v".into()),
                    ExpandUnit::Mark('.'),
                    ExpandUnit::Word("v".into()),
                    ExpandUnit::Mark('.'),
                ],
            ),
        ];

        for (input, expected) in cases {
            run_test(Language::Vietnamese, input, expected);
        }
    }

    /// `ExpandUnit::tokenize` documents that it follows the same logic as the
    /// streaming tokenizer; with no tasks registered both must agree exactly.
    #[test]
    fn test_tokenize_matches_streaming_tokenizer() {
        for input in ["", "Hi 12ab, it's 3.5!", "v.v.", "mười5 TP", "a\tb\n"] {
            let streamed: Vec<ExpandUnit> = stream_without_tasks(input)
                .into_iter()
                .map(|(unit, _)| unit)
                .collect();
            assert_eq!(
                streamed,
                ExpandUnit::tokenize(input),
                "tokenizer mismatch for input: {:?}",
                input
            );
        }
    }

    #[test]
    fn test_tokenize_splits_letter_digit_runs() {
        assert_eq!(
            ExpandUnit::tokenize("ab12cd"),
            vec![
                ExpandUnit::Word("ab".into()),
                ExpandUnit::Number("12".into()),
                ExpandUnit::Word("cd".into()),
            ]
        );
        assert!(ExpandUnit::tokenize("").is_empty());
    }

    #[test]
    fn test_new_tags_english() {
        for (_, lang) in stream_without_tasks("a 1.") {
            assert_eq!(lang, Language::English);
        }
    }

    #[test]
    fn test_text_unit_from_expand_unit() {
        let lang = Language::Vietnamese;
        assert_eq!(
            TextUnit::from_expand_unit(ExpandUnit::Word("xin".into()), lang),
            TextUnit::Word("xin".into(), lang)
        );
        assert_eq!(
            TextUnit::from_expand_unit(ExpandUnit::Number("42".into()), lang),
            TextUnit::Word("42".into(), lang)
        );
        assert_eq!(
            TextUnit::from_expand_unit(ExpandUnit::Mark(' '), lang),
            TextUnit::Space
        );
        assert_eq!(
            TextUnit::from_expand_unit(ExpandUnit::Mark('\n'), lang),
            TextUnit::Space
        );
        for c in [',', '.', '!', '?', ';', ':'] {
            assert_eq!(
                TextUnit::from_expand_unit(ExpandUnit::Mark(c), lang),
                TextUnit::ClauseBoundary(c)
            );
        }
        assert_eq!(
            TextUnit::from_expand_unit(ExpandUnit::Mark('-'), lang),
            TextUnit::Punctuation('-')
        );
    }
}