use serde::Deserialize;
use std::collections::HashMap;
/// Top-level model configuration, deserialized through serde.
///
/// Every key is optional in the input: a missing key falls back to the
/// provider named in the field's `#[serde(default ...)]` attribute, so an
/// empty JSON object (`{}`) yields a fully populated value.
#[derive(Clone, Debug, Deserialize)]
pub struct KokoroConfig {
    /// Defaults to 512 when the key is absent.
    #[serde(default = "default_hidden_dim")]
    pub hidden_dim: usize,

    /// Defaults to 128 when the key is absent.
    #[serde(default = "default_style_dim")]
    pub style_dim: usize,

    /// Defaults to 80 when the key is absent.
    #[serde(default = "default_n_mels")]
    pub n_mels: usize,

    /// Defaults to 178 when the key is absent.
    #[serde(default = "default_n_token")]
    pub n_token: usize,

    /// Defaults to 3 when the key is absent.
    #[serde(default = "default_n_layer")]
    pub n_layer: usize,

    /// Defaults to 64 when the key is absent.
    #[serde(default = "default_dim_in")]
    pub dim_in: usize,

    /// Defaults to 0.2 when the key is absent.
    #[serde(default = "default_dropout")]
    pub dropout: f64,

    /// Defaults to 512 when the key is absent.
    #[serde(default = "default_max_conv_dim")]
    pub max_conv_dim: usize,

    /// Defaults to 50 when the key is absent.
    #[serde(default = "default_max_dur")]
    pub max_dur: usize,

    /// Defaults to `true` when the key is absent.
    #[serde(default = "default_multispeaker")]
    pub multispeaker: bool,

    /// Defaults to 5 when the key is absent.
    #[serde(default = "default_text_encoder_kernel_size")]
    pub text_encoder_kernel_size: usize,

    /// Nested PLBERT settings; `PlbertConfig::default()` is used when the key
    /// is absent.
    #[serde(default)]
    pub plbert: PlbertConfig,

    /// Nested iSTFTNet settings; `IstftNetConfig::default()` is used when the
    /// key is absent.
    #[serde(default)]
    pub istftnet: IstftNetConfig,

    /// String-keyed table of `u32` values; empty when the key is absent.
    // NOTE(review): presumably maps input symbols to token ids — confirm
    // against the code that consumes it.
    #[serde(default)]
    pub vocab: HashMap<String, u32>,
}
/// Settings stored under the `plbert` key of [`KokoroConfig`].
///
/// Each key is optional in the input; a missing key takes the value returned
/// by the provider named in its `#[serde(default = "...")]` attribute.
#[derive(Clone, Debug, Deserialize)]
pub struct PlbertConfig {
    /// Defaults to 178 when the key is absent.
    #[serde(default = "default_plbert_vocab_size")]
    pub vocab_size: usize,

    /// Defaults to 768 when the key is absent.
    #[serde(default = "default_plbert_hidden_size")]
    pub hidden_size: usize,

    /// Defaults to 12 when the key is absent.
    #[serde(default = "default_plbert_num_heads")]
    pub num_attention_heads: usize,

    /// Defaults to 12 when the key is absent.
    #[serde(default = "default_plbert_num_layers")]
    pub num_hidden_layers: usize,

    /// Defaults to 2048 when the key is absent.
    #[serde(default = "default_plbert_intermediate_size")]
    pub intermediate_size: usize,

    /// Defaults to 512 when the key is absent.
    #[serde(default = "default_plbert_max_position")]
    pub max_position_embeddings: usize,

    /// Defaults to 128 when the key is absent.
    #[serde(default = "default_plbert_embedding_size")]
    pub embedding_size: usize,

    /// Defaults to 1 when the key is absent.
    #[serde(default = "default_plbert_num_hidden_groups")]
    pub num_hidden_groups: usize,

    /// Defaults to `"gelu"` when the key is absent.
    #[serde(default = "default_plbert_hidden_act")]
    pub hidden_act: String,

    /// Defaults to 0.1 when the key is absent (shares its provider with
    /// `attention_probs_dropout_prob`).
    #[serde(default = "default_plbert_dropout")]
    pub hidden_dropout_prob: f64,

    /// Defaults to 0.1 when the key is absent (shares its provider with
    /// `hidden_dropout_prob`).
    #[serde(default = "default_plbert_dropout")]
    pub attention_probs_dropout_prob: f64,

    /// Defaults to 2 when the key is absent.
    #[serde(default = "default_plbert_type_vocab_size")]
    pub type_vocab_size: usize,
}
/// Settings stored under the `istftnet` key of [`KokoroConfig`].
///
/// Each key is optional in the input; a missing key takes the value returned
/// by the provider named in its `#[serde(default = "...")]` attribute.
#[derive(Clone, Debug, Deserialize)]
pub struct IstftNetConfig {
    /// Defaults to `[10, 6]` when the key is absent.
    #[serde(default = "default_upsample_rates")]
    pub upsample_rates: Vec<usize>,

    /// Defaults to `[20, 12]` when the key is absent.
    #[serde(default = "default_upsample_kernel_sizes")]
    pub upsample_kernel_sizes: Vec<usize>,

    /// Defaults to 512 when the key is absent.
    #[serde(default = "default_upsample_initial_channel")]
    pub upsample_initial_channel: usize,

    /// Defaults to `[3, 7, 11]` when the key is absent.
    #[serde(default = "default_resblock_kernel_sizes")]
    pub resblock_kernel_sizes: Vec<usize>,

    /// Defaults to three copies of `[1, 3, 5]` when the key is absent.
    #[serde(default = "default_resblock_dilation_sizes")]
    pub resblock_dilation_sizes: Vec<Vec<usize>>,

    /// Defaults to 20 when the key is absent.
    #[serde(default = "default_gen_istft_n_fft")]
    pub gen_istft_n_fft: usize,

    /// Defaults to 5 when the key is absent.
    #[serde(default = "default_gen_istft_hop_size")]
    pub gen_istft_hop_size: usize,
}
// Fallback providers for the top-level `KokoroConfig` fields. Each one is
// referenced by name from a `#[serde(default = "...")]` attribute and is
// called when the corresponding key is missing from the input.

/// Fallback for `KokoroConfig::hidden_dim`.
fn default_hidden_dim() -> usize {
    512
}

/// Fallback for `KokoroConfig::style_dim`.
fn default_style_dim() -> usize {
    128
}

/// Fallback for `KokoroConfig::n_mels`.
fn default_n_mels() -> usize {
    80
}

/// Fallback for `KokoroConfig::n_token`.
fn default_n_token() -> usize {
    178
}

/// Fallback for `KokoroConfig::n_layer`.
fn default_n_layer() -> usize {
    3
}

/// Fallback for `KokoroConfig::dim_in`.
fn default_dim_in() -> usize {
    64
}

/// Fallback for `KokoroConfig::dropout`.
fn default_dropout() -> f64 {
    0.2
}

/// Fallback for `KokoroConfig::max_conv_dim`.
fn default_max_conv_dim() -> usize {
    512
}

/// Fallback for `KokoroConfig::max_dur`.
fn default_max_dur() -> usize {
    50
}

/// Fallback for `KokoroConfig::multispeaker`.
fn default_multispeaker() -> bool {
    true
}

/// Fallback for `KokoroConfig::text_encoder_kernel_size`.
fn default_text_encoder_kernel_size() -> usize {
    5
}
// Fallback providers for `PlbertConfig`. Each one is referenced by name from
// a `#[serde(default = "...")]` attribute on that struct.

/// Fallback for `PlbertConfig::vocab_size`.
fn default_plbert_vocab_size() -> usize {
    178
}

/// Fallback for `PlbertConfig::hidden_size`.
fn default_plbert_hidden_size() -> usize {
    768
}

/// Fallback for `PlbertConfig::num_attention_heads`.
fn default_plbert_num_heads() -> usize {
    12
}

/// Fallback for `PlbertConfig::num_hidden_layers`.
fn default_plbert_num_layers() -> usize {
    12
}

/// Fallback for `PlbertConfig::intermediate_size`.
fn default_plbert_intermediate_size() -> usize {
    2048
}

/// Fallback for `PlbertConfig::max_position_embeddings`.
fn default_plbert_max_position() -> usize {
    512
}

/// Fallback for `PlbertConfig::embedding_size`.
fn default_plbert_embedding_size() -> usize {
    128
}

/// Fallback for `PlbertConfig::num_hidden_groups`.
fn default_plbert_num_hidden_groups() -> usize {
    1
}

/// Fallback for `PlbertConfig::hidden_act`.
fn default_plbert_hidden_act() -> String {
    String::from("gelu")
}

/// Shared fallback for `PlbertConfig::hidden_dropout_prob` and
/// `PlbertConfig::attention_probs_dropout_prob`.
fn default_plbert_dropout() -> f64 {
    0.1
}

/// Fallback for `PlbertConfig::type_vocab_size`.
fn default_plbert_type_vocab_size() -> usize {
    2
}
// Fallback providers for `IstftNetConfig`. Each one is referenced by name
// from a `#[serde(default = "...")]` attribute on that struct.

/// Fallback for `IstftNetConfig::upsample_rates`.
fn default_upsample_rates() -> Vec<usize> {
    Vec::from([10, 6])
}

/// Fallback for `IstftNetConfig::upsample_kernel_sizes`.
fn default_upsample_kernel_sizes() -> Vec<usize> {
    Vec::from([20, 12])
}

/// Fallback for `IstftNetConfig::upsample_initial_channel`.
fn default_upsample_initial_channel() -> usize {
    512
}

/// Fallback for `IstftNetConfig::resblock_kernel_sizes`.
fn default_resblock_kernel_sizes() -> Vec<usize> {
    Vec::from([3, 7, 11])
}

/// Fallback for `IstftNetConfig::resblock_dilation_sizes`: three identical
/// rows of `[1, 3, 5]`.
fn default_resblock_dilation_sizes() -> Vec<Vec<usize>> {
    vec![vec![1, 3, 5]; 3]
}

/// Fallback for `IstftNetConfig::gen_istft_n_fft`.
fn default_gen_istft_n_fft() -> usize {
    20
}

/// Fallback for `IstftNetConfig::gen_istft_hop_size`.
fn default_gen_istft_hop_size() -> usize {
    5
}
impl Default for PlbertConfig {
fn default() -> Self {
serde_json::from_str("{}").unwrap()
}
}
impl Default for IstftNetConfig {
fn default() -> Self {
serde_json::from_str("{}").unwrap()
}
}
impl KokoroConfig {
    /// Reads the file at `path` and parses its bytes as a JSON configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or if its contents fail to
    /// deserialize; both failures are converted into `TtsError` by `?`.
    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, crate::error::TtsError> {
        let raw = std::fs::read(path)?;
        Self::from_bytes(raw)
    }

    /// Parses a JSON configuration from an in-memory byte buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if `bytes` does not deserialize into a `KokoroConfig`.
    pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, crate::error::TtsError> {
        Ok(serde_json::from_slice(bytes.as_ref())?)
    }

    /// Product of every entry in `istftnet.upsample_rates`, multiplied by
    /// `istftnet.gen_istft_hop_size` (300 with the default configuration).
    pub fn upsample_factor(&self) -> usize {
        let net = &self.istftnet;
        net.upsample_rates.iter().product::<usize>() * net.gen_istft_hop_size
    }

    /// Twice `style_dim`.
    pub fn full_style_dim(&self) -> usize {
        self.style_dim * 2
    }

    /// Always 24000; the value is fixed and not read from the configuration.
    pub fn sample_rate(&self) -> u32 {
        24000
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into a `KokoroConfig`, panicking on invalid input.
    fn parse(json: &str) -> KokoroConfig {
        serde_json::from_str(json).unwrap()
    }

    /// An empty object yields the top-level fallbacks.
    #[test]
    fn test_default_config() {
        let cfg = parse("{}");
        assert_eq!(cfg.hidden_dim, 512);
        assert_eq!(cfg.style_dim, 128);
        assert_eq!(cfg.n_token, 178);
        assert_eq!(cfg.n_layer, 3);
    }

    /// An empty `plbert` object still receives per-field fallbacks.
    #[test]
    fn test_plbert_defaults() {
        let plbert = parse(r#"{"plbert": {}}"#).plbert;
        assert_eq!(plbert.hidden_size, 768);
        assert_eq!(plbert.num_attention_heads, 12);
        assert_eq!(plbert.num_hidden_layers, 12);
    }

    /// An empty `istftnet` object still receives per-field fallbacks.
    #[test]
    fn test_istftnet_defaults() {
        let istftnet = parse(r#"{"istftnet": {}}"#).istftnet;
        assert_eq!(istftnet.upsample_rates, vec![10, 6]);
        assert_eq!(istftnet.gen_istft_hop_size, 5);
    }

    /// Default rates 10 and 6 times the default hop size 5 give 300.
    #[test]
    fn test_upsample_factor() {
        assert_eq!(parse("{}").upsample_factor(), 300);
    }

    /// The `vocab` object is read into the string-to-`u32` map.
    #[test]
    fn test_vocab_parsing() {
        let cfg = parse(
            r#"{
"vocab": {
";": 1,
"a": 2,
"b": 3
}
}"#,
        );
        assert_eq!(cfg.vocab.len(), 3);
        assert_eq!(cfg.vocab["a"], 2);
    }
}