use serde::Deserialize;
use std::collections::HashMap;
/// Top-level configuration for a Qwen3 TTS model, deserialized from JSON.
///
/// Every key is optional in the input: a missing key falls back to the value
/// produced by the provider named in the field's `#[serde(default ...)]`
/// attribute, so an empty object (`{}`) is a valid configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct Qwen3TtsConfig {
    /// Model type identifier; `"qwen3_tts"` when absent.
    #[serde(default = "default_model_type")]
    pub model_type: String,

    /// Settings for the talker component; `TalkerConfig::default()` when absent.
    #[serde(default)]
    pub talker_config: TalkerConfig,

    /// Optional speaker encoder settings; `None` when absent.
    #[serde(default)]
    pub speaker_encoder_config: Option<SpeakerEncoderConfig>,

    /// Tokenizer type identifier; `"qwen3_tts_tokenizer_12hz"` when absent.
    #[serde(default = "default_tokenizer_type")]
    pub tokenizer_type: String,

    /// Optional model size label; `None` when absent.
    #[serde(default)]
    pub tts_model_size: Option<String>,

    /// TTS model variant; `"custom_voice"` when absent.
    ///
    /// `is_voice_design` compares this against `"voice_design"`.
    #[serde(default = "default_tts_model_type")]
    pub tts_model_type: String,

    /// Id of the `im_start` token; `151644` when absent.
    #[serde(default = "default_im_start_token_id")]
    pub im_start_token_id: u32,

    /// Id of the `im_end` token; `151645` when absent.
    #[serde(default = "default_im_end_token_id")]
    pub im_end_token_id: u32,

    /// Id of the TTS pad token; `151671` when absent.
    #[serde(default = "default_tts_pad_token_id")]
    pub tts_pad_token_id: u32,

    /// Id of the TTS BOS token; `151672` when absent.
    #[serde(default = "default_tts_bos_token_id")]
    pub tts_bos_token_id: u32,

    /// Id of the TTS EOS token; `151673` when absent.
    #[serde(default = "default_tts_eos_token_id")]
    pub tts_eos_token_id: u32,
}
/// Configuration of the talker component, deserialized from the
/// `talker_config` object of the model configuration.
///
/// All keys are optional. Sized/numeric fields fall back to the provider named
/// in their `#[serde(default = ...)]` attribute; `Option` fields fall back to
/// `None` and maps fall back to empty.
#[derive(Debug, Clone, Deserialize)]
pub struct TalkerConfig {
    /// Hidden size; `2048` when absent.
    #[serde(default = "default_talker_hidden_size")]
    pub hidden_size: usize,

    /// Number of hidden layers; `28` when absent.
    #[serde(default = "default_talker_num_layers")]
    pub num_hidden_layers: usize,

    /// Number of attention heads; `16` when absent.
    #[serde(default = "default_talker_num_heads")]
    pub num_attention_heads: usize,

    /// Number of key/value heads; `8` when absent.
    #[serde(default = "default_talker_num_kv_heads")]
    pub num_key_value_heads: usize,

    /// Per-head dimension; `128` when absent.
    #[serde(default = "default_head_dim")]
    pub head_dim: usize,

    /// Intermediate size; `6144` when absent.
    #[serde(default = "default_talker_intermediate_size")]
    pub intermediate_size: usize,

    /// Vocabulary size; `3072` when absent.
    #[serde(default = "default_talker_vocab_size")]
    pub vocab_size: usize,

    /// Text vocabulary size; `151936` when absent.
    #[serde(default = "default_text_vocab_size")]
    pub text_vocab_size: usize,

    /// Text hidden size; `2048` when absent.
    #[serde(default = "default_text_hidden_size")]
    pub text_hidden_size: usize,

    /// Number of code groups; `16` when absent.
    #[serde(default = "default_num_code_groups")]
    pub num_code_groups: usize,

    /// RMS-norm epsilon; `1e-6` when absent.
    #[serde(default = "default_rms_norm_eps")]
    pub rms_norm_eps: f64,

    /// RoPE theta; `1000000.0` when absent.
    #[serde(default = "default_rope_theta")]
    pub rope_theta: f64,

    /// Maximum number of position embeddings; `32768` when absent.
    #[serde(default = "default_max_position_embeddings")]
    pub max_position_embeddings: usize,

    /// Name of the hidden activation; `"silu"` when absent.
    #[serde(default = "default_hidden_act")]
    pub hidden_act: String,

    /// Codec EOS token id; `0` when absent.
    #[serde(default = "default_codec_eos_token_id")]
    pub codec_eos_token_id: u32,

    /// Optional codec BOS id; `None` when absent.
    #[serde(default)]
    pub codec_bos_id: Option<u32>,

    /// Optional codec pad id; `None` when absent.
    #[serde(default)]
    pub codec_pad_id: Option<u32>,

    /// Optional codec "think" id; `None` when absent.
    #[serde(default)]
    pub codec_think_id: Option<u32>,

    /// Optional codec "nothink" id; `None` when absent.
    #[serde(default)]
    pub codec_nothink_id: Option<u32>,

    /// Optional codec "think BOS" id; `None` when absent.
    #[serde(default)]
    pub codec_think_bos_id: Option<u32>,

    /// Optional codec "think EOS" id; `None` when absent.
    #[serde(default)]
    pub codec_think_eos_id: Option<u32>,

    /// Optional `position_id_per_seconds` value; `None` when absent.
    // NOTE(review): exact semantics are not visible in this file — confirm
    // against the consumer before relying on a particular interpretation.
    #[serde(default)]
    pub position_id_per_seconds: Option<u32>,

    /// Optional `mrope_section` list; `None` when absent.
    // NOTE(review): presumably per-section sizes for multimodal RoPE — confirm.
    #[serde(default)]
    pub mrope_section: Option<Vec<usize>>,

    /// Optional `interleaved` flag; `None` when absent.
    #[serde(default)]
    pub interleaved: Option<bool>,

    /// Speaker name → speaker id; empty when absent.
    ///
    /// The keys are the names reported by `Qwen3TtsConfig::speakers`.
    #[serde(default)]
    pub spk_id: HashMap<String, u32>,

    /// Language name → codec language id; empty when absent.
    ///
    /// `Qwen3TtsConfig::languages` reports these keys, skipping any whose name
    /// contains `"dialect"`.
    #[serde(default)]
    pub codec_language_id: HashMap<String, u32>,

    /// Optional code predictor settings; `None` when absent.
    #[serde(default)]
    pub code_predictor_config: Option<CodePredictorConfig>,
}
/// Configuration of the code predictor, deserialized from the
/// `code_predictor_config` object nested inside the talker configuration.
///
/// All keys are optional and fall back to the provider named in the field's
/// `#[serde(default = ...)]` attribute.
#[derive(Debug, Clone, Deserialize)]
pub struct CodePredictorConfig {
    /// Hidden size; `1024` when absent.
    #[serde(default = "default_code_predictor_hidden")]
    pub hidden_size: usize,

    /// Number of hidden layers; `5` when absent.
    #[serde(default = "default_code_predictor_layers")]
    pub num_hidden_layers: usize,

    /// Number of attention heads; `16` when absent.
    #[serde(default = "default_code_predictor_heads")]
    pub num_attention_heads: usize,

    /// Number of key/value heads; `8` when absent.
    #[serde(default = "default_code_predictor_kv_heads")]
    pub num_key_value_heads: usize,

    /// Per-head dimension; `128` when absent (same provider as the talker).
    #[serde(default = "default_head_dim")]
    pub head_dim: usize,

    /// Intermediate size; `3072` when absent.
    #[serde(default = "default_code_predictor_intermediate")]
    pub intermediate_size: usize,

    /// Vocabulary size; `2048` when absent.
    #[serde(default = "default_code_predictor_vocab")]
    pub vocab_size: usize,

    /// Number of code groups; `16` when absent (same provider as the talker).
    #[serde(default = "default_num_code_groups")]
    pub num_code_groups: usize,
}
/// Configuration of the speaker encoder, deserialized from the
/// `speaker_encoder_config` object of the model configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct SpeakerEncoderConfig {
    /// Sample rate; `16000` when the key is absent.
    // NOTE(review): presumably in Hz — the unit is not stated in this file.
    #[serde(default = "default_speaker_sample_rate")]
    pub sample_rate: u32,
}
/// Fallback for `Qwen3TtsConfig::model_type` when the key is missing.
fn default_model_type() -> String {
    String::from("qwen3_tts")
}

/// Fallback for `Qwen3TtsConfig::tokenizer_type` when the key is missing.
fn default_tokenizer_type() -> String {
    String::from("qwen3_tts_tokenizer_12hz")
}

/// Fallback for `Qwen3TtsConfig::tts_model_type` when the key is missing.
fn default_tts_model_type() -> String {
    String::from("custom_voice")
}
/// Fallback for `Qwen3TtsConfig::im_start_token_id` when the key is missing.
fn default_im_start_token_id() -> u32 {
    151_644
}

/// Fallback for `Qwen3TtsConfig::im_end_token_id` when the key is missing.
fn default_im_end_token_id() -> u32 {
    151_645
}

/// Fallback for `Qwen3TtsConfig::tts_pad_token_id` when the key is missing.
fn default_tts_pad_token_id() -> u32 {
    151_671
}

/// Fallback for `Qwen3TtsConfig::tts_bos_token_id` when the key is missing.
fn default_tts_bos_token_id() -> u32 {
    151_672
}

/// Fallback for `Qwen3TtsConfig::tts_eos_token_id` when the key is missing.
fn default_tts_eos_token_id() -> u32 {
    151_673
}
/// Fallback for `TalkerConfig::hidden_size` when the key is missing.
fn default_talker_hidden_size() -> usize {
    2_048
}

/// Fallback for `TalkerConfig::num_hidden_layers` when the key is missing.
fn default_talker_num_layers() -> usize {
    28
}

/// Fallback for `TalkerConfig::num_attention_heads` when the key is missing.
fn default_talker_num_heads() -> usize {
    16
}

/// Fallback for `TalkerConfig::num_key_value_heads` when the key is missing.
fn default_talker_num_kv_heads() -> usize {
    8
}

/// Fallback for `head_dim` in both `TalkerConfig` and `CodePredictorConfig`.
fn default_head_dim() -> usize {
    128
}

/// Fallback for `TalkerConfig::intermediate_size` when the key is missing.
fn default_talker_intermediate_size() -> usize {
    6_144
}

/// Fallback for `TalkerConfig::vocab_size` when the key is missing.
fn default_talker_vocab_size() -> usize {
    3_072
}

/// Fallback for `TalkerConfig::text_vocab_size` when the key is missing.
fn default_text_vocab_size() -> usize {
    151_936
}

/// Fallback for `TalkerConfig::text_hidden_size` when the key is missing.
fn default_text_hidden_size() -> usize {
    2_048
}

/// Fallback for `num_code_groups` in both `TalkerConfig` and
/// `CodePredictorConfig`.
fn default_num_code_groups() -> usize {
    16
}
/// Fallback for `TalkerConfig::rms_norm_eps` when the key is missing.
fn default_rms_norm_eps() -> f64 {
    1.0e-6
}

/// Fallback for `TalkerConfig::rope_theta` when the key is missing.
fn default_rope_theta() -> f64 {
    1_000_000.0
}

/// Fallback for `TalkerConfig::max_position_embeddings` when the key is missing.
fn default_max_position_embeddings() -> usize {
    32_768
}

/// Fallback for `TalkerConfig::hidden_act` when the key is missing.
fn default_hidden_act() -> String {
    String::from("silu")
}

/// Fallback for `TalkerConfig::codec_eos_token_id` when the key is missing.
fn default_codec_eos_token_id() -> u32 {
    0
}
/// Fallback for `CodePredictorConfig::hidden_size` when the key is missing.
fn default_code_predictor_hidden() -> usize {
    1_024
}

/// Fallback for `CodePredictorConfig::num_hidden_layers` when the key is missing.
fn default_code_predictor_layers() -> usize {
    5
}

/// Fallback for `CodePredictorConfig::num_attention_heads` when the key is missing.
fn default_code_predictor_heads() -> usize {
    16
}

/// Fallback for `CodePredictorConfig::num_key_value_heads` when the key is missing.
fn default_code_predictor_kv_heads() -> usize {
    8
}

/// Fallback for `CodePredictorConfig::intermediate_size` when the key is missing.
fn default_code_predictor_intermediate() -> usize {
    3_072
}

/// Fallback for `CodePredictorConfig::vocab_size` when the key is missing.
fn default_code_predictor_vocab() -> usize {
    2_048
}
/// Fallback for `SpeakerEncoderConfig::sample_rate` when the key is missing.
fn default_speaker_sample_rate() -> u32 {
    16_000
}
impl Default for TalkerConfig {
fn default() -> Self {
serde_json::from_str("{}").unwrap()
}
}
impl Qwen3TtsConfig {
    /// Reads the file at `path` and parses its contents as a JSON configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read, or if its contents fail to
    /// deserialize (see [`Self::from_bytes`]).
    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, crate::error::TtsError> {
        Self::from_bytes(std::fs::read(path)?)
    }

    /// Parses a JSON configuration from an in-memory byte buffer.
    ///
    /// Keys missing from the input take their serde defaults, so `b"{}"` is a
    /// valid input.
    ///
    /// # Errors
    ///
    /// Returns an error if `bytes` is not valid JSON or does not match the
    /// shape of this configuration.
    pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, crate::error::TtsError> {
        Ok(serde_json::from_slice(bytes.as_ref())?)
    }

    /// Returns `true` when `tts_model_type` is exactly `"voice_design"`.
    pub fn is_voice_design(&self) -> bool {
        self.tts_model_type == "voice_design"
    }

    /// Returns the speaker names (the keys of `talker_config.spk_id`), sorted
    /// in ascending order.
    ///
    /// The names are sorted because `HashMap` iteration order is unspecified
    /// and differs between runs; sorting gives callers a stable list.
    pub fn speakers(&self) -> Vec<String> {
        let mut names: Vec<String> = self.talker_config.spk_id.keys().cloned().collect();
        names.sort_unstable();
        names
    }

    /// Returns the language names (the keys of `talker_config.codec_language_id`),
    /// skipping every key that contains the substring `"dialect"`, sorted in
    /// ascending order.
    ///
    /// The names are sorted because `HashMap` iteration order is unspecified
    /// and differs between runs; sorting gives callers a stable list.
    pub fn languages(&self) -> Vec<String> {
        let mut names: Vec<String> = self
            .talker_config
            .codec_language_id
            .keys()
            .filter(|name| !name.contains("dialect"))
            .cloned()
            .collect();
        names.sort_unstable();
        names
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `json` into a config, panicking if it does not parse.
    fn parse(json: &str) -> Qwen3TtsConfig {
        serde_json::from_str(json).unwrap()
    }

    /// An empty object yields the top-level defaults.
    #[test]
    fn test_default_config() {
        let cfg = parse("{}");
        assert_eq!(cfg.model_type, "qwen3_tts");
        assert_eq!(cfg.tts_model_type, "custom_voice");
        assert_eq!(cfg.tts_bos_token_id, 151672);
    }

    /// An empty `talker_config` object yields the talker defaults.
    #[test]
    fn test_talker_config_defaults() {
        let cfg = parse(r#"{"talker_config": {}}"#);
        let talker = &cfg.talker_config;
        assert_eq!(talker.hidden_size, 2048);
        assert_eq!(talker.num_code_groups, 16);
    }

    /// The `spk_id` object is read into the speaker map.
    #[test]
    fn test_speaker_id_parsing() {
        let cfg = parse(
            r#"{
"talker_config": {
"spk_id": {"Vivian": 0, "Ryan": 1, "Serena": 2}
}
}"#,
        );
        let speakers = &cfg.talker_config.spk_id;
        assert_eq!(speakers.len(), 3);
        assert_eq!(speakers.get("Vivian"), Some(&0));
    }

    /// `languages()` keeps plain languages and drops dialect entries.
    #[test]
    fn test_language_filtering() {
        let cfg = parse(
            r#"{
"talker_config": {
"codec_language_id": {
"Chinese": 0,
"English": 1,
"Chinese_dialect_sichuan": 2
}
}
}"#,
        );
        let langs = cfg.languages();
        assert!(langs.iter().any(|name| name == "Chinese"));
        assert!(langs.iter().any(|name| name == "English"));
        assert!(langs.iter().all(|name| !name.contains("dialect")));
    }
}