use std::path::PathBuf;
use anyhow::{bail, Context, Result};
use hf_hub::api::sync::Api;
#[cfg(feature = "backbone")]
use crate::model::NeuTTS;
/// Static description of a NeuTTS backbone repository on the HuggingFace Hub.
///
/// Entries are plain data, so they can be compared, hashed and de-duplicated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelInfo {
    /// HuggingFace repo id, e.g. `"neuphonic/neutts-nano-q4-gguf"`.
    pub repo: &'static str,
    /// Human-readable display name.
    pub name: &'static str,
    /// Language code handed to the backbone loader (e.g. `"en-us"`, `"de"`).
    pub language: &'static str,
    /// Approximate parameter count as a display string (e.g. `"0.2B"`).
    pub params: &'static str,
    /// `true` when the repo ships `.gguf` weights.
    pub is_gguf: bool,
}
/// Backbone repositories known to this module.
///
/// Each base model is listed three times: as a Q4 GGUF repo, a Q8 GGUF repo,
/// and a "(full)" non-GGUF repo (`is_gguf: false`). [`find_model`] looks entries
/// up by exact `repo` id.
pub const BACKBONE_MODELS: &[ModelInfo] = &[
// English (en-us): Nano (0.2B)
ModelInfo { repo: "neuphonic/neutts-nano-q4-gguf", name: "NeuTTS Nano Q4", language: "en-us", params: "0.2B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-nano-q8-gguf", name: "NeuTTS Nano Q8", language: "en-us", params: "0.2B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-nano", name: "NeuTTS Nano (full)", language: "en-us", params: "0.2B", is_gguf: false },
// English (en-us): Air (0.7B)
ModelInfo { repo: "neuphonic/neutts-air-q4-gguf", name: "NeuTTS Air Q4", language: "en-us", params: "0.7B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-air-q8-gguf", name: "NeuTTS Air Q8", language: "en-us", params: "0.7B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-air", name: "NeuTTS Air (full)", language: "en-us", params: "0.7B", is_gguf: false },
// German (de): Nano (0.2B)
ModelInfo { repo: "neuphonic/neutts-nano-german-q4-gguf", name: "NeuTTS Nano German Q4", language: "de", params: "0.2B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-nano-german-q8-gguf", name: "NeuTTS Nano German Q8", language: "de", params: "0.2B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-nano-german", name: "NeuTTS Nano German (full)", language: "de", params: "0.2B", is_gguf: false },
// French (fr-fr): Nano (0.2B)
ModelInfo { repo: "neuphonic/neutts-nano-french-q4-gguf", name: "NeuTTS Nano French Q4", language: "fr-fr", params: "0.2B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-nano-french-q8-gguf", name: "NeuTTS Nano French Q8", language: "fr-fr", params: "0.2B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-nano-french", name: "NeuTTS Nano French (full)", language: "fr-fr", params: "0.2B", is_gguf: false },
// Spanish (es): Nano (0.2B)
ModelInfo { repo: "neuphonic/neutts-nano-spanish-q4-gguf", name: "NeuTTS Nano Spanish Q4", language: "es", params: "0.2B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-nano-spanish-q8-gguf", name: "NeuTTS Nano Spanish Q8", language: "es", params: "0.2B", is_gguf: true },
ModelInfo { repo: "neuphonic/neutts-nano-spanish", name: "NeuTTS Nano Spanish (full)", language: "es", params: "0.2B", is_gguf: false },
];
/// Look up the [`BACKBONE_MODELS`] entry whose `repo` equals `repo` exactly.
///
/// Returns `None` for repos that are not in the table.
pub fn find_model(repo: &str) -> Option<&'static ModelInfo> {
    for candidate in BACKBONE_MODELS {
        if candidate.repo == repo {
            return Some(candidate);
        }
    }
    None
}
/// Language code for `repo` from [`BACKBONE_MODELS`], or `"en-us"` when the
/// repo is not listed there.
fn backbone_language(repo: &str) -> &'static str {
    match find_model(repo) {
        Some(info) => info.language,
        None => "en-us",
    }
}
/// Progress event reported to the callback of `load_from_hub_cb`.
#[derive(Debug, Clone)]
pub enum LoadProgress {
    /// A file is being fetched from a HuggingFace repo.
    Fetching {
        /// Current step number.
        step: u32,
        /// Total number of steps.
        total: u32,
        /// File being fetched (may be a pattern such as `*.gguf`).
        file: String,
        /// Repo id the file is fetched from.
        repo: String,
    },
    /// A fetched component is being loaded.
    Loading {
        /// Current step number.
        step: u32,
        /// Total number of steps.
        total: u32,
        /// Name of the component being loaded (e.g. `"backbone"`).
        component: String,
    },
}
/// Fetch `filename` from the model repo `repo_id` and return the local path
/// reported by `hf_hub`.
///
/// # Errors
/// Wraps any `hf_hub` failure with the file and repo names.
fn hf_download(api: &Api, repo_id: &str, filename: &str) -> Result<PathBuf> {
    api.model(repo_id.to_owned())
        .get(filename)
        .with_context(|| format!("Failed to download '{filename}' from '{repo_id}'"))
}
/// Return the file names (`rfilename` of each sibling) listed in the repo
/// info for `repo_id`.
///
/// # Errors
/// Fails if the repo info request fails.
fn hf_list_files(api: &Api, repo_id: &str) -> Result<Vec<String>> {
    let repo_info = api
        .model(repo_id.to_owned())
        .info()
        .with_context(|| format!("Failed to fetch repo info for '{repo_id}'"))?;
    let names: Vec<String> = repo_info
        .siblings
        .into_iter()
        .map(|sibling| sibling.rfilename)
        .collect();
    Ok(names)
}
/// Download the first file in `repo_id` whose name ends with one of `extensions`.
///
/// `extensions` is a priority list. All files are checked against the first
/// suffix before the next suffix is tried. Within one suffix, the earliest
/// match in the repo listing wins.
///
/// # Errors
/// Fails if the repo listing or the download fails. Also fails when no file
/// matches any suffix; that error lists every file in the repo.
fn hf_download_by_extension(
    api: &Api,
    repo_id: &str,
    extensions: &[&str],
) -> Result<PathBuf> {
    let files = hf_list_files(api, repo_id)?;
    let chosen = extensions
        .iter()
        .find_map(|ext| files.iter().find(|name| name.ends_with(ext)));
    match chosen {
        Some(fname) => hf_download(api, repo_id, fname),
        None => bail!(
            "No file with extension {:?} found in '{}'.\n\
             Available files: {:?}",
            extensions,
            repo_id,
            files
        ),
    }
}
/// Download a backbone from the HuggingFace Hub and load it, reporting progress.
///
/// * `backbone_repo` – HuggingFace repo id.
/// * `gguf_file` – exact file to fetch. When `None`, the first file ending in
///   `.gguf` in the repo listing is used.
/// * `on_progress` – receives [`LoadProgress::Fetching`] (step 1 of 2) before the
///   download and [`LoadProgress::Loading`] (step 2 of 2) before loading.
///
/// The language passed to `NeuTTS::load` comes from the matching
/// [`BACKBONE_MODELS`] entry, or `"en-us"` for repos not in the table.
///
/// # Errors
/// Fails if the Hub client cannot be created, the file cannot be found or
/// downloaded, or `NeuTTS::load` fails.
#[cfg(feature = "backbone")]
pub fn load_from_hub_cb<F>(
    backbone_repo: &str,
    gguf_file: Option<&str>,
    mut on_progress: F,
) -> Result<NeuTTS>
where
    F: FnMut(LoadProgress),
{
    const TOTAL_STEPS: u32 = 2;

    let api = Api::new().context("Failed to initialise HuggingFace Hub client")?;

    on_progress(LoadProgress::Fetching {
        step: 1,
        total: TOTAL_STEPS,
        file: gguf_file.unwrap_or("*.gguf").to_string(),
        repo: backbone_repo.into(),
    });

    let weights_path = if let Some(fname) = gguf_file {
        hf_download(&api, backbone_repo, fname).with_context(|| {
            format!(
                "Failed to download '{fname}' from '{backbone_repo}'.\n\
                 \n\
                 List available files with:\n\
                 \n\
                 \tcargo run --example speak -- \
                 --backbone {backbone_repo} --list-files"
            )
        })?
    } else {
        hf_download_by_extension(&api, backbone_repo, &[".gguf"])
            .with_context(|| format!("Failed to download GGUF from '{backbone_repo}'"))?
    };

    on_progress(LoadProgress::Loading {
        step: TOTAL_STEPS,
        total: TOTAL_STEPS,
        component: "backbone".into(),
    });

    let language = backbone_language(backbone_repo).to_string();
    NeuTTS::load(&weights_path, &language)
}
/// Download and load the first `.gguf` file of `backbone_repo`, ignoring
/// progress events.
#[cfg(feature = "backbone")]
pub fn load_from_hub(backbone_repo: &str) -> Result<NeuTTS> {
    let ignore_progress = |_event: LoadProgress| {};
    load_from_hub_cb(backbone_repo, None, ignore_progress)
}
/// List the files in `backbone_repo` whose names end in `.gguf`, in listing order.
///
/// # Errors
/// Fails if the Hub client cannot be created or the repo listing cannot be fetched.
pub fn list_gguf_files(backbone_repo: &str) -> Result<Vec<String>> {
    let api = Api::new().context("Failed to initialise HuggingFace Hub client")?;
    let mut files = hf_list_files(&api, backbone_repo)?;
    files.retain(|name| name.ends_with(".gguf"));
    Ok(files)
}
/// Load the default backbone, `neuphonic/neutts-nano-q4-gguf`, from the HuggingFace Hub.
#[cfg(feature = "backbone")]
pub fn load_default() -> Result<NeuTTS> {
    const DEFAULT_BACKBONE_REPO: &str = "neuphonic/neutts-nano-q4-gguf";
    load_from_hub(DEFAULT_BACKBONE_REPO)
}
pub fn download_encoder_onnx(encoder_repo: &str, dest_dir: &std::path::Path) -> Result<PathBuf> {
let api = Api::new().context("Failed to initialise HuggingFace Hub client")?;
let path = hf_download_by_extension(&api, encoder_repo, &[".onnx"])
.with_context(|| format!("Failed to download encoder ONNX from '{encoder_repo}'"))?;
std::fs::create_dir_all(dest_dir)
.context("Failed to create model staging directory")?;
let dest = dest_dir.join("neucodec_encoder.onnx");
std::fs::copy(&path, &dest)
.with_context(|| format!("Failed to copy encoder ONNX to {}", dest.display()))?;
Ok(dest)
}
pub fn download_decoder_onnx(decoder_repo: &str, dest_dir: &std::path::Path) -> Result<PathBuf> {
let api = Api::new().context("Failed to initialise HuggingFace Hub client")?;
let path = hf_download_by_extension(&api, decoder_repo, &[".onnx"])
.with_context(|| format!("Failed to download decoder ONNX from '{decoder_repo}'"))?;
std::fs::create_dir_all(dest_dir)
.context("Failed to create model staging directory")?;
let dest = dest_dir.join("neucodec_decoder.onnx");
std::fs::copy(&path, &dest)
.with_context(|| format!("Failed to copy decoder ONNX to {}", dest.display()))?;
Ok(dest)
}
/// Load the NeuCodec encoder from `source`.
///
/// `source` is interpreted by its extension:
///
/// * an existing `*.bin` file is loaded with `NeuCodecEncoder::load`;
/// * an existing `*.onnx` file is rejected, with instructions to stage it under
///   `models/` for build-time conversion;
/// * a `*.bin` / `*.onnx` path that does not exist is rejected as a missing file;
/// * anything else is treated as a HuggingFace repo id:
///   - if `models/neucodec_encoder.onnx` is not staged yet, the repo's ONNX
///     file is downloaded there and an error asks for a rebuild;
///   - otherwise an error says the staged model still needs a rebuild.
///
/// # Errors
/// Every route except a successful `.bin` load returns an error.
pub fn load_encoder(source: &str) -> Result<crate::codec::NeuCodecEncoder> {
    let path = std::path::Path::new(source);
    match path.extension().and_then(|e| e.to_str()) {
        Some("bin") if path.exists() => {
            return crate::codec::NeuCodecEncoder::load(path)
                .with_context(|| format!("Failed to load Burn encoder from {source}"));
        }
        Some("onnx") if path.exists() => bail!(
            "ONNX files cannot be loaded at runtime with the Burn backend.\n\
             \n\
             Stage the file for build-time conversion and rebuild:\n\
             \n\
             \tcp {source} models/neucodec_encoder.onnx\n\
             \tcargo build\n"
        ),
        // A local model path that doesn't exist is reported as missing. Falling
        // through would treat it as a HuggingFace repo id and attempt a download
        // from a repo that almost certainly doesn't exist.
        Some("bin" | "onnx") => bail!("Encoder file not found: {source}"),
        _ => {}
    }

    // Anything else is a HuggingFace repo id.
    let models_dir = std::path::Path::new("models");
    let staged = models_dir.join("neucodec_encoder.onnx");
    if !staged.exists() {
        println!("Downloading NeuCodec encoder ONNX from HuggingFace…");
        download_encoder_onnx(source, models_dir)?;
        bail!(
            "Encoder ONNX downloaded to models/neucodec_encoder.onnx.\n\
             \n\
             Rebuild to convert it to Burn:\n\
             \n\
             \tcargo build\n\
             \n\
             Then call NeuCodecEncoder::new() — no runtime file path needed."
        );
    }
    bail!(
        "models/neucodec_encoder.onnx is staged but the Burn model is not compiled in yet.\n\
         \n\
         Run:\n\
         \n\
         \tcargo build\n\
         \n\
         Then use NeuCodecEncoder::new() at runtime."
    )
}
pub fn supported_backbone_repos() -> Vec<&'static str> {
BACKBONE_MODELS.iter().map(|m| m.repo).collect()
}
/// Repo ids of the [`BACKBONE_MODELS`] entries with `is_gguf == true`, in table order.
pub fn supported_gguf_repos() -> Vec<&'static str> {
    BACKBONE_MODELS
        .iter()
        .filter_map(|model| if model.is_gguf { Some(model.repo) } else { None })
        .collect()
}
/// HuggingFace repo id of the NeuCodec ONNX decoder.
pub fn supported_codec_decoder_repo() -> &'static str {
    const DECODER_REPO: &str = "neuphonic/neucodec-onnx-decoder";
    DECODER_REPO
}
/// HuggingFace repo id of the NeuCodec ONNX encoder.
pub fn supported_codec_encoder_repo() -> &'static str {
    const ENCODER_REPO: &str = "neuphonic/neucodec-onnx-encoder";
    ENCODER_REPO
}
/// Print [`BACKBONE_MODELS`] to stdout as a fixed-width table: a header row,
/// a dashed divider, then one row per model.
pub fn print_model_table() {
    let print_row = |repo: &str, name: &str, lang: &str, params: &str, gguf: &str| {
        println!("{repo:<45} {name:<28} {lang:<7} {params:<6} {gguf}");
    };

    print_row("repo", "name", "lang", "params", "gguf");
    println!("{}", "-".repeat(95));
    for model in BACKBONE_MODELS {
        let gguf = if model.is_gguf { "yes" } else { "no" };
        print_row(model.repo, model.name, model.language, model.params, gguf);
    }
}