use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Hyper-parameters of a Kokoro text-to-speech model.
///
/// Every field is optional when deserializing. A key missing from the JSON
/// falls back to its `default_*` function, and those defaults describe
/// Kokoro-82M. Deserialization does not check consistency between fields;
/// call [`KokoroConfig::validate`] before building a model from a config.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KokoroConfig {
    /// Number of distinct input symbols; must be non-zero.
    #[serde(default = "default_n_symbols")]
    pub n_symbols: usize,
    /// Hidden width. It must be positive and even because a BiLSTM splits it in half.
    #[serde(default = "default_hidden_dim")]
    pub hidden_dim: usize,
    /// Width of a single style vector.
    #[serde(default = "default_style_dim")]
    pub style_dim: usize,
    /// Width of a voice embedding. The Kokoro-82M default (256) is `2 * style_dim`.
    #[serde(default = "default_voice_style_dim")]
    pub voice_style_dim: usize,
    /// Kernel size of the text convolutions.
    #[serde(default = "default_text_conv_kernel")]
    pub text_conv_kernel: usize,
    /// Number of stacked text convolution layers.
    #[serde(default = "default_text_conv_depth")]
    pub text_conv_depth: usize,
    /// Upper bound on a predicted duration.
    /// NOTE(review): the unit is set by the duration predictor; confirm it there.
    #[serde(default = "default_max_dur")]
    pub max_dur: usize,
    /// Negative slope used by leaky-ReLU activations.
    #[serde(default = "default_leaky_slope")]
    pub leaky_slope: f64,
    /// Audio sample rate in Hz; must be non-zero.
    #[serde(default = "default_sample_rate")]
    pub sample_rate: u32,
    /// FFT size, presumably of the generator's inverse STFT; must be non-zero.
    #[serde(default = "default_n_fft")]
    pub n_fft: usize,
    /// Hop length paired with `n_fft`; must be non-zero.
    #[serde(default = "default_hop_length")]
    pub hop_length: usize,
    /// Upsampling factor of each stage. The list must be non-empty with no zero
    /// entries. Their product is [`KokoroConfig::total_upsample`].
    #[serde(default = "default_upsample_ratios")]
    pub upsample_ratios: Vec<usize>,
    /// Kernel size of each upsampling stage.
    /// Expected to have one entry per `upsample_ratios` entry.
    #[serde(default = "default_upsample_kernel_sizes")]
    pub upsample_kernel_sizes: Vec<usize>,
    /// Channel count at the input of the first upsampling stage.
    #[serde(default = "default_upsample_initial_channel")]
    pub upsample_initial_channel: usize,
    /// Kernel sizes of the residual blocks.
    #[serde(default = "default_resblock_kernel_sizes")]
    pub resblock_kernel_sizes: Vec<usize>,
    /// Dilation rates of the residual blocks.
    /// Expected to hold one list per `resblock_kernel_sizes` entry.
    #[serde(default = "default_resblock_dilation_sizes")]
    pub resblock_dilation_sizes: Vec<Vec<usize>>,
    /// Number of harmonics used by the excitation source.
    /// NOTE(review): presumably a sine/harmonic source; confirm against the generator.
    #[serde(default = "default_harmonic_num")]
    pub harmonic_num: usize,
    /// Hidden size of the BERT text encoder.
    #[serde(default = "default_bert_hidden_size")]
    pub bert_hidden_size: usize,
    /// Number of BERT encoder layers.
    #[serde(default = "default_bert_num_layers")]
    pub bert_num_layers: usize,
    /// Number of BERT attention heads.
    #[serde(default = "default_bert_num_heads")]
    pub bert_num_heads: usize,
    /// BERT token-embedding width. The default (128) differs from `bert_hidden_size`.
    #[serde(default = "default_bert_embedding_size")]
    pub bert_embedding_size: usize,
    /// Width of the BERT feed-forward intermediate layer.
    #[serde(default = "default_bert_intermediate_size")]
    pub bert_intermediate_size: usize,
}
/// Serde/`Default` fallback for `KokoroConfig::n_symbols` (Kokoro-82M value).
fn default_n_symbols() -> usize {
    178_usize
}
/// Fallback for `KokoroConfig::hidden_dim`.
fn default_hidden_dim() -> usize {
    512_usize
}
/// Fallback for `KokoroConfig::style_dim`.
fn default_style_dim() -> usize {
    128_usize
}
/// Fallback for `KokoroConfig::voice_style_dim` (twice the default style width).
fn default_voice_style_dim() -> usize {
    256_usize
}
/// Fallback for `KokoroConfig::text_conv_kernel`.
fn default_text_conv_kernel() -> usize {
    5_usize
}
/// Fallback for `KokoroConfig::text_conv_depth`.
fn default_text_conv_depth() -> usize {
    3_usize
}
/// Fallback for `KokoroConfig::max_dur`.
fn default_max_dur() -> usize {
    50_usize
}
/// Fallback for `KokoroConfig::leaky_slope`.
fn default_leaky_slope() -> f64 {
    0.2_f64
}
/// Fallback for `KokoroConfig::sample_rate`, in Hz.
fn default_sample_rate() -> u32 {
    24_000_u32
}
/// Fallback for `KokoroConfig::n_fft`.
fn default_n_fft() -> usize {
    20_usize
}
/// Fallback for `KokoroConfig::hop_length`.
fn default_hop_length() -> usize {
    5_usize
}
/// Fallback for `KokoroConfig::upsample_ratios` (two stages, 10x then 6x).
fn default_upsample_ratios() -> Vec<usize> {
    Vec::from([10, 6])
}
/// Fallback for `KokoroConfig::upsample_kernel_sizes`, one per upsampling stage.
fn default_upsample_kernel_sizes() -> Vec<usize> {
    Vec::from([20, 12])
}
/// Fallback for `KokoroConfig::upsample_initial_channel`.
fn default_upsample_initial_channel() -> usize {
    512_usize
}
/// Fallback for `KokoroConfig::resblock_kernel_sizes`.
fn default_resblock_kernel_sizes() -> Vec<usize> {
    Vec::from([3, 7, 11])
}
/// Fallback for `KokoroConfig::resblock_dilation_sizes`.
/// Uses the same `[1, 3, 5]` dilations for each of the three residual blocks.
fn default_resblock_dilation_sizes() -> Vec<Vec<usize>> {
    vec![vec![1, 3, 5]; 3]
}
/// Fallback for `KokoroConfig::harmonic_num`.
fn default_harmonic_num() -> usize {
    8_usize
}
/// Fallback for `KokoroConfig::bert_hidden_size`.
fn default_bert_hidden_size() -> usize {
    768_usize
}
/// Fallback for `KokoroConfig::bert_num_layers`.
fn default_bert_num_layers() -> usize {
    12_usize
}
/// Fallback for `KokoroConfig::bert_num_heads`.
fn default_bert_num_heads() -> usize {
    12_usize
}
/// Fallback for `KokoroConfig::bert_embedding_size`.
fn default_bert_embedding_size() -> usize {
    128_usize
}
/// Fallback for `KokoroConfig::bert_intermediate_size`.
fn default_bert_intermediate_size() -> usize {
    2048_usize
}
impl Default for KokoroConfig {
    /// Returns the Kokoro-82M configuration.
    ///
    /// Every field is resolved through the same `default_*` function that serde
    /// uses for a missing JSON key. That keeps `KokoroConfig::default()` and a
    /// parsed `{}` identical.
    fn default() -> Self {
        let n_symbols = default_n_symbols();
        let hidden_dim = default_hidden_dim();
        let style_dim = default_style_dim();
        let voice_style_dim = default_voice_style_dim();
        let text_conv_kernel = default_text_conv_kernel();
        let text_conv_depth = default_text_conv_depth();
        let max_dur = default_max_dur();
        let leaky_slope = default_leaky_slope();
        let sample_rate = default_sample_rate();
        let n_fft = default_n_fft();
        let hop_length = default_hop_length();
        let upsample_ratios = default_upsample_ratios();
        let upsample_kernel_sizes = default_upsample_kernel_sizes();
        let upsample_initial_channel = default_upsample_initial_channel();
        let resblock_kernel_sizes = default_resblock_kernel_sizes();
        let resblock_dilation_sizes = default_resblock_dilation_sizes();
        let harmonic_num = default_harmonic_num();
        let bert_hidden_size = default_bert_hidden_size();
        let bert_num_layers = default_bert_num_layers();
        let bert_num_heads = default_bert_num_heads();
        let bert_embedding_size = default_bert_embedding_size();
        let bert_intermediate_size = default_bert_intermediate_size();

        Self {
            n_symbols,
            hidden_dim,
            style_dim,
            voice_style_dim,
            text_conv_kernel,
            text_conv_depth,
            max_dur,
            leaky_slope,
            sample_rate,
            n_fft,
            hop_length,
            upsample_ratios,
            upsample_kernel_sizes,
            upsample_initial_channel,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            harmonic_num,
            bert_hidden_size,
            bert_num_layers,
            bert_num_heads,
            bert_embedding_size,
            bert_intermediate_size,
        }
    }
}
impl KokoroConfig {
pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let bytes = std::fs::read(path).map_err(|e| Error::ModelError {
reason: format!("reading kokoro config {}: {e}", path.display()),
})?;
Self::from_json_bytes(&bytes)
}
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(|e| Error::ModelError {
reason: format!("invalid kokoro config.json: {e}"),
})
}
pub fn total_upsample(&self) -> usize {
self.upsample_ratios.iter().product()
}
pub fn validate(&self) -> Result<()> {
if self.hidden_dim == 0 || !self.hidden_dim.is_multiple_of(2) {
return Err(Error::ModelError {
reason: format!(
"hidden_dim must be positive and even (BiLSTM splits it), got {}",
self.hidden_dim
),
});
}
if self.n_symbols == 0 {
return Err(Error::ModelError {
reason: "n_symbols must be > 0".into(),
});
}
if self.sample_rate == 0 {
return Err(Error::ModelError {
reason: "sample_rate must be > 0".into(),
});
}
if self.n_fft == 0 || self.hop_length == 0 {
return Err(Error::ModelError {
reason: "n_fft and hop_length must be > 0".into(),
});
}
if self.upsample_ratios.is_empty() {
return Err(Error::ModelError {
reason: "upsample_ratios must have at least one entry".into(),
});
}
if self.upsample_ratios.contains(&0) {
return Err(Error::ModelError {
reason: "upsample_ratios entries must be > 0".into(),
});
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a config to a JSON value.
    /// This lets whole configs be compared field by field without relying on a
    /// `PartialEq` impl.
    fn as_json(cfg: &KokoroConfig) -> serde_json::Value {
        serde_json::to_value(cfg).expect("KokoroConfig always serializes")
    }

    /// Asserts that `result` failed with a `ModelError` whose reason mentions `needle`.
    fn assert_model_error<T: std::fmt::Debug>(result: Result<T>, needle: &str) {
        match result {
            Err(Error::ModelError { reason }) => {
                assert!(
                    reason.contains(needle),
                    "reason {reason:?} does not mention {needle:?}"
                );
            }
            other => panic!("expected ModelError, got {other:?}"),
        }
    }

    #[test]
    fn defaults_match_kokoro_82m() {
        let cfg = KokoroConfig::default();
        assert_eq!(cfg.n_symbols, 178);
        assert_eq!(cfg.hidden_dim, 512);
        assert_eq!(cfg.style_dim, 128);
        assert_eq!(cfg.voice_style_dim, 256);
        assert_eq!(cfg.voice_style_dim, 2 * cfg.style_dim);
        assert_eq!(cfg.max_dur, 50);
        assert_eq!(cfg.sample_rate, 24_000);
        assert_eq!(cfg.upsample_ratios, vec![10, 6]);
        assert_eq!(cfg.upsample_kernel_sizes, vec![20, 12]);
        assert_eq!(cfg.resblock_kernel_sizes, vec![3, 7, 11]);
        assert_eq!(cfg.harmonic_num, 8);
        assert_eq!(cfg.bert_hidden_size, 768);
        cfg.validate().unwrap();
    }

    #[test]
    fn parses_minimal_json_with_all_defaults() {
        let cfg = KokoroConfig::from_json_bytes(b"{}").unwrap();
        assert_eq!(cfg.n_symbols, default_n_symbols());
        // Every field, not just one, must fall back to the `Default` value.
        assert_eq!(as_json(&cfg), as_json(&KokoroConfig::default()));
    }

    #[test]
    fn partial_overrides_keep_remaining_defaults() {
        let json = br#"{"n_symbols": 200, "sample_rate": 22050}"#;
        let cfg = KokoroConfig::from_json_bytes(json).unwrap();
        assert_eq!(cfg.n_symbols, 200);
        assert_eq!(cfg.sample_rate, 22_050);
        assert_eq!(cfg.hidden_dim, default_hidden_dim());
        assert_eq!(cfg.upsample_ratios, default_upsample_ratios());
    }

    #[test]
    fn json_round_trip_preserves_every_field() {
        let cfg = KokoroConfig {
            n_symbols: 200,
            leaky_slope: 0.1,
            upsample_ratios: vec![2, 3],
            resblock_dilation_sizes: vec![vec![1, 2]],
            ..Default::default()
        };
        let bytes = serde_json::to_vec(&cfg).unwrap();
        let parsed = KokoroConfig::from_json_bytes(&bytes).unwrap();
        assert_eq!(as_json(&parsed), as_json(&cfg));
    }

    #[test]
    fn total_upsample_multiplies_ratios() {
        let cfg = KokoroConfig {
            upsample_ratios: vec![2, 3, 5],
            ..Default::default()
        };
        assert_eq!(cfg.total_upsample(), 30);
    }

    #[test]
    fn validate_rejects_odd_hidden_dim() {
        let cfg = KokoroConfig {
            hidden_dim: 7,
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_zero_hidden_dim() {
        let cfg = KokoroConfig {
            hidden_dim: 0,
            ..Default::default()
        };
        assert_model_error(cfg.validate(), "hidden_dim");
    }

    #[test]
    fn validate_rejects_zero_n_symbols() {
        let cfg = KokoroConfig {
            n_symbols: 0,
            ..Default::default()
        };
        assert_model_error(cfg.validate(), "n_symbols");
    }

    #[test]
    fn validate_rejects_zero_sample_rate() {
        let cfg = KokoroConfig {
            sample_rate: 0,
            ..Default::default()
        };
        assert_model_error(cfg.validate(), "sample_rate");
    }

    #[test]
    fn validate_rejects_zero_n_fft_or_hop_length() {
        let no_fft = KokoroConfig {
            n_fft: 0,
            ..Default::default()
        };
        assert_model_error(no_fft.validate(), "n_fft and hop_length");
        let no_hop = KokoroConfig {
            hop_length: 0,
            ..Default::default()
        };
        assert_model_error(no_hop.validate(), "n_fft and hop_length");
    }

    #[test]
    fn validate_rejects_empty_upsample() {
        let cfg = KokoroConfig {
            upsample_ratios: vec![],
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn validate_rejects_zero_upsample_entry() {
        let cfg = KokoroConfig {
            upsample_ratios: vec![2, 0, 5],
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn invalid_json_surfaces_model_error() {
        assert_model_error(
            KokoroConfig::from_json_bytes(b"not json"),
            "invalid kokoro config.json",
        );
    }

    #[test]
    fn missing_file_surfaces_model_error_with_path() {
        let path = std::env::temp_dir().join("kokoro-config-that-does-not-exist-7f3a9c.json");
        match KokoroConfig::from_json_file(&path) {
            Err(Error::ModelError { reason }) => {
                assert!(reason.contains("reading kokoro config"), "{reason}");
                assert!(reason.contains(&path.display().to_string()), "{reason}");
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }
}