use std::fs::{self, File};
use std::io::{BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
use crate::error::{InferenceError, Result};
/// Base URL under which downloadable model files live; `try_download_model`
/// appends `/<file name>` to it to build the download URL.
const ASSETS_BASE_URL: &str = "https://github.com/ultralytics/assets/releases/download/v8.4.0";
/// Model family prefixes used when building the list of downloadable file names.
const MODEL_FAMILIES: &[&str] = &["yolo26", "yolo11"];
/// Model size letters placed directly after the family prefix in a file name.
const MODEL_SIZES: &[&str] = &["n", "s", "m", "l", "x"];
/// File name suffixes placed after the size letter; the empty string is the
/// un-suffixed model. NOTE(review): the suffixes presumably name the model
/// task (segmentation, pose, ...) - confirm against the asset release.
const MODEL_VARIANTS: &[&str] = &["", "-seg", "-pose", "-obb", "-cls"];
/// Lists every model file name that is eligible for automatic download.
///
/// Each name has the shape `{family}{size}{variant}.onnx`. The result is
/// ordered by family first, then size, then variant, following the order of
/// the `MODEL_FAMILIES`, `MODEL_SIZES` and `MODEL_VARIANTS` tables.
fn downloadable_models() -> Vec<String> {
    let mut names =
        Vec::with_capacity(MODEL_FAMILIES.len() * MODEL_SIZES.len() * MODEL_VARIANTS.len());
    for family in MODEL_FAMILIES {
        for size in MODEL_SIZES {
            for variant in MODEL_VARIANTS {
                names.push(format!("{family}{size}{variant}.onnx"));
            }
        }
    }
    names
}
/// URL of the "bus" sample image.
const DEFAULT_BUS_IMAGE_URL: &str = "https://ultralytics.com/images/bus.jpg";
/// URL of the "zidane" sample image.
const DEFAULT_ZIDANE_IMAGE_URL: &str = "https://ultralytics.com/images/zidane.jpg";
/// URL of the "boats" sample image.
const DEFAULT_BOATS_IMAGE_URL: &str = "https://ultralytics.com/images/boats.jpg";
/// Default sample image URLs: the bus and zidane images.
pub const DEFAULT_IMAGES: &[&str] = &[DEFAULT_BUS_IMAGE_URL, DEFAULT_ZIDANE_IMAGE_URL];
/// Default sample image URL (the boats image).
/// NOTE(review): presumably intended for oriented-bounding-box models - confirm with callers.
pub const DEFAULT_OBB_IMAGE: &str = DEFAULT_BOATS_IMAGE_URL;
/// Connection timeout in seconds (passed to `Duration::from_secs` in `download_file`).
const CONNECT_TIMEOUT: u64 = 30;
/// Response-body receive timeout in seconds (passed to `Duration::from_secs` in `download_file`).
const READ_TIMEOUT: u64 = 300;
/// Total number of download attempts made by `download_file` (the first try included).
const MAX_RETRIES: u32 = 3;
/// Delay in seconds before the second attempt; it doubles for each later attempt.
const RETRY_BASE_DELAY_SECS: u64 = 2;
/// Width of the progress bar, in bar characters.
const BAR_WIDTH: usize = 12;
/// Minimum time in seconds between two progress line redraws.
const MIN_UPDATE_INTERVAL: f64 = 0.1;
/// Renders a byte count as a short human-readable string.
///
/// Uses binary multiples (1 KB = 1024 B). Values of at least 1 KB are shown
/// with one decimal place in the largest unit that fits (`KB`, `MB` or `GB`);
/// smaller values are shown as whole bytes with a `B` suffix.
fn format_bytes(bytes: f64) -> String {
    const KB: f64 = 1024.0;
    const MB: f64 = 1_048_576.0;
    const GB: f64 = 1_073_741_824.0;

    if bytes >= GB {
        format!("{:.1}GB", bytes / GB)
    } else if bytes >= MB {
        format!("{:.1}MB", bytes / MB)
    } else if bytes >= KB {
        format!("{:.1}KB", bytes / KB)
    } else {
        format!("{bytes:.0}B")
    }
}
/// Renders a duration given in seconds as a compact string.
///
/// * under one minute: `S.Ds` (for example `5.5s`)
/// * under one hour: `M:SS.D` (for example `1:05.0`)
/// * otherwise: `H:MM:SS.D` (for example `1:02:05.5`)
///
/// The value is rounded to tenths of a second *before* it is split into
/// fields, so a carry propagates into the minutes and hours (119.96 is shown
/// as `2:00.0`, never `1:60.0`). Non-finite input is rendered with the
/// sub-minute format.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn format_time(seconds: f64) -> String {
    // Work at display precision so every field agrees on the rounded value.
    let tenths = (seconds * 10.0).round();
    if !tenths.is_finite() || tenths < 600.0 {
        return format!("{seconds:.1}s");
    }

    // `tenths` is finite and >= 600 here, so the cast cannot lose the sign.
    let total_tenths = tenths as u64;
    let whole_seconds = total_tenths / 10;
    let fraction = total_tenths % 10;
    let hours = whole_seconds / 3600;
    let minutes = (whole_seconds % 3600) / 60;
    let secs = whole_seconds % 60;

    if hours == 0 {
        format!("{minutes}:{secs:02}.{fraction}")
    } else {
        format!("{hours}:{minutes:02}:{secs:02}.{fraction}")
    }
}
/// Draws a textual progress bar that is `width` characters wide.
///
/// `progress` is a fraction where `1.0` means complete. The leading
/// `progress * width` cells (truncated, and capped at `width`) are drawn with
/// the heavy line character, the remaining cells with the light one.
#[allow(
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss
)]
fn generate_bar(progress: f64, width: usize) -> String {
    let scaled = progress * width as f64;
    let solid = std::cmp::min(scaled as usize, width);
    (0..width)
        .map(|cell| if cell < solid { '━' } else { '─' })
        .collect()
}
const fn is_transient(e: &ureq::Error) -> bool {
match e {
ureq::Error::Timeout(_) | ureq::Error::Io(_) => true,
ureq::Error::StatusCode(c) => *c >= 500,
_ => false,
}
}
/// Downloads `url` into `dest`, showing progress on stderr and retrying
/// transient failures.
///
/// The body is streamed into a temporary file created next to `dest` and
/// renamed onto `dest` only after the whole stream was written and flushed.
/// Up to `MAX_RETRIES` attempts are made; between attempts the function
/// sleeps for `RETRY_BASE_DELAY_SECS * 2^(attempt - 1)` seconds.
///
/// # Errors
///
/// Returns `InferenceError::ModelLoadError` when the request fails, the
/// temporary file cannot be created, written or flushed, reading the body
/// fails, or the final rename fails while `dest` does not exist. A failure
/// classified as non-transient is returned immediately; a transient one is
/// returned only after the last attempt.
#[allow(
clippy::similar_names,
clippy::too_many_lines,
clippy::large_stack_arrays,
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
fn download_file(url: &str, dest: &Path) -> Result<()> {
// Placeholder value; every transient failure below overwrites it.
let mut last_err = InferenceError::ModelLoadError(String::new());
for attempt in 1..=MAX_RETRIES {
// One complete attempt. The error side carries a flag next to the error:
// `true` means transient (retry), `false` means give up at once.
let attempt_result: std::result::Result<(), (InferenceError, bool)> = (|| {
let config = ureq::Agent::config_builder()
.timeout_connect(Some(Duration::from_secs(CONNECT_TIMEOUT)))
.timeout_recv_body(Some(Duration::from_secs(READ_TIMEOUT)))
.build();
let agent = ureq::Agent::new_with_config(config);
// Request errors get a readable message; `is_transient` decides the retry flag.
let response = agent.get(url).call().map_err(|e| {
let msg = match &e {
ureq::Error::Timeout(_) => {
format!("Connection timed out while downloading {url}")
}
ureq::Error::Io(io_err) => format!("Network error downloading {url}: {io_err}"),
_ => format!("Failed to download {url}: {e}"),
};
(InferenceError::ModelLoadError(msg), is_transient(&e))
})?;
// Expected size from the `content-length` header; 0 when the header is
// missing or unparseable, which selects the progress format without a percentage.
let total_size: u64 = response
.headers()
.get("content-length")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse().ok())
.unwrap_or(0);
// Temp file beside `dest`, named `<file name>.part.<pid>.<sub-second nanos>`.
let temp_path = dest.with_file_name(format!(
"{}.part.{}.{}",
dest.file_name()
.and_then(|n| n.to_str())
.unwrap_or("download"),
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.subsec_nanos()
));
// Best-effort removal of a leftover file with the same name; the result is ignored.
let _ = fs::remove_file(&temp_path);
let mut downloaded: u64 = 0;
let start_time = Instant::now();
let desc: String = format!("Downloading {url} to '{}'", dest.display());
// The writer and reader live only inside this block, so both are dropped
// before the cleanup and rename steps that follow it.
let stream_result: std::result::Result<(), (InferenceError, bool)> = {
// Failing to create the temp file is not retried.
let mut writer = BufWriter::new(File::create(&temp_path).map_err(|e| {
(
InferenceError::ModelLoadError(format!(
"Failed to create temp file {}: {e}",
temp_path.display()
)),
false,
)
})?);
let mut reader = response.into_body().into_reader();
let mut buffer = [0u8; 65536];
let mut last_update = Instant::now();
// Inner closure so `?` exits only the copy loop, letting the temp-file
// cleanup after this block run on failure.
(|| {
loop {
// Read failures are flagged transient (retried).
let bytes_read = reader.read(&mut buffer).map_err(|e| {
(
InferenceError::ModelLoadError(format!(
"Failed to read from network: {e}"
)),
true,
)
})?;
// A zero-length read ends the copy loop.
if bytes_read == 0 {
break;
}
// Write failures are flagged non-transient (not retried).
writer.write_all(&buffer[..bytes_read]).map_err(|e| {
(
InferenceError::ModelLoadError(format!(
"Failed to write to temp file: {e}"
)),
false,
)
})?;
downloaded += bytes_read as u64;
// Throttle redraws: skip the progress output unless at least
// MIN_UPDATE_INTERVAL seconds passed since the previous one.
let now = Instant::now();
if now.duration_since(last_update).as_secs_f64() < MIN_UPDATE_INTERVAL {
continue;
}
last_update = now;
let elapsed = start_time.elapsed().as_secs_f64();
// Average rate in bytes per second since the attempt's timer started.
let rate = if elapsed > 0.0 {
downloaded as f64 / elapsed
} else {
0.0
};
// `\r\x1b[K` returns to the start of the line and clears it, so the
// progress line is redrawn in place.
if total_size > 0 {
let progress = (downloaded as f64 / total_size as f64).min(1.0);
let bar = generate_bar(progress, BAR_WIDTH);
eprint!(
"\r\x1b[K{desc}: {}% {bar} {}/{} {}/s {}",
(progress * 100.0) as u8,
format_bytes(downloaded as f64),
format_bytes(total_size as f64),
format_bytes(rate),
format_time(elapsed)
);
} else {
eprint!(
"\r\x1b[K{desc}: {} {}/s {}",
format_bytes(downloaded as f64),
format_bytes(rate),
format_time(elapsed)
);
}
// The progress line has no trailing newline, so flush stderr explicitly;
// a flush failure is ignored.
std::io::stderr().flush().ok();
}
// Flush explicitly so buffered-write errors are reported here.
writer.flush().map_err(|e| {
(
InferenceError::ModelLoadError(format!(
"Failed to flush temp file: {e}"
)),
false,
)
})?;
Ok(())
})()
};
// On any streaming failure remove the partial temp file, then propagate the error.
if stream_result.is_err() {
let _ = fs::remove_file(&temp_path);
}
stream_result?;
let elapsed = start_time.elapsed().as_secs_f64();
let rate = if elapsed > 0.0 {
downloaded as f64 / elapsed
} else {
0.0
};
// Final summary line, terminated with a newline.
if total_size > 0 {
eprintln!(
"\r\x1b[K{desc}: 100% {} {} {}/s {}",
generate_bar(1.0, BAR_WIDTH),
format_bytes(total_size as f64),
format_bytes(rate),
format_time(elapsed)
);
} else {
eprintln!(
"\r\x1b[K{desc}: {} {}/s {}",
format_bytes(downloaded as f64),
format_bytes(rate),
format_time(elapsed)
);
}
// Move the finished file into place. If the rename fails, the temp file is
// removed; when `dest` exists anyway the attempt still counts as a success.
// NOTE(review): presumably this tolerates another process having produced
// `dest` concurrently - confirm the intent.
if let Err(e) = fs::rename(&temp_path, dest) {
let _ = fs::remove_file(&temp_path);
if dest.exists() {
return Ok(());
}
return Err((
InferenceError::ModelLoadError(format!(
"Failed to move downloaded file to {}: {e}",
dest.display()
)),
false,
));
}
Ok(())
})();
match attempt_result {
Ok(()) => return Ok(()),
// Non-transient failure: give up immediately.
Err((e, false)) => return Err(e),
// Transient failure: remember it and, unless this was the last attempt,
// wait with exponential backoff before trying again.
Err((e, true)) => {
last_err = e;
if attempt < MAX_RETRIES {
let delay = RETRY_BASE_DELAY_SECS * (1 << (attempt - 1));
eprintln!(
"Download attempt {attempt}/{MAX_RETRIES} failed: {last_err}. Retrying in {delay}s..."
);
std::thread::sleep(Duration::from_secs(delay));
}
}
}
}
// Every attempt failed with a transient error; report the last one.
Err(last_err)
}
/// Downloads the model file named by `path` when it is a known downloadable model.
///
/// Only the final component of `path` is compared against the list of
/// supported model file names. On a match the file is fetched from
/// `ASSETS_BASE_URL` and written to `path`, and `path` is returned as an
/// owned `PathBuf`.
///
/// Returns `InferenceError::ModelLoadError` when the file name is not a
/// supported model (the message lists the supported names) or when the
/// download itself fails.
#[allow(clippy::missing_errors_doc)]
pub fn try_download_model<P: AsRef<Path>>(path: P) -> Result<PathBuf> {
    let target = path.as_ref();
    // A missing or non-UTF-8 file name becomes "", which matches no known model.
    let requested = target
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or_default();
    let known = downloadable_models();

    if known.iter().any(|candidate| candidate == requested) {
        let url = format!("{ASSETS_BASE_URL}/{requested}");
        download_file(&url, target)?;
        return Ok(target.to_path_buf());
    }

    Err(InferenceError::ModelLoadError(format!(
        "Model file not found: {}. Auto-download is supported for: {}",
        target.display(),
        known.join(", "),
    )))
}
/// Downloads the image at `url` into the current directory unless a file with
/// the same name is already there, and returns the path of the local file.
///
/// The local file name is the last path segment of `url`, with any query
/// string (`?...`) or fragment (`#...`) removed first. When that segment is
/// empty (for example when the URL ends in `/`), `image.jpg` is used.
///
/// The returned string is the canonicalized path when that can be resolved,
/// otherwise the current directory joined with the file name, otherwise the
/// bare file name.
///
/// # Errors
///
/// Returns the error from `download_file` when the file is absent locally and
/// the download fails.
pub fn download_image(url: &str) -> Result<String> {
    // Keep the query string and fragment out of the file name.
    let resource = url.split(['?', '#']).next().unwrap_or(url);
    // `rsplit` always yields at least one (possibly empty) segment, so the
    // fallback has to be triggered by emptiness, not by `None`.
    let filename = resource
        .rsplit('/')
        .next()
        .filter(|segment| !segment.is_empty())
        .unwrap_or("image.jpg");
    let dest = Path::new(filename);
    if !dest.exists() {
        download_file(url, dest)?;
    }
    Ok(dest
        .canonicalize()
        .or_else(|_| std::env::current_dir().map(|cwd| cwd.join(filename)))
        .map_or_else(
            |_| filename.to_string(),
            |resolved| resolved.to_string_lossy().into_owned(),
        ))
}
/// Downloads each URL with `download_image` and collects the local paths.
///
/// URLs whose download fails are skipped without reporting the error, so the
/// result may be shorter than `urls`. The order of the successful entries
/// follows the order of `urls`.
#[must_use]
pub fn download_images(urls: &[&str]) -> Vec<String> {
    let mut saved = Vec::new();
    for url in urls {
        if let Ok(local_path) = download_image(url) {
            saved.push(local_path);
        }
    }
    saved
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A file name outside the supported model list is rejected with a
    /// message that lists the supported names.
    #[test]
    fn test_unknown_model_returns_error() {
        let message = try_download_model("unknown_model.onnx")
            .unwrap_err()
            .to_string();
        assert!(message.contains("Auto-download is supported for"));
    }

    /// Each unit boundary switches to the next larger unit.
    #[test]
    fn test_format_bytes() {
        let cases = [
            (500.0, "500B"),
            (1024.0, "1.0KB"),
            (1_048_576.0, "1.0MB"),
            (1_073_741_824.0, "1.0GB"),
        ];
        for (input, expected) in cases {
            assert_eq!(format_bytes(input), expected);
        }
    }

    /// Sub-minute and minute-range durations use their respective formats.
    #[test]
    fn test_format_time() {
        for (input, expected) in [(5.5, "5.5s"), (65.0, "1:05.0")] {
            assert_eq!(format_time(input), expected);
        }
    }

    /// Empty, full and half-filled bars of width 10.
    #[test]
    fn test_generate_bar() {
        let cases = [
            (0.0, "──────────"),
            (1.0, "━━━━━━━━━━"),
            (0.5, "━━━━━─────"),
        ];
        for (progress, expected) in cases {
            assert_eq!(generate_bar(progress, 10), expected);
        }
    }
}