// piper_plus/config.rs

1use serde::Deserialize;
2use std::collections::HashMap;
3use std::path::Path;
4
5use crate::error::PiperError;
6
/// Lookup table from a phoneme key (a string) to its sequence of `i64` ids, in the
/// shape stored under the `phoneme_id_map` key of the voice config.
pub type PhonemeIdMap = HashMap<String, Vec<i64>>;
8
9#[derive(Debug, Clone, Deserialize)]
10pub struct VoiceConfig {
11    #[serde(default)]
12    pub audio: AudioConfig,
13
14    #[serde(default = "default_num_speakers")]
15    pub num_speakers: usize,
16
17    #[serde(default)]
18    pub num_symbols: usize,
19
20    #[serde(default)]
21    pub phoneme_type: PhonemeType,
22
23    #[serde(default)]
24    pub phoneme_id_map: PhonemeIdMap,
25
26    #[serde(default = "default_num_languages")]
27    pub num_languages: usize,
28
29    #[serde(default)]
30    pub language_id_map: HashMap<String, i64>,
31
32    #[serde(default)]
33    pub speaker_id_map: HashMap<String, i64>,
34}
35
36#[derive(Debug, Clone, Deserialize)]
37pub struct AudioConfig {
38    #[serde(default = "default_sample_rate")]
39    pub sample_rate: u32,
40}
41
42impl Default for AudioConfig {
43    fn default() -> Self {
44        Self { sample_rate: 22050 }
45    }
46}
47
48#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
49#[serde(rename_all = "lowercase")]
50pub enum PhonemeType {
51    #[default]
52    #[serde(alias = "espeak")]
53    Espeak,
54    #[serde(alias = "openjtalk")]
55    OpenJTalk,
56    Bilingual,
57    Multilingual,
58    Text,
59}
60
/// Serde fallback for `num_speakers`: a config without the key describes one speaker.
fn default_num_speakers() -> usize {
    const SINGLE_SPEAKER: usize = 1;
    SINGLE_SPEAKER
}
/// Serde fallback for `num_languages`: a config without the key describes one language.
fn default_num_languages() -> usize {
    const SINGLE_LANGUAGE: usize = 1;
    SINGLE_LANGUAGE
}
/// Serde fallback for `audio.sample_rate` when the key is absent.
fn default_sample_rate() -> u32 {
    const FALLBACK_SAMPLE_RATE: u32 = 22_050;
    FALLBACK_SAMPLE_RATE
}
70
71impl VoiceConfig {
72    /// config.json を読み込む
73    pub fn load(path: &Path) -> Result<Self, PiperError> {
74        let content = std::fs::read_to_string(path).map_err(|_| PiperError::ConfigNotFound {
75            path: path.display().to_string(),
76        })?;
77        let config: VoiceConfig = serde_json::from_str(&content)?;
78        Ok(config)
79    }
80
81    /// モデルがマルチスピーカーか
82    pub fn is_multi_speaker(&self) -> bool {
83        self.num_speakers > 1
84    }
85
86    /// モデルが多言語か
87    pub fn is_multilingual(&self) -> bool {
88        self.num_languages > 1
89    }
90
91    /// sid テンソルが必要か
92    pub fn needs_sid(&self) -> bool {
93        self.is_multi_speaker() || self.is_multilingual()
94    }
95
96    /// lid テンソルが必要か
97    pub fn needs_lid(&self) -> bool {
98        self.is_multilingual()
99    }
100
101    /// prosody_features テンソルが必要か (phoneme_id_map に prosody 関連キーがあるか)
102    pub fn needs_prosody(&self) -> bool {
103        // prosody_features の有無はONNXモデルの入力ノードで判定するのが正確
104        // ここではconfig情報からのヒューリスティック
105        self.phoneme_type == PhonemeType::OpenJTalk
106            || self.phoneme_type == PhonemeType::Bilingual
107            || self.phoneme_type == PhonemeType::Multilingual
108    }
109
110    /// config.json のフォールバック検索
111    /// 1. --config で明示指定
112    /// 2. {model}.onnx.json
113    /// 3. {model_dir}/config.json
114    pub fn resolve_config_path(
115        model_path: &Path,
116        explicit_config: Option<&Path>,
117    ) -> Result<std::path::PathBuf, PiperError> {
118        if let Some(p) = explicit_config {
119            if p.exists() {
120                return Ok(p.to_path_buf());
121            }
122            return Err(PiperError::ConfigNotFound {
123                path: p.display().to_string(),
124            });
125        }
126
127        // {model}.onnx.json
128        let onnx_json = model_path.with_extension("onnx.json");
129        if onnx_json.exists() {
130            return Ok(onnx_json);
131        }
132
133        // {model_dir}/config.json
134        if let Some(dir) = model_path.parent() {
135            let dir_config = dir.join("config.json");
136            if dir_config.exists() {
137                return Ok(dir_config);
138            }
139        }
140
141        Err(PiperError::ConfigNotFound {
142            path: format!("no config found for {}", model_path.display()),
143        })
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `raw` into a [`VoiceConfig`], panicking on malformed test input.
    fn parse(raw: &str) -> VoiceConfig {
        serde_json::from_str(raw).unwrap()
    }

    /// Keys missing from a minimal config fall back to the single-speaker,
    /// single-language defaults.
    #[test]
    fn test_deserialize_minimal_config() {
        let raw = r#"{"phoneme_id_map": {"a": [1]}, "audio": {"sample_rate": 22050}}"#;
        let cfg = parse(raw);

        assert_eq!(cfg.audio.sample_rate, 22050);
        assert_eq!(cfg.num_speakers, 1);
        assert_eq!(cfg.num_languages, 1);
        assert!(!cfg.is_multilingual());
        assert!(!cfg.needs_lid());
    }

    /// A multi-speaker, multi-language config requires both `sid` and `lid`.
    #[test]
    fn test_deserialize_multilingual_config() {
        let raw = r#"{
            "num_speakers": 571,
            "num_languages": 6,
            "phoneme_type": "multilingual",
            "phoneme_id_map": {"^": [1], "_": [0]},
            "language_id_map": {"ja": 0, "en": 1, "zh": 2, "es": 3, "fr": 4, "pt": 5}
        }"#;
        let cfg = parse(raw);

        assert!(cfg.is_multilingual());
        assert!(cfg.needs_sid());
        assert!(cfg.needs_lid());
        assert_eq!(cfg.language_id_map.len(), 6);
    }

    /// The lowercase wire name `"openjtalk"` maps to [`PhonemeType::OpenJTalk`].
    #[test]
    fn test_phoneme_type_deserialization() {
        let raw = r#"{"phoneme_type": "openjtalk"}"#;
        let cfg = parse(raw);

        assert_eq!(cfg.phoneme_type, PhonemeType::OpenJTalk);
    }
}
184}