use crate::error::{AumateError, Result};
use futures_util::StreamExt;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
/// Built-in catalogue of downloadable Whisper (ggml) models.
///
/// Each entry is `(id, display_name, size_bytes, url)`:
/// - `id` doubles as the local file stem (`<id>.bin` in the models directory).
/// - `display_name` is the human-readable label, including a rounded size.
/// - `size_bytes` is a rounded estimate; the downloader uses it as the expected
///   total when the server does not report a `Content-Length`.
/// - `url` points at the corresponding `ggml-*.bin` weights on Hugging Face.
pub const WHISPER_MODELS: &[(&str, &str, u64, &str)] = &[
(
"whisper-tiny",
"Whisper Tiny (75 MB)",
75_000_000,
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin",
),
(
"whisper-base",
"Whisper Base (142 MB)",
142_000_000,
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin",
),
(
"whisper-small",
"Whisper Small (466 MB)",
466_000_000,
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin",
),
(
"whisper-medium",
"Whisper Medium (1.5 GB)",
1_500_000_000,
"https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-medium.bin",
),
];
/// Metadata about a downloadable model together with its local install state.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Stable identifier; also the file stem of the local `<id>.bin` file.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Expected download size in bytes.
    pub size_bytes: u64,
    /// Remote URL the model weights are fetched from.
    pub url: String,
    /// Whether the model file was present locally when this value was built.
    pub is_downloaded: bool,
    /// Local path of the model file; `Some` when `is_downloaded` is true
    /// (as populated by `ModelManager::list_available_models`).
    pub local_path: Option<PathBuf>,
}

impl ModelInfo {
    /// Formats `size_bytes` for display using decimal (SI) units.
    ///
    /// Sizes of at least 1 GB get one decimal place (`"1.5 GB"`). MB and KB
    /// values are truncated to whole units (`"142 MB"`). Anything below 1 KB is
    /// shown in bytes (`"512 B"`) rather than being rounded down to `"0 KB"`.
    pub fn size_display(&self) -> String {
        const KB: u64 = 1_000;
        const MB: u64 = 1_000_000;
        const GB: u64 = 1_000_000_000;
        match self.size_bytes {
            bytes if bytes >= GB => format!("{:.1} GB", bytes as f64 / GB as f64),
            bytes if bytes >= MB => format!("{} MB", bytes / MB),
            bytes if bytes >= KB => format!("{} KB", bytes / KB),
            bytes => format!("{} B", bytes),
        }
    }
}
/// Snapshot of a model download's progress.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Id of the model being downloaded.
    pub model_id: String,
    /// Bytes written so far, including any prefix resumed from an earlier attempt.
    pub downloaded_bytes: u64,
    /// Expected total size in bytes. It may be an estimate when the server sends
    /// no `Content-Length`. `progress` reports `0.0` while this is zero.
    pub total_bytes: u64,
    /// Current lifecycle state of the download.
    pub status: DownloadStatus,
}

impl DownloadProgress {
    /// Fraction of the download completed, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when `total_bytes` is zero. The result is clamped to `1.0`
    /// because `total_bytes` may be an estimate that the real file exceeds.
    pub fn progress(&self) -> f32 {
        if self.total_bytes == 0 {
            return 0.0;
        }
        // Divide in f64: byte counts above 2^24 are not exactly representable in f32.
        let ratio = self.downloaded_bytes as f64 / self.total_bytes as f64;
        ratio.min(1.0) as f32
    }

    /// `progress` as a percentage with one decimal place, e.g. `"42.5%"`.
    pub fn progress_percent(&self) -> String {
        format!("{:.1}%", self.progress() * 100.0)
    }
}

/// Lifecycle state of a model download.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DownloadStatus {
    /// Registered, but no HTTP response has been accepted yet.
    Pending,
    /// Response accepted; bytes are being written to disk.
    Downloading,
    /// Paused. NOTE(review): not assigned by `ModelManager::download_model_sync`.
    Paused,
    /// The model file was fully downloaded and moved into place.
    Completed,
    /// The download was aborted; carries the error message.
    Failed(String),
}
/// Manages Whisper model files in a single models directory: listing the
/// built-in catalogue, downloading, verifying and deleting models.
pub struct ModelManager {
    /// Directory holding `<model_id>.bin` files and `<model_id>.bin.tmp` partial downloads.
    models_dir: PathBuf,
    /// Latest known progress per model id, read back by `get_download_progress`.
    downloads: Arc<Mutex<HashMap<String, DownloadProgress>>>,
}

impl ModelManager {
    /// Creates a manager rooted at the directory returned by `super::get_models_dir()`.
    ///
    /// # Errors
    /// Propagates any error from `super::get_models_dir()`.
    pub fn new() -> Result<Self> {
        let models_dir = super::get_models_dir()?;
        Ok(Self { models_dir, downloads: Arc::new(Mutex::new(HashMap::new())) })
    }

    /// Directory where model files are stored.
    pub fn models_dir(&self) -> &Path {
        &self.models_dir
    }

    /// Returns every model in `WHISPER_MODELS`. Each entry records whether its
    /// `<id>.bin` file currently exists in the models directory.
    pub fn list_available_models(&self) -> Vec<ModelInfo> {
        WHISPER_MODELS
            .iter()
            .map(|(id, name, size, url)| {
                let local_path = self.model_file(id);
                let is_downloaded = local_path.exists();
                ModelInfo {
                    id: id.to_string(),
                    name: name.to_string(),
                    size_bytes: *size,
                    url: url.to_string(),
                    is_downloaded,
                    local_path: is_downloaded.then_some(local_path),
                }
            })
            .collect()
    }

    /// Returns only the catalogue models whose file is present locally.
    pub fn list_downloaded_models(&self) -> Vec<ModelInfo> {
        self.list_available_models().into_iter().filter(|m| m.is_downloaded).collect()
    }

    /// Path to `<model_id>.bin` if that file currently exists.
    ///
    /// Any id is accepted, not only those listed in `WHISPER_MODELS`.
    pub fn get_model_path(&self, model_id: &str) -> Option<PathBuf> {
        let path = self.model_file(model_id);
        path.exists().then_some(path)
    }

    /// Latest recorded progress for `model_id`. Returns `None` if this manager
    /// has never started a download for it.
    pub fn get_download_progress(&self, model_id: &str) -> Option<DownloadProgress> {
        self.lock_downloads().get(model_id).cloned()
    }

    /// Downloads `model_id` into the models directory, blocking until done.
    ///
    /// Bytes are streamed into `<model_id>.bin.tmp`. That file is renamed to
    /// `<model_id>.bin` only after the transfer succeeds. If a `.tmp` file is left
    /// over from an interrupted attempt, the download resumes from its end via an
    /// HTTP `Range` request. The partial file is discarded and the download
    /// restarts from byte 0 when:
    /// - the server ignores the range and answers `200`, or
    /// - the server rejects it with `416`.
    ///
    /// Progress is recorded for `get_download_progress`. When a callback is given,
    /// `progress_callback` is also invoked after every received chunk.
    ///
    /// # Errors
    /// Returns an error in these cases:
    /// - `model_id` is not in `WHISPER_MODELS`;
    /// - the Tokio runtime cannot be created;
    /// - a network failure occurs;
    /// - the server returns a non-success HTTP status;
    /// - writing or renaming the file fails.
    ///
    /// Once the id is validated, every failure sets the recorded status to
    /// `DownloadStatus::Failed`.
    ///
    /// # Panics
    /// Panics if called from within an asynchronous (Tokio) execution context,
    /// because it creates and blocks on its own runtime.
    pub fn download_model_sync(
        &self,
        model_id: &str,
        progress_callback: Option<Box<dyn Fn(DownloadProgress) + Send>>,
    ) -> Result<PathBuf> {
        let model_info = self
            .list_available_models()
            .into_iter()
            .find(|m| m.id == model_id)
            .ok_or_else(|| AumateError::Other(format!("Unknown model: {}", model_id)))?;
        let output_path = self.model_file(model_id);
        let temp_path = self.models_dir.join(format!("{}.bin.tmp", model_id));

        self.lock_downloads().insert(
            model_id.to_string(),
            DownloadProgress {
                model_id: model_id.to_string(),
                downloaded_bytes: 0,
                total_bytes: model_info.size_bytes,
                status: DownloadStatus::Pending,
            },
        );

        // Every failure below, runtime creation included, flows into `result`.
        // Otherwise the recorded entry would stay stuck at `Pending`.
        let result = tokio::runtime::Runtime::new()
            .map_err(|e| AumateError::Other(format!("Failed to create runtime: {}", e)))
            .and_then(|rt| {
                rt.block_on(self.fetch_to_file(
                    model_id,
                    &model_info.url,
                    model_info.size_bytes,
                    &temp_path,
                    progress_callback.as_deref(),
                ))
            })
            .and_then(|()| {
                // Only a fully transferred file is moved into its final place.
                std::fs::rename(&temp_path, &output_path)
                    .map_err(|e| AumateError::Other(format!("Failed to rename file: {}", e)))
            });

        let final_status = match &result {
            Ok(()) => DownloadStatus::Completed,
            Err(e) => DownloadStatus::Failed(e.to_string()),
        };
        self.update_progress(model_id, |p| p.status = final_status);
        result.map(|()| output_path)
    }

    /// Streams `url` into `temp_path` and records progress under `model_id`.
    ///
    /// The download resumes from the bytes already on disk only when the server
    /// answers the `Range` request with `206 Partial Content`. `fallback_total`
    /// is used as the total size when the response has no `Content-Length`.
    async fn fetch_to_file(
        &self,
        model_id: &str,
        url: &str,
        fallback_total: u64,
        temp_path: &Path,
        progress_callback: Option<&(dyn Fn(DownloadProgress) + Send)>,
    ) -> Result<()> {
        let client = reqwest::Client::new();
        let existing = std::fs::metadata(temp_path).map(|m| m.len()).unwrap_or(0);

        let mut response = Self::send_request(&client, url, existing).await?;
        // A 416 means the partial file is not a usable prefix of the remote file,
        // e.g. it is already complete or the remote file changed. Without this
        // retry every later attempt would fail the same way, so start over.
        if existing > 0 && response.status() == reqwest::StatusCode::RANGE_NOT_SATISFIABLE {
            response = Self::send_request(&client, url, 0).await?;
        }
        if !response.status().is_success() {
            return Err(AumateError::Other(format!(
                "Download failed with status: {}",
                response.status()
            )));
        }

        // Only a 206 continues the partial file. A server that ignores `Range`
        // answers 200 with the whole body. That body must overwrite, not extend,
        // what is already on disk, or the model file gets corrupted.
        let resumed = existing > 0 && response.status() == reqwest::StatusCode::PARTIAL_CONTENT;
        let start_pos = if resumed { existing } else { 0 };
        let total_size = response.content_length().map_or(fallback_total, |len| len + start_pos);

        self.update_progress(model_id, |p| {
            p.downloaded_bytes = start_pos;
            p.total_bytes = total_size;
            p.status = DownloadStatus::Downloading;
        });

        let mut file = std::fs::OpenOptions::new()
            .create(true)
            .write(true)
            .append(resumed)
            .truncate(!resumed)
            .open(temp_path)
            .map_err(|e| AumateError::Other(format!("Failed to open file: {}", e)))?;

        let mut downloaded = start_pos;
        let mut stream = response.bytes_stream();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.map_err(|e| AumateError::Other(format!("Download error: {}", e)))?;
            file.write_all(&chunk)
                .map_err(|e| AumateError::Other(format!("Write error: {}", e)))?;
            downloaded += chunk.len() as u64;

            self.update_progress(model_id, |p| p.downloaded_bytes = downloaded);
            if let Some(callback) = progress_callback {
                callback(DownloadProgress {
                    model_id: model_id.to_string(),
                    downloaded_bytes: downloaded,
                    total_bytes: total_size,
                    status: DownloadStatus::Downloading,
                });
            }
        }

        // Push the data to stable storage before the caller renames the file into
        // place. This makes a truncated-but-complete-looking `.bin` less likely
        // after a crash.
        file.sync_all().map_err(|e| AumateError::Other(format!("Write error: {}", e)))?;
        Ok(())
    }

    /// Issues a GET for `url`. When `start_pos` is non-zero it asks for the bytes
    /// from `start_pos` onward.
    async fn send_request(
        client: &reqwest::Client,
        url: &str,
        start_pos: u64,
    ) -> Result<reqwest::Response> {
        let mut request = client.get(url);
        if start_pos > 0 {
            request = request.header("Range", format!("bytes={}-", start_pos));
        }
        request.send().await.map_err(|e| AumateError::Other(format!("Download failed: {}", e)))
    }

    /// Removes `<model_id>.bin`. Succeeds when the file does not exist.
    ///
    /// Any partial `<model_id>.bin.tmp` download is left in place.
    ///
    /// # Errors
    /// Returns any I/O error other than "not found".
    pub fn delete_model(&self, model_id: &str) -> Result<()> {
        // Attempt removal directly instead of checking `exists()` first, which
        // would race with concurrent deletion.
        match std::fs::remove_file(self.model_file(model_id)) {
            Ok(()) => Ok(()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(e.into()),
        }
    }

    /// Basic sanity check: `true` if the model file exists and is non-empty.
    ///
    /// The file's contents are not validated.
    ///
    /// # Errors
    /// Returns an error if the file's metadata cannot be read.
    pub fn verify_model(&self, model_id: &str) -> Result<bool> {
        match self.get_model_path(model_id) {
            Some(path) => Ok(std::fs::metadata(&path)?.len() > 0),
            None => Ok(false),
        }
    }

    /// Hex-encoded SHA-256 digest of the file at `path`.
    // Currently unused: `WHISPER_MODELS` carries no expected digests to compare against.
    #[allow(dead_code)]
    fn calculate_hash(path: &Path) -> Result<String> {
        let mut file = std::fs::File::open(path)?;
        let mut hasher = Sha256::new();
        std::io::copy(&mut file, &mut hasher)?;
        Ok(format!("{:x}", hasher.finalize()))
    }

    /// Path of the `.bin` file for `model_id`, whether or not it exists.
    fn model_file(&self, model_id: &str) -> PathBuf {
        self.models_dir.join(format!("{}.bin", model_id))
    }

    /// Locks the progress map. If another thread panicked while holding the lock,
    /// the map is recovered; it only holds plain progress snapshots.
    fn lock_downloads(&self) -> std::sync::MutexGuard<'_, HashMap<String, DownloadProgress>> {
        self.downloads.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Applies `update` to the tracked progress entry for `model_id`, if one exists.
    fn update_progress(&self, model_id: &str, update: impl FnOnce(&mut DownloadProgress)) {
        if let Some(progress) = self.lock_downloads().get_mut(model_id) {
            update(progress);
        }
    }
}
impl Default for ModelManager {
    /// Equivalent to `ModelManager::new`, but infallible.
    ///
    /// # Panics
    /// Panics if `ModelManager::new` returns an error.
    fn default() -> Self {
        match Self::new() {
            Ok(manager) => manager,
            Err(err) => panic!("Failed to create model manager: {:?}", err),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ModelInfo` with the given size and placeholder values elsewhere.
    fn model_with_size(size_bytes: u64) -> ModelInfo {
        ModelInfo {
            id: "test".to_string(),
            name: "Test".to_string(),
            size_bytes,
            url: "".to_string(),
            is_downloaded: false,
            local_path: None,
        }
    }

    #[test]
    fn test_model_info_size_display() {
        assert_eq!(model_with_size(142_000_000).size_display(), "142 MB");
        assert_eq!(model_with_size(1_500_000_000).size_display(), "1.5 GB");
        assert_eq!(model_with_size(2_000).size_display(), "2 KB");
    }

    #[test]
    fn test_download_progress() {
        let mut progress = DownloadProgress {
            model_id: "test".to_string(),
            downloaded_bytes: 50,
            total_bytes: 100,
            status: DownloadStatus::Downloading,
        };
        assert_eq!(progress.progress(), 0.5);
        assert_eq!(progress.progress_percent(), "50.0%");

        // An unknown (zero) total must not divide by zero.
        progress.total_bytes = 0;
        assert_eq!(progress.progress(), 0.0);
        assert_eq!(progress.progress_percent(), "0.0%");
    }

    #[test]
    fn test_whisper_models_table_is_well_formed() {
        let mut ids: Vec<&str> = WHISPER_MODELS.iter().map(|(id, ..)| *id).collect();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), WHISPER_MODELS.len(), "model ids must be unique");
        assert!(WHISPER_MODELS
            .iter()
            .all(|&(_, _, size, url)| size > 0 && url.starts_with("https://")));
    }

    #[test]
    fn test_list_available_models() {
        // Skipped when the models directory cannot be resolved in this environment.
        if let Ok(manager) = ModelManager::new() {
            let models = manager.list_available_models();
            assert_eq!(models.len(), WHISPER_MODELS.len());
            assert!(models.iter().any(|m| m.id == "whisper-base"));
            assert!(models.iter().all(|m| m.is_downloaded == m.local_path.is_some()));
        }
    }
}