use anyhow::{Context, Result};
use bzip2::read::BzDecoder;
use futures_util::StreamExt;
use sha2::{Digest, Sha256};
use std::path::Path;
use tar::Archive;
use tokio::io::AsyncWriteExt;
/// Tracks the byte count of an in-flight download and renders a single-line
/// percentage indicator on stderr.
struct DownloadProgress {
    /// Expected size of the download in bytes; zero keeps the percentage at 0.
    total: u64,
    /// Bytes received so far.
    current: u64,
    /// Percentage that was printed last, used to skip redundant redraws.
    last_percent: u8,
}

impl DownloadProgress {
    /// Creates a tracker for a download expected to be `total` bytes long.
    fn new(total: u64) -> Self {
        Self {
            total,
            current: 0,
            last_percent: 0,
        }
    }

    /// Records `bytes` newly received bytes and redraws the progress line when
    /// the whole-number percentage has changed.
    ///
    /// The percentage is capped at 100, so a body that turns out larger than
    /// `total` is reported as 100% rather than a wrapped-around value.
    fn update(&mut self, bytes: u64) {
        self.current = self.current.saturating_add(bytes);
        // Widen to u128 so `current * 100` cannot overflow, and clamp before
        // narrowing so the conversion to u8 is always lossless.
        let percent = (u128::from(self.current) * 100)
            .checked_div(u128::from(self.total))
            .map_or(0, |p| u8::try_from(p.min(100)).unwrap_or(100));
        if percent != self.last_percent {
            self.last_percent = percent;
            // `\r` rewinds to the start of the line so the indicator updates in place.
            eprint!(
                "\rDownloading... {percent}% ({:.1}MB / {:.1}MB)",
                self.current as f64 / 1_048_576.0,
                self.total as f64 / 1_048_576.0
            );
        }
    }

    /// Overwrites the progress line with a final summary and ends the line.
    fn finish(&self) {
        eprintln!(
            "\rDownload complete ({:.1}MB) ",
            self.current as f64 / 1_048_576.0
        );
    }
}
/// GitHub `owner/repo` slug used to build the bundle's release download URL.
const MODEL_BUNDLE_REPO: &str = "k2-fsa/sherpa-onnx";
/// Release tag under which the bundle archive is published.
const MODEL_BUNDLE_RELEASE: &str = "asr-models";
/// File name of the bundle archive; also used as the download label.
const MODEL_BUNDLE_FILENAME: &str = "sherpa-onnx-zipformer-vi-int8-2025-04-20.tar.bz2";
/// Leading directory inside the archive that is stripped during extraction.
const MODEL_BUNDLE_TOP_DIR: &str = "sherpa-onnx-zipformer-vi-int8-2025-04-20";
/// Expected SHA-256 digest (lowercase hex) of the bundle archive.
const MODEL_BUNDLE_SHA256: &str =
"48d0fdc9b3515eb9b00c4dfec2883207ee5ebe5c95b1959e7afce87fc3391938";
/// Files that must all be present in the model directory for the model to
/// count as installed.
pub const MODEL_FILES: &[&str] = &[
"encoder.int8.onnx",
"decoder.onnx",
"joiner.int8.onnx",
"bpe.model",
"tokens.txt",
];
/// HuggingFace `owner/repo` slug the speaker model is downloaded from.
#[cfg(feature = "diarization")]
const SPEAKER_HF_REPO: &str = "onnx-community/wespeaker-voxceleb-resnet34-LM";
/// File name the speaker model is stored under in the model directory.
#[cfg(feature = "diarization")]
pub const SPEAKER_MODEL_FILE: &str = "wespeaker_resnet34.onnx";
/// Expected SHA-256 digest (lowercase hex) of the speaker model file.
#[cfg(feature = "diarization")]
const SPEAKER_MODEL_SHA256: &str =
"3955447b0499dc9e0a4541a895df08b03c69098eba4e56c02b5603e9f7f4fcbb";
/// Returns the current user's home directory as reported by the environment.
///
/// Reads `USERPROFILE` on Windows and `HOME` on every other target. Returns
/// `None` when the variable is not set.
fn home_dir() -> Option<std::path::PathBuf> {
    // Choosing the variable with `cfg!` (instead of per-platform `#[cfg]`
    // blocks) keeps a value-producing body on every target, so the function
    // also compiles on platforms that are neither `unix` nor `windows`.
    let variable = if cfg!(windows) { "USERPROFILE" } else { "HOME" };
    std::env::var_os(variable).map(std::path::PathBuf::from)
}
/// Returns the default model directory as a string.
///
/// This is `<home>/.phostt/models` when a home directory is known, and the
/// relative path `.phostt/models` otherwise. Non-UTF-8 path segments are
/// converted lossily.
pub fn default_model_dir() -> String {
    match home_dir() {
        Some(home) => {
            let models = home.join(".phostt").join("models");
            models.to_string_lossy().into_owned()
        }
        None => ".phostt/models".into(),
    }
}
/// Makes sure the ASR model files are present in `model_dir`, downloading and
/// unpacking the release bundle when they are not.
///
/// Steps when the model is missing: create the directory, download the archive
/// (checked against `MODEL_BUNDLE_SHA256`), extract it, delete the archive,
/// normalize the ONNX file names, and finally re-check that every expected
/// file exists.
///
/// # Errors
///
/// Fails if the directory cannot be created, the download or checksum
/// verification fails, extraction or renaming fails, or the expected files are
/// still missing afterwards.
pub async fn ensure_model(model_dir: &str) -> Result<()> {
    let model_root = Path::new(model_dir);
    if model_files_exist(model_root) {
        tracing::info!("Model found at {model_dir}");
        return Ok(());
    }

    tracing::info!("Model not found, downloading Zipformer-vi bundle...");
    std::fs::create_dir_all(model_root).context("Failed to create model directory")?;

    let bundle_url = format!(
        "https://github.com/{MODEL_BUNDLE_REPO}/releases/download/{MODEL_BUNDLE_RELEASE}/{MODEL_BUNDLE_FILENAME}"
    );
    let bundle_path = model_root.join(MODEL_BUNDLE_FILENAME);
    stream_to_partial_then_finalize(
        &bundle_url,
        &bundle_path,
        Some(MODEL_BUNDLE_SHA256),
        MODEL_BUNDLE_FILENAME,
    )
    .await?;

    tracing::info!("Extracting bundle into {}", model_root.display());
    extract_bundle(&bundle_path, model_root)?;
    // Best effort: a leftover archive is harmless, so a failed delete is ignored.
    let _ = std::fs::remove_file(&bundle_path);
    normalize_model_filenames(model_root)?;

    if model_files_exist(model_root) {
        tracing::info!("Model bundle ready");
        Ok(())
    } else {
        anyhow::bail!(
            "Bundle extracted but expected files are still missing under {}",
            model_root.display()
        )
    }
}
/// Makes sure the speaker-embedding model file is present in `model_dir`,
/// downloading it from HuggingFace when it is not.
///
/// The download is checked against `SPEAKER_MODEL_SHA256` before it is moved
/// into place as `SPEAKER_MODEL_FILE`.
///
/// # Errors
///
/// Fails if the directory cannot be created or the download, checksum
/// verification, or final rename fails.
#[cfg(feature = "diarization")]
pub async fn ensure_speaker_model(model_dir: &str) -> Result<()> {
    let model_root = Path::new(model_dir);
    let speaker_path = model_root.join(SPEAKER_MODEL_FILE);

    if !speaker_path.exists() {
        tracing::info!("Speaker model not found, downloading from HuggingFace...");
        std::fs::create_dir_all(model_root).context("Failed to create model directory")?;
        let source_url =
            format!("https://huggingface.co/{SPEAKER_HF_REPO}/resolve/main/onnx/model.onnx");
        return stream_to_partial_then_finalize(
            &source_url,
            &speaker_path,
            Some(SPEAKER_MODEL_SHA256),
            SPEAKER_MODEL_FILE,
        )
        .await;
    }

    tracing::info!("Speaker model found at {}", speaker_path.display());
    Ok(())
}
/// Reports whether every entry of `MODEL_FILES` exists directly under `dir`.
fn model_files_exist(dir: &Path) -> bool {
    for file_name in MODEL_FILES {
        if !dir.join(file_name).exists() {
            return false;
        }
    }
    true
}
/// Returns the staging path for `final_path`: the same path with `.partial`
/// appended to its last component.
///
/// The suffix is appended to the raw OS string, so an existing extension is
/// kept (`a.onnx` becomes `a.onnx.partial`).
fn partial_path(final_path: &Path) -> std::path::PathBuf {
    let mut staged = final_path.as_os_str().to_os_string();
    staged.push(".partial");
    staged.into()
}
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        hex.push(DIGITS[usize::from(byte >> 4)] as char);
        hex.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    hex
}
/// Computes the SHA-256 digest of the file at `path` and returns it as a
/// lowercase hex string.
///
/// The file is read in 64 KiB chunks, so it is never held in memory as a whole.
///
/// # Errors
///
/// Fails if the file cannot be opened or a read fails part-way through.
fn sha256_file(path: &Path) -> Result<String> {
    use std::io::Read;

    let handle = std::fs::File::open(path)
        .with_context(|| format!("Failed to open file for verification: {}", path.display()))?;
    let mut source = std::io::BufReader::new(handle);
    let mut digest = Sha256::new();
    let mut chunk = [0u8; 64 * 1024];
    loop {
        match source
            .read(&mut chunk)
            .context("Read error during SHA-256")?
        {
            // A zero-length read signals end of file.
            0 => break,
            filled => digest.update(&chunk[..filled]),
        }
    }
    let raw = digest.finalize();
    Ok(bytes_to_hex(raw.as_ref()))
}
/// Verifies a fully written staging file and moves it to its final location.
///
/// When `expected_sha256` is given, the digest of `partial_path` is compared
/// with it; on a mismatch the staging file is deleted and an error is
/// returned. Without a checksum the file is renamed as-is. `label` is used only
/// in log and error messages.
///
/// # Errors
///
/// Fails if hashing fails, the digest does not match, or the rename fails.
fn finalize_download(
    partial_path: &Path,
    final_path: &Path,
    expected_sha256: Option<&str>,
    label: &str,
) -> Result<()> {
    match expected_sha256 {
        None => {}
        Some(expected) => {
            let actual = sha256_file(partial_path)?;
            if actual == expected {
                tracing::info!("SHA-256 verified: {label}");
            } else {
                // Best effort: drop the bad download so it cannot be picked up later.
                let _ = std::fs::remove_file(partial_path);
                anyhow::bail!("SHA-256 mismatch for {label}: expected {expected}, got {actual}");
            }
        }
    }

    // The final path only ever appears once the staging file has passed verification.
    std::fs::rename(partial_path, final_path).with_context(|| {
        format!(
            "Failed to rename {} -> {}",
            partial_path.display(),
            final_path.display()
        )
    })
}
/// Downloads `url` into a `.partial` staging file next to `final_dest`, then
/// verifies it and renames it into place via `finalize_download`.
///
/// A stale staging file from an earlier attempt is removed first. The HTTP
/// client uses a 30 second connect timeout and a 600 second overall timeout.
/// Progress is printed to stderr while the body streams to disk. `label` is
/// used only in log and error messages.
///
/// # Errors
///
/// Fails if the client cannot be built, the request fails or returns a
/// non-success status, the staging file cannot be created or written, the
/// stream breaks, or finalization fails.
async fn stream_to_partial_then_finalize(
    url: &str,
    final_dest: &Path,
    expected_sha256: Option<&str>,
    label: &str,
) -> Result<()> {
    let staging = partial_path(final_dest);
    if staging.exists() {
        // Best effort: `File::create` below truncates the file anyway.
        let _ = tokio::fs::remove_file(&staging).await;
    }

    tracing::info!("Downloading {label}...");
    let connect_limit = std::time::Duration::from_secs(30);
    let overall_limit = std::time::Duration::from_secs(600);
    let http = reqwest::Client::builder()
        .connect_timeout(connect_limit)
        .timeout(overall_limit)
        .build()
        .context("Failed to build HTTP client")?;

    let response = http
        .get(url)
        .send()
        .await
        .context("HTTP request failed")?;
    let status = response.status();
    if !status.is_success() {
        anyhow::bail!("Download failed for {label}: HTTP {status}");
    }

    // A missing Content-Length is treated as a total of zero.
    let mut progress = DownloadProgress::new(response.content_length().unwrap_or(0));
    let mut sink = tokio::fs::File::create(&staging)
        .await
        .context("Failed to create partial model file")?;
    let mut body = response.bytes_stream();
    let mut downloaded: u64 = 0;
    loop {
        let piece = match body.next().await {
            Some(next) => next.context("Download stream error")?,
            None => break,
        };
        sink.write_all(&piece)
            .await
            .context("Failed to write chunk")?;
        let received = piece.len() as u64;
        downloaded += received;
        progress.update(received);
    }
    // Flush and close the handle before the file is hashed and renamed.
    sink.flush().await?;
    drop(sink);
    progress.finish();

    tracing::info!("Wrote partial {} ({downloaded} bytes)", staging.display());
    finalize_download(&staging, final_dest, expected_sha256, label)?;
    tracing::info!("Saved {label}");
    Ok(())
}
fn normalize_model_filenames(dir: &Path) -> Result<()> {
let renames: &[(&str, &str)] = &[
("encoder", "encoder.int8.onnx"),
("decoder", "decoder.onnx"),
("joiner", "joiner.int8.onnx"),
];
for (prefix, target_name) in renames {
let target = dir.join(target_name);
if target.exists() {
continue;
}
let mut candidates: Vec<_> = std::fs::read_dir(dir)
.with_context(|| {
format!(
"Failed to read model dir {} for normalization",
dir.display()
)
})?
.filter_map(|e| e.ok())
.filter(|e| {
let name = e.file_name().to_string_lossy().to_lowercase();
name.starts_with(prefix) && name.ends_with(".onnx")
})
.collect();
candidates.sort_by_key(|e| e.file_name().len());
if let Some(entry) = candidates.first() {
tracing::info!(
"Renaming {} -> {}",
entry.file_name().to_string_lossy(),
target_name
);
std::fs::rename(entry.path(), &target).with_context(|| {
format!(
"Failed to rename {} -> {}",
entry.path().display(),
target.display()
)
})?;
}
}
Ok(())
}
/// Extracts the bzip2-compressed tar file `archive` into `dest_dir`.
///
/// The leading `MODEL_BUNDLE_TOP_DIR` component is stripped from entry paths
/// so the files land directly in `dest_dir`; entries outside that directory
/// keep their path as-is. The entry for the top directory itself is skipped.
///
/// Entries are refused, aborting the whole extraction, when they could write
/// outside `dest_dir`:
/// * any `..` component,
/// * any root or drive/UNC prefix component (covers absolute paths and the
///   Windows-only rooted `\x` and prefixed `C:x` forms, which `Path::join`
///   would let replace part of `dest_dir`),
/// * symlink and hardlink entries.
///
/// # Errors
///
/// Fails if the archive cannot be opened or read, an entry is refused, or a
/// directory or file cannot be created.
fn extract_bundle(archive: &Path, dest_dir: &Path) -> Result<()> {
    use std::path::Component;

    let file = std::fs::File::open(archive)
        .with_context(|| format!("Failed to open archive {}", archive.display()))?;
    let mut tar = Archive::new(BzDecoder::new(file));

    for entry in tar.entries().context("Failed to read tar entries")? {
        let mut entry = entry.context("Tar entry read error")?;
        // Own the path so the borrow on `entry` ends before `unpack` needs it mutably.
        let path = entry.path().context("Tar entry has no path")?.into_owned();
        let relative: &Path = match path.strip_prefix(MODEL_BUNDLE_TOP_DIR) {
            // The top-level directory entry itself: nothing to create.
            Ok(rel) if rel.as_os_str().is_empty() => continue,
            Ok(rel) => rel,
            Err(_) => path.as_path(),
        };

        // Only plain names (and `.`) may appear; everything else can redirect
        // the join below to a location outside `dest_dir`.
        for component in relative.components() {
            match component {
                Component::Normal(_) | Component::CurDir => {}
                Component::ParentDir => anyhow::bail!(
                    "Refusing to extract {}: parent-dir component in archive entry",
                    relative.display()
                ),
                Component::RootDir | Component::Prefix(_) => anyhow::bail!(
                    "Refusing to extract {}: absolute path in archive entry",
                    relative.display()
                ),
            }
        }

        let entry_type = entry.header().entry_type();
        if entry_type.is_symlink() || entry_type.is_hard_link() {
            anyhow::bail!(
                "Refusing to extract {}: symlink/hardlink entries are not allowed",
                relative.display()
            );
        }

        let target = dest_dir.join(relative);
        if let Some(parent) = target.parent() {
            std::fs::create_dir_all(parent)
                .with_context(|| format!("Failed to create directory {}", parent.display()))?;
        }
        entry
            .unpack(&target)
            .with_context(|| format!("Failed to unpack {}", target.display()))?;
    }
    Ok(())
}
// Unit tests for the download helpers: progress accounting, staging-file
// naming, checksum verification, atomic finalization, and bundle extraction.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
// The environment the tests run in is expected to expose a home directory.
#[test]
fn test_home_dir_returns_some() {
assert!(
home_dir().is_some(),
"home_dir() must return Some on this platform"
);
}
// Both the home-based path and the relative fallback contain ".phostt".
#[test]
fn test_default_model_dir_contains_phostt() {
let dir = default_model_dir();
assert!(
dir.contains(".phostt"),
"default_model_dir() should contain \".phostt\", got: {dir}"
);
}
// Half of the announced total must be reported as 50%.
#[test]
fn test_download_progress_basic() {
let mut progress = DownloadProgress::new(1_000_000);
progress.update(500_000);
assert_eq!(progress.current, 500_000);
assert_eq!(progress.last_percent, 50);
progress.finish();
}
// A zero total must not panic on division and keeps the percentage at 0.
#[test]
fn test_download_progress_zero_total() {
let mut progress = DownloadProgress::new(0);
progress.update(100);
assert_eq!(progress.last_percent, 0);
progress.finish();
}
// Test helper: SHA-256 of an in-memory buffer as lowercase hex.
fn sha256_hex(bytes: &[u8]) -> String {
let mut hasher = Sha256::new();
hasher.update(bytes);
bytes_to_hex(hasher.finalize().as_ref())
}
// Test helper: writes `bytes` to the staging path of `final_path` and returns
// that staging path, mimicking a completed download that has not been
// finalized yet.
fn stage_partial(final_path: &Path, bytes: &[u8]) -> std::path::PathBuf {
let partial = partial_path(final_path);
let mut f = std::fs::File::create(&partial).expect("create partial");
f.write_all(bytes).expect("write partial");
f.sync_all().expect("sync partial");
partial
}
// The staging path is the final path with ".partial" appended.
#[test]
fn test_partial_path_appends_suffix() {
let p = partial_path(Path::new("/tmp/phostt/encoder.onnx"));
assert_eq!(
p,
std::path::PathBuf::from("/tmp/phostt/encoder.onnx.partial"),
);
}
// Happy path: a staged file with a matching checksum is moved to the final
// path and the staging file disappears.
#[test]
fn test_download_writes_partial_then_renames() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("encoder.onnx");
let payload = b"fake encoder weights";
let expected = sha256_hex(payload);
let partial = stage_partial(&final_path, payload);
assert!(partial.exists(), "precondition: partial is present");
assert!(!final_path.exists(), "precondition: final is absent");
finalize_download(&partial, &final_path, Some(&expected), "encoder.onnx")
.expect("finalize should succeed");
assert!(
!partial.exists(),
"partial must be gone after atomic rename"
);
assert!(
final_path.exists(),
"final path must exist after atomic rename"
);
assert_eq!(std::fs::read(&final_path).unwrap(), payload);
}
// Simulates an interruption before finalization: only the staging file exists,
// so the model must not be considered installed.
#[test]
fn test_download_crash_before_rename_leaves_no_final_file() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("encoder.onnx");
let partial = stage_partial(&final_path, b"half-written junk");
assert!(partial.exists(), "partial must exist to simulate crash");
assert!(
!final_path.exists(),
"crash before rename must never leave the final artefact visible"
);
assert!(
!model_files_exist(tmp.path()),
"model_files_exist must not accept a staged partial"
);
}
// A checksum mismatch must return an error, delete the staging file, and never
// create the final file.
#[test]
fn test_download_rejects_sha256_mismatch() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("decoder.onnx");
let payload = b"real bytes";
let wrong_expected = sha256_hex(b"different bytes");
let partial = stage_partial(&final_path, payload);
let err = finalize_download(&partial, &final_path, Some(&wrong_expected), "decoder.onnx")
.expect_err("mismatch must error");
let msg = format!("{err}");
assert!(msg.contains("SHA-256 mismatch"), "unexpected error: {msg}");
assert!(!partial.exists(), "partial must be removed on SHA mismatch");
assert!(
!final_path.exists(),
"final must never appear on SHA mismatch"
);
}
// Without an expected checksum the staged file is renamed unverified.
#[test]
fn test_download_atomic_on_success_without_checksum() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("vocab.txt");
let payload = b"token0\ntoken1\n";
let partial = stage_partial(&final_path, payload);
finalize_download(&partial, &final_path, None, "vocab.txt")
.expect("no-checksum finalize should succeed");
assert!(!partial.exists(), "partial must be gone after rename");
assert!(final_path.exists(), "final path must exist");
assert_eq!(std::fs::read(&final_path).unwrap(), payload);
}
// The streaming file hash must agree with hashing the same bytes in memory.
#[test]
fn test_sha256_file_matches_in_memory_hash() {
let tmp = tempfile::tempdir().expect("tempdir");
let p = tmp.path().join("blob");
let payload = b"phostt-model-bytes";
std::fs::write(&p, payload).unwrap();
let got = sha256_file(&p).expect("sha256_file");
let want = sha256_hex(payload);
assert_eq!(got, want);
}
// Guards the pinned digest's format: 64 lowercase hex characters, which is the
// form `bytes_to_hex` produces and the verification compares against.
#[test]
fn test_model_bundle_sha256_shape() {
assert_eq!(
MODEL_BUNDLE_SHA256.len(),
64,
"MODEL_BUNDLE_SHA256 must be a 64-char hex digest"
);
assert!(
MODEL_BUNDLE_SHA256
.chars()
.all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)),
"MODEL_BUNDLE_SHA256 must be lowercase hex; got: {MODEL_BUNDLE_SHA256}"
);
}
// Guards against accidentally dropping a required file from MODEL_FILES.
#[test]
fn test_model_files_list_matches_required_layout() {
for required in [
"encoder.int8.onnx",
"decoder.onnx",
"joiner.int8.onnx",
"bpe.model",
"tokens.txt",
] {
assert!(
MODEL_FILES.contains(&required),
"MODEL_FILES is missing required entry {required}"
);
}
}
// Builds a small .tar.bz2 on the fly and checks that extraction strips the top
// directory. NOTE(review): despite the name, only the happy path is exercised
// here; no traversal entry is added to the archive.
#[test]
fn test_extract_bundle_strips_top_dir_and_rejects_traversal() {
use bzip2::Compression;
use bzip2::write::BzEncoder;
use std::io::Cursor;
use tar::Header;
// Appends one regular file entry with the given path and contents.
fn append(builder: &mut tar::Builder<&mut Vec<u8>>, path: &str, data: &[u8]) {
let mut header = Header::new_gnu();
header.set_size(data.len() as u64);
header.set_mode(0o644);
header.set_cksum();
builder
.append_data(&mut header, path, Cursor::new(data))
.unwrap();
}
let tmp = tempfile::tempdir().expect("tempdir");
let archive_path = tmp.path().join("bundle.tar.bz2");
{
// Build the tar in memory first, then compress it to disk.
let mut tar_buf = Vec::new();
{
let mut builder = tar::Builder::new(&mut tar_buf);
append(
&mut builder,
&format!("{MODEL_BUNDLE_TOP_DIR}/encoder.int8.onnx"),
b"encoder-bytes",
);
append(
&mut builder,
&format!("{MODEL_BUNDLE_TOP_DIR}/test_wavs/0.wav"),
b"wav-bytes",
);
builder.finish().unwrap();
}
let mut bz = BzEncoder::new(
std::fs::File::create(&archive_path).unwrap(),
Compression::fast(),
);
std::io::copy(&mut Cursor::new(tar_buf), &mut bz).unwrap();
bz.finish().unwrap();
}
let dest = tmp.path().join("out");
std::fs::create_dir_all(&dest).unwrap();
extract_bundle(&archive_path, &dest).expect("happy-path extract");
// Files land directly under `dest`, nested directories are preserved, and the
// bundle's top directory is not recreated.
assert!(dest.join("encoder.int8.onnx").exists());
assert!(dest.join("test_wavs").join("0.wav").exists());
assert!(
!dest.join(MODEL_BUNDLE_TOP_DIR).exists(),
"top dir must be stripped, not nested"
);
}
}