use std::path::PathBuf;
use wavekat_core::AudioFrame;
use crate::error::TtsError;
use crate::traits::TtsBackend;
use crate::types::{SynthesizeRequest, VoiceInfo};
use std::sync::Once;
use tokenizer::{IM_END, IM_START, NEWLINE};
/// One-shot guard for the "no style instruction" warning printed by
/// `Qwen3Tts::synthesize`; being a `Once` static, the warning is emitted
/// at most once per process regardless of how many requests lack an instruction.
static WARNED_NO_INSTRUCTION: Once = Once::new();
mod clone_model;
mod download;
mod mel;
mod model;
mod sampler;
mod tokenizer;
/// Weight precision variant of the model files to load.
///
/// Each variant is associated with a subdirectory name, see
/// [`ModelPrecision::subdir`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ModelPrecision {
    /// Int4 weights; this is the default precision.
    #[default]
    Int4,
    /// 32-bit floating-point weights.
    Fp32,
}

impl ModelPrecision {
    /// Returns the subdirectory name associated with this precision:
    /// `"int4"` for [`ModelPrecision::Int4`], `"fp32"` for [`ModelPrecision::Fp32`].
    pub(crate) fn subdir(self) -> &'static str {
        match self {
            ModelPrecision::Fp32 => "fp32",
            ModelPrecision::Int4 => "int4",
        }
    }
}
/// Execution backend requested when loading and running the model.
///
/// [`ExecutionProvider::Cpu`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ExecutionProvider {
    /// Run on the CPU (default).
    #[default]
    Cpu,
    /// Run through NVIDIA CUDA.
    Cuda,
    /// Run through NVIDIA TensorRT.
    TensorRt,
    /// Run through Apple Core ML.
    CoreMl,
}
/// Options controlling which Qwen3-TTS model files are used and how they run.
///
/// Start from `ModelConfig::default()` and refine it with the `with_*`
/// builder methods, each of which consumes and returns the config.
#[derive(Debug, Clone, Default)]
pub struct ModelConfig {
    /// Weight precision to load (defaults to [`ModelPrecision::Int4`]).
    pub precision: ModelPrecision,
    /// Execution backend to request (defaults to [`ExecutionProvider::Cpu`]).
    pub execution_provider: ExecutionProvider,
    /// Explicit model directory. When `None`, the directory is resolved by the
    /// crate's download module.
    pub model_dir: Option<PathBuf>,
}

impl ModelConfig {
    /// Returns the config with `model_dir` set to `dir`.
    pub fn with_dir(self, dir: impl Into<PathBuf>) -> Self {
        Self {
            model_dir: Some(dir.into()),
            ..self
        }
    }

    /// Returns the config with the weight precision replaced by `precision`.
    pub fn with_precision(self, precision: ModelPrecision) -> Self {
        Self { precision, ..self }
    }

    /// Returns the config with the execution backend replaced by `ep`.
    pub fn with_execution_provider(self, ep: ExecutionProvider) -> Self {
        Self {
            execution_provider: ep,
            ..self
        }
    }
}
/// Qwen3-TTS (VoiceDesign) text-to-speech backend.
///
/// Holds the loaded model together with its tokenizer; synthesis is exposed
/// through the [`TtsBackend`] trait implementation.
pub struct Qwen3Tts {
    model: model::Model,
    tokenizer: tokenizer::Tokenizer,
}

impl Qwen3Tts {
    /// Loads the model using `ModelConfig::default()`.
    ///
    /// # Errors
    /// Same as [`Qwen3Tts::from_config`].
    pub fn new() -> Result<Self, TtsError> {
        Self::from_config(Default::default())
    }

    /// Resolves the model directory for `config`, then loads the model and
    /// the tokenizer from it (in that order).
    ///
    /// # Errors
    /// Propagates any [`TtsError`] from directory resolution, model loading,
    /// or tokenizer construction.
    pub fn from_config(config: ModelConfig) -> Result<Self, TtsError> {
        let dir = download::resolve_model_dir(&config)?;
        let model = model::Model::load(dir.as_ref(), &config)?;
        // Struct fields are evaluated in source order, so the tokenizer is
        // only built after the model has loaded successfully.
        Ok(Self {
            model,
            tokenizer: tokenizer::Tokenizer::new(&dir)?,
        })
    }
}
/// Input for [`Qwen3TtsClone::synthesize_clone`].
///
/// The request only borrows its data; `'a` ties it to the caller's text and
/// audio buffers.
#[derive(Debug, Clone)]
pub struct CloneRequest<'a> {
    /// Text to synthesize in the reference voice.
    pub text: &'a str,
    /// Reference audio samples of the voice to clone.
    pub ref_samples: &'a [f32],
    /// Sample rate of `ref_samples`, in Hz; synthesis accepts only 24 kHz.
    pub ref_sample_rate: u32,
    /// Text spoken in the reference audio (its transcript).
    pub ref_text: &'a str,
    /// Target language code; `None` is treated as `"en"` during synthesis.
    pub language: Option<&'a str>,
}

impl<'a> CloneRequest<'a> {
    /// Builds a request with no explicit language.
    pub fn new(
        text: &'a str,
        ref_samples: &'a [f32],
        ref_sample_rate: u32,
        ref_text: &'a str,
    ) -> Self {
        let language = None;
        Self {
            text,
            ref_samples,
            ref_sample_rate,
            ref_text,
            language,
        }
    }

    /// Returns the request with its target language set to `language`.
    pub fn with_language(self, language: &'a str) -> Self {
        Self {
            language: Some(language),
            ..self
        }
    }
}
/// Qwen3-TTS voice-cloning backend: synthesizes text in the voice of a
/// reference recording supplied in a [`CloneRequest`].
pub struct Qwen3TtsClone {
    model: clone_model::CloneModel,
    tokenizer: tokenizer::Tokenizer,
}

impl Qwen3TtsClone {
    /// Sample rate, in Hz, that reference audio passed to
    /// [`Qwen3TtsClone::synthesize_clone`] must use.
    pub const REF_SAMPLE_RATE: u32 = 24_000;

    /// Loads the clone model using `ModelConfig::default()`.
    ///
    /// # Errors
    /// Same as [`Qwen3TtsClone::from_config`].
    pub fn new() -> Result<Self, TtsError> {
        Self::from_config(ModelConfig::default())
    }

    /// Resolves the clone-model directory for `config`, then loads the clone
    /// model and the tokenizer from it (in that order).
    ///
    /// # Errors
    /// Propagates any [`TtsError`] from directory resolution, model loading,
    /// or tokenizer construction.
    pub fn from_config(config: ModelConfig) -> Result<Self, TtsError> {
        let model_dir = download::resolve_clone_model_dir(&config)?;
        let model = clone_model::CloneModel::load(model_dir.as_ref(), &config)?;
        let tokenizer = tokenizer::Tokenizer::new(&model_dir)?;
        Ok(Self { model, tokenizer })
    }

    /// Synthesizes `request.text` in the voice of `request.ref_samples`.
    ///
    /// The language defaults to `"en"` when `request.language` is `None`.
    ///
    /// # Errors
    /// Returns [`TtsError::Synthesis`] if the reference audio is not sampled at
    /// [`Qwen3TtsClone::REF_SAMPLE_RATE`] or contains no samples; otherwise
    /// propagates tokenizer and model errors.
    pub fn synthesize_clone(
        &self,
        request: &CloneRequest,
    ) -> Result<AudioFrame<'static>, TtsError> {
        if request.ref_sample_rate != Self::REF_SAMPLE_RATE {
            return Err(TtsError::Synthesis(format!(
                "reference audio must be {} kHz, got {} Hz",
                Self::REF_SAMPLE_RATE / 1000,
                request.ref_sample_rate,
            )));
        }
        // An empty clip carries no voice to clone; reject it up front with a
        // clear message instead of handing zero samples to the model.
        if request.ref_samples.is_empty() {
            return Err(TtsError::Synthesis(
                "reference audio is empty".to_string(),
            ));
        }
        let language = request.language.unwrap_or("en");
        let ref_tokens = self.tokenizer.encode(request.ref_text)?;
        let text_tokens = self.tokenizer.encode(request.text)?;
        self.model
            .synthesize(request.ref_samples, &ref_tokens, &text_tokens, language)
    }
}
impl TtsBackend for Qwen3Tts {
    /// Synthesizes `request.text`, optionally steered by a voice-style instruction.
    ///
    /// The language defaults to `"en"`. A present instruction is encoded as
    /// `IM_START` + "user" + `NEWLINE` + "<instruct>" + instruction +
    /// "</instruct>" + `IM_END` + `NEWLINE` and passed to the model alongside
    /// the text tokens. An absent instruction triggers a one-time warning on
    /// stderr.
    fn synthesize(&self, request: &SynthesizeRequest) -> Result<AudioFrame<'static>, TtsError> {
        let text_ids = self.tokenizer.encode(request.text)?;
        let lang = request.language.unwrap_or("en");

        let instruction_ids = match request.instruction {
            None => {
                WARNED_NO_INSTRUCTION.call_once(|| {
                    eprintln!(
                        "wavekat-tts warning: Qwen3-TTS is a VoiceDesign model — \
                         synthesize quality may be inconsistent without a style instruction. \
                         Set `SynthesizeRequest::with_instruction` to control voice style."
                    );
                });
                None
            }
            Some(style) => {
                let mut prompt = vec![IM_START];
                prompt.extend(self.tokenizer.encode("user")?);
                prompt.push(NEWLINE);
                // Each piece is tokenized separately, in this exact order.
                for piece in ["<instruct>", style, "</instruct>"] {
                    prompt.extend(self.tokenizer.encode(piece)?);
                }
                prompt.extend([IM_END, NEWLINE]);
                Some(prompt)
            }
        };

        self.model
            .synthesize(&text_ids, lang, instruction_ids.as_deref())
    }

    /// Returns the single built-in voice and the language codes it advertises.
    fn voices(&self) -> Result<Vec<VoiceInfo>, TtsError> {
        const LANGUAGES: [&str; 10] = ["en", "zh", "ja", "ko", "de", "es", "fr", "ru", "it", "pt"];

        let default_voice = VoiceInfo {
            id: "default".into(),
            name: "Qwen3-TTS Default".into(),
            languages: LANGUAGES.iter().map(|&code| code.into()).collect(),
            gender: None,
        };
        Ok(vec![default_voice])
    }
}