use humansize::{BINARY, format_size};
use std::fs;
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
use thiserror::Error;
/// Errors that can occur while locating, downloading, or caching a Whisper
/// model file. Display messages are generated by `thiserror`.
#[derive(Debug, Error)]
pub enum ModelDownloadError {
/// The platform-specific cache directory could not be determined.
#[error("Failed to determine cache directory location")]
CacheDirectoryNotFound,
/// Creating the cache directory (or a parent) failed.
#[error("Failed to create cache directory at {path}: {source}")]
DirectoryCreationFailed {
path: PathBuf,
source: std::io::Error,
},
/// Building the HTTP client or sending the request failed.
#[error("Failed to download model from {url}: {source}")]
DownloadFailed { url: String, source: reqwest::Error },
/// Writing the downloaded bytes to disk failed (also used for read errors
/// while streaming the response body, since they surface as `io::Error`).
#[error("Failed to write model file {path}: {source}")]
WriteFailed {
path: PathBuf,
source: std::io::Error,
},
/// The model name is unsupported or the downloaded file fails validation.
#[error("Invalid model file at {path}: {reason}")]
InvalidModel { path: PathBuf, reason: String },
/// The server answered with a non-success HTTP status.
#[error("HTTP error downloading model: {0}")]
HttpError(String),
}
/// Model names accepted by `ensure_model_available`. Each entry corresponds
/// to a `ggml-<name>.bin` file published under `MODEL_BASE_URL` (the
/// ggerganov/whisper.cpp Hugging Face repository). `q5_*`/`q8_0` suffixes
/// denote quantized variants; `.en` denotes English-only models.
pub const SUPPORTED_MODELS: &[&str] = &[
"tiny",
"tiny.en",
"tiny-q5_1",
"tiny.en-q5_1",
"tiny-q8_0",
"base",
"base.en",
"base-q5_1",
"base.en-q5_1",
"base-q8_0",
"small",
"small.en",
"small-q5_1",
"small.en-q5_1",
"small-q8_0",
"medium",
"medium.en",
"medium-q5_0",
"medium.en-q5_0",
"medium-q8_0",
"large-v1",
"large-v2",
"large-v2-q5_0",
"large-v2-q8_0",
"large-v3",
"large-v3-q5_0",
"large-v3-turbo",
"large-v3-turbo-q5_0",
"large-v3-turbo-q8_0",
];
/// Base URL the model files are fetched from (Hugging Face `resolve` endpoint).
const MODEL_BASE_URL: &str = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main";
/// Sanity threshold (1 MiB): cached or downloaded files smaller than this are
/// treated as corrupt/partial and discarded.
const MIN_MODEL_SIZE: u64 = 1024 * 1024;
/// Returns the on-disk path of the requested Whisper model, downloading it
/// into the cache directory first if it is missing or looks corrupt.
///
/// # Errors
/// - [`ModelDownloadError::InvalidModel`] if `model_name` is not in
///   [`SUPPORTED_MODELS`].
/// - Any error from resolving the cache directory or from the download itself.
pub fn ensure_model_available(model_name: &str) -> Result<PathBuf, ModelDownloadError> {
    // Reject unknown names up front so we never hit the network for typos.
    if !SUPPORTED_MODELS.contains(&model_name) {
        return Err(ModelDownloadError::InvalidModel {
            path: PathBuf::from(model_name),
            reason: format!(
                "Unsupported model name. Supported models: {}",
                SUPPORTED_MODELS.join(", ")
            ),
        });
    }

    let cache_dir = get_model_cache_dir()?;
    let model_path = cache_dir.join(format!("ggml-{}.bin", model_name));

    if model_path.exists() {
        // A cached file only counts if we can stat it and it is plausibly a
        // real model; anything unreadable or undersized is dropped so the
        // download below replaces it.
        let looks_valid = fs::metadata(&model_path)
            .map(|metadata| metadata.len() >= MIN_MODEL_SIZE)
            .unwrap_or(false);
        if looks_valid {
            return Ok(model_path);
        }
        let _ = fs::remove_file(&model_path);
    }

    download_model(model_name, &model_path)?;
    Ok(model_path)
}
/// Downloads the given Whisper model into `target_path`.
///
/// The body is streamed into a sibling `.tmp` file and only renamed into place
/// after a minimum-size sanity check, so a partial download never masquerades
/// as a valid model. On every failure path the temporary file is removed
/// (the original leaked it on mid-stream and rename errors).
///
/// # Errors
/// - [`ModelDownloadError::DownloadFailed`] for client/request failures.
/// - [`ModelDownloadError::HttpError`] for non-success HTTP statuses.
/// - [`ModelDownloadError::WriteFailed`] for I/O errors while streaming or renaming.
/// - [`ModelDownloadError::InvalidModel`] if the download is implausibly small.
fn download_model(model_name: &str, target_path: &Path) -> Result<(), ModelDownloadError> {
    let url = format!("{}/ggml-{}.bin", MODEL_BASE_URL, model_name);
    println!("🔍 Preparing evidence kit...");
    println!(
        "📥 Downloading Whisper model '{}' from Hugging Face",
        model_name
    );
    println!("   This may take a few minutes depending on your connection...");
    print!("   Progress: ");
    io::stdout().flush().ok();

    // Generous timeout: the large models are multiple GiB.
    let client = reqwest::blocking::Client::builder()
        .timeout(std::time::Duration::from_secs(600))
        .build()
        .map_err(|e| ModelDownloadError::DownloadFailed {
            url: url.clone(),
            source: e,
        })?;
    let mut response = client
        .get(&url)
        .send()
        .map_err(|e| ModelDownloadError::DownloadFailed {
            url: url.clone(),
            source: e,
        })?;
    if !response.status().is_success() {
        return Err(ModelDownloadError::HttpError(format!(
            "HTTP {} while downloading model from {}",
            response.status(),
            url
        )));
    }

    let total_size = response.content_length();
    let temp_path = target_path.with_extension("tmp");

    // Stream into the temp file; on any error remove the partial file so it
    // neither leaks disk space nor lingers as a stale `.tmp`.
    let downloaded = match stream_response_to_file(&mut response, &temp_path, total_size) {
        Ok(bytes) => bytes,
        Err(e) => {
            let _ = fs::remove_file(&temp_path);
            return Err(e);
        }
    };
    println!("100% ✓");

    if downloaded < MIN_MODEL_SIZE {
        let _ = fs::remove_file(&temp_path);
        return Err(ModelDownloadError::InvalidModel {
            path: target_path.to_path_buf(),
            reason: format!(
                "Downloaded file is too small ({} bytes), expected at least {} bytes",
                downloaded, MIN_MODEL_SIZE
            ),
        });
    }

    // Move the finished download into its final name; clean up on failure.
    if let Err(e) = fs::rename(&temp_path, target_path) {
        let _ = fs::remove_file(&temp_path);
        return Err(ModelDownloadError::WriteFailed {
            path: target_path.to_path_buf(),
            source: e,
        });
    }
    println!("✅ Model cached at: {}", target_path.display());
    Ok(())
}

/// Copies the HTTP response body into `temp_path`, printing coarse (10%-step)
/// progress when the total size is known. Returns the number of bytes written.
fn stream_response_to_file(
    response: &mut reqwest::blocking::Response,
    temp_path: &Path,
    total_size: Option<u64>,
) -> Result<u64, ModelDownloadError> {
    let mut file = fs::File::create(temp_path).map_err(|e| ModelDownloadError::WriteFailed {
        path: temp_path.to_path_buf(),
        source: e,
    })?;
    let mut downloaded: u64 = 0;
    let mut buffer = [0u8; 8192];
    let mut last_progress_percent: u32 = 0;
    loop {
        // Read errors surface as io::Error via the `Read` impl on the
        // blocking Response, hence the WriteFailed mapping (matches the
        // error shape used throughout this module).
        let bytes_read = response
            .read(&mut buffer)
            .map_err(|e| ModelDownloadError::WriteFailed {
                path: temp_path.to_path_buf(),
                source: e,
            })?;
        if bytes_read == 0 {
            break;
        }
        file.write_all(&buffer[..bytes_read])
            .map_err(|e| ModelDownloadError::WriteFailed {
                path: temp_path.to_path_buf(),
                source: e,
            })?;
        downloaded += bytes_read as u64;
        if let Some(total) = total_size {
            // Print at most once per 10% so the progress line stays compact.
            let progress_percent = (downloaded * 100 / total) as u32;
            if progress_percent >= last_progress_percent + 10 {
                print!("{}% ", progress_percent);
                io::stdout().flush().ok();
                last_progress_percent = progress_percent;
            }
        }
    }
    Ok(downloaded)
}
/// Resolves the per-user cache directory for models (`<cache>/models`),
/// creating it if necessary.
///
/// # Errors
/// - [`ModelDownloadError::CacheDirectoryNotFound`] if the platform cache
///   location cannot be determined.
/// - [`ModelDownloadError::DirectoryCreationFailed`] if creation fails.
fn get_model_cache_dir() -> Result<PathBuf, ModelDownloadError> {
    let project_dirs = directories::ProjectDirs::from("de", "westhoffswelt", "dialogdetective")
        .ok_or(ModelDownloadError::CacheDirectoryNotFound)?;
    let models_dir = project_dirs.cache_dir().join("models");
    match fs::create_dir_all(&models_dir) {
        Ok(()) => Ok(models_dir),
        Err(source) => Err(ModelDownloadError::DirectoryCreationFailed {
            path: models_dir,
            source,
        }),
    }
}
/// Metadata for one model file found in the local cache directory.
#[derive(Debug, Clone)]
pub struct CachedModelInfo {
/// Model name parsed from the file name (the `<name>` in `ggml-<name>.bin`).
pub model_name: String,
/// Full path to the cached file.
pub path: PathBuf,
/// The file's name, e.g. `ggml-base.en.bin`.
pub file_name: String,
/// File size in bytes (0 when metadata could not be read).
pub size_bytes: u64,
}
impl CachedModelInfo {
    /// Renders the file size for display using binary (base-1024) units,
    /// e.g. "1.4 GiB".
    pub fn size_human_readable(&self) -> String {
        let bytes = self.size_bytes;
        format_size(bytes, BINARY)
    }
}
/// Returns the list of model names accepted by `ensure_model_available`.
pub fn supported_models() -> &'static [&'static str] {
SUPPORTED_MODELS
}
/// Public accessor for the model cache directory; creates it if missing.
/// Thin wrapper over the private `get_model_cache_dir`.
pub fn get_cache_dir() -> Result<PathBuf, ModelDownloadError> {
get_model_cache_dir()
}
/// Scans the model cache directory for files named `ggml-<name>.bin` and
/// returns their metadata, sorted by model name.
///
/// Best-effort by design: an unreadable cache directory yields an empty list,
/// and a file whose metadata cannot be read is reported with size 0.
///
/// # Errors
/// Only resolving/creating the cache directory itself can fail.
pub fn list_cached_models() -> Result<Vec<CachedModelInfo>, ModelDownloadError> {
    let cache_dir = get_model_cache_dir()?;
    let mut models = Vec::new();

    if let Ok(entries) = fs::read_dir(&cache_dir) {
        for entry in entries.flatten() {
            let path = entry.path();
            if !path.is_file() {
                continue;
            }
            // Only valid-UTF-8 names can match the ggml naming scheme.
            let file_name = match path.file_name().and_then(|n| n.to_str()) {
                Some(name) => name.to_string(),
                None => continue,
            };
            // Accept only `ggml-<name>.bin` with a non-empty <name>.
            let model_name = match file_name
                .strip_prefix("ggml-")
                .and_then(|rest| rest.strip_suffix(".bin"))
            {
                Some(name) if !name.is_empty() => name.to_string(),
                _ => continue,
            };
            let size_bytes = fs::metadata(&path).map(|m| m.len()).unwrap_or(0);
            models.push(CachedModelInfo {
                model_name,
                path,
                file_name,
                size_bytes,
            });
        }
    }

    models.sort_by(|a, b| a.model_name.cmp(&b.model_name));
    Ok(models)
}