1use std::path::Path;
6
7use crate::config::VoiceConfig;
8use crate::engine::{OnnxEngine, SynthesisRequest, SynthesisResult};
9use crate::error::PiperError;
10use crate::phonemize::Phonemizer;
11use crate::phonemize::phoneme_converter;
12
/// A loaded Piper voice: its parsed configuration, the ONNX inference engine,
/// and the phonemizer selected from the configuration's `phoneme_type`.
///
/// Construct one with [`PiperVoice::load`].
pub struct PiperVoice {
    /// Parsed voice configuration (phoneme-id, language-id and speaker-id maps, etc.).
    config: VoiceConfig,
    /// Inference engine built from the model file and `config`.
    engine: OnnxEngine,
    /// Text-to-phoneme front end chosen by `PiperVoice::create_phonemizer`.
    phonemizer: Box<dyn Phonemizer>,
}
19
20impl PiperVoice {
21 pub fn load(
27 model_path: &Path,
28 config_path: Option<&Path>,
29 device: &str,
30 ) -> Result<Self, PiperError> {
31 let resolved_config = VoiceConfig::resolve_config_path(model_path, config_path)?;
32 let config = VoiceConfig::load(&resolved_config)?;
33 let model_dir = model_path.parent().map(|p| p.to_path_buf());
34 let phonemizer = Self::create_phonemizer(&config, model_dir.as_deref())?;
35 let engine = OnnxEngine::load(model_path, &config, device)?;
36
37 Ok(Self {
38 config,
39 engine,
40 phonemizer,
41 })
42 }
43
44 pub fn create_phonemizer(
50 config: &VoiceConfig,
51 model_dir: Option<&Path>,
52 ) -> Result<Box<dyn Phonemizer>, PiperError> {
53 match config.phoneme_type {
54 #[cfg(feature = "japanese")]
55 crate::config::PhonemeType::OpenJTalk => {
56 Ok(Box::new(Self::create_japanese_phonemizer()?))
57 }
58 crate::config::PhonemeType::Bilingual | crate::config::PhonemeType::Multilingual => {
59 let mut languages: Vec<String> = config.language_id_map.keys().cloned().collect();
61 languages.sort(); if languages.is_empty() {
64 return Err(PiperError::InvalidConfig {
65 reason: "multilingual model requires language_id_map".to_string(),
66 });
67 }
68
69 let default_latin = if languages.contains(&"en".to_string()) {
71 "en".to_string()
72 } else {
73 languages
74 .iter()
75 .find(|l| matches!(l.as_str(), "es" | "fr" | "pt"))
76 .cloned()
77 .unwrap_or_else(|| languages[0].clone())
78 };
79
80 let mut phonemizers: std::collections::HashMap<String, Box<dyn Phonemizer>> =
82 std::collections::HashMap::new();
83
84 for lang in &languages {
85 let phonemizer: Box<dyn Phonemizer> =
86 Self::create_language_phonemizer(lang, model_dir)?;
87 phonemizers.insert(lang.clone(), phonemizer);
88 }
89
90 Ok(Box::new(
91 crate::phonemize::multilingual::MultilingualPhonemizer::new(
92 languages,
93 default_latin,
94 phonemizers,
95 ),
96 ))
97 }
98 _ => Err(PiperError::UnsupportedLanguage {
99 code: format!("{:?}", config.phoneme_type),
100 }),
101 }
102 }
103
104 fn create_language_phonemizer(
111 lang: &str,
112 model_dir: Option<&Path>,
113 ) -> Result<Box<dyn Phonemizer>, PiperError> {
114 match lang {
115 #[cfg(feature = "japanese")]
116 "ja" => match Self::create_japanese_phonemizer() {
117 Ok(p) => Ok(Box::new(p)),
118 Err(e) => {
119 tracing::warn!("Japanese phonemizer unavailable ({}), using passthrough", e);
120 Ok(Box::new(
121 crate::phonemize::multilingual::PassthroughPhonemizer::new(lang),
122 ))
123 }
124 },
125 "en" => match Self::create_english_phonemizer(model_dir) {
126 Ok(p) => Ok(Box::new(p)),
127 Err(e) => {
128 tracing::warn!("English phonemizer unavailable ({}), using passthrough", e);
129 Ok(Box::new(
130 crate::phonemize::multilingual::PassthroughPhonemizer::new(lang),
131 ))
132 }
133 },
134 "zh" => match Self::create_chinese_phonemizer(model_dir) {
135 Ok(p) => Ok(Box::new(p)),
136 Err(e) => {
137 tracing::warn!("Chinese phonemizer unavailable ({}), using passthrough", e);
138 Ok(Box::new(
139 crate::phonemize::multilingual::PassthroughPhonemizer::new(lang),
140 ))
141 }
142 },
143 "es" => Ok(Box::new(crate::phonemize::spanish::SpanishPhonemizer::new())),
144 "fr" => Ok(Box::new(crate::phonemize::french::FrenchPhonemizer::new())),
145 "pt" => Ok(Box::new(
146 crate::phonemize::portuguese::PortuguesePhonemizer::new(),
147 )),
148 "ko" => Ok(Box::new(crate::phonemize::korean::KoreanPhonemizer::new())),
149 _ => Ok(Box::new(
150 crate::phonemize::multilingual::PassthroughPhonemizer::new(lang),
151 )),
152 }
153 }
154
155 fn create_english_phonemizer(
163 model_dir: Option<&Path>,
164 ) -> Result<crate::phonemize::english::EnglishPhonemizer, PiperError> {
165 if let Some(dir) = model_dir {
167 let model_dict = dir.join("cmudict_data.json");
168 if model_dict.exists() {
169 return crate::phonemize::english::EnglishPhonemizer::new_with_dict(&model_dict);
170 }
171 }
172 crate::phonemize::english::EnglishPhonemizer::new()
174 }
175
176 fn create_chinese_phonemizer(
183 model_dir: Option<&Path>,
184 ) -> Result<crate::phonemize::chinese::ChinesePhonemizer, PiperError> {
185 if let (Ok(single), Ok(phrases)) = (
187 std::env::var("PINYIN_SINGLE_PATH"),
188 std::env::var("PINYIN_PHRASES_PATH"),
189 ) {
190 let sp = std::path::PathBuf::from(&single);
191 let pp = std::path::PathBuf::from(&phrases);
192 if sp.exists() && pp.exists() {
193 return crate::phonemize::chinese::ChinesePhonemizer::new(&sp, &pp);
194 }
195 }
196
197 if let Some(dir) = model_dir {
199 let single = dir.join("pinyin_single.json");
200 let phrases = dir.join("pinyin_phrases.json");
201 if single.exists() && phrases.exists() {
202 return crate::phonemize::chinese::ChinesePhonemizer::new(&single, &phrases);
203 }
204 }
205
206 let single = std::path::PathBuf::from("pinyin_single.json");
208 let phrases = std::path::PathBuf::from("pinyin_phrases.json");
209 if single.exists() && phrases.exists() {
210 return crate::phonemize::chinese::ChinesePhonemizer::new(&single, &phrases);
211 }
212
213 Err(PiperError::DictionaryLoad {
214 path: "pinyin_single.json / pinyin_phrases.json not found. \
215 Place dictionaries next to the model or set PINYIN_SINGLE_PATH / PINYIN_PHRASES_PATH env vars"
216 .to_string(),
217 })
218 }
219
220 pub fn synthesize_text(
225 &mut self,
226 text: &str,
227 speaker_id: Option<i64>,
228 language_override: Option<&str>,
229 noise_scale: f32,
230 length_scale: f32,
231 noise_w: f32,
232 ) -> Result<SynthesisResult, PiperError> {
233 let (tokens, prosody) = self.phonemizer.phonemize_with_prosody(text)?;
235
236 let phoneme_id_map = self
238 .phonemizer
239 .get_phoneme_id_map()
240 .unwrap_or(&self.config.phoneme_id_map);
241
242 let ids = phoneme_converter::tokens_to_ids(&tokens, phoneme_id_map)?;
243 let prosody_feats = prosody_to_optional_features(&prosody);
244
245 let (ids, prosody_feats) =
247 self.phonemizer
248 .post_process_ids(ids, prosody_feats, phoneme_id_map);
249
250 let prosody_tensor = build_prosody_tensor(&prosody_feats);
253
254 let language_id = if self.config.needs_lid() {
259 let lang_code = if let Some(ovr) = language_override {
260 ovr
261 } else {
262 self.detect_language(text)
263 };
264 Some(
265 self.config
266 .language_id_map
267 .get(lang_code)
268 .copied()
269 .unwrap_or(0),
270 )
271 } else {
272 None
273 };
274
275 let request = SynthesisRequest {
277 phoneme_ids: ids,
278 prosody_features: prosody_tensor,
279 speaker_id,
280 language_id,
281 noise_scale,
282 length_scale,
283 noise_w,
284 };
285
286 self.engine.synthesize(&request)
287 }
288
289 pub fn phonemize_to_ids(&self, text: &str) -> Result<Vec<i64>, PiperError> {
293 let (tokens, prosody) = self.phonemizer.phonemize_with_prosody(text)?;
294
295 let phoneme_id_map = self
296 .phonemizer
297 .get_phoneme_id_map()
298 .unwrap_or(&self.config.phoneme_id_map);
299
300 let ids = phoneme_converter::tokens_to_ids(&tokens, phoneme_id_map)?;
301 let prosody_feats = prosody_to_optional_features(&prosody);
302
303 let (ids, _prosody_feats) =
304 self.phonemizer
305 .post_process_ids(ids, prosody_feats, phoneme_id_map);
306
307 Ok(ids)
308 }
309
310 pub fn text_to_wav_file(
312 &mut self,
313 text: &str,
314 output: &Path,
315 speaker_id: Option<i64>,
316 ) -> Result<SynthesisResult, PiperError> {
317 let result = self.synthesize_text(text, speaker_id, None, 0.667, 1.0, 0.8)?;
318 crate::audio::write_wav(output, result.sample_rate, &result.audio)?;
319 Ok(result)
320 }
321
322 fn detect_language(&self, text: &str) -> &str {
328 self.phonemizer.detect_primary_language(text)
329 }
330
331 #[cfg(feature = "japanese")]
337 fn create_japanese_phonemizer()
338 -> Result<crate::phonemize::japanese::JapanesePhonemizer, PiperError> {
339 #[cfg(feature = "naist-jdic")]
340 {
341 crate::phonemize::japanese::JapanesePhonemizer::new_bundled()
342 }
343 #[cfg(not(feature = "naist-jdic"))]
344 {
345 match crate::dictionary_manager::ensure_dictionary() {
347 Ok(dict_path) => {
348 tracing::info!("Using OpenJTalk dictionary from {}", dict_path.display());
349 crate::phonemize::japanese::JapanesePhonemizer::new_with_dict(&dict_path)
350 }
351 Err(e) => {
352 tracing::warn!(
353 "dictionary_manager failed ({}), falling back to JapanesePhonemizer::new()",
354 e
355 );
356 crate::phonemize::japanese::JapanesePhonemizer::new()
358 }
359 }
360 }
361 }
362
363 pub fn config(&self) -> &VoiceConfig {
365 &self.config
366 }
367
368 pub fn engine(&self) -> &OnnxEngine {
370 &self.engine
371 }
372}
373
374fn prosody_to_optional_features(
382 prosody: &[Option<crate::phonemize::ProsodyInfo>],
383) -> Vec<Option<crate::phonemize::ProsodyFeature>> {
384 prosody
385 .iter()
386 .map(|p| p.map(|info| [info.a1, info.a2, info.a3]))
387 .collect()
388}
389
390fn build_prosody_tensor(
395 features: &[Option<crate::phonemize::ProsodyFeature>],
396) -> Option<Vec<crate::phonemize::ProsodyFeature>> {
397 if features.iter().any(|p| p.is_some()) {
398 Some(features.iter().map(|p| p.unwrap_or([0, 0, 0])).collect())
399 } else {
400 None
401 }
402}
403
/// Test-only reference: builds the prosody tensor straight from `ProsodyInfo`.
///
/// Used to check that `build_prosody_tensor(&prosody_to_optional_features(..))`
/// gives the same result. Returns `None` when every entry is `None` (or the
/// input is empty); otherwise missing entries become `[0, 0, 0]`.
#[cfg(test)]
fn build_prosody_direct(
    prosody: &[Option<crate::phonemize::ProsodyInfo>],
) -> Option<Vec<crate::phonemize::ProsodyFeature>> {
    if prosody.iter().all(Option::is_none) {
        return None;
    }

    let mut tensor = Vec::with_capacity(prosody.len());
    for slot in prosody {
        let feature = if let Some(info) = slot {
            [info.a1, info.a2, info.a3]
        } else {
            [0, 0, 0]
        };
        tensor.push(feature);
    }
    Some(tensor)
}
427
#[cfg(test)]
mod tests {
    //! Unit tests for phonemizer selection, language-id configuration, request
    //! construction and the prosody helpers. No test here loads a real model.

    use super::*;
    use crate::config::PhonemeType;
    use crate::phonemize::ProsodyInfo;
    use std::collections::HashMap;

    /// Returns the error of `result`, panicking if it is `Ok`.
    ///
    /// Unlike `Result::unwrap_err`, this does not require `T: Debug`.
    fn expect_err<T>(result: Result<T, PiperError>) -> PiperError {
        match result {
            Err(e) => e,
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    /// Builds a `VoiceConfig` with `phoneme_type` and a language-id map built
    /// from `languages`.
    ///
    /// Numeric fields default to 1 speaker, 0 symbols and 1 language; tests
    /// override them with struct-update syntax.
    fn config_with(phoneme_type: PhonemeType, languages: &[(&str, i64)]) -> VoiceConfig {
        VoiceConfig {
            audio: Default::default(),
            num_speakers: 1,
            num_symbols: 0,
            phoneme_type,
            phoneme_id_map: HashMap::new(),
            num_languages: 1,
            language_id_map: languages
                .iter()
                .map(|&(code, id)| (code.to_string(), id))
                .collect(),
            speaker_id_map: HashMap::new(),
        }
    }

    #[test]
    fn test_load_fails_with_missing_model() {
        let err = expect_err(PiperVoice::load(
            Path::new("/nonexistent/model.onnx"),
            None,
            "cpu",
        ));
        let msg = format!("{err}");
        // `load` resolves the config before touching the model file.
        assert!(
            msg.contains("config") || msg.contains("not found") || msg.contains("Config"),
            "unexpected error message: {msg}"
        );
    }

    #[test]
    fn test_create_phonemizer_unsupported_espeak() {
        let config = config_with(PhonemeType::Espeak, &[]);
        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
            PiperError::UnsupportedLanguage { code } => {
                assert!(
                    code.contains("Espeak"),
                    "expected 'Espeak' in code, got: {code}"
                );
            }
            other => panic!("expected UnsupportedLanguage, got: {other:?}"),
        }
    }

    #[test]
    fn test_create_phonemizer_bilingual_empty_language_id_map() {
        let config = VoiceConfig {
            num_languages: 2,
            ..config_with(PhonemeType::Bilingual, &[])
        };
        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
            PiperError::InvalidConfig { reason } => {
                assert!(
                    reason.contains("language_id_map"),
                    "expected 'language_id_map' in reason, got: {reason}"
                );
            }
            other => panic!("expected InvalidConfig, got: {other:?}"),
        }
    }

    #[test]
    fn test_create_phonemizer_bilingual_success() {
        let config = VoiceConfig {
            num_speakers: 330,
            num_symbols: 97,
            num_languages: 2,
            ..config_with(PhonemeType::Bilingual, &[("en", 0), ("es", 1)])
        };
        let result = PiperVoice::create_phonemizer(&config, None);
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        let phonemizer = result.unwrap();
        assert_eq!(phonemizer.language_code(), "en");
    }

    #[test]
    fn test_create_phonemizer_multilingual_success() {
        let config = VoiceConfig {
            num_speakers: 571,
            num_symbols: 173,
            num_languages: 5,
            ..config_with(
                PhonemeType::Multilingual,
                &[("en", 0), ("zh", 1), ("es", 2), ("fr", 3), ("pt", 4)],
            )
        };
        let result = PiperVoice::create_phonemizer(&config, None);
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        let phonemizer = result.unwrap();
        assert_eq!(phonemizer.language_code(), "en");
    }

    #[test]
    fn test_create_phonemizer_multilingual_empty_language_id_map() {
        let config = VoiceConfig {
            num_speakers: 571,
            num_symbols: 173,
            num_languages: 6,
            ..config_with(PhonemeType::Multilingual, &[])
        };
        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
            PiperError::InvalidConfig { reason } => {
                assert!(
                    reason.contains("language_id_map"),
                    "expected 'language_id_map' in reason, got: {reason}"
                );
            }
            other => panic!("expected InvalidConfig, got: {other:?}"),
        }
    }

    #[test]
    fn test_create_phonemizer_multilingual_default_latin_fallback() {
        // Without "en", the first of es/fr/pt becomes the default Latin language.
        let config = VoiceConfig {
            num_speakers: 100,
            num_symbols: 100,
            num_languages: 2,
            ..config_with(PhonemeType::Multilingual, &[("zh", 0), ("es", 1)])
        };
        let result = PiperVoice::create_phonemizer(&config, None);
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        let phonemizer = result.unwrap();
        assert_eq!(phonemizer.language_code(), "es");
    }

    #[test]
    fn test_create_phonemizer_multilingual_detect_language() {
        // Both multi-language phoneme types share one construction path; cover both.
        for phoneme_type in [PhonemeType::Bilingual, PhonemeType::Multilingual] {
            let config = VoiceConfig {
                num_speakers: 330,
                num_symbols: 97,
                num_languages: 2,
                ..config_with(phoneme_type, &[("en", 0), ("zh", 1)])
            };
            let phonemizer = PiperVoice::create_phonemizer(&config, None).unwrap();
            assert_eq!(phonemizer.detect_primary_language("Hello world"), "en");
            assert_eq!(phonemizer.detect_primary_language("你好世界"), "zh");
        }
    }

    #[test]
    fn test_create_phonemizer_unsupported_text() {
        let config = config_with(PhonemeType::Text, &[]);
        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
            PiperError::UnsupportedLanguage { code } => {
                assert!(
                    code.contains("Text"),
                    "expected 'Text' in code, got: {code}"
                );
            }
            other => panic!("expected UnsupportedLanguage, got: {other:?}"),
        }
    }

    #[test]
    fn test_language_id_single_language_no_lid() {
        let config = config_with(PhonemeType::OpenJTalk, &[]);
        assert!(!config.needs_lid());
        assert!(!config.is_multilingual());
    }

    #[test]
    fn test_language_id_multilingual_needs_lid() {
        let config = VoiceConfig {
            num_speakers: 571,
            num_symbols: 173,
            num_languages: 6,
            ..config_with(
                PhonemeType::Multilingual,
                &[
                    ("ja", 0),
                    ("en", 1),
                    ("zh", 2),
                    ("es", 3),
                    ("fr", 4),
                    ("pt", 5),
                ],
            )
        };
        assert!(config.needs_lid());
        assert_eq!(config.language_id_map.get("ja"), Some(&0));
        assert_eq!(config.language_id_map.get("en"), Some(&1));
        assert_eq!(config.language_id_map.get("zh"), Some(&2));
        // Unknown languages fall back to id 0, mirroring synthesize_text.
        assert_eq!(config.language_id_map.get("ko").copied().unwrap_or(0), 0);
    }

    #[test]
    fn test_language_id_bilingual_needs_lid() {
        let config = VoiceConfig {
            num_speakers: 330,
            num_symbols: 97,
            num_languages: 2,
            ..config_with(PhonemeType::Bilingual, &[("ja", 0), ("en", 1)])
        };
        assert!(config.needs_lid());
        assert_eq!(config.language_id_map.get("ja"), Some(&0));
        assert_eq!(config.language_id_map.get("en"), Some(&1));
    }

    #[test]
    fn test_synthesis_request_construction_basic() {
        let ids = vec![1i64, 8, 5, 39, 42, 10, 2];
        let request = SynthesisRequest {
            phoneme_ids: ids.clone(),
            prosody_features: None,
            speaker_id: Some(0),
            language_id: None,
            noise_scale: 0.667,
            length_scale: 1.0,
            noise_w: 0.8,
        };
        assert_eq!(request.phoneme_ids, ids);
        assert!(request.prosody_features.is_none());
        assert_eq!(request.speaker_id, Some(0));
        assert!(request.language_id.is_none());
    }

    #[test]
    fn test_synthesis_request_construction_with_prosody() {
        let prosody_feats = vec![[-2, 1, 5], [0, 2, 5], [1, 3, 5]];
        let request = SynthesisRequest {
            phoneme_ids: vec![1, 2, 3],
            prosody_features: Some(prosody_feats.clone()),
            speaker_id: Some(3),
            language_id: Some(0),
            noise_scale: 0.5,
            length_scale: 1.2,
            noise_w: 0.6,
        };
        assert_eq!(request.prosody_features.as_ref().unwrap().len(), 3);
        assert_eq!(request.prosody_features.as_ref().unwrap()[0], [-2, 1, 5]);
        assert_eq!(request.speaker_id, Some(3));
        assert_eq!(request.language_id, Some(0));
    }

    #[test]
    fn test_synthesis_request_construction_multilingual() {
        let request = SynthesisRequest {
            phoneme_ids: vec![1, 5, 10, 20],
            prosody_features: None,
            speaker_id: Some(100),
            language_id: Some(2),
            noise_scale: 0.667,
            length_scale: 1.0,
            noise_w: 0.8,
        };
        assert_eq!(request.language_id, Some(2));
        assert_eq!(request.speaker_id, Some(100));
    }

    #[test]
    fn test_prosody_to_optional_features_with_values() {
        let prosody = vec![
            Some(ProsodyInfo {
                a1: -2,
                a2: 1,
                a3: 5,
            }),
            None,
            Some(ProsodyInfo {
                a1: 0,
                a2: 3,
                a3: 5,
            }),
        ];
        let result = prosody_to_optional_features(&prosody);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], Some([-2, 1, 5]));
        assert_eq!(result[1], None);
        assert_eq!(result[2], Some([0, 3, 5]));
    }

    #[test]
    fn test_prosody_to_optional_features_all_none() {
        let prosody: Vec<Option<ProsodyInfo>> = vec![None, None, None];
        let result = prosody_to_optional_features(&prosody);
        assert!(result.iter().all(|p| p.is_none()));
    }

    #[test]
    fn test_prosody_to_optional_features_empty() {
        let prosody: Vec<Option<ProsodyInfo>> = vec![];
        let result = prosody_to_optional_features(&prosody);
        assert!(result.is_empty());
    }

    #[test]
    fn test_build_prosody_tensor_with_some() {
        let features = vec![Some([-2, 1, 5]), None, Some([0, 3, 5])];
        let tensor = build_prosody_tensor(&features);
        assert!(tensor.is_some());
        let t = tensor.unwrap();
        assert_eq!(t.len(), 3);
        assert_eq!(t[0], [-2, 1, 5]);
        // Missing entries are zero-filled.
        assert_eq!(t[1], [0, 0, 0]);
        assert_eq!(t[2], [0, 3, 5]);
    }

    #[test]
    fn test_build_prosody_tensor_all_none() {
        let features: Vec<Option<[i32; 3]>> = vec![None, None];
        let tensor = build_prosody_tensor(&features);
        assert!(tensor.is_none());
    }

    #[test]
    fn test_build_prosody_tensor_empty() {
        let features: Vec<Option<[i32; 3]>> = vec![];
        let tensor = build_prosody_tensor(&features);
        assert!(tensor.is_none());
    }

    #[test]
    fn test_build_prosody_direct_with_some() {
        let prosody = vec![
            Some(ProsodyInfo {
                a1: -2,
                a2: 1,
                a3: 5,
            }),
            None,
            Some(ProsodyInfo {
                a1: 0,
                a2: 3,
                a3: 5,
            }),
        ];
        let tensor = build_prosody_direct(&prosody);
        assert!(tensor.is_some());
        let t = tensor.unwrap();
        assert_eq!(t.len(), 3);
        assert_eq!(t[0], [-2, 1, 5]);
        assert_eq!(t[1], [0, 0, 0]);
        assert_eq!(t[2], [0, 3, 5]);
    }

    #[test]
    fn test_build_prosody_direct_all_none() {
        let prosody: Vec<Option<ProsodyInfo>> = vec![None, None];
        let tensor = build_prosody_direct(&prosody);
        assert!(tensor.is_none());
    }

    #[test]
    fn test_build_prosody_direct_empty() {
        let prosody: Vec<Option<ProsodyInfo>> = vec![];
        let tensor = build_prosody_direct(&prosody);
        assert!(tensor.is_none());
    }

    #[test]
    fn test_build_prosody_direct_matches_two_step() {
        // The production two-step path must agree with the direct reference.
        let prosody = vec![
            Some(ProsodyInfo {
                a1: 1,
                a2: 2,
                a3: 3,
            }),
            None,
            Some(ProsodyInfo {
                a1: -1,
                a2: 0,
                a3: 7,
            }),
            None,
        ];
        let two_step = build_prosody_tensor(&prosody_to_optional_features(&prosody));
        let direct = build_prosody_direct(&prosody);
        assert_eq!(two_step, direct);
    }

    #[test]
    fn test_tokens_to_ids_via_converter() {
        let mut id_map: HashMap<String, Vec<i64>> = HashMap::new();
        id_map.insert("a".into(), vec![5]);
        id_map.insert("k".into(), vec![10]);
        id_map.insert("o".into(), vec![15]);

        let tokens: Vec<String> = vec!["a".into(), "k".into(), "o".into()];
        let ids = phoneme_converter::tokens_to_ids(&tokens, &id_map).unwrap();
        assert_eq!(ids, vec![5, 10, 15]);
    }

    #[test]
    fn test_tokens_to_ids_unknown_phoneme() {
        let id_map: HashMap<String, Vec<i64>> = HashMap::new();
        let tokens: Vec<String> = vec!["xyz".into()];
        match expect_err(phoneme_converter::tokens_to_ids(&tokens, &id_map)) {
            PiperError::PhonemeIdNotFound { phoneme } => {
                assert_eq!(phoneme, "xyz");
            }
            other => panic!("expected PhonemeIdNotFound, got: {other:?}"),
        }
    }
}