#![allow(dead_code)]
use crate::error::{AumateError, Result};
use hf_hub::api::tokio::Api;
use std::path::PathBuf;
/// Describes which files to fetch from a Hugging Face Hub model repository.
///
/// Build one with [`ModelSource::new`] and the `with_*` builder methods, then pass it to
/// `ModelDownloader::download_source` or `download_model`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelSource {
    /// Repository identifier, e.g. `"openai/whisper-tiny"`.
    pub repo_id: String,
    /// Optional revision forwarded to the Hub; `None` means no explicit revision is requested.
    pub revision: Option<String>,
    /// File names inside the repository, in the order they will be downloaded.
    pub files: Vec<String>,
}

impl ModelSource {
    /// Creates a source for `repo_id` with no revision and no files.
    #[must_use]
    pub fn new(repo_id: impl Into<String>) -> Self {
        Self { repo_id: repo_id.into(), revision: None, files: Vec::new() }
    }

    /// Pins the source to `revision`, replacing any previously set revision.
    #[must_use]
    pub fn with_revision(mut self, revision: impl Into<String>) -> Self {
        self.revision = Some(revision.into());
        self
    }

    /// Appends a single file name to the download list.
    #[must_use]
    pub fn with_file(mut self, file: impl Into<String>) -> Self {
        self.files.push(file.into());
        self
    }

    /// Appends every file name yielded by `files`, preserving their order.
    #[must_use]
    pub fn with_files(mut self, files: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.files.extend(files.into_iter().map(|f| f.into()));
        self
    }
}
/// Downloads model files from the Hugging Face Hub using the tokio-based `hf_hub` client.
pub struct ModelDownloader {
/// Hub client built once in `ModelDownloader::new` and reused by every download call.
api: Api,
}
impl ModelDownloader {
    /// Creates a downloader backed by a default-configured `hf_hub` tokio client.
    ///
    /// The function is `async` for API stability only; its body performs no `.await`.
    ///
    /// # Errors
    ///
    /// Returns [`AumateError::Other`] if the Hub client cannot be constructed.
    pub async fn new() -> Result<Self> {
        let api = Api::new()
            .map_err(|e| AumateError::Other(format!("Failed to create HF Hub API: {}", e)))?;
        Ok(Self { api })
    }

    /// Downloads `filename` from the model repository `repo_id` and returns the local path
    /// reported by `hf_hub`.
    ///
    /// With `revision = Some(rev)` the file is requested at that revision; with `None` the
    /// repository handle from `Api::model` is used, i.e. no explicit revision is requested.
    ///
    /// # Errors
    ///
    /// Returns [`AumateError::Other`] naming the file, the repository and (when given) the
    /// revision if the download fails.
    pub async fn download_file(
        &self,
        repo_id: &str,
        filename: &str,
        revision: Option<&str>,
    ) -> Result<PathBuf> {
        // `location` is shared by the log line and the error so that a failure bubbling up
        // through `download_source` still says which repository/revision was involved.
        let (repo, location) = match revision {
            Some(rev) => (
                self.api.repo(hf_hub::Repo::with_revision(
                    repo_id.to_string(),
                    hf_hub::RepoType::Model,
                    rev.to_string(),
                )),
                format!("{}@{}", repo_id, rev),
            ),
            None => (self.api.model(repo_id.to_string()), repo_id.to_string()),
        };
        log::info!("Downloading {} from {}", filename, location);
        let path = repo.get(filename).await.map_err(|e| {
            AumateError::Other(format!("Failed to download {} from {}: {}", filename, location, e))
        })?;
        log::info!("Downloaded to: {:?}", path);
        Ok(path)
    }

    /// Downloads every file listed in `source`, one after another, at `source.revision`.
    ///
    /// Returns the local paths in the same order as `source.files`. An empty file list
    /// yields an empty vector without issuing any download.
    ///
    /// # Errors
    ///
    /// Stops at the first failed download and returns its error; paths of files fetched
    /// before the failure are not returned.
    pub async fn download_source(&self, source: &ModelSource) -> Result<Vec<PathBuf>> {
        let revision = source.revision.as_deref();
        let mut paths = Vec::with_capacity(source.files.len());
        for file in &source.files {
            paths.push(self.download_file(&source.repo_id, file, revision).await?);
        }
        Ok(paths)
    }

    /// Downloads `model.safetensors` from `repo_id` without an explicit revision.
    ///
    /// # Errors
    ///
    /// See [`ModelDownloader::download_file`].
    pub async fn download_weights(&self, repo_id: &str) -> Result<PathBuf> {
        self.download_file(repo_id, "model.safetensors", None).await
    }

    /// Downloads `tokenizer.json` from `repo_id` without an explicit revision.
    ///
    /// # Errors
    ///
    /// See [`ModelDownloader::download_file`].
    pub async fn download_tokenizer(&self, repo_id: &str) -> Result<PathBuf> {
        self.download_file(repo_id, "tokenizer.json", None).await
    }

    /// Downloads `config.json` from `repo_id` without an explicit revision.
    ///
    /// # Errors
    ///
    /// See [`ModelDownloader::download_file`].
    pub async fn download_config(&self, repo_id: &str) -> Result<PathBuf> {
        self.download_file(repo_id, "config.json", None).await
    }
}
/// One-shot helper: constructs a fresh [`ModelDownloader`] and fetches every file in `source`.
///
/// Returns the local paths in the order listed by `source.files`.
///
/// # Errors
///
/// Propagates any failure from building the Hub client or from an individual download.
pub async fn download_model(source: &ModelSource) -> Result<Vec<PathBuf>> {
    ModelDownloader::new().await?.download_source(source).await
}
pub mod whisper {
    //! Preset [`ModelSource`]s for the `openai/whisper-*` repositories on the Hugging Face Hub.
    use super::ModelSource;

    /// File list shared by every Whisper preset; keeping it in one place stops the variants
    /// from drifting apart.
    const FILES: [&str; 3] = ["model.safetensors", "tokenizer.json", "config.json"];

    /// Builds a source for `repo_id` with the shared Whisper file list and no revision.
    fn preset(repo_id: &str) -> ModelSource {
        ModelSource::new(repo_id).with_files(FILES)
    }

    /// Source for `openai/whisper-tiny`.
    pub fn tiny() -> ModelSource {
        preset("openai/whisper-tiny")
    }

    /// Source for `openai/whisper-tiny.en`.
    pub fn tiny_en() -> ModelSource {
        preset("openai/whisper-tiny.en")
    }

    /// Source for `openai/whisper-base`.
    pub fn base() -> ModelSource {
        preset("openai/whisper-base")
    }

    /// Source for `openai/whisper-base.en`.
    pub fn base_en() -> ModelSource {
        preset("openai/whisper-base.en")
    }

    /// Source for `openai/whisper-small`.
    pub fn small() -> ModelSource {
        preset("openai/whisper-small")
    }

    /// Source for `openai/whisper-small.en`.
    pub fn small_en() -> ModelSource {
        preset("openai/whisper-small.en")
    }

    /// Source for `openai/whisper-medium`.
    pub fn medium() -> ModelSource {
        preset("openai/whisper-medium")
    }

    /// Source for `openai/whisper-medium.en`.
    pub fn medium_en() -> ModelSource {
        preset("openai/whisper-medium.en")
    }

    /// Source for `openai/whisper-large-v3`.
    pub fn large_v3() -> ModelSource {
        preset("openai/whisper-large-v3")
    }
}
pub mod trocr {
    //! Preset [`ModelSource`]s for the `microsoft/trocr-*` repositories on the Hugging Face Hub.
    use super::ModelSource;

    /// File list shared by every TrOCR preset.
    ///
    /// NOTE(review): this assumes each TrOCR repo publishes `model.safetensors` and
    /// `tokenizer.json` on its default revision — confirm against the Hub before relying on
    /// these presets; a pinned revision or a different tokenizer source may be needed.
    const FILES: [&str; 3] = ["model.safetensors", "tokenizer.json", "config.json"];

    /// Builds a source for `repo_id` with the shared TrOCR file list and no revision.
    fn preset(repo_id: &str) -> ModelSource {
        ModelSource::new(repo_id).with_files(FILES)
    }

    /// Source for `microsoft/trocr-base-handwritten`.
    pub fn base_handwritten() -> ModelSource {
        preset("microsoft/trocr-base-handwritten")
    }

    /// Source for `microsoft/trocr-base-printed`.
    pub fn base_printed() -> ModelSource {
        preset("microsoft/trocr-base-printed")
    }

    /// Source for `microsoft/trocr-large-handwritten`.
    pub fn large_handwritten() -> ModelSource {
        preset("microsoft/trocr-large-handwritten")
    }

    /// Source for `microsoft/trocr-large-printed`.
    pub fn large_printed() -> ModelSource {
        preset("microsoft/trocr-large-printed")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// File list every bundled preset is expected to request.
    const PRESET_FILES: [&str; 3] = ["model.safetensors", "tokenizer.json", "config.json"];

    /// Asserts that `source` points at `repo`, has no pinned revision, and lists the
    /// standard preset files in order.
    fn assert_preset(source: &ModelSource, repo: &str) {
        assert_eq!(source.repo_id, repo);
        assert_eq!(source.revision, None);
        assert_eq!(source.files, PRESET_FILES);
    }

    #[test]
    fn test_model_source_builder() {
        let source = ModelSource::new("test/repo")
            .with_revision("main")
            .with_file("model.safetensors")
            .with_files(["tokenizer.json", "config.json"]);
        assert_eq!(source.repo_id, "test/repo");
        assert_eq!(source.revision, Some("main".to_string()));
        assert_eq!(source.files.len(), 3);
        // Builder calls must preserve insertion order: `download_source` fetches in this order.
        assert_eq!(source.files, PRESET_FILES);
    }

    #[test]
    fn test_model_source_defaults() {
        let source = ModelSource::new("test/repo");
        assert_eq!(source.revision, None);
        assert!(source.files.is_empty());
    }

    #[test]
    fn test_whisper_sources() {
        let presets = [
            (whisper::tiny(), "openai/whisper-tiny"),
            (whisper::tiny_en(), "openai/whisper-tiny.en"),
            (whisper::base(), "openai/whisper-base"),
            (whisper::base_en(), "openai/whisper-base.en"),
            (whisper::small(), "openai/whisper-small"),
            (whisper::small_en(), "openai/whisper-small.en"),
            (whisper::medium(), "openai/whisper-medium"),
            (whisper::medium_en(), "openai/whisper-medium.en"),
            (whisper::large_v3(), "openai/whisper-large-v3"),
        ];
        for (source, repo) in &presets {
            assert_preset(source, repo);
        }
    }

    #[test]
    fn test_trocr_sources() {
        let presets = [
            (trocr::base_handwritten(), "microsoft/trocr-base-handwritten"),
            (trocr::base_printed(), "microsoft/trocr-base-printed"),
            (trocr::large_handwritten(), "microsoft/trocr-large-handwritten"),
            (trocr::large_printed(), "microsoft/trocr-large-printed"),
        ];
        for (source, repo) in &presets {
            assert_preset(source, repo);
        }
    }
}