Skip to main content

piper_plus/
voice.rs

1//! PiperVoice — テキストから音声への高レベル API
2//!
3//! テキスト入力 → 音素化 → ID 変換 → ONNX 推論 → WAV 出力
4
5use std::path::Path;
6
7use crate::config::VoiceConfig;
8use crate::engine::{OnnxEngine, SynthesisRequest, SynthesisResult};
9use crate::error::PiperError;
10use crate::phonemize::Phonemizer;
11use crate::phonemize::phoneme_converter;
12
13/// テキストから音声を合成する高レベル API
14pub struct PiperVoice {
15    config: VoiceConfig,
16    engine: OnnxEngine,
17    phonemizer: Box<dyn Phonemizer>,
18}
19
20impl PiperVoice {
21    /// モデルとconfigを読み込んで初期化
22    ///
23    /// phoneme_type に基づいて適切な Phonemizer を自動選択:
24    /// - OpenJTalk → JapanesePhonemizer (feature = "japanese")
25    /// - Bilingual/Multilingual → MultilingualPhonemizer (Unicode言語検出)
26    pub fn load(
27        model_path: &Path,
28        config_path: Option<&Path>,
29        device: &str,
30    ) -> Result<Self, PiperError> {
31        let resolved_config = VoiceConfig::resolve_config_path(model_path, config_path)?;
32        let config = VoiceConfig::load(&resolved_config)?;
33        let model_dir = model_path.parent().map(|p| p.to_path_buf());
34        let phonemizer = Self::create_phonemizer(&config, model_dir.as_deref())?;
35        let engine = OnnxEngine::load(model_path, &config, device)?;
36
37        Ok(Self {
38            config,
39            engine,
40            phonemizer,
41        })
42    }
43
44    /// phoneme_type に基づいて Phonemizer を生成する。
45    ///
46    /// `model_dir` はモデルファイルの親ディレクトリ。辞書ファイルの検索に使用。
47    /// テスト容易性のため独立関数として切り出し。
48    /// `--test-mode` (CLI) で ONNX エンジンなしに phonemizer のみ使用する場合にも利用。
49    pub fn create_phonemizer(
50        config: &VoiceConfig,
51        model_dir: Option<&Path>,
52    ) -> Result<Box<dyn Phonemizer>, PiperError> {
53        match config.phoneme_type {
54            #[cfg(feature = "japanese")]
55            crate::config::PhonemeType::OpenJTalk => {
56                Ok(Box::new(Self::create_japanese_phonemizer()?))
57            }
58            crate::config::PhonemeType::Bilingual | crate::config::PhonemeType::Multilingual => {
59                // Extract language codes from language_id_map
60                let mut languages: Vec<String> = config.language_id_map.keys().cloned().collect();
61                languages.sort(); // canonical order
62
63                if languages.is_empty() {
64                    return Err(PiperError::InvalidConfig {
65                        reason: "multilingual model requires language_id_map".to_string(),
66                    });
67                }
68
69                // Determine default Latin language
70                let default_latin = if languages.contains(&"en".to_string()) {
71                    "en".to_string()
72                } else {
73                    languages
74                        .iter()
75                        .find(|l| matches!(l.as_str(), "es" | "fr" | "pt"))
76                        .cloned()
77                        .unwrap_or_else(|| languages[0].clone())
78                };
79
80                // Build per-language phonemizers
81                let mut phonemizers: std::collections::HashMap<String, Box<dyn Phonemizer>> =
82                    std::collections::HashMap::new();
83
84                for lang in &languages {
85                    let phonemizer: Box<dyn Phonemizer> =
86                        Self::create_language_phonemizer(lang, model_dir)?;
87                    phonemizers.insert(lang.clone(), phonemizer);
88                }
89
90                Ok(Box::new(
91                    crate::phonemize::multilingual::MultilingualPhonemizer::new(
92                        languages,
93                        default_latin,
94                        phonemizers,
95                    ),
96                ))
97            }
98            _ => Err(PiperError::UnsupportedLanguage {
99                code: format!("{:?}", config.phoneme_type),
100            }),
101        }
102    }
103
104    /// 言語コードに基づいて適切な Phonemizer を生成する。
105    ///
106    /// 各言語の専用 Phonemizer を使用し、辞書が必要な言語 (ja, en, zh) は
107    /// `model_dir` 配下またはデフォルトパスから辞書を検索する。
108    /// JA は dictionary_manager による自動ダウンロードも対応。
109    /// 辞書が見つからない場合は PassthroughPhonemizer にフォールバックする。
110    fn create_language_phonemizer(
111        lang: &str,
112        model_dir: Option<&Path>,
113    ) -> Result<Box<dyn Phonemizer>, PiperError> {
114        match lang {
115            #[cfg(feature = "japanese")]
116            "ja" => match Self::create_japanese_phonemizer() {
117                Ok(p) => Ok(Box::new(p)),
118                Err(e) => {
119                    tracing::warn!("Japanese phonemizer unavailable ({}), using passthrough", e);
120                    Ok(Box::new(
121                        crate::phonemize::multilingual::PassthroughPhonemizer::new(lang),
122                    ))
123                }
124            },
125            "en" => match Self::create_english_phonemizer(model_dir) {
126                Ok(p) => Ok(Box::new(p)),
127                Err(e) => {
128                    tracing::warn!("English phonemizer unavailable ({}), using passthrough", e);
129                    Ok(Box::new(
130                        crate::phonemize::multilingual::PassthroughPhonemizer::new(lang),
131                    ))
132                }
133            },
134            "zh" => match Self::create_chinese_phonemizer(model_dir) {
135                Ok(p) => Ok(Box::new(p)),
136                Err(e) => {
137                    tracing::warn!("Chinese phonemizer unavailable ({}), using passthrough", e);
138                    Ok(Box::new(
139                        crate::phonemize::multilingual::PassthroughPhonemizer::new(lang),
140                    ))
141                }
142            },
143            "es" => Ok(Box::new(crate::phonemize::spanish::SpanishPhonemizer::new())),
144            "fr" => Ok(Box::new(crate::phonemize::french::FrenchPhonemizer::new())),
145            "pt" => Ok(Box::new(
146                crate::phonemize::portuguese::PortuguesePhonemizer::new(),
147            )),
148            "ko" => Ok(Box::new(crate::phonemize::korean::KoreanPhonemizer::new())),
149            _ => Ok(Box::new(
150                crate::phonemize::multilingual::PassthroughPhonemizer::new(lang),
151            )),
152        }
153    }
154
155    /// EnglishPhonemizer を生成する。
156    ///
157    /// CMU辞書を以下の順で検索:
158    /// 1. `CMUDICT_PATH` 環境変数
159    /// 2. `{model_dir}/cmudict_data.json`
160    /// 3. `./cmudict_data.json`
161    /// 4. `/usr/share/piper/cmudict_data.json`
162    fn create_english_phonemizer(
163        model_dir: Option<&Path>,
164    ) -> Result<crate::phonemize::english::EnglishPhonemizer, PiperError> {
165        // Try model_dir first if available
166        if let Some(dir) = model_dir {
167            let model_dict = dir.join("cmudict_data.json");
168            if model_dict.exists() {
169                return crate::phonemize::english::EnglishPhonemizer::new_with_dict(&model_dict);
170            }
171        }
172        // Fall back to default search (env var, local, system)
173        crate::phonemize::english::EnglishPhonemizer::new()
174    }
175
176    /// ChinesePhonemizer を生成する。
177    ///
178    /// Pinyin辞書を以下の順で検索:
179    /// 1. `PINYIN_SINGLE_PATH` / `PINYIN_PHRASES_PATH` 環境変数
180    /// 2. `{model_dir}/pinyin_single.json` + `{model_dir}/pinyin_phrases.json`
181    /// 3. `./pinyin_single.json` + `./pinyin_phrases.json`
182    fn create_chinese_phonemizer(
183        model_dir: Option<&Path>,
184    ) -> Result<crate::phonemize::chinese::ChinesePhonemizer, PiperError> {
185        // 1. Environment variable override
186        if let (Ok(single), Ok(phrases)) = (
187            std::env::var("PINYIN_SINGLE_PATH"),
188            std::env::var("PINYIN_PHRASES_PATH"),
189        ) {
190            let sp = std::path::PathBuf::from(&single);
191            let pp = std::path::PathBuf::from(&phrases);
192            if sp.exists() && pp.exists() {
193                return crate::phonemize::chinese::ChinesePhonemizer::new(&sp, &pp);
194            }
195        }
196
197        // 2. model_dir
198        if let Some(dir) = model_dir {
199            let single = dir.join("pinyin_single.json");
200            let phrases = dir.join("pinyin_phrases.json");
201            if single.exists() && phrases.exists() {
202                return crate::phonemize::chinese::ChinesePhonemizer::new(&single, &phrases);
203            }
204        }
205
206        // 3. Local development path
207        let single = std::path::PathBuf::from("pinyin_single.json");
208        let phrases = std::path::PathBuf::from("pinyin_phrases.json");
209        if single.exists() && phrases.exists() {
210            return crate::phonemize::chinese::ChinesePhonemizer::new(&single, &phrases);
211        }
212
213        Err(PiperError::DictionaryLoad {
214            path: "pinyin_single.json / pinyin_phrases.json not found. \
215                   Place dictionaries next to the model or set PINYIN_SINGLE_PATH / PINYIN_PHRASES_PATH env vars"
216                .to_string(),
217        })
218    }
219
220    /// テキストを音声に変換
221    ///
222    /// `language_override` を指定すると、phonemizer の自動検出を上書きして
223    /// 指定言語の language_id を使用する。多言語モデルで特定言語を強制する場合に使用。
224    pub fn synthesize_text(
225        &mut self,
226        text: &str,
227        speaker_id: Option<i64>,
228        language_override: Option<&str>,
229        noise_scale: f32,
230        length_scale: f32,
231        noise_w: f32,
232    ) -> Result<SynthesisResult, PiperError> {
233        // 1. Phonemize: テキストをトークン列 + プロソディ情報に変換
234        let (tokens, prosody) = self.phonemizer.phonemize_with_prosody(text)?;
235
236        // 2. Convert tokens to IDs using phoneme_id_map
237        let phoneme_id_map = self
238            .phonemizer
239            .get_phoneme_id_map()
240            .unwrap_or(&self.config.phoneme_id_map);
241
242        let ids = phoneme_converter::tokens_to_ids(&tokens, phoneme_id_map)?;
243        let prosody_feats = prosody_to_optional_features(&prosody);
244
245        // 3. Post-process IDs (BOS/EOS/padding insertion, language-specific)
246        let (ids, prosody_feats) =
247            self.phonemizer
248                .post_process_ids(ids, prosody_feats, phoneme_id_map);
249
250        // 4. Build prosody tensor directly from post-processed features
251        //    (single pass: Option<ProsodyFeature>[] → Option<Vec<ProsodyFeature>>)
252        let prosody_tensor = build_prosody_tensor(&prosody_feats);
253
254        // 5. Determine language_id from config
255        //    language_override が指定されていればそちらを優先。
256        //    多言語モデルの場合、テキストの最初の言語セグメントを自動検出して language_id を決定。
257        //    単言語モデルの場合は phonemizer の言語コードを使用。
258        let language_id = if self.config.needs_lid() {
259            let lang_code = if let Some(ovr) = language_override {
260                ovr
261            } else {
262                self.detect_language(text)
263            };
264            Some(
265                self.config
266                    .language_id_map
267                    .get(lang_code)
268                    .copied()
269                    .unwrap_or(0),
270            )
271        } else {
272            None
273        };
274
275        // 6. Build request and run inference
276        let request = SynthesisRequest {
277            phoneme_ids: ids,
278            prosody_features: prosody_tensor,
279            speaker_id,
280            language_id,
281            noise_scale,
282            length_scale,
283            noise_w,
284        };
285
286        self.engine.synthesize(&request)
287    }
288
289    /// テキストを音素化して phoneme IDs を返す (ONNX 推論なし)
290    ///
291    /// `--test-mode` (CI用) で phonemization パイプラインのみ検証する場合に使用。
292    pub fn phonemize_to_ids(&self, text: &str) -> Result<Vec<i64>, PiperError> {
293        let (tokens, prosody) = self.phonemizer.phonemize_with_prosody(text)?;
294
295        let phoneme_id_map = self
296            .phonemizer
297            .get_phoneme_id_map()
298            .unwrap_or(&self.config.phoneme_id_map);
299
300        let ids = phoneme_converter::tokens_to_ids(&tokens, phoneme_id_map)?;
301        let prosody_feats = prosody_to_optional_features(&prosody);
302
303        let (ids, _prosody_feats) =
304            self.phonemizer
305                .post_process_ids(ids, prosody_feats, phoneme_id_map);
306
307        Ok(ids)
308    }
309
310    /// テキストを WAV ファイルに出力 (デフォルトパラメータ使用)
311    pub fn text_to_wav_file(
312        &mut self,
313        text: &str,
314        output: &Path,
315        speaker_id: Option<i64>,
316    ) -> Result<SynthesisResult, PiperError> {
317        let result = self.synthesize_text(text, speaker_id, None, 0.667, 1.0, 0.8)?;
318        crate::audio::write_wav(output, result.sample_rate, &result.audio)?;
319        Ok(result)
320    }
321
322    /// テキストの主要言語を検出する。
323    ///
324    /// 多言語/バイリンガルモデルの場合、`MultilingualPhonemizer` の
325    /// `detect_primary_language` を使用して最初の言語セグメントを検出。
326    /// 単言語モデルの場合は phonemizer の `language_code()` にフォールバック。
327    fn detect_language(&self, text: &str) -> &str {
328        self.phonemizer.detect_primary_language(text)
329    }
330
331    /// JapanesePhonemizer を生成する。
332    ///
333    /// `naist-jdic` feature が有効なら bundled 辞書を使用し、
334    /// 無効なら `dictionary_manager::ensure_dictionary()` で外部辞書を
335    /// 自動検索・ダウンロードする。
336    #[cfg(feature = "japanese")]
337    fn create_japanese_phonemizer()
338    -> Result<crate::phonemize::japanese::JapanesePhonemizer, PiperError> {
339        #[cfg(feature = "naist-jdic")]
340        {
341            crate::phonemize::japanese::JapanesePhonemizer::new_bundled()
342        }
343        #[cfg(not(feature = "naist-jdic"))]
344        {
345            // Try dictionary_manager first (searches standard paths + auto-download)
346            match crate::dictionary_manager::ensure_dictionary() {
347                Ok(dict_path) => {
348                    tracing::info!("Using OpenJTalk dictionary from {}", dict_path.display());
349                    crate::phonemize::japanese::JapanesePhonemizer::new_with_dict(&dict_path)
350                }
351                Err(e) => {
352                    tracing::warn!(
353                        "dictionary_manager failed ({}), falling back to JapanesePhonemizer::new()",
354                        e
355                    );
356                    // Fall back to jpreprocess's own dictionary search
357                    crate::phonemize::japanese::JapanesePhonemizer::new()
358                }
359            }
360        }
361    }
362
363    /// config への参照を返す
364    pub fn config(&self) -> &VoiceConfig {
365        &self.config
366    }
367
368    /// engine への参照を返す
369    pub fn engine(&self) -> &OnnxEngine {
370        &self.engine
371    }
372}
373
374// ---------------------------------------------------------------------------
375// ヘルパー関数
376// ---------------------------------------------------------------------------
377
378/// ProsodyInfo 列を Option<ProsodyFeature> 列に変換する。
379///
380/// `synthesize_text` で phonemizer の `post_process_ids` に渡すための中間形式。
381fn prosody_to_optional_features(
382    prosody: &[Option<crate::phonemize::ProsodyInfo>],
383) -> Vec<Option<crate::phonemize::ProsodyFeature>> {
384    prosody
385        .iter()
386        .map(|p| p.map(|info| [info.a1, info.a2, info.a3]))
387        .collect()
388}
389
390/// Optional prosody features を ONNX 入力用の Vec<[i32; 3]> に変換する。
391///
392/// いずれかの要素が Some なら全体を Some(Vec) として返す。
393/// 全要素が None なら None を返す (prosody テンソル不要)。
394fn build_prosody_tensor(
395    features: &[Option<crate::phonemize::ProsodyFeature>],
396) -> Option<Vec<crate::phonemize::ProsodyFeature>> {
397    if features.iter().any(|p| p.is_some()) {
398        Some(features.iter().map(|p| p.unwrap_or([0, 0, 0])).collect())
399    } else {
400        None
401    }
402}
403
/// Converts `ProsodyInfo` entries straight into the ONNX prosody input.
///
/// Equivalent to `build_prosody_tensor(&prosody_to_optional_features(prosody))`,
/// but without allocating the intermediate `Vec<Option<[i32; 3]>>`. Returns
/// `None` when every entry is `None` (or the slice is empty); otherwise missing
/// entries become `[0, 0, 0]`.
#[cfg(test)]
fn build_prosody_direct(
    prosody: &[Option<crate::phonemize::ProsodyInfo>],
) -> Option<Vec<crate::phonemize::ProsodyFeature>> {
    if prosody.iter().all(Option::is_none) {
        return None;
    }
    let tensor: Vec<_> = prosody
        .iter()
        .map(|entry| {
            entry
                .as_ref()
                .map_or([0, 0, 0], |info| [info.a1, info.a2, info.a3])
        })
        .collect();
    Some(tensor)
}
427
428// ---------------------------------------------------------------------------
429// テスト
430// ---------------------------------------------------------------------------
431
432#[cfg(test)]
433mod tests {
434    use super::*;
435    use crate::config::PhonemeType;
436    use crate::engine::SynthesisRequest;
437    use crate::phonemize::ProsodyInfo;
438    use std::collections::HashMap;
439
440    /// Helper: extract PiperError from a Result, panicking if Ok.
441    fn expect_err<T>(result: Result<T, PiperError>) -> PiperError {
442        match result {
443            Err(e) => e,
444            Ok(_) => panic!("expected Err, got Ok"),
445        }
446    }
447
448    // -----------------------------------------------------------------------
449    // 1. PiperVoice::load fails gracefully with missing model file
450    // -----------------------------------------------------------------------
451    #[test]
452    fn test_load_fails_with_missing_model() {
453        let result = PiperVoice::load(Path::new("/nonexistent/model.onnx"), None, "cpu");
454        let err = expect_err(result);
455        // config が見つからないためエラーになる
456        let msg = format!("{err}");
457        assert!(
458            msg.contains("config") || msg.contains("not found") || msg.contains("Config"),
459            "unexpected error message: {msg}"
460        );
461    }
462
463    // -----------------------------------------------------------------------
464    // 2. phoneme_type matching logic — all unsupported types return error
465    // -----------------------------------------------------------------------
466    #[test]
467    fn test_create_phonemizer_unsupported_espeak() {
468        let config = VoiceConfig {
469            audio: Default::default(),
470            num_speakers: 1,
471            num_symbols: 0,
472            phoneme_type: PhonemeType::Espeak,
473            phoneme_id_map: HashMap::new(),
474            num_languages: 1,
475            language_id_map: HashMap::new(),
476            speaker_id_map: HashMap::new(),
477        };
478        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
479            PiperError::UnsupportedLanguage { code } => {
480                assert!(
481                    code.contains("Espeak"),
482                    "expected 'Espeak' in code, got: {code}"
483                );
484            }
485            other => panic!("expected UnsupportedLanguage, got: {other:?}"),
486        }
487    }
488
489    #[test]
490    fn test_create_phonemizer_bilingual_empty_language_id_map() {
491        // Bilingual with empty language_id_map should return InvalidConfig
492        let config = VoiceConfig {
493            audio: Default::default(),
494            num_speakers: 1,
495            num_symbols: 0,
496            phoneme_type: PhonemeType::Bilingual,
497            phoneme_id_map: HashMap::new(),
498            num_languages: 2,
499            language_id_map: HashMap::new(),
500            speaker_id_map: HashMap::new(),
501        };
502        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
503            PiperError::InvalidConfig { reason } => {
504                assert!(
505                    reason.contains("language_id_map"),
506                    "expected 'language_id_map' in reason, got: {reason}"
507                );
508            }
509            other => panic!("expected InvalidConfig, got: {other:?}"),
510        }
511    }
512
513    #[test]
514    fn test_create_phonemizer_bilingual_success() {
515        // Bilingual with populated language_id_map should succeed
516        // Uses en+es (no "ja") to avoid NAIST-JDIC dependency in tests
517        let config = VoiceConfig {
518            audio: Default::default(),
519            num_speakers: 330,
520            num_symbols: 97,
521            phoneme_type: PhonemeType::Bilingual,
522            phoneme_id_map: HashMap::new(),
523            num_languages: 2,
524            language_id_map: [("en".into(), 0i64), ("es".into(), 1)]
525                .into_iter()
526                .collect(),
527            speaker_id_map: HashMap::new(),
528        };
529        let result = PiperVoice::create_phonemizer(&config, None);
530        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
531        let phonemizer = result.unwrap();
532        // MultilingualPhonemizer returns default_latin_language as language_code
533        assert_eq!(phonemizer.language_code(), "en");
534    }
535
536    #[test]
537    fn test_create_phonemizer_multilingual_success() {
538        // Multilingual with populated language_id_map should succeed
539        // Uses en+zh+es+fr+pt (no "ja") to avoid NAIST-JDIC dependency in tests
540        let config = VoiceConfig {
541            audio: Default::default(),
542            num_speakers: 571,
543            num_symbols: 173,
544            phoneme_type: PhonemeType::Multilingual,
545            phoneme_id_map: HashMap::new(),
546            num_languages: 5,
547            language_id_map: [
548                ("en".into(), 0i64),
549                ("zh".into(), 1),
550                ("es".into(), 2),
551                ("fr".into(), 3),
552                ("pt".into(), 4),
553            ]
554            .into_iter()
555            .collect(),
556            speaker_id_map: HashMap::new(),
557        };
558        let result = PiperVoice::create_phonemizer(&config, None);
559        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
560        let phonemizer = result.unwrap();
561        assert_eq!(phonemizer.language_code(), "en");
562    }
563
564    #[test]
565    fn test_create_phonemizer_multilingual_empty_language_id_map() {
566        // Multilingual with empty language_id_map should return InvalidConfig
567        let config = VoiceConfig {
568            audio: Default::default(),
569            num_speakers: 571,
570            num_symbols: 173,
571            phoneme_type: PhonemeType::Multilingual,
572            phoneme_id_map: HashMap::new(),
573            num_languages: 6,
574            language_id_map: HashMap::new(),
575            speaker_id_map: HashMap::new(),
576        };
577        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
578            PiperError::InvalidConfig { reason } => {
579                assert!(
580                    reason.contains("language_id_map"),
581                    "expected 'language_id_map' in reason, got: {reason}"
582                );
583            }
584            other => panic!("expected InvalidConfig, got: {other:?}"),
585        }
586    }
587
588    #[test]
589    fn test_create_phonemizer_multilingual_default_latin_fallback() {
590        // When 'en' is not in language_id_map, should fall back to es/fr/pt
591        // Uses zh+es (no "ja" or "en") to test fallback
592        let config = VoiceConfig {
593            audio: Default::default(),
594            num_speakers: 100,
595            num_symbols: 100,
596            phoneme_type: PhonemeType::Multilingual,
597            phoneme_id_map: HashMap::new(),
598            num_languages: 2,
599            language_id_map: [("zh".into(), 0i64), ("es".into(), 1)]
600                .into_iter()
601                .collect(),
602            speaker_id_map: HashMap::new(),
603        };
604        let result = PiperVoice::create_phonemizer(&config, None);
605        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
606        let phonemizer = result.unwrap();
607        // Should fall back to "es" as the default Latin language
608        assert_eq!(phonemizer.language_code(), "es");
609    }
610
611    #[test]
612    fn test_create_phonemizer_multilingual_detect_language() {
613        // Test that detect_primary_language works through the trait
614        // Uses en+zh (no "ja") to avoid NAIST-JDIC dependency
615        let config = VoiceConfig {
616            audio: Default::default(),
617            num_speakers: 330,
618            num_symbols: 97,
619            phoneme_type: PhonemeType::Bilingual,
620            phoneme_id_map: HashMap::new(),
621            num_languages: 2,
622            language_id_map: [("en".into(), 0i64), ("zh".into(), 1)]
623                .into_iter()
624                .collect(),
625            speaker_id_map: HashMap::new(),
626        };
627        let phonemizer = PiperVoice::create_phonemizer(&config, None).unwrap();
628        // English text should be detected as "en"
629        assert_eq!(phonemizer.detect_primary_language("Hello world"), "en");
630        // Chinese text should be detected as "zh"
631        assert_eq!(phonemizer.detect_primary_language("你好世界"), "zh");
632    }
633
634    #[test]
635    fn test_create_phonemizer_unsupported_text() {
636        let config = VoiceConfig {
637            audio: Default::default(),
638            num_speakers: 1,
639            num_symbols: 0,
640            phoneme_type: PhonemeType::Text,
641            phoneme_id_map: HashMap::new(),
642            num_languages: 1,
643            language_id_map: HashMap::new(),
644            speaker_id_map: HashMap::new(),
645        };
646        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
647            PiperError::UnsupportedLanguage { code } => {
648                assert!(
649                    code.contains("Text"),
650                    "expected 'Text' in code, got: {code}"
651                );
652            }
653            other => panic!("expected UnsupportedLanguage, got: {other:?}"),
654        }
655    }
656
657    // -----------------------------------------------------------------------
658    // 3. language_id determination
659    // -----------------------------------------------------------------------
660    #[test]
661    fn test_language_id_single_language_no_lid() {
662        let config = VoiceConfig {
663            audio: Default::default(),
664            num_speakers: 1,
665            num_symbols: 0,
666            phoneme_type: PhonemeType::OpenJTalk,
667            phoneme_id_map: HashMap::new(),
668            num_languages: 1,
669            language_id_map: HashMap::new(),
670            speaker_id_map: HashMap::new(),
671        };
672        // Single language: needs_lid() should return false
673        assert!(!config.needs_lid());
674        assert!(!config.is_multilingual());
675    }
676
677    #[test]
678    fn test_language_id_multilingual_needs_lid() {
679        let config = VoiceConfig {
680            audio: Default::default(),
681            num_speakers: 571,
682            num_symbols: 173,
683            phoneme_type: PhonemeType::Multilingual,
684            phoneme_id_map: HashMap::new(),
685            num_languages: 6,
686            language_id_map: [
687                ("ja".into(), 0i64),
688                ("en".into(), 1),
689                ("zh".into(), 2),
690                ("es".into(), 3),
691                ("fr".into(), 4),
692                ("pt".into(), 5),
693            ]
694            .into_iter()
695            .collect(),
696            speaker_id_map: HashMap::new(),
697        };
698        assert!(config.needs_lid());
699        assert_eq!(config.language_id_map.get("ja"), Some(&0));
700        assert_eq!(config.language_id_map.get("en"), Some(&1));
701        assert_eq!(config.language_id_map.get("zh"), Some(&2));
702        // Unknown language falls back to 0
703        assert_eq!(config.language_id_map.get("ko").copied().unwrap_or(0), 0);
704    }
705
706    #[test]
707    fn test_language_id_bilingual_needs_lid() {
708        let config = VoiceConfig {
709            audio: Default::default(),
710            num_speakers: 330,
711            num_symbols: 97,
712            phoneme_type: PhonemeType::Bilingual,
713            phoneme_id_map: HashMap::new(),
714            num_languages: 2,
715            language_id_map: [("ja".into(), 0i64), ("en".into(), 1)]
716                .into_iter()
717                .collect(),
718            speaker_id_map: HashMap::new(),
719        };
720        assert!(config.needs_lid());
721        assert_eq!(config.language_id_map.get("ja"), Some(&0));
722        assert_eq!(config.language_id_map.get("en"), Some(&1));
723    }
724
725    // -----------------------------------------------------------------------
726    // 4. SynthesisRequest construction
727    // -----------------------------------------------------------------------
728    #[test]
729    fn test_synthesis_request_construction_basic() {
730        let ids = vec![1i64, 8, 5, 39, 42, 10, 2];
731        let request = SynthesisRequest {
732            phoneme_ids: ids.clone(),
733            prosody_features: None,
734            speaker_id: Some(0),
735            language_id: None,
736            noise_scale: 0.667,
737            length_scale: 1.0,
738            noise_w: 0.8,
739        };
740        assert_eq!(request.phoneme_ids, ids);
741        assert!(request.prosody_features.is_none());
742        assert_eq!(request.speaker_id, Some(0));
743        assert!(request.language_id.is_none());
744    }
745
746    #[test]
747    fn test_synthesis_request_construction_with_prosody() {
748        let prosody_feats = vec![[-2, 1, 5], [0, 2, 5], [1, 3, 5]];
749        let request = SynthesisRequest {
750            phoneme_ids: vec![1, 2, 3],
751            prosody_features: Some(prosody_feats.clone()),
752            speaker_id: Some(3),
753            language_id: Some(0),
754            noise_scale: 0.5,
755            length_scale: 1.2,
756            noise_w: 0.6,
757        };
758        assert_eq!(request.prosody_features.as_ref().unwrap().len(), 3);
759        assert_eq!(request.prosody_features.as_ref().unwrap()[0], [-2, 1, 5]);
760        assert_eq!(request.speaker_id, Some(3));
761        assert_eq!(request.language_id, Some(0));
762    }
763
764    #[test]
765    fn test_synthesis_request_construction_multilingual() {
766        let request = SynthesisRequest {
767            phoneme_ids: vec![1, 5, 10, 20],
768            prosody_features: None,
769            speaker_id: Some(100),
770            language_id: Some(2), // zh
771            noise_scale: 0.667,
772            length_scale: 1.0,
773            noise_w: 0.8,
774        };
775        assert_eq!(request.language_id, Some(2));
776        assert_eq!(request.speaker_id, Some(100));
777    }
778
779    // -----------------------------------------------------------------------
780    // 5. Prosody feature conversion
781    // -----------------------------------------------------------------------
782    #[test]
783    fn test_prosody_to_optional_features_with_values() {
784        let prosody = vec![
785            Some(ProsodyInfo {
786                a1: -2,
787                a2: 1,
788                a3: 5,
789            }),
790            None,
791            Some(ProsodyInfo {
792                a1: 0,
793                a2: 3,
794                a3: 5,
795            }),
796        ];
797        let result = prosody_to_optional_features(&prosody);
798        assert_eq!(result.len(), 3);
799        assert_eq!(result[0], Some([-2, 1, 5]));
800        assert_eq!(result[1], None);
801        assert_eq!(result[2], Some([0, 3, 5]));
802    }
803
804    #[test]
805    fn test_prosody_to_optional_features_all_none() {
806        let prosody: Vec<Option<ProsodyInfo>> = vec![None, None, None];
807        let result = prosody_to_optional_features(&prosody);
808        assert!(result.iter().all(|p| p.is_none()));
809    }
810
811    #[test]
812    fn test_prosody_to_optional_features_empty() {
813        let prosody: Vec<Option<ProsodyInfo>> = vec![];
814        let result = prosody_to_optional_features(&prosody);
815        assert!(result.is_empty());
816    }
817
818    #[test]
819    fn test_build_prosody_tensor_with_some() {
820        let features = vec![Some([-2, 1, 5]), None, Some([0, 3, 5])];
821        let tensor = build_prosody_tensor(&features);
822        assert!(tensor.is_some());
823        let t = tensor.unwrap();
824        assert_eq!(t.len(), 3);
825        assert_eq!(t[0], [-2, 1, 5]);
826        assert_eq!(t[1], [0, 0, 0]); // None -> zero-filled
827        assert_eq!(t[2], [0, 3, 5]);
828    }
829
830    #[test]
831    fn test_build_prosody_tensor_all_none() {
832        let features: Vec<Option<[i32; 3]>> = vec![None, None];
833        let tensor = build_prosody_tensor(&features);
834        assert!(tensor.is_none());
835    }
836
837    #[test]
838    fn test_build_prosody_tensor_empty() {
839        let features: Vec<Option<[i32; 3]>> = vec![];
840        let tensor = build_prosody_tensor(&features);
841        assert!(tensor.is_none());
842    }
843
844    // -----------------------------------------------------------------------
845    // 6. build_prosody_direct (consolidated single-pass conversion)
846    // -----------------------------------------------------------------------
847    #[test]
848    fn test_build_prosody_direct_with_some() {
849        let prosody = vec![
850            Some(ProsodyInfo {
851                a1: -2,
852                a2: 1,
853                a3: 5,
854            }),
855            None,
856            Some(ProsodyInfo {
857                a1: 0,
858                a2: 3,
859                a3: 5,
860            }),
861        ];
862        let tensor = build_prosody_direct(&prosody);
863        assert!(tensor.is_some());
864        let t = tensor.unwrap();
865        assert_eq!(t.len(), 3);
866        assert_eq!(t[0], [-2, 1, 5]);
867        assert_eq!(t[1], [0, 0, 0]); // None -> zero-filled
868        assert_eq!(t[2], [0, 3, 5]);
869    }
870
871    #[test]
872    fn test_build_prosody_direct_all_none() {
873        let prosody: Vec<Option<ProsodyInfo>> = vec![None, None];
874        let tensor = build_prosody_direct(&prosody);
875        assert!(tensor.is_none());
876    }
877
878    #[test]
879    fn test_build_prosody_direct_empty() {
880        let prosody: Vec<Option<ProsodyInfo>> = vec![];
881        let tensor = build_prosody_direct(&prosody);
882        assert!(tensor.is_none());
883    }
884
885    #[test]
886    fn test_build_prosody_direct_matches_two_step() {
887        // Verify build_prosody_direct produces the same result as
888        // prosody_to_optional_features + build_prosody_tensor
889        let prosody = vec![
890            Some(ProsodyInfo {
891                a1: 1,
892                a2: 2,
893                a3: 3,
894            }),
895            None,
896            Some(ProsodyInfo {
897                a1: -1,
898                a2: 0,
899                a3: 7,
900            }),
901            None,
902        ];
903        let two_step = build_prosody_tensor(&prosody_to_optional_features(&prosody));
904        let direct = build_prosody_direct(&prosody);
905        assert_eq!(two_step, direct);
906    }
907
908    // -----------------------------------------------------------------------
909    // phoneme_converter integration (tokens_to_ids)
910    // -----------------------------------------------------------------------
911    #[test]
912    fn test_tokens_to_ids_via_converter() {
913        let mut id_map: HashMap<String, Vec<i64>> = HashMap::new();
914        id_map.insert("a".into(), vec![5]);
915        id_map.insert("k".into(), vec![10]);
916        id_map.insert("o".into(), vec![15]);
917
918        let tokens: Vec<String> = vec!["a".into(), "k".into(), "o".into()];
919        let ids = phoneme_converter::tokens_to_ids(&tokens, &id_map).unwrap();
920        assert_eq!(ids, vec![5, 10, 15]);
921    }
922
923    #[test]
924    fn test_tokens_to_ids_unknown_phoneme() {
925        let id_map: HashMap<String, Vec<i64>> = HashMap::new();
926        let tokens: Vec<String> = vec!["xyz".into()];
927        let result = phoneme_converter::tokens_to_ids(&tokens, &id_map);
928        assert!(result.is_err());
929        match result.unwrap_err() {
930            PiperError::PhonemeIdNotFound { phoneme } => {
931                assert_eq!(phoneme, "xyz");
932            }
933            other => panic!("expected PhonemeIdNotFound, got: {other:?}"),
934        }
935    }
936
937    // -----------------------------------------------------------------------
938    // phonemize_to_ids — cannot be unit-tested without an ONNX model
939    // -----------------------------------------------------------------------
940    // `PiperVoice::phonemize_to_ids` requires a fully initialized `PiperVoice`
941    // (ONNX engine + config), so it cannot be unit-tested without a real model
    // file. Its individual steps are covered by component tests instead (some
    // in this module, others in the modules listed below):
943    //   - phonemize_with_prosody: tested via language-specific phonemizer tests
944    //   - tokens_to_ids: tested in phoneme_converter::tests and above
945    //   - post_process_ids: tested in phonemizer trait tests
946    // End-to-end testing of phonemize_to_ids is done via integration tests
947    // (test_custom_dict_integration.rs) and CLI --test-mode.
948}