use std::{collections::HashMap, path::PathBuf};
use anyhow::{bail, Context, Result};
use hf_hub::api::sync::Api;
use serde::Deserialize;
use crate::model::KittenTtsOnnx;
/// Contents of a model repository's `config.json`, as deserialized by the
/// loaders in this file.
#[derive(Debug, Deserialize)]
pub struct ModelConfig {
/// Model format tag, stored under the JSON key `type`.
/// `load_from_hub_cb` accepts only `"ONNX1"` and `"ONNX2"`.
#[serde(rename = "type")]
pub model_type: String,
/// Name of the model file inside the repository; it is downloaded and
/// handed to `KittenTtsOnnx::load`.
pub model_file: String,
/// Name of the voices file inside the repository; `list_voices_from_hub`
/// reads it as an NPZ archive.
pub voices: String,
/// Map from name to an `f32` value, forwarded unchanged to
/// `KittenTtsOnnx::load`. Empty when the key is absent from the JSON.
/// NOTE(review): presumably keyed by voice name, giving a speed factor — confirm.
#[serde(default)]
pub speed_priors: HashMap<String, f32>,
/// Voice alias table, forwarded unchanged to `KittenTtsOnnx::load`.
/// Its keys are reported as selectable names by `list_voices_from_hub`.
/// Empty when the key is absent from the JSON.
/// NOTE(review): values are presumably the voice names the aliases resolve to — confirm.
#[serde(default)]
pub voice_aliases: HashMap<String, String>,
}
/// Fetches `filename` from the model repository `repo_id` through `api`.
///
/// Returns the path handed back by the Hub client; callers in this file read
/// it from the local filesystem.
///
/// # Errors
///
/// Returns the client's error, annotated with the file and repository names,
/// when the fetch fails.
fn hf_download(api: &Api, repo_id: &str, filename: &str) -> Result<PathBuf> {
    let repo_handle = api.model(repo_id.to_owned());
    let fetched = repo_handle
        .get(filename)
        .with_context(|| format!("Failed to download '{}' from '{}'", filename, repo_id))?;
    Ok(fetched)
}
/// Progress notifications passed to the callback of `load_from_hub_cb`.
#[derive(Debug, Clone)]
pub enum LoadProgress {
/// Emitted immediately before a file is requested from the Hub.
/// `step` is the 1-based index of the fetch, `total` the step count
/// reported by the loader, and `file` the name of the file being fetched.
Fetching { step: u32, total: u32, file: String },
/// Emitted after all files have been fetched, immediately before the
/// model is constructed from them.
Loading,
}
pub fn load_from_hub_cb<F>(repo_id: &str, mut on_progress: F) -> Result<KittenTtsOnnx>
where
F: FnMut(LoadProgress),
{
let repo_id = if repo_id.contains('/') {
repo_id.to_string()
} else {
format!("KittenML/{}", repo_id)
};
let api = Api::new().context("Failed to initialise HuggingFace Hub client")?;
on_progress(LoadProgress::Fetching {
step: 1, total: 4, file: "config.json".into(),
});
let config_path = hf_download(&api, &repo_id, "config.json")?;
let config_bytes = std::fs::read(&config_path)
.with_context(|| format!("Cannot read config: {}", config_path.display()))?;
let config: ModelConfig = serde_json::from_slice(&config_bytes)
.context("Failed to parse config.json")?;
if !matches!(config.model_type.as_str(), "ONNX1" | "ONNX2") {
bail!(
"Unsupported model type '{}' — expected ONNX1 or ONNX2",
config.model_type
);
}
on_progress(LoadProgress::Fetching {
step: 2, total: 4, file: config.model_file.clone(),
});
let model_path = hf_download(&api, &repo_id, &config.model_file)?;
on_progress(LoadProgress::Fetching {
step: 3, total: 4, file: config.voices.clone(),
});
let voices_path = hf_download(&api, &repo_id, &config.voices)?;
on_progress(LoadProgress::Loading);
KittenTtsOnnx::load(
&model_path,
&voices_path,
config.speed_priors,
config.voice_aliases,
)
}
pub fn load_from_hub(repo_id: &str) -> Result<KittenTtsOnnx> {
load_from_hub_cb(repo_id, |_| {})
}
/// Downloads and loads the default model repository via `load_from_hub`.
pub fn load_default() -> Result<KittenTtsOnnx> {
    const DEFAULT_REPO: &str = "KittenML/kitten-tts-nano-0.8-int8";
    load_from_hub(DEFAULT_REPO)
}
/// Lists every voice name selectable for a model repository on the Hugging
/// Face Hub.
///
/// `repo_id` may be a full `owner/name` identifier; a bare name (one without
/// a `/`) is resolved under the `KittenML` organisation.
///
/// The result is the union of the entry names in the repository's voices NPZ
/// file and the keys of the config's `voice_aliases` table, sorted and free
/// of duplicates. The model file itself is not downloaded.
///
/// # Errors
///
/// Fails when the Hub client cannot be created, a download fails,
/// `config.json` cannot be read or parsed, or the voices file cannot be
/// loaded as an NPZ archive.
pub fn list_voices_from_hub(repo_id: &str) -> Result<Vec<String>> {
    let repo_id = if repo_id.contains('/') {
        repo_id.to_string()
    } else {
        format!("KittenML/{}", repo_id)
    };
    let api = Api::new().context("Failed to initialise HuggingFace Hub client")?;

    let config_path = hf_download(&api, &repo_id, "config.json")?;
    let config_bytes = std::fs::read(&config_path)
        .with_context(|| format!("Cannot read config: {}", config_path.display()))?;
    let config: ModelConfig =
        serde_json::from_slice(&config_bytes).context("Failed to parse config.json")?;

    let voices_path = hf_download(&api, &repo_id, &config.voices)?;
    let raw = crate::npz::load_npz(&voices_path)
        .with_context(|| format!("Cannot load voices NPZ: {}", voices_path.display()))?;

    // Merge aliases by moving the keys out of the owned config, then sort and
    // dedup once. This replaces a linear `contains` scan plus a clone per
    // alias with a single O((n + m) log(n + m)) pass and yields the same
    // sorted, unique list.
    let mut names: Vec<String> = raw.into_keys().collect();
    names.extend(config.voice_aliases.into_keys());
    names.sort_unstable();
    names.dedup();
    Ok(names)
}