use anyhow::{bail, Context, Result};
use std::path::{Component, Path, PathBuf};
use std::time::Instant;
use tracing::{info, warn};
/// `tracing` target attached to every log event emitted by this module.
const TRACE_TARGET: &str = "studio_worker::engine::download";
/// Timeout in seconds (30 minutes) passed to the reqwest client's `timeout`
/// setting in `download_file`.
const DOWNLOAD_TIMEOUT_SECS: u64 = 30 * 60;
/// Resolves `filename` to its location inside the model cache directory `dir`.
///
/// Only a single, plain path component is accepted. The following are refused:
/// - parent (`..`) or current-directory (`.`) components;
/// - absolute or nested paths;
/// - any `/` or `\` separator, on every platform;
/// - the empty string.
///
/// # Errors
///
/// Returns an error quoting `filename` when it is not a plain file name.
pub fn model_cache_path(dir: &Path, filename: &str) -> Result<PathBuf> {
    // Check separators explicitly so that `\` is refused even on Unix, where
    // `Path` would treat it as an ordinary filename character.
    let has_separator = filename.contains(|c: char| c == '/' || c == '\\');
    let mut parts = Path::new(filename).components();
    let first = parts.next();
    let is_single = parts.next().is_none();
    match first {
        Some(Component::Normal(name)) if is_single && !has_separator => Ok(dir.join(name)),
        _ => bail!("model filename must be a plain file name: {filename:?}"),
    }
}
/// Checks that the number of bytes written matches the server-declared length.
///
/// `expected` is the `Content-Length` the server reported, if any. When it is
/// `None` there is nothing to compare against, so the check passes.
///
/// # Errors
///
/// Returns a "size mismatch" error when `copied` differs from `expected`.
/// This covers both a short (truncated) download and an overlong one.
pub fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
    if let Some(expected) = expected.filter(|&declared| declared != copied) {
        bail!(
            "size mismatch: wrote {copied} bytes but the server declared \
             Content-Length {expected} (download truncated or corrupt)"
        );
    }
    Ok(())
}
/// Best-effort removal of a partially downloaded file at `path`.
///
/// A missing file already counts as success, because that is the desired end
/// state. Any other removal failure is not returned to the caller. Instead it
/// is logged as a warning, so the leftover file is visible in the logs.
pub fn remove_partial(path: &Path) {
    let err = match std::fs::remove_file(path) {
        Ok(()) => return,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return,
        Err(e) => e,
    };
    warn!(
        target: TRACE_TARGET,
        op = "cleanup",
        path = %path.display(),
        error = %err,
        "failed to remove partial download"
    );
}
/// Returns the cache path for `filename` under `dir`, downloading it from `url` if absent.
///
/// `filename` is validated with [`model_cache_path`] first, so an invalid name
/// fails before any network access. If a regular file already exists at the
/// resolved path, it is returned as-is. No size or content check is done on a
/// cached file.
///
/// # Errors
///
/// Fails if `filename` is not a plain file name, or if the download fails.
/// A download error is wrapped with the filename, URL and destination path.
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn ensure_file(dir: &Path, filename: &str, url: &str) -> Result<PathBuf> {
    let local = model_cache_path(dir, filename)?;
    if !local.is_file() {
        download_file(url, &local).with_context(|| {
            format!("downloading {filename} ({url}) -> {}", local.display())
        })?;
        return Ok(local);
    }
    tracing::debug!(
        target: TRACE_TARGET,
        op = "ensure_file",
        filename,
        path = %local.display(),
        "cached"
    );
    Ok(local)
}
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn download_file(url: &str, dest: &Path) -> Result<()> {
if let Some(parent) = dest.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("creating {}", parent.display()))?;
}
let part = dest.with_extension("part");
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
.user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
.build()?;
info!(
target: TRACE_TARGET,
op = "download",
url,
dest = %dest.display(),
"starting"
);
let started = Instant::now();
let mut response = client.get(url).send().context("GET")?;
if !response.status().is_success() {
bail!("GET {url} -> {}", response.status());
}
let expected_len = response.content_length();
let mut file =
std::fs::File::create(&part).with_context(|| format!("creating {}", part.display()))?;
let copied = std::io::copy(&mut response, &mut file);
drop(file);
let bytes = match copied {
Ok(bytes) => bytes,
Err(e) => {
remove_partial(&part);
return Err(e).context("streaming body");
}
};
if let Err(e) = verify_download_len(bytes, expected_len) {
remove_partial(&part);
return Err(e).with_context(|| format!("downloading {url}"));
}
std::fs::rename(&part, dest)
.with_context(|| format!("renaming {} -> {}", part.display(), dest.display()))?;
let elapsed_ms = started.elapsed().as_millis() as u64;
info!(
target: TRACE_TARGET,
op = "download",
url,
dest = %dest.display(),
bytes,
elapsed_ms,
"done"
);
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn model_cache_path_accepts_plain_filenames_only() {
        let root = Path::new("/models");
        assert_eq!(
            model_cache_path(root, "model.gguf").unwrap(),
            PathBuf::from("/models/model.gguf")
        );
        assert!(model_cache_path(root, "../outside.gguf").is_err());
        assert!(model_cache_path(root, "nested/model.gguf").is_err());
        assert!(model_cache_path(root, "/tmp/model.gguf").is_err());
        assert!(model_cache_path(root, r"nested\model.gguf").is_err());
        assert!(model_cache_path(root, "").is_err());
    }

    #[test]
    fn model_cache_path_rejects_special_components() {
        let root = Path::new("/models");
        // `.` and `..` parse as CurDir/ParentDir components, never as Normal.
        assert!(model_cache_path(root, ".").is_err());
        assert!(model_cache_path(root, "..").is_err());
        assert!(model_cache_path(root, "./model.gguf").is_err());
        // `Path::components` drops a trailing slash, so only the explicit
        // separator check rejects this one.
        assert!(model_cache_path(root, "model.gguf/").is_err());
    }

    #[test]
    fn verify_download_len_accepts_exact_match() {
        assert!(verify_download_len(2_700_000_000, Some(2_700_000_000)).is_ok());
    }

    #[test]
    fn verify_download_len_accepts_when_length_unknown() {
        assert!(verify_download_len(123, None).is_ok());
    }

    #[test]
    fn verify_download_len_rejects_truncated_download() {
        let err = verify_download_len(40, Some(100)).unwrap_err().to_string();
        assert!(err.contains("size mismatch"), "got: {err}");
        assert!(err.contains("40"), "got: {err}");
        assert!(err.contains("100"), "got: {err}");
    }

    #[test]
    fn verify_download_len_rejects_overlong_download() {
        assert!(verify_download_len(120, Some(100)).is_err());
    }

    #[test]
    fn ensure_file_returns_cached_path_without_network() {
        let dir = tempdir().unwrap();
        std::fs::write(dir.path().join("cached.gguf"), b"already here").unwrap();
        // `.invalid` URL: the test fails loudly if a download is attempted.
        let path = ensure_file(dir.path(), "cached.gguf", "https://example.invalid/x").unwrap();
        assert_eq!(path, dir.path().join("cached.gguf"));
        assert_eq!(std::fs::read(&path).unwrap(), b"already here");
    }

    #[test]
    fn ensure_file_rejects_path_traversal_before_any_network() {
        let dir = tempdir().unwrap();
        let err = ensure_file(dir.path(), "../escape.gguf", "https://example.invalid/x")
            .unwrap_err()
            .to_string();
        assert!(err.contains("plain file name"), "got: {err}");
    }

    #[test]
    fn remove_partial_ignores_a_missing_file() {
        let dir = tempdir().unwrap();
        let out = crate::test_support::capture({
            let missing = dir.path().join("never.part");
            move || remove_partial(&missing)
        });
        assert!(
            !out.contains("failed to remove partial download"),
            "a not-found partial is the desired end state: {out:?}"
        );
    }

    #[test]
    fn remove_partial_surfaces_a_failed_removal() {
        let dir = tempdir().unwrap();
        // `remove_file` on a directory fails with an error other than NotFound.
        let stubborn = dir.path().join("subdir");
        std::fs::create_dir(&stubborn).unwrap();
        let out = crate::test_support::capture(move || remove_partial(&stubborn));
        assert!(
            out.contains("failed to remove partial download"),
            "a failed removal must surface in the logs: {out:?}"
        );
    }
}