1use serde::Deserialize;
2use std::collections::HashMap;
3use std::path::Path;
4
5use crate::error::PiperError;
6
/// Phoneme symbol (e.g. `"^"`, `"a"`) mapped to the list of model input IDs it expands to.
pub type PhonemeIdMap = std::collections::HashMap<String, Vec<i64>>;
8
9#[derive(Debug, Clone, Deserialize)]
10pub struct VoiceConfig {
11 #[serde(default)]
12 pub audio: AudioConfig,
13
14 #[serde(default = "default_num_speakers")]
15 pub num_speakers: usize,
16
17 #[serde(default)]
18 pub num_symbols: usize,
19
20 #[serde(default)]
21 pub phoneme_type: PhonemeType,
22
23 #[serde(default)]
24 pub phoneme_id_map: PhonemeIdMap,
25
26 #[serde(default = "default_num_languages")]
27 pub num_languages: usize,
28
29 #[serde(default)]
30 pub language_id_map: HashMap<String, i64>,
31
32 #[serde(default)]
33 pub speaker_id_map: HashMap<String, i64>,
34}
35
36#[derive(Debug, Clone, Deserialize)]
37pub struct AudioConfig {
38 #[serde(default = "default_sample_rate")]
39 pub sample_rate: u32,
40}
41
42impl Default for AudioConfig {
43 fn default() -> Self {
44 Self { sample_rate: 22050 }
45 }
46}
47
48#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
49#[serde(rename_all = "lowercase")]
50pub enum PhonemeType {
51 #[default]
52 #[serde(alias = "espeak")]
53 Espeak,
54 #[serde(alias = "openjtalk")]
55 OpenJTalk,
56 Bilingual,
57 Multilingual,
58 Text,
59}
60
/// Serde default for `num_speakers`: a single-speaker voice.
fn default_num_speakers() -> usize {
    1_usize
}

/// Serde default for `num_languages`: a single-language voice.
fn default_num_languages() -> usize {
    1_usize
}

/// Serde default for `audio.sample_rate`.
fn default_sample_rate() -> u32 {
    22_050_u32
}
70
71impl VoiceConfig {
72 pub fn load(path: &Path) -> Result<Self, PiperError> {
74 let content = std::fs::read_to_string(path).map_err(|_| PiperError::ConfigNotFound {
75 path: path.display().to_string(),
76 })?;
77 let config: VoiceConfig = serde_json::from_str(&content)?;
78 Ok(config)
79 }
80
81 pub fn is_multi_speaker(&self) -> bool {
83 self.num_speakers > 1
84 }
85
86 pub fn is_multilingual(&self) -> bool {
88 self.num_languages > 1
89 }
90
91 pub fn needs_sid(&self) -> bool {
93 self.is_multi_speaker() || self.is_multilingual()
94 }
95
96 pub fn needs_lid(&self) -> bool {
98 self.is_multilingual()
99 }
100
101 pub fn needs_prosody(&self) -> bool {
103 self.phoneme_type == PhonemeType::OpenJTalk
106 || self.phoneme_type == PhonemeType::Bilingual
107 || self.phoneme_type == PhonemeType::Multilingual
108 }
109
110 pub fn resolve_config_path(
115 model_path: &Path,
116 explicit_config: Option<&Path>,
117 ) -> Result<std::path::PathBuf, PiperError> {
118 if let Some(p) = explicit_config {
119 if p.exists() {
120 return Ok(p.to_path_buf());
121 }
122 return Err(PiperError::ConfigNotFound {
123 path: p.display().to_string(),
124 });
125 }
126
127 let onnx_json = model_path.with_extension("onnx.json");
129 if onnx_json.exists() {
130 return Ok(onnx_json);
131 }
132
133 if let Some(dir) = model_path.parent() {
135 let dir_config = dir.join("config.json");
136 if dir_config.exists() {
137 return Ok(dir_config);
138 }
139 }
140
141 Err(PiperError::ConfigNotFound {
142 path: format!("no config found for {}", model_path.display()),
143 })
144 }
145
146 pub fn validate(&self) -> Result<(), String> {
149 if self.phoneme_id_map.is_empty() {
151 return Err("phoneme_id_map is empty".to_string());
152 }
153
154 if !self.phoneme_id_map.contains_key("^") {
156 return Err("phoneme_id_map missing required BOS marker '^'".to_string());
157 }
158 if !self.phoneme_id_map.contains_key("_") {
159 return Err("phoneme_id_map missing required PAD marker '_'".to_string());
160 }
161 if !self.phoneme_id_map.contains_key("$") {
162 return Err("phoneme_id_map missing required EOS marker '$'".to_string());
163 }
164
165 for (key, ids) in &self.phoneme_id_map {
167 if ids.is_empty() {
168 return Err(format!("phoneme_id_map[\"{key}\"] has empty ID list"));
169 }
170 }
171
172 if self.audio.sample_rate < 8000 || self.audio.sample_rate > 48000 {
174 return Err(format!(
175 "audio.sample_rate={} out of range [8000, 48000]",
176 self.audio.sample_rate
177 ));
178 }
179
180 if matches!(
182 self.phoneme_type,
183 PhonemeType::Multilingual | PhonemeType::Bilingual
184 ) {
185 if self.language_id_map.is_empty() {
186 return Err("multilingual model requires non-empty language_id_map".to_string());
187 }
188 if self.num_languages > 1 && self.language_id_map.len() != self.num_languages {
189 return Err(format!(
190 "num_languages={} but language_id_map has {} entries",
191 self.num_languages,
192 self.language_id_map.len()
193 ));
194 }
195 }
196
197 if self.num_speakers > 1 && self.speaker_id_map.is_empty() {
199 eprintln!(
200 "warning: num_speakers={} but speaker_id_map is empty",
201 self.num_speakers
202 );
203 }
204
205 Ok(())
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deserialize_minimal_config() {
        let json = r#"{"phoneme_id_map": {"a": [1]}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
        assert_eq!(config.num_speakers, 1);
        assert_eq!(config.num_languages, 1);
        assert!(!config.is_multilingual());
        assert!(!config.needs_lid());
    }

    #[test]
    fn test_deserialize_empty_object_uses_defaults() {
        let config: VoiceConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
        assert_eq!(config.phoneme_type, PhonemeType::Espeak);
        assert_eq!(config.num_speakers, 1);
        assert_eq!(config.num_languages, 1);
        assert_eq!(config.num_symbols, 0);

        let config: VoiceConfig = serde_json::from_str(r#"{"audio": {}}"#).unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
    }

    #[test]
    fn test_deserialize_multilingual_config() {
        let json = r#"{
            "num_speakers": 571,
            "num_languages": 6,
            "phoneme_type": "multilingual",
            "phoneme_id_map": {"^": [1], "_": [0]},
            "language_id_map": {"ja": 0, "en": 1, "zh": 2, "es": 3, "fr": 4, "pt": 5}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert!(config.is_multilingual());
        assert!(config.needs_sid());
        assert!(config.needs_lid());
        assert_eq!(config.language_id_map.len(), 6);
    }

    #[test]
    fn test_phoneme_type_deserialization() {
        let json = r#"{"phoneme_type": "openjtalk"}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.phoneme_type, PhonemeType::OpenJTalk);
    }

    #[test]
    fn test_needs_prosody_by_phoneme_type() {
        let cases = [
            ("espeak", false),
            ("text", false),
            ("openjtalk", true),
            ("bilingual", true),
            ("multilingual", true),
        ];
        for (name, expected) in cases {
            let json = format!(r#"{{"phoneme_type": "{name}"}}"#);
            let config: VoiceConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(config.needs_prosody(), expected, "phoneme_type={name}");
        }
    }

    #[test]
    fn test_validate_minimal_valid() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2], "a": [15]},
            "audio": {"sample_rate": 22050}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_empty_phoneme_id_map() {
        let json = r#"{"phoneme_id_map": {}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("empty"), "Error: {err}");
    }

    #[test]
    fn test_validate_missing_bos() {
        let json = r#"{"phoneme_id_map": {"_": [0], "$": [2]}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("BOS"), "Error: {err}");
    }

    #[test]
    fn test_validate_missing_pad() {
        let json = r#"{"phoneme_id_map": {"^": [1], "$": [2]}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("PAD"), "Error: {err}");
    }

    #[test]
    fn test_validate_missing_eos() {
        let json = r#"{"phoneme_id_map": {"^": [1], "_": [0]}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("EOS"), "Error: {err}");
    }

    #[test]
    fn test_validate_empty_id_list() {
        let json = r#"{"phoneme_id_map": {"^": [1], "_": [0], "$": [2], "a": []}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("empty ID list"), "Error: {err}");
    }

    #[test]
    fn test_validate_sample_rate_zero() {
        let json =
            r#"{"phoneme_id_map": {"^": [1], "_": [0], "$": [2]}, "audio": {"sample_rate": 0}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("out of range"), "Error: {err}");
    }

    #[test]
    fn test_validate_sample_rate_too_high() {
        let json = r#"{"phoneme_id_map": {"^": [1], "_": [0], "$": [2]}, "audio": {"sample_rate": 100000}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("out of range"), "Error: {err}");
    }

    #[test]
    fn test_validate_sample_rate_bounds_inclusive() {
        for (rate, ok) in [(7999, false), (8000, true), (48000, true), (48001, false)] {
            let json = format!(
                r#"{{"phoneme_id_map": {{"^": [1], "_": [0], "$": [2]}}, "audio": {{"sample_rate": {rate}}}}}"#
            );
            let config: VoiceConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(config.validate().is_ok(), ok, "sample_rate={rate}");
        }
    }

    #[test]
    fn test_validate_multilingual_empty_lang_map() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2]},
            "audio": {"sample_rate": 22050},
            "phoneme_type": "multilingual",
            "num_languages": 6,
            "language_id_map": {}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("requires non-empty"), "Error: {err}");
    }

    #[test]
    fn test_validate_bilingual_empty_lang_map() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2]},
            "audio": {"sample_rate": 22050},
            "phoneme_type": "bilingual"
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("requires non-empty"), "Error: {err}");
    }

    #[test]
    fn test_validate_language_count_mismatch() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2]},
            "audio": {"sample_rate": 22050},
            "phoneme_type": "multilingual",
            "num_languages": 6,
            "language_id_map": {"ja": 0, "en": 1}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("num_languages=6"), "Error: {err}");
        assert!(err.contains("2 entries"), "Error: {err}");
    }

    #[test]
    fn test_validate_multilingual_valid() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2], "a": [15]},
            "audio": {"sample_rate": 22050},
            "phoneme_type": "multilingual",
            "num_languages": 6,
            "language_id_map": {"ja": 0, "en": 1, "zh": 2, "es": 3, "fr": 4, "pt": 5}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_single_lang_empty_lang_map_ok() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2]},
            "audio": {"sample_rate": 22050},
            "num_languages": 1,
            "language_id_map": {}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_resolve_config_path_explicit_missing() {
        let model = std::path::Path::new("model.onnx");
        let explicit = std::path::Path::new("/definitely/not/a/real/dir/voice_config.json");
        assert!(VoiceConfig::resolve_config_path(model, Some(explicit)).is_err());
    }

    #[test]
    fn test_resolve_config_path_discovery_order() {
        let dir = std::env::temp_dir().join(format!(
            "voice_config_resolve_test_{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).unwrap();
        let model = dir.join("voice.onnx");
        let sidecar = dir.join("voice.onnx.json");
        let shared = dir.join("config.json");
        std::fs::write(&sidecar, "{}").unwrap();
        std::fs::write(&shared, "{}").unwrap();

        // `<model>.onnx.json` wins over the directory-level `config.json`.
        let found = VoiceConfig::resolve_config_path(&model, None).ok();
        assert_eq!(found.as_deref(), Some(sidecar.as_path()));

        std::fs::remove_file(&sidecar).unwrap();
        let found = VoiceConfig::resolve_config_path(&model, None).ok();
        assert_eq!(found.as_deref(), Some(shared.as_path()));

        std::fs::remove_file(&shared).unwrap();
        assert!(VoiceConfig::resolve_config_path(&model, None).is_err());

        let _ = std::fs::remove_dir(&dir);
    }

    #[test]
    fn test_load_missing_file_is_config_not_found() {
        let missing = std::env::temp_dir().join("voice_config_definitely_missing_8f3c.json");
        assert!(matches!(
            VoiceConfig::load(&missing),
            Err(PiperError::ConfigNotFound { .. })
        ));
    }
}