use std::path::Path;
use crate::config::VoiceConfig;
use crate::engine::{OnnxEngine, SynthesisRequest, SynthesisResult};
use crate::error::PiperError;
use crate::phonemize::Phonemizer;
use crate::phonemize::adapter::G2pAdapter;
/// A loaded Piper voice: the parsed voice config, the ONNX inference engine
/// that runs synthesis requests, and the phonemizer that turns input text into
/// phoneme tokens (plus optional prosody) for that engine.
// NOTE(review): fields drop in declaration order (config, engine, phonemizer);
// keep this order unless engine/phonemizer teardown order is known not to matter.
pub struct PiperVoice {
    /// Voice configuration loaded from the resolved config file.
    config: VoiceConfig,
    /// Inference engine that executes `SynthesisRequest`s.
    engine: OnnxEngine,
    /// Text front end selected from `config.phoneme_type`.
    phonemizer: Box<dyn Phonemizer>,
}
impl PiperVoice {
pub fn load(
model_path: &Path,
config_path: Option<&Path>,
device: &str,
) -> Result<Self, PiperError> {
let resolved_config = VoiceConfig::resolve_config_path(model_path, config_path)?;
let config = VoiceConfig::load(&resolved_config)?;
let model_dir = model_path.parent().map(|p| p.to_path_buf());
#[cfg(not(target_arch = "wasm32"))]
let (phonemizer, engine) = {
let config_clone = config.clone();
let model_dir_clone = model_dir.clone();
let phonemizer_handle = std::thread::spawn(move || {
Self::create_phonemizer(&config_clone, model_dir_clone.as_deref())
});
let engine = OnnxEngine::load(model_path, &config, device)?;
let phonemizer = phonemizer_handle.join().map_err(|e| {
let msg = e
.downcast_ref::<&str>()
.map(|s| s.to_string())
.or_else(|| e.downcast_ref::<String>().cloned())
.unwrap_or_else(|| "unknown panic".to_string());
PiperError::ModelLoad(format!("phonemizer init panicked: {}", msg))
})??;
(phonemizer, engine)
};
#[cfg(target_arch = "wasm32")]
let (phonemizer, engine) = {
let phonemizer = Self::create_phonemizer(&config, model_dir.as_deref())?;
let engine = OnnxEngine::load(model_path, &config, device)?;
(phonemizer, engine)
};
Ok(Self {
config,
engine,
phonemizer,
})
}
pub fn create_phonemizer(
config: &VoiceConfig,
model_dir: Option<&Path>,
) -> Result<Box<dyn Phonemizer>, PiperError> {
match config.phoneme_type {
#[cfg(feature = "japanese")]
crate::config::PhonemeType::OpenJTalk => Ok(Box::new(G2pAdapter::new(Box::new(
Self::create_japanese_phonemizer()?,
)))),
crate::config::PhonemeType::Bilingual | crate::config::PhonemeType::Multilingual => {
let mut languages: Vec<String> = config.language_id_map.keys().cloned().collect();
languages.sort();
if languages.is_empty() {
return Err(PiperError::InvalidConfig {
reason: "multilingual model requires language_id_map".to_string(),
});
}
let default_latin = if languages.contains(&"en".to_string()) {
"en".to_string()
} else {
languages
.iter()
.find(|l| matches!(l.as_str(), "es" | "fr" | "pt" | "sv"))
.cloned()
.unwrap_or_else(|| languages[0].clone())
};
let mut g2p_phonemizers: std::collections::HashMap<
String,
Box<dyn piper_plus_g2p::Phonemizer>,
> = std::collections::HashMap::new();
for lang in &languages {
let p = Self::create_language_g2p_phonemizer(lang, model_dir)?;
g2p_phonemizers.insert(lang.clone(), p);
}
Ok(Box::new(G2pAdapter::new(Box::new(
piper_plus_g2p::multilingual::MultilingualPhonemizer::new(
languages,
default_latin,
g2p_phonemizers,
),
))))
}
_ => Err(PiperError::UnsupportedLanguage {
code: format!("{:?}", config.phoneme_type),
}),
}
}
fn create_language_g2p_phonemizer(
lang: &str,
model_dir: Option<&Path>,
) -> Result<Box<dyn piper_plus_g2p::Phonemizer>, PiperError> {
match lang {
#[cfg(feature = "japanese")]
"ja" => match Self::create_japanese_phonemizer() {
Ok(p) => Ok(Box::new(p)),
Err(e) => {
tracing::warn!("Japanese phonemizer unavailable ({}), using passthrough", e);
Ok(Box::new(
piper_plus_g2p::multilingual::PassthroughPhonemizer::new(lang),
))
}
},
"en" => match Self::create_english_phonemizer(model_dir) {
Ok(p) => Ok(Box::new(p)),
Err(e) => {
tracing::warn!("English phonemizer unavailable ({}), using passthrough", e);
Ok(Box::new(
piper_plus_g2p::multilingual::PassthroughPhonemizer::new(lang),
))
}
},
"zh" => match Self::create_chinese_phonemizer(model_dir) {
Ok(p) => Ok(Box::new(p)),
Err(e) => {
tracing::warn!("Chinese phonemizer unavailable ({}), using passthrough", e);
Ok(Box::new(
piper_plus_g2p::multilingual::PassthroughPhonemizer::new(lang),
))
}
},
"es" => Ok(Box::new(piper_plus_g2p::spanish::SpanishPhonemizer::new())),
"fr" => Ok(Box::new(piper_plus_g2p::french::FrenchPhonemizer::new())),
"pt" => Ok(Box::new(
piper_plus_g2p::portuguese::PortuguesePhonemizer::new(),
)),
"ko" => Ok(Box::new(piper_plus_g2p::korean::KoreanPhonemizer::new())),
"sv" => Ok(Box::new(piper_plus_g2p::swedish::SwedishPhonemizer::new())),
_ => Ok(Box::new(
piper_plus_g2p::multilingual::PassthroughPhonemizer::new(lang),
)),
}
}
fn create_english_phonemizer(
model_dir: Option<&Path>,
) -> Result<piper_plus_g2p::english::EnglishPhonemizer, PiperError> {
if let Some(dir) = model_dir {
let model_dict = dir.join("cmudict_data.json");
if model_dict.exists() {
return piper_plus_g2p::english::EnglishPhonemizer::new_with_dict(&model_dict)
.map_err(PiperError::from);
}
}
piper_plus_g2p::english::EnglishPhonemizer::new().map_err(PiperError::from)
}
fn create_chinese_phonemizer(
model_dir: Option<&Path>,
) -> Result<piper_plus_g2p::chinese::ChinesePhonemizer, PiperError> {
if let (Ok(single), Ok(phrases)) = (
std::env::var("PINYIN_SINGLE_PATH"),
std::env::var("PINYIN_PHRASES_PATH"),
) {
let sp = std::path::PathBuf::from(&single);
let pp = std::path::PathBuf::from(&phrases);
if sp.exists() && pp.exists() {
return piper_plus_g2p::chinese::ChinesePhonemizer::new(&sp, &pp)
.map_err(PiperError::from);
}
}
if let Some(dir) = model_dir {
let single = dir.join("pinyin_single.json");
let phrases = dir.join("pinyin_phrases.json");
if single.exists() && phrases.exists() {
return piper_plus_g2p::chinese::ChinesePhonemizer::new(&single, &phrases)
.map_err(PiperError::from);
}
}
let single = std::path::PathBuf::from("pinyin_single.json");
let phrases = std::path::PathBuf::from("pinyin_phrases.json");
if single.exists() && phrases.exists() {
return piper_plus_g2p::chinese::ChinesePhonemizer::new(&single, &phrases)
.map_err(PiperError::from);
}
Err(PiperError::DictionaryLoad {
path: "pinyin_single.json / pinyin_phrases.json not found. \
Place dictionaries next to the model or set PINYIN_SINGLE_PATH / PINYIN_PHRASES_PATH env vars"
.to_string(),
})
}
pub fn synthesize_text(
&mut self,
text: &str,
speaker_id: Option<i64>,
language_override: Option<&str>,
noise_scale: f32,
length_scale: f32,
noise_w: f32,
) -> Result<SynthesisResult, PiperError> {
let (tokens, prosody) = self.phonemizer.phonemize_with_prosody(text)?;
let phoneme_id_map = self
.phonemizer
.get_phoneme_id_map()
.unwrap_or(&self.config.phoneme_id_map);
let ids = piper_plus_g2p::encode::tokens_to_ids(&tokens, phoneme_id_map)
.map_err(PiperError::from)?;
let prosody_feats = prosody_to_optional_features(&prosody);
let (ids, prosody_feats) =
self.phonemizer
.post_process_ids(ids, prosody_feats, phoneme_id_map);
let prosody_tensor = build_prosody_tensor(&prosody_feats);
let language_id = if self.config.needs_lid() {
let lang_code = if let Some(ovr) = language_override {
ovr
} else {
self.detect_language(text)
};
Some(
self.config
.language_id_map
.get(lang_code)
.copied()
.unwrap_or(0),
)
} else {
None
};
let request = SynthesisRequest {
phoneme_ids: ids,
prosody_features: prosody_tensor,
speaker_id,
language_id,
noise_scale,
length_scale,
noise_w,
};
self.engine.synthesize(&request)
}
pub fn phonemize_to_ids(&self, text: &str) -> Result<Vec<i64>, PiperError> {
let (tokens, prosody) = self.phonemizer.phonemize_with_prosody(text)?;
let phoneme_id_map = self
.phonemizer
.get_phoneme_id_map()
.unwrap_or(&self.config.phoneme_id_map);
let ids = piper_plus_g2p::encode::tokens_to_ids(&tokens, phoneme_id_map)
.map_err(PiperError::from)?;
let prosody_feats = prosody_to_optional_features(&prosody);
let (ids, _prosody_feats) =
self.phonemizer
.post_process_ids(ids, prosody_feats, phoneme_id_map);
Ok(ids)
}
pub fn text_to_wav_file(
&mut self,
text: &str,
output: &Path,
speaker_id: Option<i64>,
) -> Result<SynthesisResult, PiperError> {
let result = self.synthesize_text(text, speaker_id, None, 0.667, 1.0, 0.8)?;
crate::audio::write_wav(output, result.sample_rate, &result.audio)?;
Ok(result)
}
fn detect_language(&self, text: &str) -> &str {
self.phonemizer.detect_primary_language(text)
}
#[cfg(feature = "japanese")]
fn create_japanese_phonemizer()
-> Result<piper_plus_g2p::japanese::JapanesePhonemizer, PiperError> {
#[cfg(feature = "naist-jdic")]
{
piper_plus_g2p::japanese::JapanesePhonemizer::new_bundled().map_err(PiperError::from)
}
#[cfg(not(feature = "naist-jdic"))]
{
match crate::dictionary_manager::ensure_dictionary() {
Ok(dict_path) => {
tracing::info!("Using OpenJTalk dictionary from {}", dict_path.display());
piper_plus_g2p::japanese::JapanesePhonemizer::new_with_dict(&dict_path)
.map_err(PiperError::from)
}
Err(e) => {
tracing::warn!(
"dictionary_manager failed ({}), falling back to JapanesePhonemizer::new()",
e
);
piper_plus_g2p::japanese::JapanesePhonemizer::new().map_err(PiperError::from)
}
}
}
}
pub fn warmup(&mut self, runs: usize) -> Result<(), PiperError> {
self.engine.warmup(runs)
}
pub fn config(&self) -> &VoiceConfig {
&self.config
}
pub fn engine(&self) -> &OnnxEngine {
&self.engine
}
}
/// Turns each token's prosody info into an `[a1, a2, a3]` feature triple.
///
/// Tokens without prosody stay `None`, and the output has exactly one entry
/// per input entry.
fn prosody_to_optional_features(
    prosody: &[Option<crate::phonemize::ProsodyInfo>],
) -> Vec<Option<crate::phonemize::ProsodyFeature>> {
    let mut features = Vec::with_capacity(prosody.len());
    for slot in prosody {
        features.push(slot.as_ref().map(|info| [info.a1, info.a2, info.a3]));
    }
    features
}
/// Densifies optional prosody features, replacing missing entries with
/// `[0, 0, 0]`.
///
/// Returns `None` when no entry is present, which includes empty input.
fn build_prosody_tensor(
    features: &[Option<crate::phonemize::ProsodyFeature>],
) -> Option<Vec<crate::phonemize::ProsodyFeature>> {
    const MISSING: crate::phonemize::ProsodyFeature = [0, 0, 0];
    let has_prosody = features.iter().any(Option::is_some);
    has_prosody.then(|| {
        features
            .iter()
            .map(|slot| slot.unwrap_or(MISSING))
            .collect()
    })
}
/// Test-only, single-pass reference for
/// `build_prosody_tensor(&prosody_to_optional_features(..))`.
///
/// Missing entries become `[0, 0, 0]`. Returns `None` when every entry is
/// missing, including for empty input.
#[cfg(test)]
fn build_prosody_direct(
    prosody: &[Option<crate::phonemize::ProsodyInfo>],
) -> Option<Vec<crate::phonemize::ProsodyFeature>> {
    if prosody.iter().all(Option::is_none) {
        return None;
    }
    let mut tensor = Vec::with_capacity(prosody.len());
    for slot in prosody {
        let feature = match slot {
            Some(info) => [info.a1, info.a2, info.a3],
            None => [0, 0, 0],
        };
        tensor.push(feature);
    }
    Some(tensor)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::PhonemeType;
    use crate::engine::SynthesisRequest;
    use crate::phonemize::ProsodyInfo;
    use std::collections::HashMap;

    /// Returns the error of `result`, panicking if it is `Ok`.
    ///
    /// Used instead of `Result::unwrap_err`, which requires `T: Debug`.
    /// `Box<dyn Phonemizer>` presumably does not implement `Debug`.
    fn expect_err<T>(result: Result<T, PiperError>) -> PiperError {
        match result {
            Err(e) => e,
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    /// Returns the value of `result`, panicking with the error's `Debug` output.
    fn expect_ok<T>(result: Result<T, PiperError>) -> T {
        match result {
            Ok(value) => value,
            Err(e) => panic!("expected Ok, got: {e:?}"),
        }
    }

    /// A single-speaker, single-language config with empty maps.
    ///
    /// Tests override the fields they care about with struct-update syntax.
    fn base_config(phoneme_type: PhonemeType) -> VoiceConfig {
        VoiceConfig {
            audio: Default::default(),
            num_speakers: 1,
            num_symbols: 0,
            phoneme_type,
            phoneme_id_map: HashMap::new(),
            num_languages: 1,
            language_id_map: HashMap::new(),
            speaker_id_map: HashMap::new(),
        }
    }

    /// Builds a `language_id_map` from `(code, id)` pairs.
    fn lang_map(pairs: &[(&str, i64)]) -> HashMap<String, i64> {
        pairs
            .iter()
            .map(|&(code, id)| (code.to_string(), id))
            .collect()
    }

    #[test]
    fn test_load_fails_with_missing_model() {
        // Config resolution for a nonexistent model is expected to fail first.
        let result = PiperVoice::load(Path::new("/nonexistent/model.onnx"), None, "cpu");
        let err = expect_err(result);
        let msg = format!("{err}");
        assert!(
            msg.contains("config") || msg.contains("not found") || msg.contains("Config"),
            "unexpected error message: {msg}"
        );
    }

    #[test]
    fn test_create_phonemizer_unsupported_espeak() {
        let config = base_config(PhonemeType::Espeak);
        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
            PiperError::UnsupportedLanguage { code } => {
                assert!(
                    code.contains("Espeak"),
                    "expected 'Espeak' in code, got: {code}"
                );
            }
            other => panic!("expected UnsupportedLanguage, got: {other:?}"),
        }
    }

    #[test]
    fn test_create_phonemizer_bilingual_empty_language_id_map() {
        let config = VoiceConfig {
            num_languages: 2,
            ..base_config(PhonemeType::Bilingual)
        };
        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
            PiperError::InvalidConfig { reason } => {
                assert!(
                    reason.contains("language_id_map"),
                    "expected 'language_id_map' in reason, got: {reason}"
                );
            }
            other => panic!("expected InvalidConfig, got: {other:?}"),
        }
    }

    #[test]
    fn test_create_phonemizer_bilingual_success() {
        let config = VoiceConfig {
            num_speakers: 330,
            num_symbols: 97,
            num_languages: 2,
            language_id_map: lang_map(&[("en", 0), ("es", 1)]),
            ..base_config(PhonemeType::Bilingual)
        };
        let phonemizer = expect_ok(PiperVoice::create_phonemizer(&config, None));
        assert_eq!(phonemizer.language_code(), "en");
    }

    #[test]
    fn test_create_phonemizer_multilingual_success() {
        let config = VoiceConfig {
            num_speakers: 571,
            num_symbols: 173,
            num_languages: 5,
            language_id_map: lang_map(&[("en", 0), ("zh", 1), ("es", 2), ("fr", 3), ("pt", 4)]),
            ..base_config(PhonemeType::Multilingual)
        };
        let phonemizer = expect_ok(PiperVoice::create_phonemizer(&config, None));
        assert_eq!(phonemizer.language_code(), "en");
    }

    #[test]
    fn test_create_phonemizer_multilingual_empty_language_id_map() {
        let config = VoiceConfig {
            num_speakers: 571,
            num_symbols: 173,
            num_languages: 6,
            ..base_config(PhonemeType::Multilingual)
        };
        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
            PiperError::InvalidConfig { reason } => {
                assert!(
                    reason.contains("language_id_map"),
                    "expected 'language_id_map' in reason, got: {reason}"
                );
            }
            other => panic!("expected InvalidConfig, got: {other:?}"),
        }
    }

    #[test]
    fn test_create_phonemizer_multilingual_default_latin_fallback() {
        // Without "en", the first Latin-script language ("es") becomes the default.
        let config = VoiceConfig {
            num_speakers: 100,
            num_symbols: 100,
            num_languages: 2,
            language_id_map: lang_map(&[("zh", 0), ("es", 1)]),
            ..base_config(PhonemeType::Multilingual)
        };
        let phonemizer = expect_ok(PiperVoice::create_phonemizer(&config, None));
        assert_eq!(phonemizer.language_code(), "es");
    }

    #[test]
    fn test_create_phonemizer_multilingual_detect_language() {
        // Bilingual and Multilingual share the same construction path.
        let config = VoiceConfig {
            num_speakers: 330,
            num_symbols: 97,
            num_languages: 2,
            language_id_map: lang_map(&[("en", 0), ("zh", 1)]),
            ..base_config(PhonemeType::Bilingual)
        };
        let phonemizer = expect_ok(PiperVoice::create_phonemizer(&config, None));
        assert_eq!(phonemizer.detect_primary_language("Hello world"), "en");
        assert_eq!(phonemizer.detect_primary_language("你好世界"), "zh");
    }

    #[test]
    fn test_create_phonemizer_unsupported_text() {
        let config = base_config(PhonemeType::Text);
        match expect_err(PiperVoice::create_phonemizer(&config, None)) {
            PiperError::UnsupportedLanguage { code } => {
                assert!(
                    code.contains("Text"),
                    "expected 'Text' in code, got: {code}"
                );
            }
            other => panic!("expected UnsupportedLanguage, got: {other:?}"),
        }
    }

    #[test]
    fn test_language_id_single_language_no_lid() {
        let config = base_config(PhonemeType::OpenJTalk);
        assert!(!config.needs_lid());
        assert!(!config.is_multilingual());
    }

    #[test]
    fn test_language_id_multilingual_needs_lid() {
        let config = VoiceConfig {
            num_speakers: 571,
            num_symbols: 173,
            num_languages: 6,
            language_id_map: lang_map(&[
                ("ja", 0),
                ("en", 1),
                ("zh", 2),
                ("es", 3),
                ("fr", 4),
                ("pt", 5),
            ]),
            ..base_config(PhonemeType::Multilingual)
        };
        assert!(config.needs_lid());
        assert_eq!(config.language_id_map.get("ja"), Some(&0));
        assert_eq!(config.language_id_map.get("en"), Some(&1));
        assert_eq!(config.language_id_map.get("zh"), Some(&2));
        // Unknown languages fall back to id 0 in `synthesize_text`.
        assert_eq!(config.language_id_map.get("ko").copied().unwrap_or(0), 0);
    }

    #[test]
    fn test_language_id_bilingual_needs_lid() {
        let config = VoiceConfig {
            num_speakers: 330,
            num_symbols: 97,
            num_languages: 2,
            language_id_map: lang_map(&[("ja", 0), ("en", 1)]),
            ..base_config(PhonemeType::Bilingual)
        };
        assert!(config.needs_lid());
        assert_eq!(config.language_id_map.get("ja"), Some(&0));
        assert_eq!(config.language_id_map.get("en"), Some(&1));
    }

    #[test]
    fn test_synthesis_request_construction_basic() {
        let ids = vec![1i64, 8, 5, 39, 42, 10, 2];
        let request = SynthesisRequest {
            phoneme_ids: ids.clone(),
            prosody_features: None,
            speaker_id: Some(0),
            language_id: None,
            noise_scale: 0.667,
            length_scale: 1.0,
            noise_w: 0.8,
        };
        assert_eq!(request.phoneme_ids, ids);
        assert!(request.prosody_features.is_none());
        assert_eq!(request.speaker_id, Some(0));
        assert!(request.language_id.is_none());
    }

    #[test]
    fn test_synthesis_request_construction_with_prosody() {
        let prosody_feats = vec![[-2, 1, 5], [0, 2, 5], [1, 3, 5]];
        let request = SynthesisRequest {
            phoneme_ids: vec![1, 2, 3],
            prosody_features: Some(prosody_feats.clone()),
            speaker_id: Some(3),
            language_id: Some(0),
            noise_scale: 0.5,
            length_scale: 1.2,
            noise_w: 0.6,
        };
        let feats = request
            .prosody_features
            .as_ref()
            .expect("prosody_features was set above");
        assert_eq!(feats.len(), 3);
        assert_eq!(feats[0], [-2, 1, 5]);
        assert_eq!(request.speaker_id, Some(3));
        assert_eq!(request.language_id, Some(0));
    }

    #[test]
    fn test_synthesis_request_construction_multilingual() {
        let request = SynthesisRequest {
            phoneme_ids: vec![1, 5, 10, 20],
            prosody_features: None,
            speaker_id: Some(100),
            language_id: Some(2),
            noise_scale: 0.667,
            length_scale: 1.0,
            noise_w: 0.8,
        };
        assert_eq!(request.language_id, Some(2));
        assert_eq!(request.speaker_id, Some(100));
    }

    #[test]
    fn test_prosody_to_optional_features_with_values() {
        let prosody = vec![
            Some(ProsodyInfo {
                a1: -2,
                a2: 1,
                a3: 5,
            }),
            None,
            Some(ProsodyInfo {
                a1: 0,
                a2: 3,
                a3: 5,
            }),
        ];
        let result = prosody_to_optional_features(&prosody);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], Some([-2, 1, 5]));
        assert_eq!(result[1], None);
        assert_eq!(result[2], Some([0, 3, 5]));
    }

    #[test]
    fn test_prosody_to_optional_features_all_none() {
        let prosody: Vec<Option<ProsodyInfo>> = vec![None, None, None];
        let result = prosody_to_optional_features(&prosody);
        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|p| p.is_none()));
    }

    #[test]
    fn test_prosody_to_optional_features_empty() {
        let prosody: Vec<Option<ProsodyInfo>> = vec![];
        let result = prosody_to_optional_features(&prosody);
        assert!(result.is_empty());
    }

    #[test]
    fn test_build_prosody_tensor_with_some() {
        let features = vec![Some([-2, 1, 5]), None, Some([0, 3, 5])];
        let t = build_prosody_tensor(&features).expect("some features are present");
        assert_eq!(t.len(), 3);
        assert_eq!(t[0], [-2, 1, 5]);
        // Missing entries are zero-filled.
        assert_eq!(t[1], [0, 0, 0]);
        assert_eq!(t[2], [0, 3, 5]);
    }

    #[test]
    fn test_build_prosody_tensor_all_none() {
        let features: Vec<Option<[i32; 3]>> = vec![None, None];
        assert!(build_prosody_tensor(&features).is_none());
    }

    #[test]
    fn test_build_prosody_tensor_empty() {
        let features: Vec<Option<[i32; 3]>> = vec![];
        assert!(build_prosody_tensor(&features).is_none());
    }

    #[test]
    fn test_build_prosody_direct_with_some() {
        let prosody = vec![
            Some(ProsodyInfo {
                a1: -2,
                a2: 1,
                a3: 5,
            }),
            None,
            Some(ProsodyInfo {
                a1: 0,
                a2: 3,
                a3: 5,
            }),
        ];
        let t = build_prosody_direct(&prosody).expect("some prosody is present");
        assert_eq!(t.len(), 3);
        assert_eq!(t[0], [-2, 1, 5]);
        // Missing entries are zero-filled.
        assert_eq!(t[1], [0, 0, 0]);
        assert_eq!(t[2], [0, 3, 5]);
    }

    #[test]
    fn test_build_prosody_direct_all_none() {
        let prosody: Vec<Option<ProsodyInfo>> = vec![None, None];
        assert!(build_prosody_direct(&prosody).is_none());
    }

    #[test]
    fn test_build_prosody_direct_empty() {
        let prosody: Vec<Option<ProsodyInfo>> = vec![];
        assert!(build_prosody_direct(&prosody).is_none());
    }

    #[test]
    fn test_build_prosody_direct_matches_two_step() {
        let prosody = vec![
            Some(ProsodyInfo {
                a1: 1,
                a2: 2,
                a3: 3,
            }),
            None,
            Some(ProsodyInfo {
                a1: -1,
                a2: 0,
                a3: 7,
            }),
            None,
        ];
        let two_step = build_prosody_tensor(&prosody_to_optional_features(&prosody));
        let direct = build_prosody_direct(&prosody);
        assert_eq!(two_step, direct);
    }

    #[test]
    fn test_tokens_to_ids_via_converter() {
        let mut id_map: HashMap<String, Vec<i64>> = HashMap::new();
        id_map.insert("a".into(), vec![5]);
        id_map.insert("k".into(), vec![10]);
        id_map.insert("o".into(), vec![15]);
        let tokens: Vec<String> = vec!["a".into(), "k".into(), "o".into()];
        let ids = expect_ok(
            piper_plus_g2p::encode::tokens_to_ids(&tokens, &id_map).map_err(PiperError::from),
        );
        assert_eq!(ids, vec![5, 10, 15]);
    }

    #[test]
    fn test_tokens_to_ids_unknown_phoneme() {
        let id_map: HashMap<String, Vec<i64>> = HashMap::new();
        let tokens: Vec<String> = vec!["xyz".into()];
        let result =
            piper_plus_g2p::encode::tokens_to_ids(&tokens, &id_map).map_err(PiperError::from);
        match expect_err(result) {
            PiperError::PhonemeIdNotFound { phoneme } => {
                assert_eq!(phoneme, "xyz");
            }
            other => panic!("expected PhonemeIdNotFound, got: {other:?}"),
        }
    }
}