use std::path::PathBuf;
use hf_hub::api::sync::ApiBuilder;
use hf_hub::{Repo, RepoType};
use crate::TtsError;
/// Hugging Face repository hosting the Qwen3-TTS 1.7B VoiceDesign ONNX export.
const REPO_ID: &str = "wavekat/Qwen3-TTS-1.7B-VoiceDesign-ONNX";
/// Revision of [`REPO_ID`] that is downloaded.
const REVISION: &str = "2026-04-06";

/// Hugging Face repository hosting the Qwen3-TTS 0.6B Base ONNX export (the clone model).
const CLONE_REPO_ID: &str = "wavekat/Qwen3-TTS-0.6B-Base-ONNX";
/// Revision of [`CLONE_REPO_ID`] that is downloaded.
// NOTE(review): unlike `REVISION`, this follows the moving `main` branch, so the
// fetched files may change between runs; consider pinning a tag or commit.
const CLONE_REVISION: &str = "main";
/// ONNX graphs for the int4 build of the 1.7B model.
///
/// Every `.onnx` graph is immediately followed by its `.onnx.data` sidecar file.
const ONNX_FILES_INT4: &[&str] = &[
    // talker
    "int4/talker_prefill.onnx",
    "int4/talker_prefill.onnx.data",
    "int4/talker_decode.onnx",
    "int4/talker_decode.onnx.data",
    // code predictor
    "int4/code_predictor.onnx",
    "int4/code_predictor.onnx.data",
    // vocoder
    "int4/vocoder.onnx",
    "int4/vocoder.onnx.data",
];

/// ONNX graphs for the fp32 build of the 1.7B model; mirrors [`ONNX_FILES_INT4`]
/// entry-for-entry under the `fp32/` directory.
const ONNX_FILES_FP32: &[&str] = &[
    // talker
    "fp32/talker_prefill.onnx",
    "fp32/talker_prefill.onnx.data",
    "fp32/talker_decode.onnx",
    "fp32/talker_decode.onnx.data",
    // code predictor
    "fp32/code_predictor.onnx",
    "fp32/code_predictor.onnx.data",
    // vocoder
    "fp32/vocoder.onnx",
    "fp32/vocoder.onnx.data",
];
/// Precision-independent files needed by the 1.7B model.
///
/// The first entry must stay `config.json`: it is downloaded first and its cached
/// location is used as the model directory.
const SHARED_FILES: &[&str] = &[
    // configuration (must be first)
    "config.json",
    // text embedding + projection
    "embeddings/text_embedding.npy",
    "embeddings/text_projection_fc1_weight.npy",
    "embeddings/text_projection_fc1_bias.npy",
    "embeddings/text_projection_fc2_weight.npy",
    "embeddings/text_projection_fc2_bias.npy",
    // codec embeddings
    "embeddings/talker_codec_embedding.npy",
    "embeddings/cp_codec_embedding_0.npy",
    "embeddings/cp_codec_embedding_1.npy",
    "embeddings/cp_codec_embedding_2.npy",
    "embeddings/cp_codec_embedding_3.npy",
    "embeddings/cp_codec_embedding_4.npy",
    "embeddings/cp_codec_embedding_5.npy",
    "embeddings/cp_codec_embedding_6.npy",
    "embeddings/cp_codec_embedding_7.npy",
    "embeddings/cp_codec_embedding_8.npy",
    "embeddings/cp_codec_embedding_9.npy",
    "embeddings/cp_codec_embedding_10.npy",
    "embeddings/cp_codec_embedding_11.npy",
    "embeddings/cp_codec_embedding_12.npy",
    "embeddings/cp_codec_embedding_13.npy",
    "embeddings/cp_codec_embedding_14.npy",
    // tokenizer
    "tokenizer/vocab.json",
    "tokenizer/merges.txt",
];
pub fn resolve_model_dir(config: &super::ModelConfig) -> Result<PathBuf, TtsError> {
if let Some(dir) = &config.model_dir {
return Ok(dir.clone());
}
let cache_dir_override = match std::env::var("WAVEKAT_MODEL_DIR") {
Ok(dir) => {
let path = PathBuf::from(&dir);
if path.join("config.json").exists() {
return Ok(path);
}
Some(path)
}
Err(_) => None,
};
let precision = config.precision;
let mut builder = ApiBuilder::from_env();
if let Some(ref dir) = cache_dir_override {
builder = builder.with_cache_dir(dir.clone());
}
if let Ok(token) = std::env::var("HF_TOKEN") {
if !token.is_empty() {
builder = builder.with_token(Some(token));
}
}
let api = builder
.build()
.map_err(|e| TtsError::Model(format!("failed to initialize HF Hub client: {e}")))?;
let repo = api.repo(Repo::with_revision(
REPO_ID.to_string(),
RepoType::Model,
REVISION.to_string(),
));
let onnx_files = match precision {
super::ModelPrecision::Int4 => ONNX_FILES_INT4,
super::ModelPrecision::Fp32 => ONNX_FILES_FP32,
};
let total = 1 + onnx_files.len() + SHARED_FILES[1..].len();
eprintln!(
"Ensuring Qwen3-TTS 1.7B ({}) model ({total} files from {REPO_ID})...",
precision.subdir()
);
eprintln!("[1/{total}] {}", SHARED_FILES[0]);
let config_path = repo
.get(SHARED_FILES[0])
.map_err(|e| TtsError::Model(format!("failed to download {}: {e}", SHARED_FILES[0])))?;
let model_dir = config_path
.parent()
.ok_or_else(|| TtsError::Model("unexpected cache path for config.json".into()))?
.to_path_buf();
for (i, filename) in onnx_files
.iter()
.chain(SHARED_FILES[1..].iter())
.enumerate()
{
eprintln!("[{}/{total}] {filename}", i + 2);
repo.get(filename)
.map_err(|e| TtsError::Model(format!("failed to download {filename}: {e}")))?;
}
eprintln!("Files ready. Loading model ...");
Ok(model_dir)
}
const CLONE_ONNX_FILES_INT4: &[&str] = &[
"int4/talker_prefill.onnx",
"int4/talker_prefill.onnx.data",
"int4/talker_decode.onnx",
"int4/talker_decode.onnx.data",
"int4/code_predictor.onnx",
"int4/code_predictor.onnx.data",
"int4/vocoder.onnx",
"int4/vocoder.onnx.data",
];
const CLONE_ONNX_FILES_FP32: &[&str] = &[
"fp32/talker_prefill.onnx",
"fp32/talker_prefill.onnx.data",
"fp32/talker_decode.onnx",
"fp32/talker_decode.onnx.data",
"fp32/code_predictor.onnx",
"fp32/code_predictor.onnx.data",
"fp32/vocoder.onnx",
"fp32/vocoder.onnx.data",
];
const CLONE_SHARED_FILES: &[&str] = &[
"config.json",
"speaker_encoder.onnx",
"speaker_encoder.onnx.data",
"tokenizer_encoder.onnx",
"tokenizer_encoder.onnx.data",
"embeddings/text_embedding.npy",
"embeddings/text_projection_fc1_weight.npy",
"embeddings/text_projection_fc1_bias.npy",
"embeddings/text_projection_fc2_weight.npy",
"embeddings/text_projection_fc2_bias.npy",
"embeddings/talker_codec_embedding.npy",
"embeddings/cp_codec_embedding_0.npy",
"embeddings/cp_codec_embedding_1.npy",
"embeddings/cp_codec_embedding_2.npy",
"embeddings/cp_codec_embedding_3.npy",
"embeddings/cp_codec_embedding_4.npy",
"embeddings/cp_codec_embedding_5.npy",
"embeddings/cp_codec_embedding_6.npy",
"embeddings/cp_codec_embedding_7.npy",
"embeddings/cp_codec_embedding_8.npy",
"embeddings/cp_codec_embedding_9.npy",
"embeddings/cp_codec_embedding_10.npy",
"embeddings/cp_codec_embedding_11.npy",
"embeddings/cp_codec_embedding_12.npy",
"embeddings/cp_codec_embedding_13.npy",
"embeddings/cp_codec_embedding_14.npy",
"tokenizer/vocab.json",
"tokenizer/merges.txt",
];
pub fn resolve_clone_model_dir(config: &super::ModelConfig) -> Result<PathBuf, TtsError> {
if let Some(dir) = &config.model_dir {
return Ok(dir.clone());
}
let cache_dir_override = match std::env::var("WAVEKAT_CLONE_MODEL_DIR") {
Ok(dir) => {
let path = PathBuf::from(&dir);
if path.join("config.json").exists() {
return Ok(path);
}
Some(path)
}
Err(_) => None,
};
let precision = config.precision;
let mut builder = ApiBuilder::from_env();
if let Some(ref dir) = cache_dir_override {
builder = builder.with_cache_dir(dir.clone());
}
if let Ok(token) = std::env::var("HF_TOKEN") {
if !token.is_empty() {
builder = builder.with_token(Some(token));
}
}
let api = builder
.build()
.map_err(|e| TtsError::Model(format!("failed to initialize HF Hub client: {e}")))?;
let repo = api.repo(Repo::with_revision(
CLONE_REPO_ID.to_string(),
RepoType::Model,
CLONE_REVISION.to_string(),
));
let onnx_files = match precision {
super::ModelPrecision::Int4 => CLONE_ONNX_FILES_INT4,
super::ModelPrecision::Fp32 => CLONE_ONNX_FILES_FP32,
};
let total = 1 + onnx_files.len() + CLONE_SHARED_FILES[1..].len();
eprintln!(
"Ensuring Qwen3-TTS 0.6B Clone ({}) model ({total} files from {CLONE_REPO_ID})...",
precision.subdir()
);
eprintln!("[1/{total}] {}", CLONE_SHARED_FILES[0]);
let config_path = repo.get(CLONE_SHARED_FILES[0]).map_err(|e| {
TtsError::Model(format!("failed to download {}: {e}", CLONE_SHARED_FILES[0]))
})?;
let model_dir = config_path
.parent()
.ok_or_else(|| TtsError::Model("unexpected cache path for config.json".into()))?
.to_path_buf();
for (i, filename) in onnx_files
.iter()
.chain(CLONE_SHARED_FILES[1..].iter())
.enumerate()
{
eprintln!("[{}/{total}] {filename}", i + 2);
repo.get(filename)
.map_err(|e| TtsError::Model(format!("failed to download {filename}: {e}")))?;
}
eprintln!("Files ready. Loading clone model ...");
Ok(model_dir)
}