// piper_plus/phonemize/multilingual.rs
1//! Multilingual phonemizer for code-switching text across N languages.
2//!
3//! Generalizes the concept of bilingual phonemization to support arbitrary
4//! language combinations. Detects language segments via Unicode ranges,
5//! delegates to language-specific phonemizers, and returns unified phoneme IDs.
6//!
7//! Port of the Python `multilingual.py`.
8
9use std::collections::{HashMap, HashSet};
10use std::sync::{Mutex, OnceLock};
11
12use super::token_map::token_to_pua;
13use super::{Phonemizer, ProsodyFeature, ProsodyInfo};
14use crate::config::PhonemeIdMap;
15use crate::error::PiperError;
16
17// ---------------------------------------------------------------------------
18// UnicodeLanguageDetector
19// ---------------------------------------------------------------------------
20
/// Detect language from Unicode character ranges.
///
/// Supports CJK disambiguation (JA vs ZH) by checking for kana presence.
/// Latin characters are mapped to a configurable default language.
pub struct UnicodeLanguageDetector {
    /// All supported language codes.
    languages: HashSet<String>,
    /// Language assigned to Latin-script characters.
    default_latin_language: String,
    // Membership flags cached at construction so the hot `detect_char`
    // path avoids repeated hash-set lookups.
    has_ja: bool,
    has_zh: bool,
    has_ko: bool,
}

/// True when `cp` lies in a kana block: Hiragana (U+3040-309F),
/// Katakana (U+30A0-30FF), or Katakana Phonetic Extensions (U+31F0-31FF).
///
/// Shared by `detect_char` and `has_kana` so the two range checks
/// cannot drift apart.
fn is_kana_codepoint(cp: u32) -> bool {
    (0x3040..=0x30FF).contains(&cp) || (0x31F0..=0x31FF).contains(&cp)
}

impl UnicodeLanguageDetector {
    /// Create a new detector for the given set of languages.
    ///
    /// `default_latin_language` controls which language Latin-script
    /// characters (A-Z, a-z, accented Latin) are assigned to.
    pub fn new(languages: &[String], default_latin_language: &str) -> Self {
        let lang_set: HashSet<String> = languages.iter().cloned().collect();
        Self {
            has_ja: lang_set.contains("ja"),
            has_zh: lang_set.contains("zh"),
            has_ko: lang_set.contains("ko"),
            default_latin_language: default_latin_language.to_string(),
            languages: lang_set,
        }
    }

    /// Resolve Latin-script input to the configured default language,
    /// or `None` when that language is not in the supported set.
    fn latin_default(&self) -> Option<&str> {
        if self.languages.contains(&self.default_latin_language) {
            Some(self.default_latin_language.as_str())
        } else {
            None
        }
    }

    /// Detect language for a single character.
    ///
    /// `context_has_kana` is used for CJK ideograph disambiguation: if the
    /// surrounding text contains kana, CJK ideographs are classified as
    /// Japanese rather than Chinese.
    ///
    /// Returns `None` for neutral characters (whitespace, digits,
    /// ASCII punctuation, etc.).
    pub fn detect_char(&self, ch: char, context_has_kana: bool) -> Option<&str> {
        let cp = ch as u32;

        // 1. Hiragana, Katakana, Katakana Phonetic Extensions → Japanese
        if is_kana_codepoint(cp) {
            return if self.has_ja { Some("ja") } else { None };
        }

        // 2. Hangul Syllables (U+AC00-D7AF), Jamo (U+1100-11FF),
        //    Compatibility Jamo (U+3130-318F)
        if (0xAC00..=0xD7AF).contains(&cp)
            || (0x1100..=0x11FF).contains(&cp)
            || (0x3130..=0x318F).contains(&cp)
        {
            return if self.has_ko { Some("ko") } else { None };
        }

        // 3. CJK Unified Ideographs (U+4E00-9FFF), Extension A (U+3400-4DBF),
        //    Compatibility Ideographs (U+F900-FAFF)
        if (0x4E00..=0x9FFF).contains(&cp)
            || (0x3400..=0x4DBF).contains(&cp)
            || (0xF900..=0xFAFF).contains(&cp)
        {
            if self.has_ja && self.has_zh {
                // Both JA and ZH are supported: kana anywhere in the
                // surrounding text tips the ideograph toward Japanese.
                return if context_has_kana {
                    Some("ja")
                } else {
                    Some("zh")
                };
            }
            if self.has_ja {
                return Some("ja");
            }
            if self.has_zh {
                return Some("zh");
            }
            return None;
        }

        // 4. Fullwidth Latin letters: U+FF21-FF3A (A-Z), U+FF41-FF5A (a-z)
        if (0xFF21..=0xFF3A).contains(&cp) || (0xFF41..=0xFF5A).contains(&cp) {
            return self.latin_default();
        }

        // 5. CJK punctuation (U+3000-303F) and fullwidth forms
        //    (U+FF00-FF20, U+FF3B-FF40, U+FF5B-FFEF),
        //    excluding fullwidth Latin letters handled above.
        if (0x3000..=0x303F).contains(&cp)
            || (0xFF00..=0xFF20).contains(&cp)
            || (0xFF3B..=0xFF40).contains(&cp)
            || (0xFF5B..=0xFFEF).contains(&cp)
        {
            return if self.has_ja { Some("ja") } else { None };
        }

        // 6. Latin characters: A-Z, a-z, and extended Latin with diacritics
        //    (U+00C0-00D6, U+00D8-00F6, U+00F8-00FF)
        //    Excludes multiplication sign (U+00D7) and division sign (U+00F7).
        if ch.is_ascii_alphabetic()
            || (0x00C0..=0x00D6).contains(&cp)
            || (0x00D8..=0x00F6).contains(&cp)
            || (0x00F8..=0x00FF).contains(&cp)
        {
            return self.latin_default();
        }

        // 7. Everything else: digits, ASCII punctuation, whitespace → neutral
        None
    }

    /// Check if text contains any Hiragana or Katakana characters.
    pub fn has_kana(&self, text: &str) -> bool {
        text.chars().any(|ch| is_kana_codepoint(ch as u32))
    }
}
144
145// ---------------------------------------------------------------------------
146// segment_text
147// ---------------------------------------------------------------------------
148
149/// Split text into `(language, segment_text)` pairs using Unicode detection.
150///
151/// Neutral characters (whitespace, digits, punctuation) are absorbed into
152/// the preceding language segment. If no language-specific characters are
153/// found (e.g., text is only digits), falls back to `default_latin_language`.
154pub fn segment_text(text: &str, detector: &UnicodeLanguageDetector) -> Vec<(String, String)> {
155    if text.trim().is_empty() {
156        return Vec::new();
157    }
158
159    let context_has_kana = detector.has_kana(text);
160
161    let mut segments: Vec<(String, String)> = Vec::new();
162    let mut current_lang: Option<&str> = None;
163    let mut current_chars = String::new();
164
165    for ch in text.chars() {
166        let lang = detector.detect_char(ch, context_has_kana);
167
168        if let Some(detected) = lang {
169            if let Some(prev) = current_lang
170                && detected != prev
171            {
172                // Language changed — flush the current segment
173                segments.push((prev.to_string(), std::mem::take(&mut current_chars)));
174                // current_chars is now empty String (no allocation needed for clear)
175            }
176            current_lang = Some(detected);
177        }
178        // If lang is None (neutral char), keep current_lang unchanged
179        // so the neutral char gets absorbed into the current segment.
180        current_chars.push(ch);
181    }
182
183    // Flush remaining characters
184    if let Some(lang) = current_lang
185        && !current_chars.is_empty()
186    {
187        segments.push((lang.to_string(), current_chars));
188    }
189
190    // Fallback: if no language-specific chars were detected, use default
191    if segments.is_empty() && !text.trim().is_empty() {
192        segments.push((detector.default_latin_language.clone(), text.to_string()));
193    }
194
195    segments
196}
197
198// ---------------------------------------------------------------------------
199// default_post_process_ids
200// ---------------------------------------------------------------------------
201
202/// Shared BOS/EOS/padding post-processing (espeak-ng compatible).
203///
204/// Used by EN, ZH, KO, ES, FR, PT phonemizers and by
205/// `MultilingualPhonemizer`. Inserts pad tokens between every phoneme ID
206/// and wraps with BOS (^) / EOS markers.
207///
208/// The `eos_token` parameter allows a dynamic EOS (e.g., `"?"` or PUA
209/// question markers from Japanese). Falls back to `"$"` when the
210/// requested token is not found in the map.
211pub fn default_post_process_ids(
212    ids: Vec<i64>,
213    prosody: Vec<Option<ProsodyFeature>>,
214    id_map: &PhonemeIdMap,
215    eos_token: &str,
216) -> (Vec<i64>, Vec<Option<ProsodyFeature>>) {
217    let pad_ids = id_map.get("_").cloned().unwrap_or_else(|| vec![0]);
218    let bos_ids = id_map.get("^");
219    let eos_ids = id_map.get(eos_token).or_else(|| id_map.get("$"));
220
221    // Intersperse: pad after every phoneme, but skip after existing pad
222    // tokens to match the training data padding scheme.
223    let mut padded_ids = Vec::with_capacity(ids.len() * 2);
224    let mut padded_prosody = Vec::with_capacity(ids.len() * 2);
225
226    for (id, p) in ids.iter().zip(prosody.iter()) {
227        padded_ids.push(*id);
228        padded_prosody.push(*p);
229        if !pad_ids.contains(id) {
230            padded_ids.extend_from_slice(&pad_ids);
231            padded_prosody.extend(std::iter::repeat_n(None, pad_ids.len()));
232        }
233    }
234
235    // Wrap with BOS
236    if let Some(bos) = bos_ids {
237        let mut with_bos_ids = Vec::with_capacity(bos.len() + 1 + padded_ids.len());
238        with_bos_ids.extend_from_slice(bos);
239        with_bos_ids.push(pad_ids[0]);
240        with_bos_ids.extend_from_slice(&padded_ids);
241        let mut with_bos_prosody = Vec::with_capacity(bos.len() + 1 + padded_prosody.len());
242        with_bos_prosody.extend(std::iter::repeat_n(None, bos.len() + 1));
243        with_bos_prosody.extend_from_slice(&padded_prosody);
244        padded_ids = with_bos_ids;
245        padded_prosody = with_bos_prosody;
246    }
247
248    // Append EOS
249    if let Some(eos) = eos_ids {
250        padded_ids.extend_from_slice(eos);
251        padded_prosody.extend(std::iter::repeat_n(None, eos.len()));
252    }
253
254    (padded_ids, padded_prosody)
255}
256
257// ---------------------------------------------------------------------------
258// PassthroughPhonemizer
259// ---------------------------------------------------------------------------
260
/// A simple phonemizer that performs character-level tokenization.
///
/// Used for languages without a native Rust phonemizer (en, zh, ko, es, fr, pt).
/// Each character becomes a separate token. Relies on the phoneme_id_map from
/// config.json for ID conversion.
pub struct PassthroughPhonemizer {
    /// Language code this instance reports via `language_code()`.
    lang_code: String,
}

impl PassthroughPhonemizer {
    /// Create a new passthrough phonemizer for the given language code.
    pub fn new(lang_code: &str) -> Self {
        let lang_code = lang_code.to_string();
        Self { lang_code }
    }
}
278
279impl Phonemizer for PassthroughPhonemizer {
280    fn phonemize_with_prosody(
281        &self,
282        text: &str,
283    ) -> Result<(Vec<String>, Vec<Option<ProsodyInfo>>), PiperError> {
284        let tokens: Vec<String> = text.chars().map(|c| c.to_string()).collect();
285        let prosody: Vec<Option<ProsodyInfo>> = vec![None; tokens.len()];
286        Ok((tokens, prosody))
287    }
288
289    fn get_phoneme_id_map(&self) -> Option<&PhonemeIdMap> {
290        None
291    }
292
293    fn post_process_ids(
294        &self,
295        ids: Vec<i64>,
296        prosody: Vec<Option<ProsodyFeature>>,
297        id_map: &PhonemeIdMap,
298    ) -> (Vec<i64>, Vec<Option<ProsodyFeature>>) {
299        default_post_process_ids(ids, prosody, id_map, "$")
300    }
301
302    fn language_code(&self) -> &str {
303        &self.lang_code
304    }
305}
306
307// ---------------------------------------------------------------------------
308// MultilingualPhonemizer
309// ---------------------------------------------------------------------------
310
311/// Phonemizer that handles code-switching between N languages.
312///
313/// Segments the input text by language using Unicode ranges, delegates to
314/// language-specific phonemizers, and concatenates results in a unified
315/// phoneme space.
316///
317/// `last_eos` is set by `phonemize_with_prosody` and read by
318/// `post_process_ids`. A `Mutex` provides interior mutability while
319/// satisfying the `Send + Sync` bounds required by the `Phonemizer` trait.
320pub struct MultilingualPhonemizer {
321    languages: Vec<String>,
322    default_latin_language: String,
323    detector: UnicodeLanguageDetector,
324    phonemizers: HashMap<String, Box<dyn Phonemizer>>,
325    /// Dynamic EOS token captured during the last `phonemize_with_prosody`
326    /// call. Consumed by `post_process_ids`.
327    last_eos: Mutex<String>,
328}
329
330impl MultilingualPhonemizer {
331    /// Create a new multilingual phonemizer.
332    ///
333    /// `languages` lists the supported language codes (e.g., `["ja", "en"]`).
334    /// Each must have a corresponding entry in `phonemizers`.
335    ///
336    /// `default_latin_language` controls which language Latin-script
337    /// characters are assigned to. If not present in `languages`, falls
338    /// back to the first language.
339    pub fn new(
340        languages: Vec<String>,
341        mut default_latin_language: String,
342        phonemizers: HashMap<String, Box<dyn Phonemizer>>,
343    ) -> Self {
344        // Validate default_latin_language is in the supported set
345        if !languages.contains(&default_latin_language) {
346            default_latin_language = languages
347                .first()
348                .cloned()
349                .unwrap_or_else(|| "en".to_string());
350        }
351
352        let detector = UnicodeLanguageDetector::new(&languages, &default_latin_language);
353
354        Self {
355            languages,
356            default_latin_language,
357            detector,
358            phonemizers,
359            last_eos: Mutex::new("$".to_string()),
360        }
361    }
362
363    /// Return the list of supported language codes.
364    pub fn languages(&self) -> &[String] {
365        &self.languages
366    }
367
368    /// Detect the primary language of the text.
369    ///
370    /// Returns the language code of the first detected language segment,
371    /// or the default_latin_language if no segments are detected.
372    pub fn detect_primary_language(&self, text: &str) -> &str {
373        let segments = segment_text(text, &self.detector);
374        if let Some((lang, _)) = segments.first() {
375            // Match against known language codes to return &str with correct lifetime
376            for supported in &self.languages {
377                if supported == lang {
378                    return supported.as_str();
379                }
380            }
381        }
382        &self.default_latin_language
383    }
384
385    /// Build the set of BOS/EOS-like tokens to strip from individual
386    /// segment outputs. Includes PUA-encoded Japanese question markers.
387    /// Cached via `OnceLock` to avoid re-constructing the `HashSet` on every call.
388    fn bos_eos_tokens() -> &'static HashSet<String> {
389        static TOKENS: OnceLock<HashSet<String>> = OnceLock::new();
390        TOKENS.get_or_init(|| {
391            let mut set = HashSet::new();
392            set.insert("^".to_string());
393            set.insert("$".to_string());
394            set.insert("?".to_string());
395            // PUA-encoded question markers (?!, ?., ?~)
396            for marker in &["?!", "?.", "?~"] {
397                if let Some(pua) = token_to_pua(marker) {
398                    set.insert(pua.to_string());
399                }
400            }
401            set
402        })
403    }
404
405    /// Build the set of EOS-like tokens (subset of BOS/EOS used to track
406    /// the last EOS for dynamic post-processing).
407    /// Cached via `OnceLock` to avoid re-constructing the `HashSet` on every call.
408    fn eos_tokens() -> &'static HashSet<String> {
409        static TOKENS: OnceLock<HashSet<String>> = OnceLock::new();
410        TOKENS.get_or_init(|| {
411            let mut set = HashSet::new();
412            set.insert("$".to_string());
413            set.insert("?".to_string());
414            for marker in &["?!", "?.", "?~"] {
415                if let Some(pua) = token_to_pua(marker) {
416                    set.insert(pua.to_string());
417                }
418            }
419            set
420        })
421    }
422}
423
424impl Phonemizer for MultilingualPhonemizer {
425    fn phonemize_with_prosody(
426        &self,
427        text: &str,
428    ) -> Result<(Vec<String>, Vec<Option<ProsodyInfo>>), PiperError> {
429        let segments = segment_text(text, &self.detector);
430        if segments.is_empty() {
431            return Ok((Vec::new(), Vec::new()));
432        }
433
434        let bos_eos = Self::bos_eos_tokens();
435        let eos_set = Self::eos_tokens();
436
437        let mut all_phonemes: Vec<String> = Vec::new();
438        let mut all_prosody: Vec<Option<ProsodyInfo>> = Vec::new();
439        let mut last_eos = "$".to_string();
440
441        for (lang, segment_text) in &segments {
442            let phonemizer = self
443                .phonemizers
444                .get(lang)
445                .ok_or_else(|| PiperError::UnsupportedLanguage { code: lang.clone() })?;
446
447            let (phonemes, prosody_list) = phonemizer.phonemize_with_prosody(segment_text)?;
448
449            // Strip BOS/EOS from individual segments.
450            // This includes PUA-encoded question markers from Japanese.
451            for (ph, pr) in phonemes.iter().zip(prosody_list.iter()) {
452                if bos_eos.contains(ph) {
453                    if eos_set.contains(ph) {
454                        last_eos = ph.clone();
455                    }
456                    continue;
457                }
458                all_phonemes.push(ph.clone());
459                all_prosody.push(*pr);
460            }
461        }
462
463        // Update last_eos via interior mutability (Mutex).
464        if let Ok(mut guard) = self.last_eos.lock() {
465            *guard = last_eos;
466        }
467
468        Ok((all_phonemes, all_prosody))
469    }
470
471    fn get_phoneme_id_map(&self) -> Option<&PhonemeIdMap> {
472        // Multilingual uses the phoneme_id_map from config.json
473        // (the unified multilingual map generated during preprocessing).
474        None
475    }
476
477    fn post_process_ids(
478        &self,
479        ids: Vec<i64>,
480        prosody: Vec<Option<ProsodyFeature>>,
481        id_map: &PhonemeIdMap,
482    ) -> (Vec<i64>, Vec<Option<ProsodyFeature>>) {
483        let eos = self
484            .last_eos
485            .lock()
486            .map(|g| g.clone())
487            .unwrap_or_else(|_| "$".to_string());
488        default_post_process_ids(ids, prosody, id_map, &eos)
489    }
490
491    fn language_code(&self) -> &str {
492        // Return the default Latin language for multi-language mode.
493        &self.default_latin_language
494    }
495
496    fn detect_primary_language(&self, text: &str) -> &str {
497        // Delegate to the inherent method
498        MultilingualPhonemizer::detect_primary_language(self, text)
499    }
500}
501
502// ---------------------------------------------------------------------------
503// Tests
504// ---------------------------------------------------------------------------
505
#[cfg(test)]
mod tests {
    use super::*;

    // ===== UnicodeLanguageDetector =====

    // Helper: build a detector from &str codes to keep each test terse.
    fn make_detector(langs: &[&str], default_latin: &str) -> UnicodeLanguageDetector {
        let lang_strings: Vec<String> = langs.iter().map(|s| s.to_string()).collect();
        UnicodeLanguageDetector::new(&lang_strings, default_latin)
    }

    #[test]
    fn test_detect_hiragana_as_ja() {
        let det = make_detector(&["ja", "en"], "en");
        assert_eq!(det.detect_char('\u{3042}', false), Some("ja")); // あ
        assert_eq!(det.detect_char('\u{3093}', false), Some("ja")); // ん
    }

    #[test]
    fn test_detect_katakana_as_ja() {
        let det = make_detector(&["ja", "en"], "en");
        assert_eq!(det.detect_char('\u{30A2}', false), Some("ja")); // ア
        assert_eq!(det.detect_char('\u{30F3}', false), Some("ja")); // ン
    }

    #[test]
    fn test_detect_katakana_phonetic_ext_as_ja() {
        let det = make_detector(&["ja", "en"], "en");
        assert_eq!(det.detect_char('\u{31F0}', false), Some("ja")); // ㇰ
    }

    #[test]
    fn test_detect_hangul_as_ko() {
        let det = make_detector(&["ja", "en", "ko"], "en");
        assert_eq!(det.detect_char('\u{AC00}', false), Some("ko")); // 가
        assert_eq!(det.detect_char('\u{D7AF}', false), Some("ko")); // end of Hangul Syllables block
    }

    #[test]
    fn test_detect_hangul_jamo_as_ko() {
        let det = make_detector(&["ko", "en"], "en");
        assert_eq!(det.detect_char('\u{1100}', false), Some("ko")); // ᄀ
        assert_eq!(det.detect_char('\u{3131}', false), Some("ko")); // ㄱ (compat)
    }

    #[test]
    fn test_detect_cjk_as_zh_without_kana() {
        let det = make_detector(&["ja", "en", "zh"], "en");
        // CJK ideograph, no kana context → Chinese
        assert_eq!(det.detect_char('\u{4E16}', false), Some("zh")); // 世
    }

    #[test]
    fn test_detect_cjk_as_ja_with_kana_context() {
        let det = make_detector(&["ja", "en", "zh"], "en");
        // CJK ideograph, kana context → Japanese
        assert_eq!(det.detect_char('\u{4E16}', true), Some("ja")); // 世
    }

    #[test]
    fn test_detect_cjk_ja_only() {
        let det = make_detector(&["ja", "en"], "en");
        // Only JA is available, no ZH → always JA regardless of context
        assert_eq!(det.detect_char('\u{4E16}', false), Some("ja"));
    }

    #[test]
    fn test_detect_cjk_zh_only() {
        let det = make_detector(&["zh", "en"], "en");
        // Only ZH is available → always ZH
        assert_eq!(det.detect_char('\u{4E16}', true), Some("zh"));
    }

    #[test]
    fn test_detect_fullwidth_latin_as_default_latin() {
        let det = make_detector(&["ja", "en"], "en");
        assert_eq!(det.detect_char('\u{FF21}', false), Some("en")); // A
        assert_eq!(det.detect_char('\u{FF5A}', false), Some("en")); // z
    }

    #[test]
    fn test_detect_cjk_punctuation_as_ja() {
        let det = make_detector(&["ja", "en"], "en");
        assert_eq!(det.detect_char('\u{3001}', false), Some("ja")); // 、
        assert_eq!(det.detect_char('\u{3002}', false), Some("ja")); // 。
        assert_eq!(det.detect_char('\u{300C}', false), Some("ja")); // 「
    }

    #[test]
    fn test_detect_latin_as_default_language() {
        let det = make_detector(&["ja", "en"], "en");
        assert_eq!(det.detect_char('H', false), Some("en"));
        assert_eq!(det.detect_char('z', false), Some("en"));
    }

    #[test]
    fn test_detect_accented_latin() {
        let det = make_detector(&["ja", "fr"], "fr");
        assert_eq!(det.detect_char('\u{00E9}', false), Some("fr")); // é
        assert_eq!(det.detect_char('\u{00C0}', false), Some("fr")); // À
    }

    #[test]
    fn test_detect_neutral_characters() {
        let det = make_detector(&["ja", "en"], "en");
        assert_eq!(det.detect_char(' ', false), None);
        assert_eq!(det.detect_char('0', false), None);
        assert_eq!(det.detect_char('!', false), None);
        assert_eq!(det.detect_char('.', false), None);
        assert_eq!(det.detect_char(',', false), None);
    }

    #[test]
    fn test_detect_multiplication_sign_is_neutral() {
        let det = make_detector(&["ja", "en"], "en");
        // U+00D7 (×) is in the range but excluded from Latin
        assert_eq!(det.detect_char('\u{00D7}', false), None);
    }

    #[test]
    fn test_has_kana() {
        let det = make_detector(&["ja", "en"], "en");
        assert!(det.has_kana("こんにちは world"));
        assert!(det.has_kana("アイウ"));
        assert!(!det.has_kana("Hello world"));
        assert!(!det.has_kana("你好世界"));
    }

    // ===== segment_text =====

    #[test]
    fn test_segment_pure_japanese() {
        let det = make_detector(&["ja", "en"], "en");
        let segs = segment_text("こんにちは", &det);
        assert_eq!(segs.len(), 1);
        assert_eq!(segs[0].0, "ja");
        assert_eq!(segs[0].1, "こんにちは");
    }

    #[test]
    fn test_segment_pure_english() {
        let det = make_detector(&["ja", "en"], "en");
        let segs = segment_text("Hello world", &det);
        assert_eq!(segs.len(), 1);
        assert_eq!(segs[0].0, "en");
        assert_eq!(segs[0].1, "Hello world");
    }

    #[test]
    fn test_segment_mixed_ja_en() {
        let det = make_detector(&["ja", "en"], "en");
        let segs = segment_text("今日はgood morningですね", &det);
        assert_eq!(segs.len(), 3);
        assert_eq!(segs[0].0, "ja");
        assert_eq!(segs[0].1, "今日は");
        assert_eq!(segs[1].0, "en");
        assert_eq!(segs[1].1, "good morning");
        assert_eq!(segs[2].0, "ja");
        assert_eq!(segs[2].1, "ですね");
    }

    #[test]
    fn test_segment_neutral_absorbed_into_preceding() {
        let det = make_detector(&["ja", "en"], "en");
        // "Hello, " — comma and space are neutral, absorbed into English
        let segs = segment_text("Hello, こんにちは", &det);
        assert_eq!(segs.len(), 2);
        assert_eq!(segs[0].0, "en");
        assert_eq!(segs[0].1, "Hello, ");
        assert_eq!(segs[1].0, "ja");
        assert_eq!(segs[1].1, "こんにちは");
    }

    #[test]
    fn test_segment_leading_neutral_absorbed_into_first_language() {
        let det = make_detector(&["ja", "en"], "en");
        // Leading "123 " are neutral — no preceding segment, so they get
        // absorbed into whatever language comes first.
        let segs = segment_text("123 Hello", &det);
        assert_eq!(segs.len(), 1);
        assert_eq!(segs[0].0, "en");
        assert_eq!(segs[0].1, "123 Hello");
    }

    #[test]
    fn test_segment_empty_string() {
        let det = make_detector(&["ja", "en"], "en");
        let segs = segment_text("", &det);
        assert!(segs.is_empty());
    }

    #[test]
    fn test_segment_whitespace_only() {
        let det = make_detector(&["ja", "en"], "en");
        let segs = segment_text("   ", &det);
        assert!(segs.is_empty());
    }

    #[test]
    fn test_segment_digits_only_fallback() {
        let det = make_detector(&["ja", "en"], "en");
        // No language-specific characters — falls back to default
        let segs = segment_text("12345", &det);
        assert_eq!(segs.len(), 1);
        assert_eq!(segs[0].0, "en");
        assert_eq!(segs[0].1, "12345");
    }

    #[test]
    fn test_segment_cjk_disambiguation_with_kana() {
        let det = make_detector(&["ja", "en", "zh"], "en");
        // Text with kana + CJK ideographs: the ideographs become JA
        let segs = segment_text("漢字とかな", &det);
        assert_eq!(segs.len(), 1);
        assert_eq!(segs[0].0, "ja");
    }

    #[test]
    fn test_segment_cjk_without_kana_is_zh() {
        let det = make_detector(&["ja", "en", "zh"], "en");
        // Pure CJK ideographs without kana → Chinese
        let segs = segment_text("你好世界", &det);
        assert_eq!(segs.len(), 1);
        assert_eq!(segs[0].0, "zh");
    }

    #[test]
    fn test_segment_mixed_zh_en() {
        let det = make_detector(&["zh", "en"], "en");
        let segs = segment_text("Hello你好", &det);
        assert_eq!(segs.len(), 2);
        assert_eq!(segs[0].0, "en");
        assert_eq!(segs[0].1, "Hello");
        assert_eq!(segs[1].0, "zh");
        assert_eq!(segs[1].1, "你好");
    }

    // ===== default_post_process_ids =====

    // Helper: minimal id map with pad (_=0), BOS (^=1), EOS ($=2), and
    // question-mark EOS (?=3) entries.
    fn make_id_map() -> PhonemeIdMap {
        let mut m = HashMap::new();
        m.insert("_".to_string(), vec![0]);
        m.insert("^".to_string(), vec![1]);
        m.insert("$".to_string(), vec![2]);
        m.insert("?".to_string(), vec![3]);
        m
    }

    #[test]
    fn test_post_process_basic_padding() {
        let id_map = make_id_map();
        let ids = vec![10, 11, 12];
        let prosody = vec![None, None, None];
        let (out_ids, out_prosody) = default_post_process_ids(ids, prosody, &id_map, "$");
        // Expected: ^(1) + pad(0) + 10 + pad(0) + 11 + pad(0) + 12 + pad(0) + $(2)
        assert_eq!(out_ids, vec![1, 0, 10, 0, 11, 0, 12, 0, 2]);
        assert_eq!(out_prosody.len(), out_ids.len());
    }

    #[test]
    fn test_post_process_skip_padding_after_pad_token() {
        let id_map = make_id_map();
        // ID 0 is a pad token — should NOT get another pad after it
        let ids = vec![10, 0, 12];
        let prosody = vec![None, None, None];
        let (out_ids, _) = default_post_process_ids(ids, prosody, &id_map, "$");
        // Expected: ^(1) + pad(0) + 10 + pad(0) + 0 (no pad after) + 12 + pad(0) + $(2)
        assert_eq!(out_ids, vec![1, 0, 10, 0, 0, 12, 0, 2]);
    }

    #[test]
    fn test_post_process_with_question_eos() {
        let id_map = make_id_map();
        let ids = vec![10];
        let prosody = vec![None];
        let (out_ids, _) = default_post_process_ids(ids, prosody, &id_map, "?");
        // Expected: ^(1) + pad(0) + 10 + pad(0) + ?(3)
        assert_eq!(out_ids, vec![1, 0, 10, 0, 3]);
    }

    #[test]
    fn test_post_process_eos_fallback_to_dollar() {
        let id_map = make_id_map();
        let ids = vec![10];
        let prosody = vec![None];
        // Request EOS token "nonexistent" — should fall back to "$"
        let (out_ids, _) = default_post_process_ids(ids, prosody, &id_map, "nonexistent");
        // Expected: ^(1) + pad(0) + 10 + pad(0) + $(2)
        assert_eq!(out_ids, vec![1, 0, 10, 0, 2]);
    }

    #[test]
    fn test_post_process_empty_input() {
        let id_map = make_id_map();
        let ids: Vec<i64> = Vec::new();
        let prosody: Vec<Option<ProsodyFeature>> = Vec::new();
        let (out_ids, out_prosody) = default_post_process_ids(ids, prosody, &id_map, "$");
        // Expected: ^(1) + pad(0) + $(2)
        assert_eq!(out_ids, vec![1, 0, 2]);
        assert_eq!(out_prosody.len(), out_ids.len());
    }

    #[test]
    fn test_post_process_prosody_propagated() {
        let id_map = make_id_map();
        let ids = vec![10, 11];
        let prosody = vec![Some([1, 2, 3]), None];
        let (out_ids, out_prosody) = default_post_process_ids(ids, prosody, &id_map, "$");
        // ^=None pad=None 10=Some([1,2,3]) pad=None 11=None pad=None $=None
        assert_eq!(out_ids, vec![1, 0, 10, 0, 11, 0, 2]);
        assert!(out_prosody[0].is_none()); // ^
        assert!(out_prosody[1].is_none()); // pad
        assert_eq!(out_prosody[2], Some([1, 2, 3])); // phoneme 10
        assert!(out_prosody[3].is_none()); // pad
        assert!(out_prosody[4].is_none()); // phoneme 11
        assert!(out_prosody[5].is_none()); // pad
        assert!(out_prosody[6].is_none()); // $
    }

    // ===== BOS/EOS token sets =====

    #[test]
    fn test_bos_eos_tokens_include_pua_markers() {
        let set = MultilingualPhonemizer::bos_eos_tokens();
        assert!(set.contains("^"));
        assert!(set.contains("$"));
        assert!(set.contains("?"));
        // PUA markers for ?!, ?., ?~
        // NOTE(review): assumes token_to_pua maps ?!/?./?~ to
        // U+E016..U+E018 — confirm against token_map.rs if it changes.
        assert!(set.contains(&"\u{E016}".to_string())); // ?!
        assert!(set.contains(&"\u{E017}".to_string())); // ?.
        assert!(set.contains(&"\u{E018}".to_string())); // ?~
    }

    #[test]
    fn test_eos_tokens_subset() {
        let eos_set = MultilingualPhonemizer::eos_tokens();
        let bos_eos_set = MultilingualPhonemizer::bos_eos_tokens();
        // EOS set should be a subset of BOS/EOS set
        for token in eos_set {
            assert!(
                bos_eos_set.contains(token),
                "EOS token {:?} not in BOS/EOS set",
                token
            );
        }
        // BOS (^) should be in bos_eos but NOT in eos
        assert!(!eos_set.contains("^"));
    }

    // ===== Integration: post_process_ids via trait =====

    #[test]
    fn test_post_process_ids_and_prosody_lengths_match() {
        let id_map = make_id_map();
        let ids = vec![5, 6, 7, 8, 9];
        let prosody: Vec<Option<ProsodyFeature>> =
            vec![Some([1, 0, 3]), None, Some([0, 2, 4]), None, None];
        let (out_ids, out_prosody) = default_post_process_ids(ids, prosody, &id_map, "$");
        assert_eq!(
            out_ids.len(),
            out_prosody.len(),
            "IDs ({}) and prosody ({}) length mismatch",
            out_ids.len(),
            out_prosody.len()
        );
    }
}