use std::path::PathBuf;
use hf_hub::api::sync::ApiBuilder;
use hf_hub::{Repo, RepoType};
use super::pipecat::SmartTurnLang;
use crate::error::TurnError;
/// HuggingFace model repository that the Smart Turn ONNX files are fetched from.
const REPO_ID: &str = "wavekat/smart-turn-ONNX";
/// Revision of [`REPO_ID`] requested from the HuggingFace Hub.
const REVISION: &str = "main";
/// Environment variable naming a local directory to load the model from.
/// When it is set, the language-specific model file is looked up under this
/// directory instead of being downloaded.
const LOCAL_DIR_ENV: &str = "WAVEKAT_TURN_MODEL_DIR";
/// Maps a language to its model file path.
///
/// The same relative path is used both inside the HuggingFace repository and
/// beneath the local override directory.
fn relative_path(lang: SmartTurnLang) -> &'static str {
    // CPU-targeted Chinese Smart Turn model.
    const ZH_MODEL: &str = "zh/smart-turn-cpu.onnx";

    // Exhaustive on purpose: adding a language must force a decision here.
    match lang {
        SmartTurnLang::Zh => ZH_MODEL,
    }
}
/// Resolves the on-disk path of the Smart Turn ONNX model for `lang`.
///
/// Resolution order:
/// 1. If the `WAVEKAT_TURN_MODEL_DIR` environment variable is set to a non-empty
///    value, the model is expected at `<dir>/<relative path>` and no network
///    access is attempted.
/// 2. Otherwise the file is fetched from the HuggingFace repository
///    `wavekat/smart-turn-ONNX` at revision `main` through `hf-hub`. `hf-hub`
///    returns its cached copy when one is present.
///
/// # Errors
///
/// - [`TurnError::ModelNotLoaded`] if the override directory is set but the model
///   file is not an existing file inside it.
/// - [`TurnError::BackendError`] if the HuggingFace client cannot be built or the
///   download fails.
pub(crate) fn resolve_model(lang: SmartTurnLang) -> Result<PathBuf, TurnError> {
    let rel = relative_path(lang);

    // An empty override is treated as unset. Otherwise `WAVEKAT_TURN_MODEL_DIR=`
    // would silently resolve the model relative to the current working directory.
    if let Some(dir) = std::env::var_os(LOCAL_DIR_ENV).filter(|d| !d.is_empty()) {
        return resolve_local_override(PathBuf::from(dir), rel);
    }

    download_from_hub(rel)
}

/// Returns `dir/rel` if it names an existing file, otherwise a `ModelNotLoaded` error.
fn resolve_local_override(dir: PathBuf, rel: &str) -> Result<PathBuf, TurnError> {
    let candidate = dir.join(rel);
    // `is_file` (rather than `exists`) also rejects a directory at the model path.
    // Such a directory would otherwise only fail later, with a less helpful error,
    // when the model is loaded.
    if !candidate.is_file() {
        return Err(TurnError::ModelNotLoaded(format!(
            "{LOCAL_DIR_ENV} is set but {} does not exist or is not a file",
            candidate.display()
        )));
    }
    Ok(candidate)
}

/// Fetches `rel` from [`REPO_ID`] at [`REVISION`] and returns its local path.
fn download_from_hub(rel: &str) -> Result<PathBuf, TurnError> {
    let mut builder = ApiBuilder::new();
    // Only override the token when HF_TOKEN is set and non-empty.
    // `with_token(None)` would discard the token `ApiBuilder::new()` already loaded
    // from the hf-hub cache (e.g. after a CLI login). An empty string would be sent
    // as a blank credential.
    if let Some(token) = std::env::var("HF_TOKEN").ok().filter(|t| !t.is_empty()) {
        builder = builder.with_token(Some(token));
    }
    let api = builder
        .build()
        .map_err(|e| TurnError::BackendError(format!("failed to build hf-hub client: {e}")))?;

    let repo = api.repo(Repo::with_revision(
        REPO_ID.to_string(),
        RepoType::Model,
        REVISION.to_string(),
    ));

    repo.get(rel).map_err(|e| {
        TurnError::BackendError(format!(
            "failed to download {REPO_ID}@{REVISION}:{rel} from HuggingFace: {e}. \
             Set {LOCAL_DIR_ENV} to a directory containing {rel} to skip the download."
        ))
    })
}