#[cfg(feature = "local-embeddings")]
use anyhow::Context;
#[cfg(all(feature = "local-embeddings", feature = "reqwest"))]
use reqwest::Client;
/// Downloads the model files for `model_name` from HuggingFace into `cache_dir`.
///
/// Three files are fetched from `https://huggingface.co/{model_name}/resolve/main`:
/// `<name>.onnx`, `<name>_tokenizer.json` and `<name>_config.json`, where `<name>` is
/// `model_name` with every `/` replaced by `_`. A file that already exists in `cache_dir`
/// with a non-zero size is kept as-is and not downloaded again.
///
/// # Errors
///
/// Returns an error as soon as one file fails to download or fails validation; files
/// later in the list are not attempted.
#[cfg(all(feature = "local-embeddings", feature = "reqwest"))]
pub async fn download_model(model_name: &str, cache_dir: &std::path::Path) -> anyhow::Result<()> {
    tracing::info!("Starting model download from HuggingFace: {}", model_name);

    let sanitized_name = model_name.replace('/', "_");
    let base_url = format!("https://huggingface.co/{model_name}/resolve/main");
    let files_to_download = [
        format!("{sanitized_name}.onnx"),
        format!("{sanitized_name}_tokenizer.json"),
        format!("{sanitized_name}_config.json"),
    ];

    for filename in &files_to_download {
        let url = format!("{base_url}/{filename}");
        let target_path = cache_dir.join(filename);

        // A single non-blocking stat both detects the file and yields its size, so there
        // is no window between an existence check and the metadata read.
        match tokio::fs::metadata(&target_path).await {
            Ok(metadata) if metadata.len() > 0 => {
                tracing::info!(
                    "File already exists and is valid: {} ({:.2} MB)",
                    filename,
                    metadata.len() as f64 / 1_048_576.0
                );
                continue;
            }
            Ok(_) => {
                tracing::warn!("File exists but is empty, will re-download: {}", filename);
            }
            // Any stat failure (most commonly "not found") means there is no usable
            // cached copy, so fall through and download.
            Err(_) => {}
        }

        tracing::info!("Downloading {} from {}", filename, url);
        download_file_with_progress(&url, &target_path).await?;
        validate_downloaded_file(&target_path, filename)?;
    }

    tracing::info!("Successfully downloaded model: {}", model_name);
    Ok(())
}
/// Downloads `url` to `path`, retrying failed attempts with exponential backoff.
///
/// At most three attempts are made. After a failed attempt that is not the last one, the
/// function logs a warning, waits `100 ms * 2^attempt` (200 ms, then 400 ms) and tries
/// again. Whatever a failed attempt left at `path` is removed on a best-effort basis.
///
/// # Errors
///
/// Returns the error of the final attempt when every attempt fails.
#[cfg(all(feature = "local-embeddings", feature = "reqwest"))]
async fn download_file_with_progress(url: &str, path: &std::path::Path) -> anyhow::Result<()> {
    use std::time::Duration;

    const MAX_ATTEMPTS: u32 = 3;

    let mut attempt: u32 = 1;
    loop {
        let err = match attempt_download(url, path).await {
            Ok(()) => return Ok(()),
            Err(e) => e,
        };

        // Best-effort cleanup: a missing file (or any other removal failure) is not worth
        // masking the download error for, so the result is deliberately ignored.
        let _ = tokio::fs::remove_file(path).await;

        if attempt >= MAX_ATTEMPTS {
            return Err(err);
        }

        let backoff_duration = Duration::from_millis(100 * 2_u64.pow(attempt));
        // `{:#}` prints the whole anyhow context chain, not just the outermost message.
        tracing::warn!(
            "Download attempt {} failed: {:#}. Retrying in {:?}...",
            attempt,
            err,
            backoff_duration
        );
        tokio::time::sleep(backoff_duration).await;
        attempt += 1;
    }
}
/// Performs one download of `url` and stores the response body at `path`.
///
/// The body is streamed into a sibling file named `<path>.part` and renamed onto `path`
/// only after every chunk has been written and flushed, so `path` never holds a partially
/// written download. The `.part` file is removed (best effort) when the attempt fails.
/// Missing parent directories of `path` are created first.
///
/// # Errors
///
/// Fails when the HTTP client cannot be built, the request fails or exceeds the 300 s
/// client timeout, the server answers with a non-success status, or any filesystem
/// operation (directory creation, write, flush, rename) fails.
#[cfg(all(feature = "local-embeddings", feature = "reqwest"))]
async fn attempt_download(url: &str, path: &std::path::Path) -> anyhow::Result<()> {
    use std::time::Duration;

    let client = Client::builder()
        .timeout(Duration::from_secs(300))
        .build()
        .context("Failed to build HTTP client")?;
    let response = client
        .get(url)
        .send()
        .await
        .with_context(|| format!("Failed to fetch URL: {url}"))?;
    if !response.status().is_success() {
        return Err(anyhow::anyhow!(
            "HTTP error downloading {url}: {}",
            response.status()
        ));
    }

    if let Some(parent) = path.parent() {
        tokio::fs::create_dir_all(parent)
            .await
            .with_context(|| format!("Failed to create directory: {}", parent.display()))?;
    }

    // Append ".part" to the complete file name (rather than swapping the extension) so
    // that targets differing only by extension never share a temporary file.
    let mut part_name = path.as_os_str().to_os_string();
    part_name.push(".part");
    let part_path = std::path::PathBuf::from(part_name);

    let start_time = std::time::Instant::now();
    let downloaded = match stream_body_to_file(response, &part_path).await {
        Ok(bytes) => bytes,
        Err(e) => {
            let _ = tokio::fs::remove_file(&part_path).await;
            return Err(e);
        }
    };

    if let Err(e) = tokio::fs::rename(&part_path, path).await {
        let _ = tokio::fs::remove_file(&part_path).await;
        return Err(anyhow::Error::new(e).context(format!(
            "Failed to move {} to {}",
            part_path.display(),
            path.display()
        )));
    }

    let elapsed = start_time.elapsed();
    let speed = if elapsed.as_secs_f64() > 0.0 {
        (downloaded as f64 / 1_048_576.0) / elapsed.as_secs_f64()
    } else {
        0.0
    };
    tracing::info!(
        "Successfully downloaded {} bytes to {} in {:.2}s ({:.2} MB/s)",
        downloaded,
        path.display(),
        elapsed.as_secs_f64(),
        speed
    );
    Ok(())
}

/// Streams the body of `response` into a newly created file at `part_path` and returns the
/// number of bytes written.
///
/// When the response reports a non-zero content length, a progress line is logged at most
/// once every five seconds. The file handle is closed when this function returns.
#[cfg(all(feature = "local-embeddings", feature = "reqwest"))]
async fn stream_body_to_file(
    response: reqwest::Response,
    part_path: &std::path::Path,
) -> anyhow::Result<u64> {
    use futures::StreamExt;
    use std::time::{Duration, Instant};
    use tokio::io::{AsyncWriteExt, BufWriter};

    const PROGRESS_INTERVAL: Duration = Duration::from_secs(5);

    // An unknown or zero length disables percentage reporting (and avoids dividing by 0).
    let total_size = response.content_length().filter(|&len| len > 0);
    let mut stream = response.bytes_stream();

    let file = tokio::fs::File::create(part_path)
        .await
        .with_context(|| format!("Failed to create file: {}", part_path.display()))?;
    let mut writer = BufWriter::new(file);

    let started = Instant::now();
    let mut last_report = started;
    let mut downloaded = 0u64;

    while let Some(chunk_result) = stream.next().await {
        let chunk = chunk_result.context("Failed to read download chunk")?;
        writer
            .write_all(&chunk)
            .await
            .context("Failed to write to file")?;
        downloaded += chunk.len() as u64;

        let Some(total) = total_size else { continue };
        // Throttle on time since the previous report, not time since the start;
        // otherwise every chunk would log once the first interval has passed.
        if last_report.elapsed() < PROGRESS_INTERVAL {
            continue;
        }
        last_report = Instant::now();

        let secs = started.elapsed().as_secs_f64();
        let downloaded_mb = downloaded as f64 / 1_048_576.0;
        let speed = if secs > 0.0 { downloaded_mb / secs } else { 0.0 };
        tracing::info!(
            "Progress: {:.1}% ({:.2}/{:.2} MB, {:.2} MB/s)",
            (downloaded as f64 / total as f64) * 100.0,
            downloaded_mb,
            total as f64 / 1_048_576.0,
            speed
        );
    }

    // Dropping a BufWriter does not report write errors, so flush explicitly.
    writer.flush().await.context("Failed to flush file")?;
    Ok(downloaded)
}
/// Runs basic sanity checks on a file that was just downloaded.
///
/// The file at `path` must exist, have a non-zero size according to its metadata, and
/// yield at least one byte when read. `filename` is only used in the log line emitted on
/// success.
///
/// # Errors
///
/// Returns an error when the file is missing, is empty, cannot be opened or read, or
/// produces no data on the first read.
#[cfg(all(feature = "local-embeddings", feature = "reqwest"))]
pub fn validate_downloaded_file(path: &std::path::Path, filename: &str) -> anyhow::Result<()> {
    use std::io::Read;

    anyhow::ensure!(
        path.exists(),
        "Downloaded file not found: {}",
        path.display()
    );

    let size_bytes = std::fs::metadata(path)
        .with_context(|| format!("Failed to read metadata for file: {}", path.display()))?
        .len();
    anyhow::ensure!(
        size_bytes != 0,
        "Downloaded file is empty: {}",
        path.display()
    );

    // Read a small prefix to confirm the contents are actually readable, not merely
    // listed with a size.
    let mut probe = [0u8; 1024];
    let read_len = std::fs::File::open(path)
        .with_context(|| format!("Failed to open file for validation: {}", path.display()))?
        .read(&mut probe)
        .with_context(|| format!("Failed to read from file: {}", path.display()))?;
    anyhow::ensure!(
        read_len != 0,
        "File contains no readable data: {}",
        path.display()
    );

    tracing::info!(
        "Validated {}: {:.2} MB, readable",
        filename,
        size_bytes as f64 / 1_048_576.0
    );
    Ok(())
}