// piper_plus_g2p/multilingual.rs
//! Multilingual phonemizer for code-switching text across N languages.
//!
//! Generalizes the concept of bilingual phonemization to support arbitrary
//! language combinations. Detects language segments via Unicode ranges,
//! delegates to language-specific phonemizers, and returns unified phoneme IDs.
//!
//! Port of the Python `multilingual.py`.

use std::collections::{HashMap, HashSet};
use std::sync::{Mutex, OnceLock};

use crate::error::G2pError;
use crate::phonemizer::{PhonemeIdMap, Phonemizer, ProsodyFeature, ProsodyInfo};
use crate::token_map::token_to_pua;
// ---------------------------------------------------------------------------
// UnicodeLanguageDetector
// ---------------------------------------------------------------------------

/// Detect language from Unicode character ranges.
///
/// Supports CJK disambiguation (JA vs ZH) by checking for kana presence.
/// Latin characters are mapped to a configurable default language.
pub struct UnicodeLanguageDetector {
    /// Full set of supported language codes.
    languages: HashSet<String>,
    /// Language assigned to Latin-script characters.
    default_latin_language: String,
    // Cached membership flags for the hot per-character path.
    has_ja: bool,
    has_zh: bool,
    has_ko: bool,
}

impl UnicodeLanguageDetector {
    /// Create a new detector for the given set of languages.
    ///
    /// `default_latin_language` controls which language Latin-script
    /// characters (A-Z, a-z, accented Latin) are assigned to.
    pub fn new(languages: &[String], default_latin_language: &str) -> Self {
        let languages: HashSet<String> = languages.iter().cloned().collect();
        let (has_ja, has_zh, has_ko) = (
            languages.contains("ja"),
            languages.contains("zh"),
            languages.contains("ko"),
        );
        Self {
            languages,
            default_latin_language: default_latin_language.to_owned(),
            has_ja,
            has_zh,
            has_ko,
        }
    }

    /// Detect language for a single character.
    ///
    /// `context_has_kana` is used for CJK ideograph disambiguation: if the
    /// surrounding text contains kana, CJK ideographs are classified as
    /// Japanese rather than Chinese.
    ///
    /// Returns `None` for neutral characters (whitespace, digits,
    /// ASCII punctuation, etc.).
    pub fn detect_char(&self, ch: char, context_has_kana: bool) -> Option<&str> {
        match ch as u32 {
            // Hiragana (U+3040-309F), Katakana (U+30A0-30FF),
            // Katakana Phonetic Extensions (U+31F0-31FF)
            0x3040..=0x30FF | 0x31F0..=0x31FF => self.has_ja.then_some("ja"),

            // Hangul Syllables (U+AC00-D7AF), Jamo (U+1100-11FF),
            // Compatibility Jamo (U+3130-318F)
            0xAC00..=0xD7AF | 0x1100..=0x11FF | 0x3130..=0x318F => self.has_ko.then_some("ko"),

            // CJK Unified Ideographs (U+4E00-9FFF), Extension A (U+3400-4DBF),
            // Compatibility Ideographs (U+F900-FAFF)
            0x4E00..=0x9FFF | 0x3400..=0x4DBF | 0xF900..=0xFAFF => {
                match (self.has_ja, self.has_zh) {
                    // Both JA and ZH available: kana anywhere in the text
                    // tips ideographs toward Japanese.
                    (true, true) if context_has_kana => Some("ja"),
                    (true, true) => Some("zh"),
                    (true, false) => Some("ja"),
                    (false, true) => Some("zh"),
                    (false, false) => None,
                }
            }

            // Fullwidth Latin letters: U+FF21-FF3A (A-Z), U+FF41-FF5A (a-z)
            0xFF21..=0xFF3A | 0xFF41..=0xFF5A => self.default_latin(),

            // CJK punctuation (U+3000-303F) and fullwidth forms,
            // excluding the fullwidth Latin letters matched above.
            0x3000..=0x303F | 0xFF00..=0xFF20 | 0xFF3B..=0xFF40 | 0xFF5B..=0xFFEF => {
                self.has_ja.then_some("ja")
            }

            // Latin: ASCII A-Z / a-z plus accented Latin-1 letters.
            // U+00D7 (multiplication sign) and U+00F7 (division sign)
            // fall between the ranges and stay neutral.
            0x41..=0x5A | 0x61..=0x7A | 0x00C0..=0x00D6 | 0x00D8..=0x00F6 | 0x00F8..=0x00FF => {
                self.default_latin()
            }

            // Everything else: digits, ASCII punctuation, whitespace -> neutral.
            _ => None,
        }
    }

    /// Check if text contains any Hiragana or Katakana characters.
    pub fn has_kana(&self, text: &str) -> bool {
        text.chars()
            .any(|ch| matches!(ch as u32, 0x3040..=0x30FF | 0x31F0..=0x31FF))
    }

    /// The default Latin language, if it is one of the supported languages.
    fn default_latin(&self) -> Option<&str> {
        if self.languages.contains(&self.default_latin_language) {
            Some(self.default_latin_language.as_str())
        } else {
            None
        }
    }
}
143
144// ---------------------------------------------------------------------------
145// segment_text
146// ---------------------------------------------------------------------------
147
148/// Split text into `(language, segment_text)` pairs using Unicode detection.
149///
150/// Neutral characters (whitespace, digits, punctuation) are absorbed into
151/// the preceding language segment. If no language-specific characters are
152/// found (e.g., text is only digits), falls back to `default_latin_language`.
153pub fn segment_text(text: &str, detector: &UnicodeLanguageDetector) -> Vec<(String, String)> {
154    if text.trim().is_empty() {
155        return Vec::new();
156    }
157
158    let context_has_kana = detector.has_kana(text);
159
160    let mut segments: Vec<(String, String)> = Vec::new();
161    let mut current_lang: Option<&str> = None;
162    let mut current_chars = String::new();
163
164    for ch in text.chars() {
165        let lang = detector.detect_char(ch, context_has_kana);
166
167        if let Some(detected) = lang {
168            if let Some(prev) = current_lang
169                && detected != prev
170            {
171                // Language changed — flush the current segment
172                segments.push((prev.to_string(), std::mem::take(&mut current_chars)));
173                // current_chars is now empty String (no allocation needed for clear)
174            }
175            current_lang = Some(detected);
176        }
177        // If lang is None (neutral char), keep current_lang unchanged
178        // so the neutral char gets absorbed into the current segment.
179        current_chars.push(ch);
180    }
181
182    // Flush remaining characters
183    if let Some(lang) = current_lang
184        && !current_chars.is_empty()
185    {
186        segments.push((lang.to_string(), current_chars));
187    }
188
189    // Fallback: if no language-specific chars were detected, use default
190    if segments.is_empty() && !text.trim().is_empty() {
191        segments.push((detector.default_latin_language.clone(), text.to_string()));
192    }
193
194    segments
195}

/// A text segment with its detected language.
///
/// Structured counterpart of the `(language, text)` tuples returned by
/// `segment_text`; produced by `MultilingualPhonemizer::segment_text_structured`.
#[derive(Debug, Clone)]
pub struct TextSegment {
    /// ISO 639-1 language code.
    pub language: String,
    /// The text content of this segment.
    pub text: String,
}
205
206// ---------------------------------------------------------------------------
207// default_post_process_ids
208// ---------------------------------------------------------------------------
209
210/// Shared BOS/EOS/padding post-processing (espeak-ng compatible).
211///
212/// Used by EN, ZH, KO, ES, FR, PT phonemizers and by
213/// `MultilingualPhonemizer`. Inserts pad tokens between every phoneme ID
214/// and wraps with BOS (^) / EOS markers.
215///
216/// The `eos_token` parameter allows a dynamic EOS (e.g., `"?"` or PUA
217/// question markers from Japanese). Falls back to `"$"` when the
218/// requested token is not found in the map.
219pub fn default_post_process_ids(
220    ids: Vec<i64>,
221    prosody: Vec<Option<ProsodyFeature>>,
222    id_map: &PhonemeIdMap,
223    eos_token: &str,
224) -> (Vec<i64>, Vec<Option<ProsodyFeature>>) {
225    let pad_ids = id_map.get("_").cloned().unwrap_or_else(|| vec![0]);
226    let bos_ids = id_map.get("^");
227    let eos_ids = id_map.get(eos_token).or_else(|| id_map.get("$"));
228
229    // Intersperse: pad after every phoneme, but skip after existing pad
230    // tokens to match the training data padding scheme.
231    let mut padded_ids = Vec::with_capacity(ids.len() * 2);
232    let mut padded_prosody = Vec::with_capacity(ids.len() * 2);
233
234    for (id, p) in ids.iter().zip(prosody.iter()) {
235        padded_ids.push(*id);
236        padded_prosody.push(*p);
237        if !pad_ids.contains(id) {
238            padded_ids.extend_from_slice(&pad_ids);
239            padded_prosody.extend(std::iter::repeat_n(None, pad_ids.len()));
240        }
241    }
242
243    // Wrap with BOS
244    if let Some(bos) = bos_ids {
245        let mut with_bos_ids = Vec::with_capacity(bos.len() + 1 + padded_ids.len());
246        with_bos_ids.extend_from_slice(bos);
247        with_bos_ids.push(pad_ids[0]);
248        with_bos_ids.extend_from_slice(&padded_ids);
249        let mut with_bos_prosody = Vec::with_capacity(bos.len() + 1 + padded_prosody.len());
250        with_bos_prosody.extend(std::iter::repeat_n(None, bos.len() + 1));
251        with_bos_prosody.extend_from_slice(&padded_prosody);
252        padded_ids = with_bos_ids;
253        padded_prosody = with_bos_prosody;
254    }
255
256    // Append EOS
257    if let Some(eos) = eos_ids {
258        padded_ids.extend_from_slice(eos);
259        padded_prosody.extend(std::iter::repeat_n(None, eos.len()));
260    }
261
262    (padded_ids, padded_prosody)
263}

// ---------------------------------------------------------------------------
// PassthroughPhonemizer
// ---------------------------------------------------------------------------

/// A simple phonemizer that performs character-level tokenization.
///
/// Used for languages without a native Rust phonemizer (en, zh, ko, es, fr, pt).
/// Each character becomes a separate token. Relies on the phoneme_id_map from
/// config.json for ID conversion.
pub struct PassthroughPhonemizer {
    /// ISO 639-1 code this instance reports via `language_code()`.
    lang_code: String,
}

impl PassthroughPhonemizer {
    /// Create a new passthrough phonemizer for the given language code.
    pub fn new(lang_code: &str) -> Self {
        let lang_code = lang_code.to_owned();
        Self { lang_code }
    }
}
286
287impl Phonemizer for PassthroughPhonemizer {
288    fn phonemize_with_prosody(
289        &self,
290        text: &str,
291    ) -> Result<(Vec<String>, Vec<Option<ProsodyInfo>>), G2pError> {
292        let tokens: Vec<String> = text.chars().map(|c| c.to_string()).collect();
293        let prosody: Vec<Option<ProsodyInfo>> = vec![None; tokens.len()];
294        Ok((tokens, prosody))
295    }
296
297    fn language_code(&self) -> &str {
298        &self.lang_code
299    }
300}

// ---------------------------------------------------------------------------
// MultilingualPhonemizer
// ---------------------------------------------------------------------------

/// Phonemizer that handles code-switching between N languages.
///
/// Segments the input text by language using Unicode ranges, delegates to
/// language-specific phonemizers, and concatenates results in a unified
/// phoneme space.
///
/// `last_eos` is set by `phonemize_with_prosody` and accessible via
/// `last_eos()`. A `Mutex` provides interior mutability while
/// satisfying the `Send + Sync` bounds required by the `Phonemizer` trait.
pub struct MultilingualPhonemizer {
    /// Supported language codes, in construction order.
    languages: Vec<String>,
    /// Language assigned to Latin-script characters by the detector.
    default_latin_language: String,
    /// Unicode-range based per-character language detector.
    detector: UnicodeLanguageDetector,
    /// One delegate phonemizer per supported language code.
    phonemizers: HashMap<String, Box<dyn Phonemizer>>,
    /// Dynamic EOS token captured during the last `phonemize_with_prosody`
    /// call. Accessible via `last_eos()`.
    last_eos: Mutex<String>,
}

impl MultilingualPhonemizer {
    /// Create a new multilingual phonemizer.
    ///
    /// `languages` lists the supported language codes (e.g., `["ja", "en"]`).
    /// Each must have a corresponding entry in `phonemizers`.
    ///
    /// `default_latin_language` controls which language Latin-script
    /// characters are assigned to. If not present in `languages`, falls
    /// back to the first language.
    pub fn new(
        languages: Vec<String>,
        mut default_latin_language: String,
        phonemizers: HashMap<String, Box<dyn Phonemizer>>,
    ) -> Self {
        // Validate default_latin_language is in the supported set;
        // otherwise fall back to the first listed language, or "en"
        // when the language list is empty.
        if !languages.contains(&default_latin_language) {
            default_latin_language = languages
                .first()
                .cloned()
                .unwrap_or_else(|| "en".to_string());
        }

        let detector = UnicodeLanguageDetector::new(&languages, &default_latin_language);

        Self {
            languages,
            default_latin_language,
            detector,
            phonemizers,
            // "$" is the neutral default EOS until a segment produces another.
            last_eos: Mutex::new("$".to_string()),
        }
    }

    /// Replace the phonemizer for a given language.
    ///
    /// Used by WASM external dictionary loading: initially a PassthroughPhonemizer
    /// is used for JA, then replaced with a real JapanesePhonemizer once the
    /// dictionary bytes are available.
    pub fn replace_phonemizer(&mut self, lang: &str, phonemizer: Box<dyn Phonemizer>) {
        self.phonemizers.insert(lang.to_string(), phonemizer);
    }

    /// Return the list of supported language codes.
    pub fn languages(&self) -> &[String] {
        &self.languages
    }

    /// Return the last EOS token captured during `phonemize_with_prosody`.
    ///
    /// Defaults to `"$"`. Japanese segments may produce `"?"`, `"?!"`, etc.
    /// Used by the encoder to pick the correct EOS during ID conversion.
    pub fn last_eos(&self) -> String {
        // A poisoned mutex degrades to the neutral "$" EOS.
        self.last_eos
            .lock()
            .map(|g| g.clone())
            .unwrap_or_else(|_| "$".to_string())
    }

    /// Segment mixed-language text into per-language chunks.
    ///
    /// Each segment contains contiguous characters of the same detected
    /// language. Neutral characters (whitespace, digits, punctuation) are
    /// absorbed into the preceding segment.
    ///
    /// Returns a list of `TextSegment { language, text }` structs.
    pub fn segment_text_structured(&self, text: &str) -> Vec<TextSegment> {
        segment_text(text, &self.detector)
            .into_iter()
            .map(|(lang, txt)| TextSegment {
                language: lang,
                text: txt,
            })
            .collect()
    }

    /// Detect the primary language of the text.
    ///
    /// Returns the language code of the first detected language segment,
    /// or the default_latin_language if no segments are detected.
    pub fn detect_primary_language(&self, text: &str) -> &str {
        let segments = segment_text(text, &self.detector);
        if let Some((lang, _)) = segments.first() {
            // Match against known language codes to return &str with correct lifetime
            // (the segment's String would not outlive this function).
            for supported in &self.languages {
                if supported == lang {
                    return supported.as_str();
                }
            }
        }
        &self.default_latin_language
    }

    /// Phonemize text with an explicit language hint.
    ///
    /// When a language hint is provided and the phonemizer for that language
    /// exists, the entire text is routed to that language's phonemizer
    /// (bypassing Unicode-based auto-detection). This is critical for
    /// Latin-script languages (es/fr/pt) which cannot be distinguished from
    /// English by Unicode ranges alone.
    ///
    /// Falls back to auto-detected segmentation if the hint is unknown.
    pub fn phonemize_with_language_hint(
        &self,
        text: &str,
        language: &str,
    ) -> Result<(Vec<String>, Vec<Option<ProsodyInfo>>), G2pError> {
        if let Some(phonemizer) = self.phonemizers.get(language) {
            let (tokens, prosody) = phonemizer.phonemize_with_prosody(text)?;

            // Strip BOS/EOS tokens from the segment, then re-wrap.
            // The most recent EOS-like token is remembered so the encoder
            // can re-apply it during ID conversion.
            let bos_eos = Self::bos_eos_tokens();
            let eos_set = Self::eos_tokens();
            let mut last_eos = "$".to_string();
            let mut filtered_tokens = Vec::new();
            let mut filtered_prosody = Vec::new();
            for (ph, pr) in tokens.iter().zip(prosody.iter()) {
                if bos_eos.contains(ph) {
                    if eos_set.contains(ph) {
                        last_eos = ph.clone();
                    }
                    continue;
                }
                filtered_tokens.push(ph.clone());
                filtered_prosody.push(*pr);
            }

            // Publish the captured EOS via interior mutability; a poisoned
            // lock leaves the previous value in place.
            if let Ok(mut guard) = self.last_eos.lock() {
                *guard = last_eos;
            }

            Ok((filtered_tokens, filtered_prosody))
        } else {
            // Unknown language hint — fall back to auto-detection
            self.phonemize_with_prosody(text)
        }
    }

    /// Build the set of BOS/EOS-like tokens to strip from individual
    /// segment outputs. Includes PUA-encoded Japanese question markers.
    /// Cached via `OnceLock` to avoid re-constructing the `HashSet` on every call.
    fn bos_eos_tokens() -> &'static HashSet<String> {
        static TOKENS: OnceLock<HashSet<String>> = OnceLock::new();
        TOKENS.get_or_init(|| {
            let mut set = HashSet::new();
            set.insert("^".to_string());
            set.insert("$".to_string());
            set.insert("?".to_string());
            // PUA-encoded question markers (?!, ?., ?~)
            for marker in &["?!", "?.", "?~"] {
                if let Some(pua) = token_to_pua(marker) {
                    set.insert(pua.to_string());
                }
            }
            set
        })
    }

    /// Build the set of EOS-like tokens (subset of BOS/EOS used to track
    /// the last EOS for dynamic post-processing).
    /// Cached via `OnceLock` to avoid re-constructing the `HashSet` on every call.
    fn eos_tokens() -> &'static HashSet<String> {
        static TOKENS: OnceLock<HashSet<String>> = OnceLock::new();
        TOKENS.get_or_init(|| {
            let mut set = HashSet::new();
            set.insert("$".to_string());
            set.insert("?".to_string());
            for marker in &["?!", "?.", "?~"] {
                if let Some(pua) = token_to_pua(marker) {
                    set.insert(pua.to_string());
                }
            }
            set
        })
    }
}
500
501impl Phonemizer for MultilingualPhonemizer {
502    fn phonemize_with_prosody(
503        &self,
504        text: &str,
505    ) -> Result<(Vec<String>, Vec<Option<ProsodyInfo>>), G2pError> {
506        let segments = segment_text(text, &self.detector);
507        if segments.is_empty() {
508            return Ok((Vec::new(), Vec::new()));
509        }
510
511        let bos_eos = Self::bos_eos_tokens();
512        let eos_set = Self::eos_tokens();
513
514        let mut all_phonemes: Vec<String> = Vec::new();
515        let mut all_prosody: Vec<Option<ProsodyInfo>> = Vec::new();
516        let mut last_eos = "$".to_string();
517
518        for (lang, segment_text) in &segments {
519            let phonemizer = self
520                .phonemizers
521                .get(lang)
522                .ok_or_else(|| G2pError::UnsupportedLanguage { code: lang.clone() })?;
523
524            let (phonemes, prosody_list) = phonemizer.phonemize_with_prosody(segment_text)?;
525
526            // Strip BOS/EOS from individual segments.
527            // This includes PUA-encoded question markers from Japanese.
528            for (ph, pr) in phonemes.iter().zip(prosody_list.iter()) {
529                if bos_eos.contains(ph) {
530                    if eos_set.contains(ph) {
531                        last_eos = ph.clone();
532                    }
533                    continue;
534                }
535                all_phonemes.push(ph.clone());
536                all_prosody.push(*pr);
537            }
538        }
539
540        // Update last_eos via interior mutability (Mutex).
541        if let Ok(mut guard) = self.last_eos.lock() {
542            *guard = last_eos;
543        }
544
545        Ok((all_phonemes, all_prosody))
546    }
547
548    fn language_code(&self) -> &str {
549        // Return the default Latin language for multi-language mode.
550        &self.default_latin_language
551    }
552
553    fn detect_primary_language(&self, text: &str) -> &str {
554        // Delegate to the inherent method
555        MultilingualPhonemizer::detect_primary_language(self, text)
556    }
557}
558
559// ---------------------------------------------------------------------------
560// Tests
561// ---------------------------------------------------------------------------
562
563#[cfg(test)]
564mod tests {
565    use super::*;
566
567    // ===== UnicodeLanguageDetector =====
568
569    fn make_detector(langs: &[&str], default_latin: &str) -> UnicodeLanguageDetector {
570        let lang_strings: Vec<String> = langs.iter().map(|s| s.to_string()).collect();
571        UnicodeLanguageDetector::new(&lang_strings, default_latin)
572    }
573
574    #[test]
575    fn test_detect_hiragana_as_ja() {
576        let det = make_detector(&["ja", "en"], "en");
577        assert_eq!(det.detect_char('\u{3042}', false), Some("ja")); // あ
578        assert_eq!(det.detect_char('\u{3093}', false), Some("ja")); // ん
579    }
580
581    #[test]
582    fn test_detect_katakana_as_ja() {
583        let det = make_detector(&["ja", "en"], "en");
584        assert_eq!(det.detect_char('\u{30A2}', false), Some("ja")); // ア
585        assert_eq!(det.detect_char('\u{30F3}', false), Some("ja")); // ン
586    }
587
588    #[test]
589    fn test_detect_katakana_phonetic_ext_as_ja() {
590        let det = make_detector(&["ja", "en"], "en");
591        assert_eq!(det.detect_char('\u{31F0}', false), Some("ja")); // ㇰ
592    }
593
594    #[test]
595    fn test_detect_hangul_as_ko() {
596        let det = make_detector(&["ja", "en", "ko"], "en");
597        assert_eq!(det.detect_char('\u{AC00}', false), Some("ko")); // 가
598        assert_eq!(det.detect_char('\u{D7AF}', false), Some("ko")); // last hangul syllable
599    }
600
601    #[test]
602    fn test_detect_hangul_jamo_as_ko() {
603        let det = make_detector(&["ko", "en"], "en");
604        assert_eq!(det.detect_char('\u{1100}', false), Some("ko")); // ᄀ
605        assert_eq!(det.detect_char('\u{3131}', false), Some("ko")); // ㄱ (compat)
606    }
607
608    #[test]
609    fn test_detect_cjk_as_zh_without_kana() {
610        let det = make_detector(&["ja", "en", "zh"], "en");
611        // CJK ideograph, no kana context → Chinese
612        assert_eq!(det.detect_char('\u{4E16}', false), Some("zh")); // 世
613    }
614
615    #[test]
616    fn test_detect_cjk_as_ja_with_kana_context() {
617        let det = make_detector(&["ja", "en", "zh"], "en");
618        // CJK ideograph, kana context → Japanese
619        assert_eq!(det.detect_char('\u{4E16}', true), Some("ja")); // 世
620    }
621
622    #[test]
623    fn test_detect_cjk_ja_only() {
624        let det = make_detector(&["ja", "en"], "en");
625        // Only JA is available, no ZH → always JA regardless of context
626        assert_eq!(det.detect_char('\u{4E16}', false), Some("ja"));
627    }
628
629    #[test]
630    fn test_detect_cjk_zh_only() {
631        let det = make_detector(&["zh", "en"], "en");
632        // Only ZH is available → always ZH
633        assert_eq!(det.detect_char('\u{4E16}', true), Some("zh"));
634    }
635
636    #[test]
637    fn test_detect_fullwidth_latin_as_default_latin() {
638        let det = make_detector(&["ja", "en"], "en");
639        assert_eq!(det.detect_char('\u{FF21}', false), Some("en")); // A
640        assert_eq!(det.detect_char('\u{FF5A}', false), Some("en")); // z
641    }
642
643    #[test]
644    fn test_detect_cjk_punctuation_as_ja() {
645        let det = make_detector(&["ja", "en"], "en");
646        assert_eq!(det.detect_char('\u{3001}', false), Some("ja")); // 、
647        assert_eq!(det.detect_char('\u{3002}', false), Some("ja")); // 。
648        assert_eq!(det.detect_char('\u{300C}', false), Some("ja")); // 「
649    }
650
651    #[test]
652    fn test_detect_latin_as_default_language() {
653        let det = make_detector(&["ja", "en"], "en");
654        assert_eq!(det.detect_char('H', false), Some("en"));
655        assert_eq!(det.detect_char('z', false), Some("en"));
656    }
657
658    #[test]
659    fn test_detect_accented_latin() {
660        let det = make_detector(&["ja", "fr"], "fr");
661        assert_eq!(det.detect_char('\u{00E9}', false), Some("fr")); // é
662        assert_eq!(det.detect_char('\u{00C0}', false), Some("fr")); // À
663    }
664
665    #[test]
666    fn test_detect_neutral_characters() {
667        let det = make_detector(&["ja", "en"], "en");
668        assert_eq!(det.detect_char(' ', false), None);
669        assert_eq!(det.detect_char('0', false), None);
670        assert_eq!(det.detect_char('!', false), None);
671        assert_eq!(det.detect_char('.', false), None);
672        assert_eq!(det.detect_char(',', false), None);
673    }
674
675    #[test]
676    fn test_detect_multiplication_sign_is_neutral() {
677        let det = make_detector(&["ja", "en"], "en");
678        // U+00D7 (×) is in the range but excluded from Latin
679        assert_eq!(det.detect_char('\u{00D7}', false), None);
680    }
681
682    #[test]
683    fn test_has_kana() {
684        let det = make_detector(&["ja", "en"], "en");
685        assert!(det.has_kana("こんにちは world"));
686        assert!(det.has_kana("アイウ"));
687        assert!(!det.has_kana("Hello world"));
688        assert!(!det.has_kana("你好世界"));
689    }
690
691    // ===== segment_text =====
692
693    #[test]
694    fn test_segment_pure_japanese() {
695        let det = make_detector(&["ja", "en"], "en");
696        let segs = segment_text("こんにちは", &det);
697        assert_eq!(segs.len(), 1);
698        assert_eq!(segs[0].0, "ja");
699        assert_eq!(segs[0].1, "こんにちは");
700    }
701
702    #[test]
703    fn test_segment_pure_english() {
704        let det = make_detector(&["ja", "en"], "en");
705        let segs = segment_text("Hello world", &det);
706        assert_eq!(segs.len(), 1);
707        assert_eq!(segs[0].0, "en");
708        assert_eq!(segs[0].1, "Hello world");
709    }
710
711    #[test]
712    fn test_segment_mixed_ja_en() {
713        let det = make_detector(&["ja", "en"], "en");
714        let segs = segment_text("今日はgood morningですね", &det);
715        assert_eq!(segs.len(), 3);
716        assert_eq!(segs[0].0, "ja");
717        assert_eq!(segs[0].1, "今日は");
718        assert_eq!(segs[1].0, "en");
719        assert_eq!(segs[1].1, "good morning");
720        assert_eq!(segs[2].0, "ja");
721        assert_eq!(segs[2].1, "ですね");
722    }
723
724    #[test]
725    fn test_segment_neutral_absorbed_into_preceding() {
726        let det = make_detector(&["ja", "en"], "en");
727        // "Hello, " — comma and space are neutral, absorbed into English
728        let segs = segment_text("Hello, こんにちは", &det);
729        assert_eq!(segs.len(), 2);
730        assert_eq!(segs[0].0, "en");
731        assert_eq!(segs[0].1, "Hello, ");
732        assert_eq!(segs[1].0, "ja");
733        assert_eq!(segs[1].1, "こんにちは");
734    }
735
736    #[test]
737    fn test_segment_leading_neutral_absorbed_into_first_language() {
738        let det = make_detector(&["ja", "en"], "en");
739        // Leading "123 " are neutral — no preceding segment, so they get
740        // absorbed into whatever language comes first.
741        let segs = segment_text("123 Hello", &det);
742        assert_eq!(segs.len(), 1);
743        assert_eq!(segs[0].0, "en");
744        assert_eq!(segs[0].1, "123 Hello");
745    }
746
747    #[test]
748    fn test_segment_empty_string() {
749        let det = make_detector(&["ja", "en"], "en");
750        let segs = segment_text("", &det);
751        assert!(segs.is_empty());
752    }
753
754    #[test]
755    fn test_segment_whitespace_only() {
756        let det = make_detector(&["ja", "en"], "en");
757        let segs = segment_text("   ", &det);
758        assert!(segs.is_empty());
759    }
760
761    #[test]
762    fn test_segment_digits_only_fallback() {
763        let det = make_detector(&["ja", "en"], "en");
764        // No language-specific characters — falls back to default
765        let segs = segment_text("12345", &det);
766        assert_eq!(segs.len(), 1);
767        assert_eq!(segs[0].0, "en");
768        assert_eq!(segs[0].1, "12345");
769    }
770
771    #[test]
772    fn test_segment_cjk_disambiguation_with_kana() {
773        let det = make_detector(&["ja", "en", "zh"], "en");
774        // Text with kana + CJK ideographs: the ideographs become JA
775        let segs = segment_text("漢字とかな", &det);
776        assert_eq!(segs.len(), 1);
777        assert_eq!(segs[0].0, "ja");
778    }
779
780    #[test]
781    fn test_segment_cjk_without_kana_is_zh() {
782        let det = make_detector(&["ja", "en", "zh"], "en");
783        // Pure CJK ideographs without kana → Chinese
784        let segs = segment_text("你好世界", &det);
785        assert_eq!(segs.len(), 1);
786        assert_eq!(segs[0].0, "zh");
787    }
788
789    #[test]
790    fn test_segment_mixed_zh_en() {
791        let det = make_detector(&["zh", "en"], "en");
792        let segs = segment_text("Hello你好", &det);
793        assert_eq!(segs.len(), 2);
794        assert_eq!(segs[0].0, "en");
795        assert_eq!(segs[0].1, "Hello");
796        assert_eq!(segs[1].0, "zh");
797        assert_eq!(segs[1].1, "你好");
798    }
799
800    // ===== default_post_process_ids =====
801
802    fn make_id_map() -> PhonemeIdMap {
803        let mut m = HashMap::new();
804        m.insert("_".to_string(), vec![0]);
805        m.insert("^".to_string(), vec![1]);
806        m.insert("$".to_string(), vec![2]);
807        m.insert("?".to_string(), vec![3]);
808        m
809    }
810
811    #[test]
812    fn test_post_process_basic_padding() {
813        let id_map = make_id_map();
814        let ids = vec![10, 11, 12];
815        let prosody = vec![None, None, None];
816        let (out_ids, out_prosody) = default_post_process_ids(ids, prosody, &id_map, "$");
817        // Expected: ^(1) + pad(0) + 10 + pad(0) + 11 + pad(0) + 12 + pad(0) + $(2)
818        assert_eq!(out_ids, vec![1, 0, 10, 0, 11, 0, 12, 0, 2]);
819        assert_eq!(out_prosody.len(), out_ids.len());
820    }
821
822    #[test]
823    fn test_post_process_skip_padding_after_pad_token() {
824        let id_map = make_id_map();
825        // ID 0 is a pad token — should NOT get another pad after it
826        let ids = vec![10, 0, 12];
827        let prosody = vec![None, None, None];
828        let (out_ids, _) = default_post_process_ids(ids, prosody, &id_map, "$");
829        // Expected: ^(1) + pad(0) + 10 + pad(0) + 0 (no pad after) + 12 + pad(0) + $(2)
830        assert_eq!(out_ids, vec![1, 0, 10, 0, 0, 12, 0, 2]);
831    }
832
833    #[test]
834    fn test_post_process_with_question_eos() {
835        let id_map = make_id_map();
836        let ids = vec![10];
837        let prosody = vec![None];
838        let (out_ids, _) = default_post_process_ids(ids, prosody, &id_map, "?");
839        // Expected: ^(1) + pad(0) + 10 + pad(0) + ?(3)
840        assert_eq!(out_ids, vec![1, 0, 10, 0, 3]);
841    }
842
843    #[test]
844    fn test_post_process_eos_fallback_to_dollar() {
845        let id_map = make_id_map();
846        let ids = vec![10];
847        let prosody = vec![None];
848        // Request EOS token "nonexistent" — should fall back to "$"
849        let (out_ids, _) = default_post_process_ids(ids, prosody, &id_map, "nonexistent");
850        // Expected: ^(1) + pad(0) + 10 + pad(0) + $(2)
851        assert_eq!(out_ids, vec![1, 0, 10, 0, 2]);
852    }
853
854    #[test]
855    fn test_post_process_empty_input() {
856        let id_map = make_id_map();
857        let ids: Vec<i64> = Vec::new();
858        let prosody: Vec<Option<ProsodyFeature>> = Vec::new();
859        let (out_ids, out_prosody) = default_post_process_ids(ids, prosody, &id_map, "$");
860        // Expected: ^(1) + pad(0) + $(2)
861        assert_eq!(out_ids, vec![1, 0, 2]);
862        assert_eq!(out_prosody.len(), out_ids.len());
863    }
864
865    #[test]
866    fn test_post_process_prosody_propagated() {
867        let id_map = make_id_map();
868        let ids = vec![10, 11];
869        let prosody = vec![Some([1, 2, 3]), None];
870        let (out_ids, out_prosody) = default_post_process_ids(ids, prosody, &id_map, "$");
871        // ^=None pad=None 10=Some([1,2,3]) pad=None 11=None pad=None $=None
872        assert_eq!(out_ids, vec![1, 0, 10, 0, 11, 0, 2]);
873        assert!(out_prosody[0].is_none()); // ^
874        assert!(out_prosody[1].is_none()); // pad
875        assert_eq!(out_prosody[2], Some([1, 2, 3])); // phoneme 10
876        assert!(out_prosody[3].is_none()); // pad
877        assert!(out_prosody[4].is_none()); // phoneme 11
878        assert!(out_prosody[5].is_none()); // pad
879        assert!(out_prosody[6].is_none()); // $
880    }
881
882    // ===== BOS/EOS token sets =====
883
884    #[test]
885    fn test_bos_eos_tokens_include_pua_markers() {
886        let set = MultilingualPhonemizer::bos_eos_tokens();
887        assert!(set.contains("^"));
888        assert!(set.contains("$"));
889        assert!(set.contains("?"));
890        // PUA markers for ?!, ?., ?~
891        assert!(set.contains(&"\u{E016}".to_string())); // ?!
892        assert!(set.contains(&"\u{E017}".to_string())); // ?.
893        assert!(set.contains(&"\u{E018}".to_string())); // ?~
894    }
895
896    #[test]
897    fn test_eos_tokens_subset() {
898        let eos_set = MultilingualPhonemizer::eos_tokens();
899        let bos_eos_set = MultilingualPhonemizer::bos_eos_tokens();
900        // EOS set should be a subset of BOS/EOS set
901        for token in eos_set {
902            assert!(
903                bos_eos_set.contains(token),
904                "EOS token {:?} not in BOS/EOS set",
905                token
906            );
907        }
908        // BOS (^) should be in bos_eos but NOT in eos
909        assert!(!eos_set.contains("^"));
910    }
911
912    // ===== Integration: default_post_process_ids =====
913
914    #[test]
915    fn test_default_post_process_ids_and_prosody_lengths_match() {
916        let id_map = make_id_map();
917        let ids = vec![5, 6, 7, 8, 9];
918        let prosody: Vec<Option<ProsodyFeature>> =
919            vec![Some([1, 0, 3]), None, Some([0, 2, 4]), None, None];
920        let (out_ids, out_prosody) = default_post_process_ids(ids, prosody, &id_map, "$");
921        assert_eq!(
922            out_ids.len(),
923            out_prosody.len(),
924            "IDs ({}) and prosody ({}) length mismatch",
925            out_ids.len(),
926            out_prosody.len()
927        );
928    }
929
930    // ===== replace_phonemizer =====
931
932    #[test]
933    fn test_replace_phonemizer() {
934        // Setup: create a multilingual phonemizer with 2 languages
935        let mut phonemizers: HashMap<String, Box<dyn Phonemizer>> = HashMap::new();
936        phonemizers.insert("ja".to_string(), Box::new(PassthroughPhonemizer::new("ja")));
937        phonemizers.insert("en".to_string(), Box::new(PassthroughPhonemizer::new("en")));
938
939        let mut mp = MultilingualPhonemizer::new(
940            vec!["ja".to_string(), "en".to_string()],
941            "en".to_string(),
942            phonemizers,
943        );
944
945        // Phonemize Japanese text with passthrough (should produce character-level tokens)
946        let (tokens_before, _) = mp.phonemize_with_prosody("あ").unwrap();
947
948        // Replace JA phonemizer with a new passthrough (same type, but proves replacement works)
949        mp.replace_phonemizer("ja", Box::new(PassthroughPhonemizer::new("ja")));
950
951        // Phonemize again — should still work after replacement
952        let (tokens_after, _) = mp.phonemize_with_prosody("あ").unwrap();
953        assert_eq!(
954            tokens_before, tokens_after,
955            "replacement should produce same results"
956        );
957    }
958
959    fn make_hint_test_phonemizer() -> MultilingualPhonemizer {
960        let mut phonemizers: HashMap<String, Box<dyn Phonemizer>> = HashMap::new();
961        phonemizers.insert("ja".to_string(), Box::new(PassthroughPhonemizer::new("ja")));
962        phonemizers.insert("en".to_string(), Box::new(PassthroughPhonemizer::new("en")));
963        phonemizers.insert("es".to_string(), Box::new(PassthroughPhonemizer::new("es")));
964        MultilingualPhonemizer::new(
965            vec!["ja".to_string(), "en".to_string(), "es".to_string()],
966            "en".to_string(),
967            phonemizers,
968        )
969    }
970
971    #[test]
972    fn test_language_hint_routes_to_correct_phonemizer() {
973        let mp = make_hint_test_phonemizer();
974
975        // Without hint: "Hola" is Latin → default_latin (en)
976        let (tokens_auto, _) = mp.phonemize_with_prosody("Hola").unwrap();
977
978        // With hint "es": routes directly to es phonemizer
979        let (tokens_hint, _) = mp.phonemize_with_language_hint("Hola", "es").unwrap();
980
981        // Both should produce output (not empty)
982        assert!(!tokens_auto.is_empty(), "auto-detect should produce tokens");
983        assert!(
984            !tokens_hint.is_empty(),
985            "language hint should produce tokens"
986        );
987    }
988
989    #[test]
990    fn test_language_hint_unknown_falls_back_to_auto() {
991        let mp = make_hint_test_phonemizer();
992
993        // Unknown language hint should fall back to auto-detection
994        let (tokens, _) = mp.phonemize_with_language_hint("Hello", "xx").unwrap();
995        assert!(
996            !tokens.is_empty(),
997            "unknown hint should fall back to auto-detect"
998        );
999    }
1000
1001    #[test]
1002    fn test_language_hint_ja_matches_auto_detect() {
1003        let mp = make_hint_test_phonemizer();
1004
1005        // "あ" with ja hint → JA phonemizer
1006        let (tokens_hint, _) = mp.phonemize_with_language_hint("あ", "ja").unwrap();
1007        let (tokens_auto, _) = mp.phonemize_with_prosody("あ").unwrap();
1008
1009        // Both should produce the same result since auto-detect also detects ja
1010        assert_eq!(
1011            tokens_hint, tokens_auto,
1012            "ja hint should match auto-detected ja"
1013        );
1014    }
1015}