use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use crate::WhisperModel;
/// Errors reported while locating or downloading a Whisper model file.
#[derive(Debug)]
pub enum ModelError {
/// A filesystem operation failed; wraps the underlying [`std::io::Error`].
Io(std::io::Error),
/// An HTTP request or response stream failed, or the server answered with a
/// non-success status. The payload is a human-readable description.
Network(String),
/// The platform cache directory could not be determined.
NoProjectDir,
/// A download for the same model is already in progress.
AlreadyDownloading,
}
impl std::fmt::Display for ModelError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ModelError::Io(e) => write!(f, "I/O error: {}", e),
ModelError::Network(msg) => write!(f, "Network error: {}", msg),
ModelError::NoProjectDir => write!(f, "Could not determine platform cache directory"),
ModelError::AlreadyDownloading => {
write!(f, "A download for this model is already in progress")
}
}
}
}
impl std::error::Error for ModelError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
ModelError::Io(e) => Some(e),
_ => None,
}
}
}
/// Locates cached Whisper model files on disk and downloads missing ones.
///
/// Clones share the same in-progress set, because it is held behind an `Arc`.
#[derive(Clone)]
pub struct ModelManager {
/// Directory in which model files are stored.
cache_dir: PathBuf,
/// Slugs of the models that currently have a download running.
in_progress: Arc<Mutex<HashSet<String>>>,
}
impl ModelManager {
    /// Creates a manager whose cache lives under
    /// `<platform cache dir>/<app_name>/whisper`.
    ///
    /// If the platform cache directory cannot be determined, the current working
    /// directory (`.`) is used as the base instead and a warning is logged.
    pub fn new(app_name: &str) -> Self {
        let base = directories::BaseDirs::new()
            .map(|dirs| dirs.cache_dir().to_path_buf())
            .unwrap_or_else(|| {
                tracing::warn!(
                    "[ModelManager] Could not determine platform cache directory; \
                     falling back to the current directory"
                );
                PathBuf::from(".")
            });
        Self {
            cache_dir: base.join(app_name).join("whisper"),
            in_progress: Arc::new(Mutex::new(HashSet::new())),
        }
    }

    /// Returns the on-disk location for `model`: `<cache_dir>/ggml-<slug>.bin`.
    ///
    /// The file is not required to exist.
    pub fn path(&self, model: WhisperModel) -> PathBuf {
        self.cache_dir.join(format!("ggml-{}.bin", model.slug()))
    }

    /// Returns `true` when the model file exists, is a regular file and is non-empty.
    ///
    /// Any error while reading the metadata (including "not found") counts as
    /// "not available".
    pub fn is_available(&self, model: WhisperModel) -> bool {
        self.path(model)
            .metadata()
            .map(|meta| meta.is_file() && meta.len() > 0)
            .unwrap_or(false)
    }

    /// Lists the models that are currently available in the cache, in the order
    /// produced by `WhisperModel::all_in_size_order()`.
    pub fn list_cached(&self) -> Vec<WhisperModel> {
        WhisperModel::all_in_size_order()
            .iter()
            .copied()
            .filter(|&model| self.is_available(model))
            .collect()
    }

    /// Downloads `model` into the cache, reporting progress through `on_progress`.
    ///
    /// `on_progress` receives a percentage in `0..=100`; `100` is only reported
    /// once the file is in its final location.
    ///
    /// # Errors
    ///
    /// * [`ModelError::AlreadyDownloading`] if a download of the same model is
    ///   already running on this manager or one of its clones.
    /// * [`ModelError::Network`] / [`ModelError::Io`] if the transfer or a
    ///   filesystem operation fails.
    ///
    /// # Panics
    ///
    /// Panics if the in-progress lock is poisoned when the download is registered.
    pub async fn download(
        &self,
        model: WhisperModel,
        on_progress: impl Fn(u8) + Send + 'static,
    ) -> Result<(), ModelError> {
        /// Unregisters the slug when dropped, so the slot is released even if the
        /// download future is dropped before completion or the callback panics.
        struct InProgressSlot {
            registry: Arc<Mutex<HashSet<String>>>,
            slug: String,
        }

        impl Drop for InProgressSlot {
            fn drop(&mut self) {
                // Never panic in `drop`: recover the set even if the lock is poisoned.
                let mut registry = self
                    .registry
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                registry.remove(&self.slug);
            }
        }

        let slug = model.slug().to_string();
        {
            let mut in_progress = self
                .in_progress
                .lock()
                .expect("ModelManager in_progress lock poisoned");
            // `insert` returns `false` when the slug was already registered.
            if !in_progress.insert(slug.clone()) {
                return Err(ModelError::AlreadyDownloading);
            }
        }
        let _slot = InProgressSlot {
            registry: Arc::clone(&self.in_progress),
            slug,
        };

        self.do_download(model, on_progress).await
    }

    /// Fetches `model` over HTTP and stores it at [`ModelManager::path`].
    ///
    /// The body is streamed into a sibling `.bin.part` file that is renamed onto
    /// the destination once complete. If streaming or the rename fails, the
    /// partial file is removed on a best-effort basis.
    async fn do_download(
        &self,
        model: WhisperModel,
        on_progress: impl Fn(u8),
    ) -> Result<(), ModelError> {
        let dest = self.path(model);
        if let Some(parent) = dest.parent() {
            // Async variant: avoids blocking the executor thread on filesystem work.
            tokio::fs::create_dir_all(parent)
                .await
                .map_err(ModelError::Io)?;
        }

        let url = model.download_url();
        tracing::info!("[ModelManager] Downloading {} from {}", model.slug(), url);

        let client = reqwest::Client::new();
        let response = client
            .get(&url)
            .send()
            .await
            .map_err(|e| ModelError::Network(e.to_string()))?;
        if !response.status().is_success() {
            return Err(ModelError::Network(format!("HTTP {}", response.status())));
        }

        let tmp_path = dest.with_extension("bin.part");
        // The callback is moved (not borrowed) into the helper so that no reference
        // to it is held across an `.await`.
        match Self::stream_response_to_file(response, &tmp_path, &dest, on_progress).await {
            Ok(downloaded) => {
                tracing::info!(
                    "[ModelManager] Downloaded {} ({} bytes)",
                    model.slug(),
                    downloaded
                );
                Ok(())
            }
            Err(err) => {
                // Best-effort cleanup; the original error is what the caller needs.
                let _ = tokio::fs::remove_file(&tmp_path).await;
                Err(err)
            }
        }
    }

    /// Streams the body of `response` into `tmp_path`, renames it to `dest` and
    /// returns the number of bytes written.
    ///
    /// Progress: `0` is reported before the first byte is written. While
    /// streaming, the percentage is capped at `99` and only reported when it
    /// increases. `100` is reported after the rename succeeds. If the response
    /// has no known content length, no intermediate values are reported.
    async fn stream_response_to_file(
        response: reqwest::Response,
        tmp_path: &std::path::Path,
        dest: &std::path::Path,
        on_progress: impl Fn(u8),
    ) -> Result<u64, ModelError> {
        use futures::StreamExt;
        use tokio::io::AsyncWriteExt;

        let total_size = response.content_length().unwrap_or(0);
        let mut downloaded: u64 = 0;
        let mut last_percent: u8 = 0;
        on_progress(0);

        let mut file = tokio::fs::File::create(tmp_path)
            .await
            .map_err(ModelError::Io)?;
        let mut stream = response.bytes_stream();
        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result.map_err(|e| ModelError::Network(e.to_string()))?;
            file.write_all(&chunk).await.map_err(ModelError::Io)?;
            downloaded += chunk.len() as u64;
            if total_size > 0 {
                let percent = ((downloaded * 100) / total_size).min(99) as u8;
                if percent > last_percent {
                    on_progress(percent);
                    last_percent = percent;
                }
            }
        }
        file.flush().await.map_err(ModelError::Io)?;
        // Close the handle before renaming the file.
        drop(file);

        tokio::fs::rename(tmp_path, dest)
            .await
            .map_err(ModelError::Io)?;
        on_progress(100);
        Ok(downloaded)
    }
}