use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use anyhow::{Context, Result};
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::Client;
use reqwest::header::RANGE;
use serde_json::Value;
use tokio::fs;
use tokio::io::AsyncWriteExt;
use crate::model::{HfModelResponse, HfSibling, ModelChoice, model_directory};
/// Repository files this tool downloads for a model. Manifest entries whose
/// name is not listed here are ignored; listed names that the repository does
/// not contain are simply never fetched.
const REQUIRED_MODEL_FILES: &[&str] = &[
    "config.json",
    "model.bin",
    "tokenizer.json",
    "tokenizer_config.json",
    "vocabulary.json",
    "vocabulary.txt",
    "preprocessor_config.json",
];

/// Files whose manifest size is at least this many bytes (8 MiB) are downloaded
/// with the system `curl` binary when it is available.
const CURL_DOWNLOAD_THRESHOLD: u64 = 8 << 20;
/// Makes sure every required file of `choice` is present in its local model
/// directory, downloading or resuming whatever is missing, and returns that
/// directory.
///
/// The Hugging Face manifest for the repo is fetched and filtered to
/// `REQUIRED_MODEL_FILES`. For each file:
/// - a local file whose size matches the manifest is skipped;
/// - a shorter local file is resumed;
/// - a local file longer than the manifest size is deleted and downloaded again.
///
/// Finally a `preprocessor_config.json` is generated if none exists locally.
///
/// # Errors
/// Fails if the model directory cannot be created, the manifest cannot be
/// fetched or decoded, a download fails or ends up with the wrong size, or the
/// generated preprocessor config cannot be written.
pub async fn ensure_model_downloaded(
    choice: ModelChoice,
    models_root: Option<&Path>,
) -> Result<PathBuf> {
    let target_dir = model_directory(choice, models_root)?;
    fs::create_dir_all(&target_dir)
        .await
        .with_context(|| format!("failed to create `{}`", target_dir.display()))?;

    let manifest_bar = ProgressBar::new_spinner();
    manifest_bar.set_style(
        ProgressStyle::with_template(" resolving model {spinner:.green} {msg}")
            .context("failed to configure manifest spinner")?,
    );
    manifest_bar.enable_steady_tick(std::time::Duration::from_millis(80));
    manifest_bar.set_message(choice.repo_id().to_string());

    let client = Client::builder()
        .connect_timeout(std::time::Duration::from_secs(15))
        .user_agent("transcribe-cli/0.1.0")
        .build()
        .context("failed to build HTTP client")?;
    let manifest = fetch_manifest(&client, choice).await?;
    let files = resolve_manifest_sizes(&client, choice, manifest).await?;
    manifest_bar.finish_with_message(format!("manifest ready [{} files]", files.len()));

    // Only files with a known manifest size contribute to the initial total.
    let total_size = files
        .iter()
        .filter_map(|file| file.expected_size)
        .sum::<u64>();
    // Keep the bar length non-zero even when no sizes are known.
    let download_bar = ProgressBar::new(total_size.max(1));
    download_bar.set_style(
        ProgressStyle::with_template(
            " downloading model [{bar:40.cyan/blue}] {bytes}/{total_bytes} {bytes_per_sec} ETA {eta} {msg}",
        )
        .context("failed to configure download progress bar")?
        .progress_chars("=> "),
    );

    for file in files {
        let destination = target_dir.join(&file.relative_path);
        if let Some(parent) = destination.parent() {
            fs::create_dir_all(parent)
                .await
                .with_context(|| format!("failed to create `{}`", parent.display()))?;
        }

        let mut existing_size = existing_file_size(&destination).await?;
        if file
            .expected_size
            .is_some_and(|expected_size| existing_size > expected_size)
        {
            // A local file longer than the manifest says cannot be resumed:
            // the range request would start past the end of the remote file
            // and be rejected, so this file would fail on every run. Discard it
            // and start over.
            fs::remove_file(&destination).await.with_context(|| {
                format!("failed to remove oversized `{}`", destination.display())
            })?;
            existing_size = 0;
        }

        if file_is_complete(existing_size, file.expected_size) {
            if let Some(expected_size) = file.expected_size {
                download_bar.inc(expected_size);
            } else if existing_size > 0 {
                download_bar.inc(existing_size);
            }
            continue;
        }

        if existing_size > 0 {
            // Credit the bytes already on disk; the download helpers rely on
            // this and undo it if the server refuses to resume.
            download_bar.inc(existing_size);
            download_bar.set_message(format!(
                "{} (resuming from {})",
                file.relative_path,
                indicatif::HumanBytes(existing_size)
            ));
        } else {
            download_bar.set_message(file.relative_path.clone());
        }

        download_file(
            &client,
            &file.download_url,
            &destination,
            &download_bar,
            file.expected_size,
            existing_size,
        )
        .await?;
    }

    ensure_preprocessor_config(&target_dir).await?;
    download_bar.finish_with_message(format!(" downloaded model [{}]", choice.runtime_name()));
    Ok(target_dir)
}
/// Writes a Whisper-style `preprocessor_config.json` into `model_dir` unless one
/// is already there.
///
/// `feature_size` comes from `num_mel_bins` in the sibling `config.json` when
/// that file exists and the key holds an unsigned integer; otherwise 80 is used.
///
/// # Errors
/// Fails if either file cannot be inspected, or if an existing `config.json`
/// cannot be read or is not valid JSON. It also fails if the generated config
/// cannot be serialized or written.
async fn ensure_preprocessor_config(model_dir: &Path) -> Result<()> {
    const DEFAULT_MEL_BINS: u64 = 80;

    let preprocessor_path = model_dir.join("preprocessor_config.json");
    let already_present = fs::try_exists(&preprocessor_path)
        .await
        .with_context(|| format!("failed to inspect `{}`", preprocessor_path.display()))?;
    if already_present {
        return Ok(());
    }

    let config_path = model_dir.join("config.json");
    let config_present = fs::try_exists(&config_path)
        .await
        .with_context(|| format!("failed to inspect `{}`", config_path.display()))?;
    let mel_bins = if !config_present {
        DEFAULT_MEL_BINS
    } else {
        let raw_config = fs::read_to_string(&config_path)
            .await
            .with_context(|| format!("failed to read `{}`", config_path.display()))?;
        let parsed: Value = serde_json::from_str(&raw_config)
            .with_context(|| format!("failed to parse `{}`", config_path.display()))?;
        parsed
            .get("num_mel_bins")
            .and_then(Value::as_u64)
            .unwrap_or(DEFAULT_MEL_BINS)
    };
    let feature_size = mel_bins as usize;

    let generated = serde_json::json!({
        "chunk_length": 30,
        "feature_extractor_type": "WhisperFeatureExtractor",
        "feature_size": feature_size,
        "hop_length": 160,
        "n_fft": 400,
        "n_samples": 480000,
        "nb_max_frames": 3000,
        "padding_side": "right",
        "padding_value": 0.0,
        "processor_class": "WhisperProcessor",
        "return_attention_mask": false,
        "sampling_rate": 16000
    });
    let bytes = serde_json::to_vec_pretty(&generated)
        .context("failed to serialize generated preprocessor config")?;
    fs::write(&preprocessor_path, bytes)
        .await
        .with_context(|| format!("failed to write `{}`", preprocessor_path.display()))
}
/// A manifest entry resolved into everything needed to download it.
#[derive(Debug)]
struct ResolvedModelFile {
/// Repository-relative file name; also used as the path under the local model directory.
relative_path: String,
/// URL the file is fetched from.
download_url: String,
/// Size reported by the manifest, if any; drives progress totals and completeness checks.
expected_size: Option<u64>,
}
/// Fetches the Hugging Face model metadata for `choice` and returns only the
/// repository entries named in `REQUIRED_MODEL_FILES`. Entries ending in `/`
/// are dropped.
///
/// # Errors
/// Fails when the request cannot be sent, the server returns an error status,
/// or the body cannot be decoded as an `HfModelResponse`.
async fn fetch_manifest(client: &Client, choice: ModelChoice) -> Result<Vec<HfSibling>> {
    let repo_id = choice.repo_id();
    let endpoint = format!("https://huggingface.co/api/models/{}", repo_id);

    let sent = client
        .get(endpoint)
        .send()
        .await
        .with_context(|| format!("failed to fetch model metadata for `{}`", repo_id))?;
    let ok_response = sent
        .error_for_status()
        .with_context(|| format!("model metadata request failed for `{}`", repo_id))?;
    let metadata: HfModelResponse = ok_response
        .json()
        .await
        .with_context(|| format!("failed to decode manifest for `{}`", repo_id))?;

    let wanted: Vec<HfSibling> = metadata
        .siblings
        .into_iter()
        .filter(|sibling| {
            let name = sibling.rfilename.as_str();
            !name.ends_with('/') && is_required_model_file(name)
        })
        .collect();
    Ok(wanted)
}
/// Pairs every manifest entry with its download URL and the size reported by
/// the entry's `expected_size()`.
///
/// No network requests are made here; `_client` is currently unused.
async fn resolve_manifest_sizes(
    _client: &Client,
    choice: ModelChoice,
    manifest: Vec<HfSibling>,
) -> Result<Vec<ResolvedModelFile>> {
    let resolved: Vec<ResolvedModelFile> = manifest
        .into_iter()
        .map(|sibling| ResolvedModelFile {
            // Borrowing uses come first; `rfilename` is moved out last.
            expected_size: sibling.expected_size(),
            download_url: file_download_url(choice, &sibling.rfilename),
            relative_path: sibling.rfilename,
        })
        .collect();
    Ok(resolved)
}
/// Returns the size in bytes of the file at `path`, or `0` if it does not exist.
///
/// # Errors
/// Any metadata failure other than "not found" (for example, permission denied)
/// is propagated. Reporting such failures as size 0 would make the download
/// restart from scratch and truncate a partial file that could have been resumed.
async fn existing_file_size(path: &Path) -> Result<u64> {
    match fs::metadata(path).await {
        Ok(metadata) => Ok(metadata.len()),
        Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(0),
        Err(error) => {
            Err(error).with_context(|| format!("failed to inspect `{}`", path.display()))
        }
    }
}
/// Reports whether a local file of `existing_size` bytes needs no further download.
///
/// With a known `expected_size` the sizes must match exactly. Without one, any
/// non-empty file is accepted as complete.
fn file_is_complete(existing_size: u64, expected_size: Option<u64>) -> bool {
    expected_size.map_or(existing_size > 0, |expected| expected == existing_size)
}
async fn download_file(
client: &Client,
url: &str,
destination: &Path,
progress: &ProgressBar,
expected_size: Option<u64>,
existing_size: u64,
) -> Result<()> {
if should_use_curl(destination, expected_size) && curl_is_available() {
return download_with_curl(url, destination, progress, expected_size, existing_size);
}
download_file_once(
client,
url,
destination,
progress,
expected_size,
existing_size,
)
.await
}
/// Streams `url` into `destination` with the shared `reqwest` client and
/// advances `progress` as bytes arrive.
///
/// When `existing_size > 0`, a `Range: bytes={existing_size}-` header is sent.
/// - A `206` reply is appended to the existing file.
/// - Any other success status is treated as the whole file: `destination` is
///   truncated, `progress` is rewound by `existing_size`, and writing restarts
///   at byte 0.
///
/// # Errors
/// Fails on request or status errors, on file open/write/flush errors, or on a
/// broken body stream. It also fails when the byte count (already-present bytes
/// plus streamed bytes) does not match `expected_size`. When `expected_size` is
/// unknown, the response `Content-Length` is used instead, offset by
/// `existing_size` when resumed.
async fn download_file_once(
client: &Client,
url: &str,
destination: &Path,
progress: &ProgressBar,
expected_size: Option<u64>,
existing_size: u64,
) -> Result<()> {
let mut request = client.get(url);
// Ask only for the bytes we do not have yet.
if existing_size > 0 {
request = request.header(RANGE, format!("bytes={existing_size}-"));
}
let response = request
.send()
.await
.with_context(|| format!("failed to download `{}`", destination.display()))?
.error_for_status()
.with_context(|| format!("download failed for `{}`", destination.display()))?;
// Only a 206 reply is treated as honouring the Range header; anything else restarts.
let resumed = existing_size > 0 && response.status().as_u16() == 206;
let content_length = response.content_length();
// Files without a manifest size were not counted in the bar's initial total;
// extend it from Content-Length now.
if expected_size.is_none() && existing_size == 0 {
if let Some(content_length) = content_length {
progress.inc_length(content_length);
}
}
// Tracks the full file length so far, including bytes already on disk when resuming.
let mut downloaded = if resumed { existing_size } else { 0 };
let mut stream = response.bytes_stream();
let mut file = if resumed {
fs::OpenOptions::new()
.append(true)
.open(destination)
.await
.with_context(|| format!("failed to reopen `{}` for resume", destination.display()))?
} else {
// The caller already credited `existing_size` to the bar; take it back
// because the file is about to be rewritten from scratch.
if existing_size > 0 {
let current = progress.position();
progress.set_position(current.saturating_sub(existing_size));
progress.set_message(format!(
"{} (resume not supported, restarting)",
destination
.file_name()
.and_then(|name| name.to_str())
.unwrap_or("file")
));
}
// `File::create` truncates any partial file left at `destination`.
fs::File::create(destination)
.await
.with_context(|| format!("failed to create `{}`", destination.display()))?
};
while let Some(chunk) = stream.next().await {
let chunk = chunk.with_context(|| format!("failed to read `{}`", destination.display()))?;
file.write_all(&chunk)
.await
.with_context(|| format!("failed to write `{}`", destination.display()))?;
let chunk_len = chunk.len() as u64;
downloaded += chunk_len;
progress.inc(chunk_len);
}
file.flush()
.await
.with_context(|| format!("failed to flush `{}`", destination.display()))?;
// Prefer the manifest size. Otherwise use Content-Length, which covers only
// the requested range when resuming, so add the bytes already on disk.
let effective_size = expected_size.or_else(|| {
content_length.map(|length| {
if resumed {
existing_size + length
} else {
length
}
})
});
// Without either size there is nothing to verify against.
if let Some(expected_size) = effective_size {
if downloaded != expected_size {
anyhow::bail!(
"incomplete download for `{}`: expected {} bytes, got {} bytes",
destination.display(),
expected_size,
downloaded
);
}
}
Ok(())
}
/// Builds the `resolve/main` download URL for `filename` in `choice`'s repository.
///
/// `filename` is inserted verbatim; it is not percent-encoded.
fn file_download_url(choice: ModelChoice, filename: &str) -> String {
    let repo = choice.repo_id();
    format!("https://huggingface.co/{repo}/resolve/main/{filename}?download=true")
}
/// Returns `true` when `filename` exactly matches an entry in `REQUIRED_MODEL_FILES`.
fn is_required_model_file(filename: &str) -> bool {
    REQUIRED_MODEL_FILES.contains(&filename)
}
/// Decides whether `destination` should be downloaded with the system `curl`.
///
/// This is the case for any file named `model.bin`, and for any file whose
/// known size is at least `CURL_DOWNLOAD_THRESHOLD` bytes.
fn should_use_curl(destination: &Path, expected_size: Option<u64>) -> bool {
    let file_name = destination.file_name().and_then(|name| name.to_str());
    let is_model_weights = file_name == Some("model.bin");
    let is_large = matches!(expected_size, Some(size) if size >= CURL_DOWNLOAD_THRESHOLD);
    is_model_weights || is_large
}
/// Probes for a usable `curl` by running `curl --version` with its output discarded.
///
/// Returns `false` if the process cannot be spawned or exits unsuccessfully.
fn curl_is_available() -> bool {
    let probe = Command::new("curl")
        .arg("--version")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();
    matches!(probe, Ok(exit) if exit.success())
}
/// Downloads `url` to `destination` by running the system `curl` and waiting for it.
///
/// `curl` is invoked as follows:
/// - it follows redirects, fails on HTTP errors, and retries up to 5 times;
/// - it resumes from the current size of `destination` (`--continue-at -`);
/// - its own progress bar goes to the inherited stderr.
///
/// After a successful run, the file size is checked against `expected_size`
/// when that is known. `progress` is then advanced by the bytes written beyond
/// `existing_size`.
///
/// # Errors
/// Fails if `curl` cannot be spawned or exits unsuccessfully, or if the output
/// file cannot be stat'ed. It also fails if the final size differs from
/// `expected_size`.
fn download_with_curl(
    url: &str,
    destination: &Path,
    progress: &ProgressBar,
    expected_size: Option<u64>,
    existing_size: u64,
) -> Result<()> {
    let display_name = destination
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("file");
    progress.suspend(|| println!(" downloading model via curl: {}", display_name));

    let exit = Command::new("curl")
        .args([
            "-L",
            "--fail",
            "--progress-bar",
            "--connect-timeout",
            "15",
            "--retry",
            "5",
            "--retry-delay",
            "2",
            "--continue-at",
            "-",
            "--output",
        ])
        .arg(destination)
        .arg(url)
        .stdin(Stdio::null())
        .stdout(Stdio::inherit())
        .stderr(Stdio::inherit())
        .status()
        .with_context(|| format!("failed to spawn curl for `{}`", destination.display()))?;
    if !exit.success() {
        anyhow::bail!("curl failed while downloading `{}`", destination.display());
    }

    let written = std::fs::metadata(destination)
        .with_context(|| {
            format!(
                "failed to stat `{}` after curl download",
                destination.display()
            )
        })?
        .len();
    match expected_size {
        Some(expected) if written != expected => anyhow::bail!(
            "incomplete curl download for `{}`: expected {} bytes, got {} bytes",
            destination.display(),
            expected,
            written
        ),
        _ => {}
    }

    // The caller already credited `existing_size`; add only the newly written bytes.
    let newly_written = written.saturating_sub(existing_size);
    if newly_written > 0 {
        progress.inc(newly_written);
    }
    Ok(())
}