use anyhow::{Context, Result};
use futures_util::StreamExt;
use sha2::{Digest, Sha256};
use std::path::Path;
use tokio::io::AsyncWriteExt;
#[cfg(unix)]
use std::os::fd::AsRawFd;
/// Tracks how many bytes of a download have arrived and renders a single-line
/// progress indicator on stderr.
struct DownloadProgress {
    /// Expected size in bytes; `0` means the size is unknown.
    total: u64,
    /// Bytes received so far.
    current: u64,
    /// Last whole percentage that was printed (used to skip redundant redraws).
    last_percent: u8,
}

impl DownloadProgress {
    /// Creates a tracker for a download expected to be `total` bytes long.
    fn new(total: u64) -> Self {
        Self {
            total,
            current: 0,
            last_percent: 0,
        }
    }

    /// Whole-number completion percentage in `0..=100`.
    ///
    /// Returns `0` when the total is unknown. The value is clamped to 100 so a
    /// body that turns out larger than the declared length cannot wrap the
    /// `u8` (300% used to be reported as 44%).
    fn percent(&self) -> u8 {
        if self.total == 0 {
            return 0;
        }
        // Widen to u128 so `current * 100` cannot overflow.
        let pct = u128::from(self.current) * 100 / u128::from(self.total);
        u8::try_from(pct.min(100)).unwrap_or(100)
    }

    /// Records `bytes` newly received bytes and redraws the progress line when
    /// the whole-number percentage changes.
    fn update(&mut self, bytes: u64) {
        self.current = self.current.saturating_add(bytes);
        let percent = self.percent();
        if percent != self.last_percent {
            self.last_percent = percent;
            eprint!(
                "\rDownloading... {percent}% ({:.1}MB / {:.1}MB)",
                self.current as f64 / 1_048_576.0,
                self.total as f64 / 1_048_576.0
            );
        }
    }

    /// Prints the final summary line, overwriting the progress line.
    fn finish(&self) {
        eprintln!(
            "\rDownload complete ({:.1}MB) ",
            self.current as f64 / 1_048_576.0
        );
    }
}
/// Hugging Face repository id interpolated into the ASR model download URLs.
const HF_REPO: &str = "istupakov/gigaam-v3-onnx";
/// Hugging Face repository id interpolated into the punctuation model download URLs.
const PUNCT_HF_REPO: &str = "ekhodzitsky/rupunct-small-onnx";
/// Direct download URL of the VAD ONNX model (the path pins the `v5.1.2` tag).
const VAD_MODEL_URL: &str =
"https://github.com/snakers4/silero-vad/raw/v5.1.2/src/silero_vad/data/silero_vad.onnx";
/// Expected SHA-256 (lowercase hex) of the file downloaded from `VAD_MODEL_URL`.
const VAD_MODEL_SHA256: &str = "2623a2953f6ff3d2c1e61740c6cdb7168133479b267dfef114a4a3cc5bdd788f";
/// Punctuation model files as `(file name, expected SHA-256)` pairs.
///
/// Each file name is used both as the path inside `PUNCT_HF_REPO` and as the
/// local file name; the digest is checked before the file is moved into place.
const PUNCT_FILES: &[(&str, &str)] = &[
(
crate::punctuation::PUNCT_MODEL_FILE,
"b105da023474d98aa13ba18953ae67b04b17bd0595034bc06030c17536893933",
),
(
crate::punctuation::PUNCT_TOKENIZER_FILE,
"7ca617388c2092a3a84272025c52bbf3c6db0aee225c0351186295c0b5d3ddc6",
),
(
crate::punctuation::PUNCT_CONFIG_FILE,
"6924a8cf41ec2bd3a3aa73a387ae0ccd0aed253ec7cac4d2f53c7d27440891eb",
),
];
/// Which ASR model file set to use.
///
/// Each variant maps to its own encoder/decoder/joint/vocab file names (see the
/// accessor methods), so the two sets can coexist in one directory.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ModelVariant {
/// The `v3_rnnt_*` file set; this is the default variant.
#[default]
Rnnt,
/// The `v3_e2e_rnnt_*` file set.
// NOTE(review): "e2e" presumably means the end-to-end model flavour — confirm upstream.
E2eRnnt,
}
impl ModelVariant {
    /// Selects the string belonging to this variant. Keeping the single
    /// exhaustive `match` here means a new variant forces every accessor to
    /// be revisited.
    fn pick(self, rnnt: &'static str, e2e_rnnt: &'static str) -> &'static str {
        match self {
            Self::Rnnt => rnnt,
            Self::E2eRnnt => e2e_rnnt,
        }
    }

    /// File name of the encoder that is part of the download set.
    pub fn encoder_file(self) -> &'static str {
        self.pick("v3_rnnt_encoder.onnx", "v3_e2e_rnnt_encoder.onnx")
    }

    /// File name of the `_int8` encoder. It is not part of
    /// [`download_files`](Self::download_files) but is recognised by
    /// [`detect_in_dir`](Self::detect_in_dir).
    pub fn encoder_int8_file(self) -> &'static str {
        self.pick("v3_rnnt_encoder_int8.onnx", "v3_e2e_rnnt_encoder_int8.onnx")
    }

    /// File name of the decoder.
    pub fn decoder_file(self) -> &'static str {
        self.pick("v3_rnnt_decoder.onnx", "v3_e2e_rnnt_decoder.onnx")
    }

    /// File name of the joint network.
    pub fn joint_file(self) -> &'static str {
        self.pick("v3_rnnt_joint.onnx", "v3_e2e_rnnt_joint.onnx")
    }

    /// File name of the vocabulary.
    pub fn vocab_file(self) -> &'static str {
        self.pick("v3_vocab.txt", "v3_e2e_rnnt_vocab.txt")
    }

    /// The four files that make up a complete install of this variant, in
    /// download order: encoder, decoder, joint, vocab.
    pub fn download_files(self) -> [&'static str; 4] {
        let encoder = self.encoder_file();
        let decoder = self.decoder_file();
        let joint = self.joint_file();
        let vocab = self.vocab_file();
        [encoder, decoder, joint, vocab]
    }

    /// Looks up the pinned SHA-256 for `filename` in this variant's table.
    ///
    /// Returns `None` when the file is not listed or has no pinned digest.
    pub fn checksum(self, filename: &str) -> Option<&'static str> {
        let table = match self {
            Self::Rnnt => RNNT_CHECKSUMS,
            Self::E2eRnnt => E2E_RNNT_CHECKSUMS,
        };
        for &(name, digest) in table {
            if name == filename {
                // The first matching entry decides, even if its digest is `None`.
                return digest;
            }
        }
        None
    }

    /// Returns the first variant (RNNT is checked before E2E) whose encoder
    /// file, full or `_int8`, exists in `dir`.
    pub fn detect_in_dir(dir: &Path) -> Option<Self> {
        for candidate in [Self::Rnnt, Self::E2eRnnt] {
            let has_encoder = dir.join(candidate.encoder_file()).exists()
                || dir.join(candidate.encoder_int8_file()).exists();
            if has_encoder {
                return Some(candidate);
            }
        }
        None
    }
}
impl std::str::FromStr for ModelVariant {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.trim().to_ascii_lowercase().as_str() {
"rnnt" => Ok(ModelVariant::Rnnt),
"e2e_rnnt" | "e2e-rnnt" => Ok(ModelVariant::E2eRnnt),
other => Err(format!(
"unknown model variant '{other}' (expected 'rnnt' or 'e2e_rnnt')"
)),
}
}
}
/// Pinned SHA-256 digests (lowercase hex) for the `ModelVariant::Rnnt` files,
/// as `(file name, digest)` pairs. A `None` digest would make the download
/// skip verification for that file; every entry is currently pinned.
const RNNT_CHECKSUMS: &[(&str, Option<&str>)] = &[
(
"v3_rnnt_encoder.onnx",
Some("7ae7509c3f1128369564df0b00e2ee4950adf539de2392ac5c800a5bc04c7132"),
),
(
"v3_rnnt_decoder.onnx",
Some("443c3b7bd42b453611618135d6b1e7d9467e5dd97c8a68501da4aa355750c0da"),
),
(
"v3_rnnt_joint.onnx",
Some("fd1d02f45c2ad3d6b67cc149811ad794ab4b020ed49a0a9e2790a8619d1cddd8"),
),
(
"v3_vocab.txt",
Some("a9143c30844d3c0bee3e9e927e4084774eb1b9eeaafc473b2c4521e4911a7c07"),
),
];
/// Pinned SHA-256 digests (lowercase hex) for the `ModelVariant::E2eRnnt`
/// files, in the same `(file name, digest)` layout as `RNNT_CHECKSUMS`.
const E2E_RNNT_CHECKSUMS: &[(&str, Option<&str>)] = &[
(
"v3_e2e_rnnt_encoder.onnx",
Some("cd60b3764a832e8560ae6d3ad0b10adc1a42ffae412b9476f25620aae4f4a508"),
),
(
"v3_e2e_rnnt_decoder.onnx",
Some("7b0a16d67fd2cb37061decc93c69e364a9ab27afee3c57495d55b1c974cf7231"),
),
(
"v3_e2e_rnnt_joint.onnx",
Some("602ff7017a93311aad34df1437c8d7f49911353c13d6eae7a6ee7b041339465c"),
),
(
"v3_e2e_rnnt_vocab.txt",
Some("39abae20e692998290c574e606f11a9edef2902a1995463fcff63d1490cf22b7"),
),
];
/// Hugging Face repository id the speaker model is downloaded from.
#[cfg(feature = "diarization")]
const SPEAKER_HF_REPO: &str = "onnx-community/wespeaker-voxceleb-resnet34-LM";
/// Local file name under which the speaker model is stored. The remote path
/// differs: the download URL points at `onnx/model.onnx` in the repository.
#[cfg(feature = "diarization")]
pub const SPEAKER_MODEL_FILE: &str = "wespeaker_resnet34.onnx";
/// Expected SHA-256 (lowercase hex) of the downloaded speaker model.
#[cfg(feature = "diarization")]
const SPEAKER_MODEL_SHA256: &str =
"3955447b0499dc9e0a4541a895df08b03c69098eba4e56c02b5603e9f7f4fcbb";
/// Returns the current user's home directory as given by the environment:
/// `HOME` on Unix, `USERPROFILE` on Windows. `None` when the variable is unset.
fn home_dir() -> Option<std::path::PathBuf> {
    #[cfg(unix)]
    const HOME_VAR: &str = "HOME";
    #[cfg(windows)]
    const HOME_VAR: &str = "USERPROFILE";

    let raw = std::env::var_os(HOME_VAR)?;
    Some(std::path::PathBuf::from(raw))
}
pub fn default_model_dir() -> String {
home_dir()
.map(|h| {
h.join(".gigastt")
.join("models")
.to_string_lossy()
.into_owned()
})
.unwrap_or_else(|| ".gigastt/models".into())
}
pub fn default_punct_model_dir() -> String {
home_dir()
.map(|h| {
h.join(".gigastt")
.join("models")
.join("punct")
.to_string_lossy()
.into_owned()
})
.unwrap_or_else(|| ".gigastt/models/punct".into())
}
pub fn default_vad_model_dir() -> String {
home_dir()
.map(|h| {
h.join(".gigastt")
.join("models")
.join("vad")
.to_string_lossy()
.into_owned()
})
.unwrap_or_else(|| ".gigastt/models/vad".into())
}
#[cfg(unix)]
fn acquire_download_lock(dir: &Path) -> Result<std::fs::File> {
let lock_path = dir.join(".download.lock");
let file = std::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open(&lock_path)
.context("Failed to create download lock file")?;
let fd = file.as_raw_fd();
let ret = unsafe { libc::flock(fd, libc::LOCK_EX) };
if ret != 0 {
anyhow::bail!("Failed to acquire download lock (another process is downloading)");
}
Ok(file)
}
/// Outcome of [`resolve_variant`]: what to do given the requested variant and
/// the variant already installed.
#[derive(Debug, PartialEq, Eq)]
pub enum VariantAction {
/// Use the given variant as-is; no download is needed.
Use(ModelVariant),
/// The given variant has to be downloaded.
Download(ModelVariant),
}
/// Decides whether an installed model can be used or a download is needed.
///
/// * No explicit request: use whatever is installed, otherwise download the
///   default variant.
/// * Explicit request: use it only if that exact variant is installed,
///   otherwise download it.
pub fn resolve_variant(
    requested: Option<ModelVariant>,
    existing: Option<ModelVariant>,
) -> VariantAction {
    let Some(wanted) = requested else {
        return match existing {
            Some(installed) => VariantAction::Use(installed),
            None => VariantAction::Download(ModelVariant::default()),
        };
    };
    if existing == Some(wanted) {
        VariantAction::Use(wanted)
    } else {
        VariantAction::Download(wanted)
    }
}
/// Ensures an ASR model is available in `model_dir` without requesting a
/// specific variant; the variant that ends up being used is discarded.
pub async fn ensure_model(model_dir: &str) -> Result<()> {
    ensure_model_variant(None, model_dir).await.map(|_| ())
}
/// Ensures the ASR model files for the resolved variant exist in `model_dir`,
/// downloading whatever is missing, and returns the variant in use.
///
/// `requested == None` means "use what is installed, else the default"; see
/// [`resolve_variant`] for the full decision table.
///
/// # Errors
/// Fails if the directory cannot be created, the download lock cannot be
/// taken (Unix only), or any file fails to download or verify.
pub async fn ensure_model_variant(
    requested: Option<ModelVariant>,
    model_dir: &str,
) -> Result<ModelVariant> {
    let dir = Path::new(model_dir);
    // Only a *complete* file set counts as installed. Every variant is tested
    // for completeness; testing just the first variant whose encoder happens
    // to exist would let a stray file from one variant hide a complete install
    // of the other.
    let existing = [ModelVariant::Rnnt, ModelVariant::E2eRnnt]
        .into_iter()
        .find(|&v| is_model_present(v, dir));
    let variant = match resolve_variant(requested, existing) {
        VariantAction::Use(v) => {
            tracing::info!("Using existing {v:?} model at {model_dir}");
            return Ok(v);
        }
        VariantAction::Download(v) => v,
    };
    if let Some(other) = existing
        && other != variant
    {
        tracing::warn!(
            "Model directory {model_dir} holds {other:?} files but {variant:?} was \
             requested; downloading the {variant:?} set (variants are never mixed)"
        );
    }
    std::fs::create_dir_all(dir).context("Failed to create model directory")?;
    #[cfg(unix)]
    let _lock = acquire_download_lock(dir)?;
    // Another process may have finished the download while we waited for the lock.
    if is_model_present(variant, dir) {
        tracing::info!("Model ({variant:?}) found at {model_dir} after lock acquisition");
        return Ok(variant);
    }
    tracing::info!("Model ({variant:?}) not found, downloading from HuggingFace...");
    for file in variant.download_files() {
        // Files reach their final name only via rename after verification, so
        // one that is already present (e.g. from an interrupted earlier run)
        // need not be fetched again. This matches `ensure_punct_model`.
        if dir.join(file).exists() {
            tracing::info!("{file} already present, skipping download");
            continue;
        }
        download_file(variant, file, dir).await?;
    }
    tracing::info!("Model download complete");
    Ok(variant)
}
/// Ensures the speaker model exists at `<model_dir>/SPEAKER_MODEL_FILE`,
/// downloading and SHA-256-verifying it when it is missing.
#[cfg(feature = "diarization")]
pub async fn ensure_speaker_model(model_dir: &str) -> Result<()> {
    let model_root = Path::new(model_dir);
    let target = model_root.join(SPEAKER_MODEL_FILE);
    if !target.exists() {
        tracing::info!("Speaker model not found, downloading from HuggingFace...");
        std::fs::create_dir_all(model_root).context("Failed to create model directory")?;
        // The remote file is `onnx/model.onnx`; it is stored locally under
        // `SPEAKER_MODEL_FILE`.
        let source_url =
            format!("https://huggingface.co/{SPEAKER_HF_REPO}/resolve/main/onnx/model.onnx");
        return stream_to_partial_then_finalize(
            &source_url,
            &target,
            Some(SPEAKER_MODEL_SHA256),
            SPEAKER_MODEL_FILE,
        )
        .await;
    }
    tracing::info!("Speaker model found at {}", target.display());
    Ok(())
}
pub async fn ensure_punct_model(punct_model_dir: &str) -> Result<()> {
let dir = Path::new(punct_model_dir);
if PUNCT_FILES.iter().all(|(file, _)| dir.join(file).exists()) {
tracing::info!("Punctuation model found at {punct_model_dir}");
return Ok(());
}
tracing::info!("Punctuation model not found, downloading from HuggingFace...");
std::fs::create_dir_all(dir).context("Failed to create punctuation model directory")?;
#[cfg(unix)]
let _lock = acquire_download_lock(dir)?;
for (file, sha256) in PUNCT_FILES {
let final_dest = dir.join(file);
if final_dest.exists() {
continue;
}
let url = format!("https://huggingface.co/{PUNCT_HF_REPO}/resolve/main/{file}");
stream_to_partial_then_finalize(&url, &final_dest, Some(sha256), file).await?;
}
tracing::info!("Punctuation model download complete");
Ok(())
}
/// Ensures the VAD model exists in `vad_model_dir`, downloading it from
/// `VAD_MODEL_URL` and verifying its SHA-256 when it is missing.
pub async fn ensure_vad_model(vad_model_dir: &str) -> Result<()> {
    let vad_dir = Path::new(vad_model_dir);
    let target = vad_dir.join(crate::vad::VAD_MODEL_FILE);
    if target.exists() {
        tracing::info!("VAD model found at {}", target.display());
        return Ok(());
    }
    tracing::info!("VAD model not found, downloading from {VAD_MODEL_URL}...");
    std::fs::create_dir_all(vad_dir).context("Failed to create VAD model directory")?;
    #[cfg(unix)]
    let _lock = acquire_download_lock(vad_dir)?;
    // Re-checked under the lock: another process may have delivered the file
    // while we were waiting.
    if !target.exists() {
        stream_to_partial_then_finalize(
            VAD_MODEL_URL,
            &target,
            Some(VAD_MODEL_SHA256),
            crate::vad::VAD_MODEL_FILE,
        )
        .await?;
        tracing::info!("VAD model download complete");
    }
    Ok(())
}
/// Returns `true` only when every file of `variant`'s download set exists in `dir`.
pub fn is_model_present(variant: ModelVariant, dir: &Path) -> bool {
    for name in variant.download_files() {
        if !dir.join(name).exists() {
            return false;
        }
    }
    true
}
/// Test helper: the fixed staging path `<final_path>.partial`.
#[cfg(test)]
fn partial_path(final_path: &Path) -> std::path::PathBuf {
    let mut staged = final_path.as_os_str().to_os_string();
    staged.push(".partial");
    staged.into()
}
/// Builds a staging path `<final_path>.partial.<pid>.<nanos>`, where `<nanos>`
/// is the time since the Unix epoch in nanoseconds (`0` if the system clock
/// reads earlier than the epoch).
fn partial_path_unique(final_path: &Path) -> std::path::PathBuf {
    let nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => std::time::Duration::default().as_nanos(),
    };
    let pid = std::process::id();
    let mut staged = final_path.as_os_str().to_os_string();
    staged.push(format!(".partial.{}.{}", pid, nanos));
    staged.into()
}
fn sha256_file(path: &Path) -> Result<String> {
let data = std::fs::read(path)
.with_context(|| format!("Failed to read file for verification: {}", path.display()))?;
let mut hasher = Sha256::new();
hasher.update(&data);
Ok(hex::encode(hasher.finalize()))
}
/// Promotes a fully written staging file to its final location.
///
/// When `expected_sha256` is given, the staging file is hashed first; on a
/// mismatch it is deleted (best effort) and an error is returned, so the final
/// path never receives unverified bytes. With `None` the file is renamed
/// without verification. `label` is used only in log and error messages.
fn finalize_download(
    partial_path: &Path,
    final_path: &Path,
    expected_sha256: Option<&str>,
    label: &str,
) -> Result<()> {
    match expected_sha256 {
        None => {}
        Some(expected) => {
            let actual = sha256_file(partial_path)?;
            if actual != expected {
                let _ = std::fs::remove_file(partial_path);
                anyhow::bail!("SHA-256 mismatch for {label}: expected {expected}, got {actual}");
            }
            tracing::info!("SHA-256 verified: {label}");
        }
    }
    std::fs::rename(partial_path, final_path).with_context(|| {
        format!(
            "Failed to rename {} -> {}",
            partial_path.display(),
            final_path.display()
        )
    })
}
/// Downloads one ASR model file from `HF_REPO` into `dir`, verifying it
/// against the checksum pinned for `variant` (if any).
async fn download_file(variant: ModelVariant, filename: &str, dir: &Path) -> Result<()> {
    stream_to_partial_then_finalize(
        &format!("https://huggingface.co/{HF_REPO}/resolve/main/{filename}"),
        &dir.join(filename),
        variant.checksum(filename),
        filename,
    )
    .await
}
/// Downloads `url` into a uniquely named staging file next to `final_dest`,
/// optionally verifies its SHA-256, and renames it into place.
///
/// The staging file is removed (best effort) on every failure path. Its name
/// is unique per call, so a leftover would never be reused or overwritten and
/// would otherwise accumulate on disk.
///
/// `label` is used only in log and error messages.
///
/// # Errors
/// Fails on HTTP client/transport errors, a non-success status, local I/O
/// errors, or a checksum mismatch.
async fn stream_to_partial_then_finalize(
    url: &str,
    final_dest: &Path,
    expected_sha256: Option<&str>,
    label: &str,
) -> Result<()> {
    let partial = partial_path_unique(final_dest);
    tracing::info!("Downloading {label}...");
    let client = reqwest::Client::builder()
        .connect_timeout(std::time::Duration::from_secs(30))
        .read_timeout(std::time::Duration::from_secs(300))
        .redirect(reqwest::redirect::Policy::limited(5))
        .build()
        .context("Failed to build HTTP client")?;
    let response = client
        .get(url)
        .send()
        .await
        .context("HTTP request failed")?;
    let status = response.status();
    if !status.is_success() {
        anyhow::bail!("Download failed for {label}: HTTP {status}");
    }
    let downloaded = match write_body_to_partial(response, &partial).await {
        Ok(bytes) => bytes,
        Err(err) => {
            // Do not leave a half-written staging file behind.
            let _ = tokio::fs::remove_file(&partial).await;
            return Err(err);
        }
    };
    tracing::info!("Wrote partial {} ({downloaded} bytes)", partial.display());
    if let Err(err) = finalize_download(&partial, final_dest, expected_sha256, label) {
        // A checksum mismatch already deleted the file; this covers hash-read
        // and rename failures. Errors (e.g. already gone) are ignored.
        let _ = std::fs::remove_file(&partial);
        return Err(err);
    }
    tracing::info!("Saved {label}");
    Ok(())
}

/// Streams the body of `response` into a new file at `partial`, reporting
/// progress on stderr, and returns the number of bytes written.
async fn write_body_to_partial(response: reqwest::Response, partial: &Path) -> Result<u64> {
    // A missing Content-Length is tracked as 0 ("unknown") by the progress display.
    let total_size = response.content_length().unwrap_or(0);
    let mut progress = DownloadProgress::new(total_size);
    let mut file = tokio::fs::File::create(partial)
        .await
        .context("Failed to create partial model file")?;
    let mut stream = response.bytes_stream();
    let mut downloaded: u64 = 0;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.context("Download stream error")?;
        file.write_all(&chunk)
            .await
            .context("Failed to write chunk")?;
        downloaded += chunk.len() as u64;
        progress.update(chunk.len() as u64);
    }
    file.flush()
        .await
        .context("Failed to flush partial model file")?;
    // Close the handle before the caller hashes and renames the file.
    drop(file);
    progress.finish();
    Ok(downloaded)
}
// Unit tests for the model download module: default paths, progress
// accounting, staging/finalisation of downloads, HTTP behaviour against a
// mock server, and model-variant resolution.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
// --- Default directories -------------------------------------------------
// Relies on the home-directory environment variable being set where the
// tests run.
#[test]
fn test_home_dir_returns_some() {
assert!(
home_dir().is_some(),
"home_dir() must return Some on this platform"
);
}
// Both the home-based path and the relative fallback contain ".gigastt".
#[test]
fn test_default_model_dir_contains_gigastt() {
let dir = default_model_dir();
assert!(
dir.contains(".gigastt"),
"default_model_dir() should contain \".gigastt\", got: {dir}"
);
}
// --- DownloadProgress ----------------------------------------------------
// Half of the declared total must be tracked as 50%.
#[test]
fn test_download_progress_basic() {
let mut progress = DownloadProgress::new(1_000_000);
progress.update(500_000);
assert_eq!(progress.current, 500_000);
assert_eq!(progress.last_percent, 50);
progress.finish();
}
// An unknown (zero) total must not divide by zero and stays at 0%.
#[test]
fn test_download_progress_zero_total() {
let mut progress = DownloadProgress::new(0);
progress.update(100);
assert_eq!(progress.last_percent, 0);
progress.finish();
}
// --- Helpers -------------------------------------------------------------
// Lowercase hex SHA-256 of an in-memory byte slice.
fn sha256_hex(bytes: &[u8]) -> String {
let mut hasher = Sha256::new();
hasher.update(bytes);
hex::encode(hasher.finalize())
}
// Writes `bytes` to `<final_path>.partial`, standing in for a completed
// download that has not been finalised yet; returns the staging path.
fn stage_partial(final_path: &Path, bytes: &[u8]) -> std::path::PathBuf {
let partial = partial_path(final_path);
let mut f = std::fs::File::create(&partial).expect("create partial");
f.write_all(bytes).expect("write partial");
f.sync_all().expect("sync partial");
partial
}
// --- Staging and finalisation ---------------------------------------------
#[test]
fn test_partial_path_appends_suffix() {
let p = partial_path(Path::new("/tmp/gigastt/encoder.onnx"));
assert_eq!(
p,
std::path::PathBuf::from("/tmp/gigastt/encoder.onnx.partial"),
);
}
// A staged file with a matching checksum is moved to the final path and the
// staging file disappears.
#[test]
fn test_download_writes_partial_then_renames() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("encoder.onnx");
let payload = b"fake encoder weights";
let expected = sha256_hex(payload);
let partial = stage_partial(&final_path, payload);
assert!(partial.exists(), "precondition: partial is present");
assert!(!final_path.exists(), "precondition: final is absent");
finalize_download(&partial, &final_path, Some(&expected), "encoder.onnx")
.expect("finalize should succeed");
assert!(
!partial.exists(),
"partial must be gone after atomic rename"
);
assert!(
final_path.exists(),
"final path must exist after atomic rename"
);
assert_eq!(std::fs::read(&final_path).unwrap(), payload);
}
// A staging file that was never finalised must not make the model look
// installed for either variant.
#[test]
fn test_download_crash_before_rename_leaves_no_final_file() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("encoder.onnx");
let partial = stage_partial(&final_path, b"half-written junk");
assert!(partial.exists(), "partial must exist to simulate crash");
assert!(
!final_path.exists(),
"crash before rename must never leave the final artefact visible"
);
assert!(
!is_model_present(ModelVariant::Rnnt, tmp.path()),
"is_model_present must not accept a staged partial"
);
assert!(
!is_model_present(ModelVariant::E2eRnnt, tmp.path()),
"is_model_present must not accept a staged partial"
);
}
// A checksum mismatch must fail, delete the staging file, and never create
// the final file.
#[test]
fn test_download_rejects_sha256_mismatch() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("decoder.onnx");
let payload = b"real bytes";
let wrong_expected = sha256_hex(b"different bytes");
let partial = stage_partial(&final_path, payload);
let err = finalize_download(&partial, &final_path, Some(&wrong_expected), "decoder.onnx")
.expect_err("mismatch must error");
let msg = format!("{err}");
assert!(msg.contains("SHA-256 mismatch"), "unexpected error: {msg}");
assert!(!partial.exists(), "partial must be removed on SHA mismatch");
assert!(
!final_path.exists(),
"final must never appear on SHA mismatch"
);
}
// Without an expected checksum the staged file is promoted unverified.
#[test]
fn test_download_atomic_on_success_without_checksum() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("vocab.txt");
let payload = b"token0\ntoken1\n";
let partial = stage_partial(&final_path, payload);
finalize_download(&partial, &final_path, None, "vocab.txt")
.expect("no-checksum finalize should succeed");
assert!(!partial.exists(), "partial must be gone after rename");
assert!(final_path.exists(), "final path must exist");
assert_eq!(std::fs::read(&final_path).unwrap(), payload);
}
// Hashing a file must give the same digest as hashing the same bytes in memory.
#[test]
fn test_sha256_file_matches_in_memory_hash() {
let tmp = tempfile::tempdir().expect("tempdir");
let p = tmp.path().join("blob");
let payload = b"gigastt-model-bytes";
std::fs::write(&p, payload).unwrap();
let got = sha256_file(&p).expect("sha256_file");
let want = sha256_hex(payload);
assert_eq!(got, want);
}
// --- Speaker model (diarization feature) -----------------------------------
// Guards against a malformed pinned digest: 64 lowercase hex characters.
#[cfg(feature = "diarization")]
#[test]
fn test_speaker_model_sha256_shape() {
assert_eq!(
SPEAKER_MODEL_SHA256.len(),
64,
"SPEAKER_MODEL_SHA256 must be a 64-char hex digest"
);
assert!(
SPEAKER_MODEL_SHA256
.chars()
.all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)),
"SPEAKER_MODEL_SHA256 must be lowercase hex; got: {SPEAKER_MODEL_SHA256}"
);
}
// Bytes that do not match the pinned speaker digest must be rejected.
#[cfg(feature = "diarization")]
#[test]
fn test_speaker_model_rejects_sha256_mismatch() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join(SPEAKER_MODEL_FILE);
let partial = stage_partial(&final_path, b"not the real wespeaker weights");
let err = finalize_download(
&partial,
&final_path,
Some(SPEAKER_MODEL_SHA256),
SPEAKER_MODEL_FILE,
)
.expect_err("speaker mismatch must error");
assert!(
format!("{err}").contains("SHA-256 mismatch"),
"unexpected error: {err}"
);
assert!(
!partial.exists(),
"partial speaker model must be removed on mismatch"
);
assert!(
!final_path.exists(),
"final speaker model must never appear on mismatch"
);
}
// A staged speaker model with a matching digest is promoted to the final path.
#[cfg(feature = "diarization")]
#[test]
fn test_speaker_model_partial_promoted_on_match() {
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join(SPEAKER_MODEL_FILE);
let payload = b"wespeaker-surrogate";
let expected = sha256_hex(payload);
let partial = stage_partial(&final_path, payload);
finalize_download(&partial, &final_path, Some(&expected), SPEAKER_MODEL_FILE)
.expect("matching partial must promote");
assert!(!partial.exists());
assert!(final_path.exists());
assert_eq!(std::fs::read(&final_path).unwrap(), payload);
}
// --- Unique staging names and locking ---------------------------------------
// Only the ".partial." marker and the process id are asserted; the
// timestamp part is not checked.
#[test]
fn test_partial_path_unique_contains_pid_and_timestamp() {
let p = partial_path_unique(Path::new("/tmp/final.onnx"));
let s = p.to_string_lossy();
assert!(s.contains(".partial."));
assert!(s.contains(&std::process::id().to_string()));
}
// Taking the lock must create ".download.lock" in the directory.
#[cfg(unix)]
#[test]
fn test_acquire_download_lock_creates_lock_file() {
let tmp = tempfile::tempdir().expect("tempdir");
let lock = acquire_download_lock(tmp.path()).expect("acquire lock");
assert!(tmp.path().join(".download.lock").exists());
drop(lock);
}
// --- End-to-end download against a local mock HTTP server -------------------
// A 200 response with a body ends up at the final path, byte for byte.
#[tokio::test]
async fn test_stream_to_partial_then_finalize_success() {
let server = wiremock::MockServer::start().await;
let payload = b"fake model bytes";
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/model.onnx"))
.respond_with(
wiremock::ResponseTemplate::new(200)
.set_body_bytes(payload.as_slice())
.insert_header("content-length", payload.len().to_string()),
)
.mount(&server)
.await;
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("model.onnx");
let url = format!("{}/model.onnx", server.uri());
stream_to_partial_then_finalize(&url, &final_path, None, "model.onnx")
.await
.expect("download should succeed");
assert!(final_path.exists());
assert_eq!(std::fs::read(&final_path).unwrap(), payload);
}
// A non-success status must produce an error that names the status code.
#[tokio::test]
async fn test_stream_to_partial_then_finalize_http_error() {
let server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/missing.onnx"))
.respond_with(wiremock::ResponseTemplate::new(404))
.mount(&server)
.await;
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("missing.onnx");
let url = format!("{}/missing.onnx", server.uri());
let err = stream_to_partial_then_finalize(&url, &final_path, None, "missing.onnx")
.await
.expect_err("404 should fail");
let msg = format!("{err}");
assert!(msg.contains("404"), "error should mention 404: {msg}");
}
// A body whose digest differs from the expected one must fail the download.
#[tokio::test]
async fn test_stream_to_partial_then_finalize_checksum_mismatch() {
let server = wiremock::MockServer::start().await;
let payload = b"wrong bytes";
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/model.onnx"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_bytes(payload.as_slice()))
.mount(&server)
.await;
let tmp = tempfile::tempdir().expect("tempdir");
let final_path = tmp.path().join("model.onnx");
let url = format!("{}/model.onnx", server.uri());
let wrong_hash = sha256_hex(b"different bytes");
let err =
stream_to_partial_then_finalize(&url, &final_path, Some(&wrong_hash), "model.onnx")
.await
.expect_err("checksum mismatch should fail");
let msg = format!("{err}");
assert!(
msg.contains("SHA-256 mismatch"),
"error should mention mismatch: {msg}"
);
}
// --- Punctuation model -------------------------------------------------------
// Every punctuation file must carry a well-formed pinned digest.
#[test]
fn test_punct_files_checksums_are_pinned() {
assert_eq!(PUNCT_FILES.len(), 3);
for (file, sum) in PUNCT_FILES {
assert_eq!(
sum.len(),
64,
"{file} punct checksum must be 64 hex chars, got: {sum}"
);
assert!(
sum.chars()
.all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)),
"{file} punct checksum must be lowercase hex, got: {sum}"
);
}
}
// With all files already present, nothing is downloaded: no staging files
// appear and the stub contents stay untouched.
#[tokio::test]
async fn test_ensure_punct_model_present_no_download() {
let tmp = tempfile::tempdir().expect("tempdir");
let dir = tmp.path();
for (file, _) in PUNCT_FILES {
std::fs::write(dir.join(file), b"stub").unwrap();
}
ensure_punct_model(dir.to_str().unwrap())
.await
.expect("present model must short-circuit");
let partials: Vec<_> = std::fs::read_dir(dir)
.unwrap()
.filter_map(|e| e.ok())
.filter(|e| e.file_name().to_string_lossy().contains(".partial"))
.collect();
assert!(partials.is_empty(), "no .partial files: {partials:?}");
for (file, _) in PUNCT_FILES {
assert_eq!(std::fs::read(dir.join(file)).unwrap(), b"stub");
}
}
// --- ModelVariant: defaults, file names, parsing, checksums -------------------
#[test]
fn test_model_variant_default_is_rnnt() {
assert_eq!(ModelVariant::default(), ModelVariant::Rnnt);
}
#[test]
fn test_model_variant_rnnt_file_mapping() {
let v = ModelVariant::Rnnt;
assert_eq!(v.encoder_file(), "v3_rnnt_encoder.onnx");
assert_eq!(v.encoder_int8_file(), "v3_rnnt_encoder_int8.onnx");
assert_eq!(v.decoder_file(), "v3_rnnt_decoder.onnx");
assert_eq!(v.joint_file(), "v3_rnnt_joint.onnx");
assert_eq!(v.vocab_file(), "v3_vocab.txt");
assert_eq!(
v.download_files(),
[
"v3_rnnt_encoder.onnx",
"v3_rnnt_decoder.onnx",
"v3_rnnt_joint.onnx",
"v3_vocab.txt",
]
);
}
#[test]
fn test_model_variant_e2e_rnnt_file_mapping() {
let v = ModelVariant::E2eRnnt;
assert_eq!(v.encoder_file(), "v3_e2e_rnnt_encoder.onnx");
assert_eq!(v.encoder_int8_file(), "v3_e2e_rnnt_encoder_int8.onnx");
assert_eq!(v.decoder_file(), "v3_e2e_rnnt_decoder.onnx");
assert_eq!(v.joint_file(), "v3_e2e_rnnt_joint.onnx");
assert_eq!(v.vocab_file(), "v3_e2e_rnnt_vocab.txt");
assert_eq!(
v.download_files(),
[
"v3_e2e_rnnt_encoder.onnx",
"v3_e2e_rnnt_decoder.onnx",
"v3_e2e_rnnt_joint.onnx",
"v3_e2e_rnnt_vocab.txt",
]
);
}
// Parsing ignores case and surrounding whitespace and accepts both the
// underscore and the hyphen spelling.
#[test]
fn test_model_variant_from_str() {
use std::str::FromStr;
assert_eq!(ModelVariant::from_str("rnnt").unwrap(), ModelVariant::Rnnt);
assert_eq!(
ModelVariant::from_str("e2e_rnnt").unwrap(),
ModelVariant::E2eRnnt
);
assert_eq!(
ModelVariant::from_str("E2E-RNNT").unwrap(),
ModelVariant::E2eRnnt
);
assert_eq!(
ModelVariant::from_str(" RNNT ").unwrap(),
ModelVariant::Rnnt
);
assert!(ModelVariant::from_str("whisper").is_err());
}
// Every downloadable file of every variant must have a well-formed pinned digest.
#[test]
fn test_model_variant_checksums_are_pinned() {
for variant in [ModelVariant::Rnnt, ModelVariant::E2eRnnt] {
for file in variant.download_files() {
let sum = variant
.checksum(file)
.unwrap_or_else(|| panic!("{variant:?} {file} must have a pinned checksum"));
assert_eq!(
sum.len(),
64,
"{variant:?} {file} checksum must be 64 hex chars, got: {sum}"
);
assert!(
sum.chars()
.all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)),
"{variant:?} {file} checksum must be lowercase hex, got: {sum}"
);
}
}
}
// --- Variant detection from directory contents ---------------------------------
// Either encoder file (full or `_int8`) alone is enough to detect a variant.
#[test]
fn test_detect_in_dir_rnnt_by_fp32_encoder() {
let tmp = tempfile::tempdir().expect("tempdir");
std::fs::write(tmp.path().join("v3_rnnt_encoder.onnx"), b"fp32").unwrap();
assert_eq!(
ModelVariant::detect_in_dir(tmp.path()),
Some(ModelVariant::Rnnt)
);
}
#[test]
fn test_detect_in_dir_rnnt_by_int8_encoder() {
let tmp = tempfile::tempdir().expect("tempdir");
std::fs::write(tmp.path().join("v3_rnnt_encoder_int8.onnx"), b"int8").unwrap();
assert_eq!(
ModelVariant::detect_in_dir(tmp.path()),
Some(ModelVariant::Rnnt)
);
}
#[test]
fn test_detect_in_dir_e2e_by_fp32_encoder() {
let tmp = tempfile::tempdir().expect("tempdir");
std::fs::write(tmp.path().join("v3_e2e_rnnt_encoder.onnx"), b"fp32").unwrap();
assert_eq!(
ModelVariant::detect_in_dir(tmp.path()),
Some(ModelVariant::E2eRnnt)
);
}
#[test]
fn test_detect_in_dir_e2e_by_int8_encoder() {
let tmp = tempfile::tempdir().expect("tempdir");
std::fs::write(tmp.path().join("v3_e2e_rnnt_encoder_int8.onnx"), b"int8").unwrap();
assert_eq!(
ModelVariant::detect_in_dir(tmp.path()),
Some(ModelVariant::E2eRnnt)
);
}
#[test]
fn test_detect_in_dir_none_when_empty() {
let tmp = tempfile::tempdir().expect("tempdir");
assert_eq!(ModelVariant::detect_in_dir(tmp.path()), None);
}
// --- Completeness check ----------------------------------------------------------
// A complete RNNT set must not make the E2E set look present.
#[test]
fn test_is_model_present_per_variant() {
let tmp = tempfile::tempdir().expect("tempdir");
let dir = tmp.path();
for f in ModelVariant::Rnnt.download_files() {
std::fs::write(dir.join(f), b"x").unwrap();
}
assert!(
is_model_present(ModelVariant::Rnnt, dir),
"rnnt set is complete"
);
assert!(
!is_model_present(ModelVariant::E2eRnnt, dir),
"e2e set is absent — must not be reported present"
);
}
// All four files are required; here the vocab file is left out.
#[test]
fn test_is_model_present_false_when_one_file_missing() {
let tmp = tempfile::tempdir().expect("tempdir");
let dir = tmp.path();
for f in [
ModelVariant::Rnnt.encoder_file(),
ModelVariant::Rnnt.decoder_file(),
ModelVariant::Rnnt.joint_file(),
] {
std::fs::write(dir.join(f), b"x").unwrap();
}
assert!(
!is_model_present(ModelVariant::Rnnt, dir),
"a missing vocab must make the set incomplete"
);
}
// --- resolve_variant decision table (requested x existing) -------------------------
#[test]
fn test_resolve_variant_none_empty_dir_downloads_default() {
assert_eq!(
resolve_variant(None, None),
VariantAction::Download(ModelVariant::Rnnt),
);
}
#[test]
fn test_resolve_variant_none_e2e_present_uses_e2e() {
assert_eq!(
resolve_variant(None, Some(ModelVariant::E2eRnnt)),
VariantAction::Use(ModelVariant::E2eRnnt),
);
}
#[test]
fn test_resolve_variant_none_rnnt_present_uses_rnnt() {
assert_eq!(
resolve_variant(None, Some(ModelVariant::Rnnt)),
VariantAction::Use(ModelVariant::Rnnt),
);
}
#[test]
fn test_resolve_variant_some_rnnt_rnnt_present_uses_rnnt() {
assert_eq!(
resolve_variant(Some(ModelVariant::Rnnt), Some(ModelVariant::Rnnt)),
VariantAction::Use(ModelVariant::Rnnt),
);
}
#[test]
fn test_resolve_variant_some_e2e_rnnt_present_downloads_e2e() {
assert_eq!(
resolve_variant(Some(ModelVariant::E2eRnnt), Some(ModelVariant::Rnnt)),
VariantAction::Download(ModelVariant::E2eRnnt),
);
}
#[test]
fn test_resolve_variant_some_e2e_empty_downloads_e2e() {
assert_eq!(
resolve_variant(Some(ModelVariant::E2eRnnt), None),
VariantAction::Download(ModelVariant::E2eRnnt),
);
}
#[test]
fn test_resolve_variant_some_rnnt_e2e_present_downloads_rnnt() {
assert_eq!(
resolve_variant(Some(ModelVariant::Rnnt), Some(ModelVariant::E2eRnnt)),
VariantAction::Download(ModelVariant::Rnnt),
);
}
// --- ensure_model_variant -----------------------------------------------------------
// With no explicit request and a complete E2E set on disk, the installed
// variant is used as-is: no staging files appear and the stubs stay unchanged.
#[tokio::test]
async fn test_ensure_model_none_respects_existing_e2e_install() {
let tmp = tempfile::tempdir().expect("tempdir");
let dir = tmp.path();
for f in ModelVariant::E2eRnnt.download_files() {
std::fs::write(dir.join(f), b"stub").unwrap();
}
let variant = ensure_model_variant(None, dir.to_str().unwrap())
.await
.expect("ensure_model_variant should succeed");
assert_eq!(
variant,
ModelVariant::E2eRnnt,
"must use the installed E2eRnnt"
);
let partials: Vec<_> = std::fs::read_dir(dir)
.unwrap()
.filter_map(|e| e.ok())
.filter(|e| e.file_name().to_string_lossy().contains(".partial"))
.collect();
assert!(
partials.is_empty(),
"no .partial files must exist: {partials:?}"
);
for f in ModelVariant::E2eRnnt.download_files() {
assert_eq!(
std::fs::read(dir.join(f)).unwrap(),
b"stub",
"{f} must be unchanged"
);
}
}
}