1use serde::Deserialize;
2use std::collections::HashMap;
3use std::path::Path;
4
5use crate::error::PiperError;
6
/// Lookup table from a phoneme string to the sequence of integer ids it is encoded as.
///
/// This is the type of [`VoiceConfig::phoneme_id_map`].
pub type PhonemeIdMap = HashMap<String, Vec<i64>>;
8
9#[derive(Debug, Clone, Deserialize)]
10pub struct VoiceConfig {
11 #[serde(default)]
12 pub audio: AudioConfig,
13
14 #[serde(default = "default_num_speakers")]
15 pub num_speakers: usize,
16
17 #[serde(default)]
18 pub num_symbols: usize,
19
20 #[serde(default)]
21 pub phoneme_type: PhonemeType,
22
23 #[serde(default)]
24 pub phoneme_id_map: PhonemeIdMap,
25
26 #[serde(default = "default_num_languages")]
27 pub num_languages: usize,
28
29 #[serde(default)]
30 pub language_id_map: HashMap<String, i64>,
31
32 #[serde(default)]
33 pub speaker_id_map: HashMap<String, i64>,
34}
35
36#[derive(Debug, Clone, Deserialize)]
37pub struct AudioConfig {
38 #[serde(default = "default_sample_rate")]
39 pub sample_rate: u32,
40}
41
42impl Default for AudioConfig {
43 fn default() -> Self {
44 Self { sample_rate: 22050 }
45 }
46}
47
48#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
49#[serde(rename_all = "lowercase")]
50pub enum PhonemeType {
51 #[default]
52 #[serde(alias = "espeak")]
53 Espeak,
54 #[serde(alias = "openjtalk")]
55 OpenJTalk,
56 Bilingual,
57 Multilingual,
58 Text,
59}
60
/// Serde fallback for `VoiceConfig::num_speakers`: a single speaker.
fn default_num_speakers() -> usize {
    1_usize
}

/// Serde fallback for `VoiceConfig::num_languages`: a single language.
fn default_num_languages() -> usize {
    1_usize
}

/// Serde fallback for `AudioConfig::sample_rate`.
fn default_sample_rate() -> u32 {
    22_050_u32
}
70
71impl VoiceConfig {
72 pub fn load(path: &Path) -> Result<Self, PiperError> {
74 let content = std::fs::read_to_string(path).map_err(|_| PiperError::ConfigNotFound {
75 path: path.display().to_string(),
76 })?;
77 let config: VoiceConfig = serde_json::from_str(&content)?;
78 Ok(config)
79 }
80
81 pub fn is_multi_speaker(&self) -> bool {
83 self.num_speakers > 1
84 }
85
86 pub fn is_multilingual(&self) -> bool {
88 self.num_languages > 1
89 }
90
91 pub fn needs_sid(&self) -> bool {
93 self.is_multi_speaker() || self.is_multilingual()
94 }
95
96 pub fn needs_lid(&self) -> bool {
98 self.is_multilingual()
99 }
100
101 pub fn needs_prosody(&self) -> bool {
103 self.phoneme_type == PhonemeType::OpenJTalk
106 || self.phoneme_type == PhonemeType::Bilingual
107 || self.phoneme_type == PhonemeType::Multilingual
108 }
109
110 pub fn resolve_config_path(
115 model_path: &Path,
116 explicit_config: Option<&Path>,
117 ) -> Result<std::path::PathBuf, PiperError> {
118 if let Some(p) = explicit_config {
119 if p.exists() {
120 return Ok(p.to_path_buf());
121 }
122 return Err(PiperError::ConfigNotFound {
123 path: p.display().to_string(),
124 });
125 }
126
127 let onnx_json = model_path.with_extension("onnx.json");
129 if onnx_json.exists() {
130 return Ok(onnx_json);
131 }
132
133 if let Some(dir) = model_path.parent() {
135 let dir_config = dir.join("config.json");
136 if dir_config.exists() {
137 return Ok(dir_config);
138 }
139 }
140
141 Err(PiperError::ConfigNotFound {
142 path: format!("no config found for {}", model_path.display()),
143 })
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::PiperError;
    use std::fs;
    use std::path::Path;

    #[test]
    fn test_deserialize_minimal_config() {
        let json = r#"{"phoneme_id_map": {"a": [1]}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
        assert_eq!(config.num_speakers, 1);
        assert_eq!(config.num_languages, 1);
        assert!(!config.is_multilingual());
        assert!(!config.needs_lid());
    }

    #[test]
    fn test_deserialize_multilingual_config() {
        let json = r#"{
            "num_speakers": 571,
            "num_languages": 6,
            "phoneme_type": "multilingual",
            "phoneme_id_map": {"^": [1], "_": [0]},
            "language_id_map": {"ja": 0, "en": 1, "zh": 2, "es": 3, "fr": 4, "pt": 5}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert!(config.is_multilingual());
        assert!(config.needs_sid());
        assert!(config.needs_lid());
        assert_eq!(config.language_id_map.len(), 6);
    }

    #[test]
    fn test_phoneme_type_deserialization() {
        let json = r#"{"phoneme_type": "openjtalk"}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.phoneme_type, PhonemeType::OpenJTalk);
    }

    /// An empty object must deserialize, with every field taking its default.
    #[test]
    fn test_empty_config_uses_defaults() {
        let config: VoiceConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
        assert_eq!(config.num_speakers, 1);
        assert_eq!(config.num_symbols, 0);
        assert_eq!(config.num_languages, 1);
        assert_eq!(config.phoneme_type, PhonemeType::Espeak);
        assert!(config.phoneme_id_map.is_empty());
        assert!(config.language_id_map.is_empty());
        assert!(config.speaker_id_map.is_empty());
        assert!(!config.needs_sid());
        assert!(!config.needs_lid());
        assert!(!config.needs_prosody());
    }

    /// A present-but-empty `audio` section uses the field-level sample-rate default.
    #[test]
    fn test_audio_section_without_sample_rate() {
        let config: VoiceConfig = serde_json::from_str(r#"{"audio": {}}"#).unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
    }

    #[test]
    fn test_multi_speaker_single_language() {
        let config: VoiceConfig = serde_json::from_str(r#"{"num_speakers": 4}"#).unwrap();
        assert!(config.is_multi_speaker());
        assert!(!config.is_multilingual());
        assert!(config.needs_sid());
        assert!(!config.needs_lid());
    }

    #[test]
    fn test_all_phoneme_types_and_prosody() {
        let cases = [
            ("espeak", PhonemeType::Espeak, false),
            ("openjtalk", PhonemeType::OpenJTalk, true),
            ("bilingual", PhonemeType::Bilingual, true),
            ("multilingual", PhonemeType::Multilingual, true),
            ("text", PhonemeType::Text, false),
        ];
        for (name, expected, prosody) in &cases {
            let json = format!(r#"{{"phoneme_type": "{}"}}"#, name);
            let config: VoiceConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(&config.phoneme_type, expected, "phoneme_type {}", name);
            assert_eq!(config.needs_prosody(), *prosody, "needs_prosody for {}", name);
        }
    }

    #[test]
    fn test_unknown_phoneme_type_is_rejected() {
        let result: Result<VoiceConfig, _> =
            serde_json::from_str(r#"{"phoneme_type": "klingon"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_missing_file_reports_config_not_found() {
        let result = VoiceConfig::load(Path::new("this-voice-config-does-not-exist.onnx.json"));
        assert!(matches!(result, Err(PiperError::ConfigNotFound { .. })));
    }

    #[test]
    fn test_resolve_explicit_missing_config_is_error() {
        let result = VoiceConfig::resolve_config_path(
            Path::new("model.onnx"),
            Some(Path::new("this-voice-config-does-not-exist.json")),
        );
        assert!(matches!(result, Err(PiperError::ConfigNotFound { .. })));
    }

    /// `config.json` is found as a fallback, but a `<model>.onnx.json` sidecar wins once it exists.
    #[test]
    fn test_resolve_prefers_sidecar_over_directory_config() {
        let dir = std::env::temp_dir()
            .join(format!("piper-voice-config-test-{}", std::process::id()));
        // Clear leftovers from an aborted earlier run that happened to reuse this pid.
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();

        let model = dir.join("voice.onnx");
        let dir_config = dir.join("config.json");
        let sidecar = dir.join("voice.onnx.json");

        fs::write(&dir_config, "{}").unwrap();
        assert_eq!(
            VoiceConfig::resolve_config_path(&model, None).ok(),
            Some(dir_config.clone())
        );

        fs::write(&sidecar, "{}").unwrap();
        assert_eq!(
            VoiceConfig::resolve_config_path(&model, None).ok(),
            Some(sidecar.clone())
        );

        let _ = fs::remove_dir_all(&dir);
    }
}