use super::progress::{ProgressBar, ProgressStyle};
use super::registry::ModelInfo;
use super::{default_cache_dir, get_hf_token, HubError, Result};
use regex::Regex;
use sha2::{Digest, Sha256};
use std::fs::{self, File};
use std::io::{self, BufWriter, Write};
use std::path::{Path, PathBuf};
/// Hosts that downloads may be fetched from. `validate_url` also accepts any
/// subdomain of these (e.g. `cdn-lfs-us-1.hf.co`), so `cdn-lfs.huggingface.co`
/// is already covered by `huggingface.co` and is listed for clarity only.
const ALLOWED_DOMAINS: &[&str] = &["huggingface.co", "hf.co", "cdn-lfs.huggingface.co"];
/// Check that `url` is an HTTPS URL whose host is in [`ALLOWED_DOMAINS`]
/// or is a subdomain of one of them.
///
/// The scheme and host are matched ASCII case-insensitively.
///
/// The authority component is delimited exactly as RFC 3986 §3.2 does, ending
/// at the first `/`, `?` or `#`. This means the host that is checked is the
/// host the download tool will actually contact.
///
/// URLs carrying userinfo (`user@host`) are rejected outright. Hosts must
/// consist only of ASCII letters, digits, `-` and `.`, and any port must be
/// numeric.
///
/// # Errors
/// Returns `HubError::InvalidFormat` in these cases:
/// - the URL is not HTTPS;
/// - the URL contains embedded credentials;
/// - the host is malformed or outside the allow-list.
fn validate_url(url: &str) -> Result<()> {
    const SCHEME: &str = "https://";

    let has_https_scheme = url
        .get(..SCHEME.len())
        .map_or(false, |prefix| prefix.eq_ignore_ascii_case(SCHEME));
    if !has_https_scheme {
        return Err(HubError::InvalidFormat(
            "Only HTTPS URLs are allowed for downloads".to_string(),
        ));
    }
    // The prefix matched an ASCII literal, so this index is a char boundary.
    let rest = &url[SCHEME.len()..];

    // Stopping only at '/' would let "https://evil.com?.huggingface.co"
    // through with "evil.com?.huggingface.co" as the apparent host.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];

    // "https://huggingface.co:x@evil.com/" really targets evil.com.
    if authority.contains('@') {
        return Err(HubError::InvalidFormat(
            "URLs with embedded credentials are not allowed".to_string(),
        ));
    }

    let (host, port) = match authority.split_once(':') {
        Some((host, port)) => (host, port),
        None => (authority, ""),
    };
    let host = host.to_ascii_lowercase();

    let well_formed = !host.is_empty()
        && host
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'.')
        && port.bytes().all(|b| b.is_ascii_digit());
    let is_allowed = well_formed
        && ALLOWED_DOMAINS
            .iter()
            .any(|&domain| host == domain || host.ends_with(&format!(".{}", domain)));

    if !is_allowed {
        return Err(HubError::InvalidFormat(format!(
            "URL host '{}' is not in the allowed domains: {:?}",
            host, ALLOWED_DOMAINS
        )));
    }
    Ok(())
}
/// Validate a repository id of the form `owner/name`.
///
/// The id must contain exactly one `/`. Each half must start with an ASCII
/// letter or digit and may continue with letters, digits, `.`, `_` and `-`.
/// The id must not contain `..`.
///
/// # Errors
/// `HubError::InvalidFormat` describing the first rule that is violated. The
/// slash count is checked first, then the character set, then `..`.
fn validate_repo_id(repo_id: &str) -> Result<()> {
    let has_single_separator = repo_id.matches('/').count() == 1;
    if !has_single_separator {
        return Err(HubError::InvalidFormat(
            "Repository ID must be in format 'username/repo-name'".to_string(),
        ));
    }

    let allowed_shape = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*$")
        .expect("Invalid regex pattern");

    // The character-set rule takes precedence over the traversal rule.
    match (allowed_shape.is_match(repo_id), repo_id.contains("..")) {
        (false, _) => Err(HubError::InvalidFormat(format!(
            "Repository ID '{}' contains invalid characters. Only alphanumeric, /, -, _, . are allowed",
            repo_id
        ))),
        (true, true) => Err(HubError::InvalidFormat(
            "Repository ID cannot contain '..' (path traversal)".to_string(),
        )),
        (true, false) => Ok(()),
    }
}
/// Resolve `path` to an absolute, symlink-free path and ensure it lies inside
/// `base_dir`. Missing directories are created along the way.
///
/// `base_dir` itself is created when it does not exist yet, for example the
/// cache directory on a first run. This is required so it can be
/// canonicalized.
///
/// Missing parent directories of `path` are created only after two checks:
/// - the deepest already-existing ancestor must resolve inside `base_dir`;
/// - the not-yet-existing remainder must consist only of plain names (no `..`).
///
/// As a result, a rejected path never leaves directories behind outside the
/// base. If the file itself already exists it is canonicalized too, so a
/// symlink pointing outside the base is rejected.
///
/// # Errors
/// - `HubError::InvalidFormat` in any of these cases:
///   - `path` has no file name;
///   - the missing part of its parent contains `..`;
///   - it resolves outside `base_dir`.
/// - `HubError::Config` if a directory cannot be canonicalized.
/// - I/O errors from creating directories or reading the working directory.
fn validate_and_canonicalize_path(path: &Path, base_dir: &Path) -> Result<PathBuf> {
    let invalid_path = || HubError::InvalidFormat("Invalid file path".to_string());
    let outside = |p: &Path, base: &Path| {
        HubError::InvalidFormat(format!(
            "Path '{}' is outside allowed directory '{}'",
            p.display(),
            base.display()
        ))
    };

    fs::create_dir_all(base_dir)?;
    let canonical_base = base_dir
        .canonicalize()
        .map_err(|e| HubError::Config(format!("Failed to canonicalize base directory: {}", e)))?;

    let file_name = path.file_name().ok_or_else(invalid_path)?;
    let parent = path.parent().ok_or_else(invalid_path)?;
    // Anchor relative paths (including a bare file name, whose parent is "")
    // at the working directory so walking up always reaches an existing root.
    let parent: PathBuf = if parent.is_absolute() {
        parent.to_path_buf()
    } else {
        std::env::current_dir()?.join(parent)
    };

    // Validate before touching the filesystem: the existing part must already
    // be inside the base, and the part we are about to create must not be
    // able to climb back out of it.
    let existing = parent
        .ancestors()
        .find(|ancestor| ancestor.exists())
        .ok_or_else(invalid_path)?;
    let missing = parent.strip_prefix(existing).map_err(|_| invalid_path())?;
    if missing
        .components()
        .any(|c| !matches!(c, std::path::Component::Normal(_)))
    {
        return Err(invalid_path());
    }
    let canonical_existing = existing
        .canonicalize()
        .map_err(|e| HubError::Config(format!("Failed to canonicalize parent path: {}", e)))?;
    if !canonical_existing.starts_with(&canonical_base) {
        return Err(outside(&canonical_existing, &canonical_base));
    }

    fs::create_dir_all(&parent)?;
    let canonical_parent = parent
        .canonicalize()
        .map_err(|e| HubError::Config(format!("Failed to canonicalize parent path: {}", e)))?;
    let joined = canonical_parent.join(file_name);
    let canonical_path = if joined.exists() {
        joined
            .canonicalize()
            .map_err(|e| HubError::Config(format!("Failed to canonicalize path: {}", e)))?
    } else {
        joined
    };

    if !canonical_path.starts_with(&canonical_base) {
        return Err(outside(&canonical_path, &canonical_base));
    }
    Ok(canonical_path)
}
/// Settings controlling how `ModelDownloader` fetches and validates model files.
#[derive(Debug, Clone)]
pub struct DownloadConfig {
    /// Default download directory. Every destination path must resolve inside it.
    pub cache_dir: PathBuf,
    /// Access token. When present it is sent as an `Authorization: Bearer` header.
    pub hf_token: Option<String>,
    /// Controls what happens when a file already exists at the destination.
    /// When `true`, the download tool's resume option continues that file.
    /// When `false`, the existing file is returned without downloading.
    pub resume: bool,
    /// Intended to control progress output from the download tool.
    /// NOTE(review): confirm the downloader honors this flag.
    pub show_progress: bool,
    /// When `true` and the model has a known checksum, the file is verified
    /// with SHA-256.
    pub verify_checksum: bool,
    /// Intended maximum number of retries for a failed download.
    /// NOTE(review): confirm the downloader honors this value.
    pub max_retries: u32,
}
impl Default for DownloadConfig {
    /// Defaults:
    /// - cache directory from `default_cache_dir()`;
    /// - token from `get_hf_token()`;
    /// - resume, progress output and checksum verification enabled;
    /// - 3 retries.
    fn default() -> Self {
        // Environment-dependent values are computed first, in the same order
        // as the struct's fields.
        let cache_dir = default_cache_dir();
        let hf_token = get_hf_token();

        Self {
            cache_dir,
            hf_token,
            resume: true,
            show_progress: true,
            verify_checksum: true,
            max_retries: 3,
        }
    }
}
/// Snapshot of a download's progress.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Expected total size in bytes. `percentage()` treats 0 as "unknown".
    pub total_bytes: u64,
    /// Bytes received so far.
    pub downloaded_bytes: u64,
    /// Current transfer rate, formatted as bytes per second by `speed_str()`.
    pub speed_bps: f64,
    /// Estimated seconds remaining. `eta_str()` truncates it to whole seconds.
    pub eta_seconds: f64,
    /// Which phase the download is in.
    pub stage: DownloadStage,
}
/// Phase of a download, as reported in `DownloadProgress::stage`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DownloadStage {
    Preparing,
    Downloading,
    Verifying,
    Complete,
    /// The download failed. Presumably carries a human-readable reason
    /// (confirm with producers).
    Failed(String),
}
impl DownloadProgress {
    /// Share of `total_bytes` received so far, in percent.
    /// Returns 0.0 when the total is unknown (zero).
    pub fn percentage(&self) -> f32 {
        match self.total_bytes {
            0 => 0.0,
            total => {
                let ratio = self.downloaded_bytes as f64 / total as f64;
                (ratio * 100.0) as f32
            }
        }
    }

    /// Human-readable transfer rate, e.g. `"1.00 MB/s"`.
    pub fn speed_str(&self) -> String {
        let rate = self.speed_bps;
        format_bytes_per_sec(rate)
    }

    /// Human-readable remaining time, after truncating `eta_seconds` to whole seconds.
    pub fn eta_str(&self) -> String {
        let whole_secs = self.eta_seconds as u64;
        format_duration(whole_secs)
    }
}
/// Incremental SHA-256 hasher used to verify downloaded files against a
/// known digest.
pub struct ChecksumVerifier {
    /// Running SHA-256 state.
    hasher: Sha256,
    /// Total number of bytes passed to `update` so far.
    bytes_hashed: u64,
}
impl ChecksumVerifier {
    /// Create a verifier with an empty SHA-256 state.
    pub fn new() -> Self {
        Self {
            hasher: Sha256::new(),
            bytes_hashed: 0,
        }
    }

    /// Feed `data` into the running hash.
    pub fn update(&mut self, data: &[u8]) {
        self.hasher.update(data);
        self.bytes_hashed += data.len() as u64;
    }

    /// Number of bytes fed in via [`ChecksumVerifier::update`] so far.
    pub fn bytes_hashed(&self) -> u64 {
        self.bytes_hashed
    }

    /// Consume the verifier and return the digest as lowercase hex.
    pub fn finalize(self) -> String {
        format!("{:x}", self.hasher.finalize())
    }

    /// Consume the verifier and compare its digest with `expected`.
    ///
    /// The comparison ignores ASCII case and surrounding whitespace. Hex digests
    /// are frequently published in upper case or read from files with a
    /// trailing newline, and neither affects the digest's value.
    ///
    /// # Errors
    /// `HubError::ChecksumMismatch` carrying the expected string as given and
    /// the computed lowercase digest.
    pub fn verify(self, expected: &str) -> Result<()> {
        let actual = self.finalize();
        if actual.eq_ignore_ascii_case(expected.trim()) {
            Ok(())
        } else {
            Err(HubError::ChecksumMismatch {
                expected: expected.to_string(),
                actual,
            })
        }
    }
}
impl Default for ChecksumVerifier {
fn default() -> Self {
Self::new()
}
}
/// Downloads model files by invoking the external `curl` (preferred) or
/// `wget` tool. All behavior is driven by the contained `DownloadConfig`.
pub struct ModelDownloader {
    config: DownloadConfig,
}
impl ModelDownloader {
pub fn new() -> Self {
Self {
config: DownloadConfig::default(),
}
}
pub fn with_config(config: DownloadConfig) -> Self {
Self { config }
}
pub fn download_by_id(&self, model_id: &str) -> Result<PathBuf> {
let registry = super::registry::RuvLtraRegistry::new();
let model_info = registry
.get(model_id)
.ok_or_else(|| HubError::NotFound(model_id.to_string()))?;
self.download(model_info, None)
}
pub fn download(&self, model_info: &ModelInfo, target_path: Option<&Path>) -> Result<PathBuf> {
let path = if let Some(p) = target_path {
p.to_path_buf()
} else {
self.config.cache_dir.join(&model_info.filename)
};
let path = validate_and_canonicalize_path(&path, &self.config.cache_dir)?;
if path.exists() && !self.config.resume {
if self.config.verify_checksum {
if let Some(checksum) = &model_info.checksum {
self.verify_file(&path, checksum)?;
}
}
return Ok(path);
}
let url = model_info.download_url();
validate_url(&url)?;
self.download_file(
&url,
&path,
model_info.size_bytes,
model_info.checksum.as_deref(),
)?;
Ok(path)
}
fn download_file(
&self,
url: &str,
path: &Path,
expected_size: u64,
expected_checksum: Option<&str>,
) -> Result<()> {
if self.has_curl() {
self.download_with_curl(url, path, expected_size, expected_checksum)
} else if self.has_wget() {
self.download_with_wget(url, path, expected_size, expected_checksum)
} else {
Err(HubError::Config(
"Download requires curl or wget. Please install: brew install curl (macOS) or apt install curl (Linux)"
.to_string(),
))
}
}
fn has_curl(&self) -> bool {
std::process::Command::new("which")
.arg("curl")
.output()
.map(|o| o.status.success())
.unwrap_or(false)
}
fn has_wget(&self) -> bool {
std::process::Command::new("which")
.arg("wget")
.output()
.map(|o| o.status.success())
.unwrap_or(false)
}
fn download_with_curl(
&self,
url: &str,
path: &Path,
_expected_size: u64,
expected_checksum: Option<&str>,
) -> Result<()> {
let mut args = vec![
"-L".to_string(), "-#".to_string(), "--fail".to_string(), ];
if self.config.resume && path.exists() {
args.push("-C".to_string());
args.push("-".to_string()); }
if let Some(token) = &self.config.hf_token {
args.push("-H".to_string());
args.push(format!("Authorization: Bearer {}", token));
}
args.push("-o".to_string());
args.push(path.to_str().unwrap().to_string());
args.push(url.to_string());
let status = std::process::Command::new("curl")
.args(&args)
.status()
.map_err(|e| HubError::Network(e.to_string()))?;
if !status.success() {
return Err(HubError::Network(format!(
"curl failed with status: {}",
status
)));
}
if self.config.verify_checksum {
if let Some(checksum) = expected_checksum {
self.verify_file(path, checksum)?;
}
}
Ok(())
}
fn download_with_wget(
&self,
url: &str,
path: &Path,
_expected_size: u64,
expected_checksum: Option<&str>,
) -> Result<()> {
let mut args = vec![
"-q".to_string(), "--show-progress".to_string(), ];
if self.config.resume && path.exists() {
args.push("-c".to_string()); }
if let Some(token) = &self.config.hf_token {
args.push("--header".to_string());
args.push(format!("Authorization: Bearer {}", token));
}
args.push("-O".to_string());
args.push(path.to_str().unwrap().to_string());
args.push(url.to_string());
let status = std::process::Command::new("wget")
.args(&args)
.status()
.map_err(|e| HubError::Network(e.to_string()))?;
if !status.success() {
return Err(HubError::Network(format!(
"wget failed with status: {}",
status
)));
}
if self.config.verify_checksum {
if let Some(checksum) = expected_checksum {
self.verify_file(path, checksum)?;
}
}
Ok(())
}
fn verify_file(&self, path: &Path, expected_checksum: &str) -> Result<()> {
use std::io::Read;
let mut file = File::open(path)?;
let mut verifier = ChecksumVerifier::new();
let mut buffer = [0u8; 8192];
loop {
let n = file.read(&mut buffer)?;
if n == 0 {
break;
}
verifier.update(&buffer[..n]);
}
verifier.verify(expected_checksum)
}
}
impl Default for ModelDownloader {
fn default() -> Self {
Self::new()
}
}
/// Error type for download failures.
/// NOTE(review): nothing in this file constructs it (the downloader returns
/// `HubError`); confirm whether it is still needed.
#[derive(Debug, thiserror::Error)]
pub enum DownloadError {
    /// Transport/HTTP failure, carrying a description.
    #[error("HTTP error: {0}")]
    Http(String),
    /// Underlying I/O failure. Converted automatically from `io::Error`.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    /// The downloaded data did not match the expected digest.
    #[error("Checksum verification failed")]
    ChecksumMismatch,
}
/// Format a transfer rate using binary (1024-based) units.
///
/// Rates of at least 1 KB/s are shown with two decimals in the largest
/// fitting unit (KB/s, MB/s or GB/s). Smaller rates are shown as whole
/// bytes per second.
fn format_bytes_per_sec(bps: f64) -> String {
    const UNITS: [(&str, f64); 3] = [
        ("GB/s", 1024.0 * 1024.0 * 1024.0),
        ("MB/s", 1024.0 * 1024.0),
        ("KB/s", 1024.0),
    ];

    // Units are ordered largest first, so the first one that fits wins.
    match UNITS.iter().find(|&&(_, scale)| bps >= scale) {
        Some(&(unit, scale)) => format!("{:.2} {}", bps / scale, unit),
        None => format!("{:.0} B/s", bps),
    }
}
/// Format a duration in seconds compactly.
///
/// The output form depends on the magnitude:
/// - under a minute: `"Ns"`;
/// - under an hour: `"Mm Ss"`;
/// - otherwise: `"Hh Mm"`, with seconds dropped.
fn format_duration(secs: u64) -> String {
    match secs {
        0..=59 => format!("{}s", secs),
        60..=3599 => {
            let (minutes, seconds) = (secs / 60, secs % 60);
            format!("{}m {}s", minutes, seconds)
        }
        _ => {
            let hours = secs / 3600;
            let minutes = secs % 3600 / 60;
            format!("{}h {}m", hours, minutes)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_download_config_default() {
        let DownloadConfig {
            resume,
            show_progress,
            verify_checksum,
            ..
        } = DownloadConfig::default();
        assert!(resume);
        assert!(show_progress);
        assert!(verify_checksum);
    }

    #[test]
    fn test_download_progress() {
        let half_done = DownloadProgress {
            stage: DownloadStage::Downloading,
            eta_seconds: 30.0,
            speed_bps: 1024.0 * 1024.0,
            downloaded_bytes: 500,
            total_bytes: 1000,
        };
        assert_eq!(half_done.percentage(), 50.0);
        assert!(half_done.speed_str().contains("MB/s"));
    }

    #[test]
    fn test_checksum_verifier() {
        let digest = {
            let mut hasher = ChecksumVerifier::new();
            hasher.update(b"hello world");
            hasher.finalize()
        };
        assert!(!digest.is_empty());
        // SHA-256 is 32 bytes, i.e. 64 hex characters.
        assert_eq!(digest.len(), 64);
    }

    #[test]
    fn test_format_bytes_per_sec() {
        let cases = [
            (500.0, "500 B/s"),
            (1024.0 * 10.0, "10.00 KB/s"),
            (1024.0 * 1024.0 * 5.0, "5.00 MB/s"),
        ];
        for &(rate, expected) in cases.iter() {
            assert_eq!(format_bytes_per_sec(rate), expected);
        }
    }

    #[test]
    fn test_format_duration() {
        let cases = [(30, "30s"), (90, "1m 30s"), (3700, "1h 1m")];
        for &(secs, expected) in cases.iter() {
            assert_eq!(format_duration(secs), expected);
        }
    }
}