Skip to main content

piper_plus/
config.rs

1use serde::Deserialize;
2use std::collections::HashMap;
3use std::path::Path;
4
5use crate::error::PiperError;
6
/// Maps each phoneme symbol (as a string key) to its list of model input IDs.
pub type PhonemeIdMap = std::collections::HashMap<String, Vec<i64>>;
8
9#[derive(Debug, Clone, Deserialize)]
10pub struct VoiceConfig {
11    #[serde(default)]
12    pub audio: AudioConfig,
13
14    #[serde(default = "default_num_speakers")]
15    pub num_speakers: usize,
16
17    #[serde(default)]
18    pub num_symbols: usize,
19
20    #[serde(default)]
21    pub phoneme_type: PhonemeType,
22
23    #[serde(default)]
24    pub phoneme_id_map: PhonemeIdMap,
25
26    #[serde(default = "default_num_languages")]
27    pub num_languages: usize,
28
29    #[serde(default)]
30    pub language_id_map: HashMap<String, i64>,
31
32    #[serde(default)]
33    pub speaker_id_map: HashMap<String, i64>,
34}
35
36#[derive(Debug, Clone, Deserialize)]
37pub struct AudioConfig {
38    #[serde(default = "default_sample_rate")]
39    pub sample_rate: u32,
40}
41
42impl Default for AudioConfig {
43    fn default() -> Self {
44        Self { sample_rate: 22050 }
45    }
46}
47
48#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
49#[serde(rename_all = "lowercase")]
50pub enum PhonemeType {
51    #[default]
52    #[serde(alias = "espeak")]
53    Espeak,
54    #[serde(alias = "openjtalk")]
55    OpenJTalk,
56    Bilingual,
57    Multilingual,
58    Text,
59}
60
/// Serde default for `VoiceConfig::num_speakers`: a single speaker.
fn default_num_speakers() -> usize {
    1_usize
}

/// Serde default for `VoiceConfig::num_languages`: a single language.
fn default_num_languages() -> usize {
    1_usize
}

/// Serde default for `AudioConfig::sample_rate` (22.05 kHz).
fn default_sample_rate() -> u32 {
    22_050
}
70
71impl VoiceConfig {
72    /// config.json を読み込む
73    pub fn load(path: &Path) -> Result<Self, PiperError> {
74        let content = std::fs::read_to_string(path).map_err(|_| PiperError::ConfigNotFound {
75            path: path.display().to_string(),
76        })?;
77        let config: VoiceConfig = serde_json::from_str(&content)?;
78        Ok(config)
79    }
80
81    /// モデルがマルチスピーカーか
82    pub fn is_multi_speaker(&self) -> bool {
83        self.num_speakers > 1
84    }
85
86    /// モデルが多言語か
87    pub fn is_multilingual(&self) -> bool {
88        self.num_languages > 1
89    }
90
91    /// sid テンソルが必要か
92    pub fn needs_sid(&self) -> bool {
93        self.is_multi_speaker() || self.is_multilingual()
94    }
95
96    /// lid テンソルが必要か
97    pub fn needs_lid(&self) -> bool {
98        self.is_multilingual()
99    }
100
101    /// prosody_features テンソルが必要か (phoneme_id_map に prosody 関連キーがあるか)
102    pub fn needs_prosody(&self) -> bool {
103        // prosody_features の有無はONNXモデルの入力ノードで判定するのが正確
104        // ここではconfig情報からのヒューリスティック
105        self.phoneme_type == PhonemeType::OpenJTalk
106            || self.phoneme_type == PhonemeType::Bilingual
107            || self.phoneme_type == PhonemeType::Multilingual
108    }
109
110    /// config.json のフォールバック検索
111    /// 1. --config で明示指定
112    /// 2. {model}.onnx.json
113    /// 3. {model_dir}/config.json
114    pub fn resolve_config_path(
115        model_path: &Path,
116        explicit_config: Option<&Path>,
117    ) -> Result<std::path::PathBuf, PiperError> {
118        if let Some(p) = explicit_config {
119            if p.exists() {
120                return Ok(p.to_path_buf());
121            }
122            return Err(PiperError::ConfigNotFound {
123                path: p.display().to_string(),
124            });
125        }
126
127        // {model}.onnx.json
128        let onnx_json = model_path.with_extension("onnx.json");
129        if onnx_json.exists() {
130            return Ok(onnx_json);
131        }
132
133        // {model_dir}/config.json
134        if let Some(dir) = model_path.parent() {
135            let dir_config = dir.join("config.json");
136            if dir_config.exists() {
137                return Ok(dir_config);
138            }
139        }
140
141        Err(PiperError::ConfigNotFound {
142            path: format!("no config found for {}", model_path.display()),
143        })
144    }
145
146    /// Validate the config for correctness.
147    /// Returns Ok(()) if valid, or Err with a description of the first problem found.
148    pub fn validate(&self) -> Result<(), String> {
149        // 1. phoneme_id_map must not be empty
150        if self.phoneme_id_map.is_empty() {
151            return Err("phoneme_id_map is empty".to_string());
152        }
153
154        // 2-4. Required markers
155        if !self.phoneme_id_map.contains_key("^") {
156            return Err("phoneme_id_map missing required BOS marker '^'".to_string());
157        }
158        if !self.phoneme_id_map.contains_key("_") {
159            return Err("phoneme_id_map missing required PAD marker '_'".to_string());
160        }
161        if !self.phoneme_id_map.contains_key("$") {
162            return Err("phoneme_id_map missing required EOS marker '$'".to_string());
163        }
164
165        // 5. Each ID list must be non-empty
166        for (key, ids) in &self.phoneme_id_map {
167            if ids.is_empty() {
168                return Err(format!("phoneme_id_map[\"{key}\"] has empty ID list"));
169            }
170        }
171
172        // 6. sample_rate range check
173        if self.audio.sample_rate < 8000 || self.audio.sample_rate > 48000 {
174            return Err(format!(
175                "audio.sample_rate={} out of range [8000, 48000]",
176                self.audio.sample_rate
177            ));
178        }
179
180        // 7-8. Multilingual/Bilingual require non-empty language_id_map
181        if matches!(
182            self.phoneme_type,
183            PhonemeType::Multilingual | PhonemeType::Bilingual
184        ) {
185            if self.language_id_map.is_empty() {
186                return Err("multilingual model requires non-empty language_id_map".to_string());
187            }
188            if self.num_languages > 1 && self.language_id_map.len() != self.num_languages {
189                return Err(format!(
190                    "num_languages={} but language_id_map has {} entries",
191                    self.num_languages,
192                    self.language_id_map.len()
193                ));
194            }
195        }
196
197        // 9. speaker_id_map warning (non-blocking)
198        if self.num_speakers > 1 && self.speaker_id_map.is_empty() {
199            eprintln!(
200                "warning: num_speakers={} but speaker_id_map is empty",
201                self.num_speakers
202            );
203        }
204
205        Ok(())
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_deserialize_minimal_config() {
        let json = r#"{"phoneme_id_map": {"a": [1]}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
        assert_eq!(config.num_speakers, 1);
        assert_eq!(config.num_languages, 1);
        assert!(!config.is_multilingual());
        assert!(!config.needs_lid());
    }

    #[test]
    fn test_deserialize_multilingual_config() {
        let json = r#"{
            "num_speakers": 571,
            "num_languages": 6,
            "phoneme_type": "multilingual",
            "phoneme_id_map": {"^": [1], "_": [0]},
            "language_id_map": {"ja": 0, "en": 1, "zh": 2, "es": 3, "fr": 4, "pt": 5}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert!(config.is_multilingual());
        assert!(config.needs_sid());
        assert!(config.needs_lid());
        assert_eq!(config.language_id_map.len(), 6);
    }

    #[test]
    fn test_phoneme_type_deserialization() {
        let json = r#"{"phoneme_type": "openjtalk"}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.phoneme_type, PhonemeType::OpenJTalk);
    }

    #[test]
    fn test_missing_audio_uses_default_sample_rate() {
        let config: VoiceConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
        let config: VoiceConfig = serde_json::from_str(r#"{"audio": {}}"#).unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
    }

    #[test]
    fn test_needs_prosody_follows_phoneme_type() {
        for (phoneme_type, expected) in [
            ("espeak", false),
            ("text", false),
            ("openjtalk", true),
            ("bilingual", true),
            ("multilingual", true),
        ] {
            let json = format!(r#"{{"phoneme_type": "{phoneme_type}"}}"#);
            let config: VoiceConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(
                config.needs_prosody(),
                expected,
                "phoneme_type={phoneme_type}"
            );
        }
    }

    #[test]
    fn test_validate_minimal_valid() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2], "a": [15]},
            "audio": {"sample_rate": 22050}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_empty_phoneme_id_map() {
        let json = r#"{"phoneme_id_map": {}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("empty"), "Error: {err}");
    }

    #[test]
    fn test_validate_missing_bos() {
        let json = r#"{"phoneme_id_map": {"_": [0], "$": [2]}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("BOS"), "Error: {err}");
    }

    #[test]
    fn test_validate_missing_pad() {
        let json = r#"{"phoneme_id_map": {"^": [1], "$": [2]}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("PAD"), "Error: {err}");
    }

    #[test]
    fn test_validate_missing_eos() {
        let json = r#"{"phoneme_id_map": {"^": [1], "_": [0]}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("EOS"), "Error: {err}");
    }

    #[test]
    fn test_validate_empty_id_list() {
        let json = r#"{"phoneme_id_map": {"^": [1], "_": [0], "$": [2], "a": []}, "audio": {"sample_rate": 22050}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("empty ID list"), "Error: {err}");
    }

    #[test]
    fn test_validate_sample_rate_zero() {
        let json =
            r#"{"phoneme_id_map": {"^": [1], "_": [0], "$": [2]}, "audio": {"sample_rate": 0}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("out of range"), "Error: {err}");
    }

    #[test]
    fn test_validate_sample_rate_too_high() {
        let json = r#"{"phoneme_id_map": {"^": [1], "_": [0], "$": [2]}, "audio": {"sample_rate": 100000}}"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("out of range"), "Error: {err}");
    }

    #[test]
    fn test_validate_sample_rate_bounds_inclusive() {
        for rate in [8000, 48000] {
            let json = format!(
                r#"{{"phoneme_id_map": {{"^": [1], "_": [0], "$": [2]}}, "audio": {{"sample_rate": {rate}}}}}"#
            );
            let config: VoiceConfig = serde_json::from_str(&json).unwrap();
            assert!(config.validate().is_ok(), "sample_rate={rate}");
        }
    }

    #[test]
    fn test_validate_multilingual_empty_lang_map() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2]},
            "audio": {"sample_rate": 22050},
            "phoneme_type": "multilingual",
            "num_languages": 6,
            "language_id_map": {}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("requires non-empty"), "Error: {err}");
    }

    #[test]
    fn test_validate_bilingual_empty_lang_map() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2]},
            "phoneme_type": "bilingual"
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("requires non-empty"), "Error: {err}");
    }

    #[test]
    fn test_validate_num_languages_mismatch() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2]},
            "audio": {"sample_rate": 22050},
            "phoneme_type": "multilingual",
            "num_languages": 3,
            "language_id_map": {"ja": 0, "en": 1}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.contains("num_languages=3"), "Error: {err}");
    }

    #[test]
    fn test_validate_multilingual_valid() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2], "a": [15]},
            "audio": {"sample_rate": 22050},
            "phoneme_type": "multilingual",
            "num_languages": 6,
            "language_id_map": {"ja": 0, "en": 1, "zh": 2, "es": 3, "fr": 4, "pt": 5}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_single_lang_empty_lang_map_ok() {
        let json = r#"{
            "phoneme_id_map": {"^": [1], "_": [0], "$": [2]},
            "audio": {"sample_rate": 22050},
            "num_languages": 1,
            "language_id_map": {}
        }"#;
        let config: VoiceConfig = serde_json::from_str(json).unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_resolve_config_path_missing_explicit_is_error() {
        let model = std::path::Path::new("this/path/does/not/exist/model.onnx");
        let missing = std::path::Path::new("this/path/does/not/exist/config.json");
        assert!(VoiceConfig::resolve_config_path(model, Some(missing)).is_err());
    }

    #[test]
    fn test_resolve_config_path_no_candidates_is_error() {
        let model = std::path::Path::new("this/path/does/not/exist/model.onnx");
        assert!(VoiceConfig::resolve_config_path(model, None).is_err());
    }
}