use anyhow::{bail, Context, Result};
use sha2::{Digest, Sha256};
use std::io::Write;
use std::path::{Component, Path, PathBuf};
use std::time::Instant;
use tracing::{info, warn};
use crate::types::ModelFile;
/// `tracing` target attached to every log event emitted by this module.
const TRACE_TARGET: &str = "studio_worker::engine::download";
/// Timeout, in seconds (30 minutes), handed to the blocking HTTP client
/// builder used for downloads.
const DOWNLOAD_TIMEOUT_SECS: u64 = 30 * 60;
/// Resolves `filename` to a path inside the model cache directory `dir`.
///
/// Only a single plain path component is accepted. Empty strings, `.`/`..`,
/// absolute paths, drive prefixes and anything containing `/` or `\` are
/// rejected, so a registry-supplied name can never point outside `dir`.
///
/// # Errors
/// Returns an error quoting `filename` when it is not a plain file name.
pub fn model_cache_path(dir: &Path, filename: &str) -> Result<PathBuf> {
    let mut parts = Path::new(filename).components();
    let first = parts.next();
    // `\` is an ordinary character on Unix and survives component parsing,
    // so separators are also rejected textually.
    let plain = parts.next().is_none() && !filename.contains(['/', '\\']);
    match first {
        Some(Component::Normal(name)) if plain => Ok(dir.join(name)),
        _ => bail!("model filename must be a plain file name: {filename:?}"),
    }
}
/// Checks the number of bytes written against the server's declared
/// `Content-Length`.
///
/// When the server declared no length (`expected` is `None`), there is
/// nothing to compare and the check passes.
///
/// # Errors
/// Fails when `copied` differs from `expected` in either direction, which
/// covers both truncated and over-long bodies.
pub fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
    let Some(expected) = expected else {
        return Ok(());
    };
    if copied == expected {
        return Ok(());
    }
    bail!(
        "size mismatch: wrote {copied} bytes but the server declared \
         Content-Length {expected} (download truncated or corrupt)"
    )
}
/// Compares a computed SHA-256 hex digest with the registry's expected one.
///
/// The expected value is trimmed of surrounding whitespace, and the
/// comparison ignores ASCII case. A missing expectation (`None`) always
/// passes, so registry rows without a recorded hash keep working.
///
/// # Errors
/// Fails, naming both digests, when they differ.
pub fn verify_sha256(actual_hex: &str, expected: Option<&str>) -> Result<()> {
    match expected {
        None => Ok(()),
        Some(wanted) if actual_hex.eq_ignore_ascii_case(wanted.trim()) => Ok(()),
        Some(expected) => bail!(
            "sha256 mismatch: downloaded body hashes to {actual_hex} but the registry \
             expects {expected} (corrupted or tampered download)"
        ),
    }
}
/// A [`Write`] adapter that feeds every byte accepted by `inner` into a
/// running SHA-256 hash.
///
/// Only the prefix the inner writer reports as written is hashed. The digest
/// therefore always describes exactly the bytes that reached `inner`, even
/// when `inner` performs short writes.
struct HashingWriter<W: Write> {
    /// Destination that receives the bytes.
    inner: W,
    /// Digest state covering everything `inner` has accepted so far.
    hasher: Sha256,
}

impl<W: Write> Write for HashingWriter<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = self.inner.write(buf)?;
        let (consumed, _unwritten) = buf.split_at(accepted);
        self.hasher.update(consumed);
        Ok(accepted)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}
/// Identifies an image format from its leading magic bytes.
///
/// Returns the extension to use for JPEG (`jpg`), PNG, WebP, GIF (87a/89a),
/// BMP and TIFF in either byte order (`tif`). Returns `None` when no known
/// signature matches, including when `bytes` is too short to hold one.
pub fn sniff_image_extension(bytes: &[u8]) -> Option<&'static str> {
    // Formats identified purely by a fixed prefix. The first bytes are
    // pairwise distinct, so the order of this table does not matter.
    const PREFIXES: &[(&[u8], &str)] = &[
        (b"\xff\xd8\xff", "jpg"),
        (b"\x89PNG\r\n\x1a\n", "png"),
        (b"GIF87a", "gif"),
        (b"GIF89a", "gif"),
        (b"BM", "bmp"),
        (b"II*\0", "tif"),
        (b"MM\0*", "tif"),
    ];
    // WebP is a RIFF container whose form type sits at byte offset 8.
    if bytes.len() >= 12 && bytes.starts_with(b"RIFF") && &bytes[8..12] == b"WEBP" {
        return Some("webp");
    }
    PREFIXES
        .iter()
        .find_map(|&(magic, ext)| bytes.starts_with(magic).then_some(ext))
}
/// Renames an input image so its extension matches the format its bytes
/// actually contain, for sd-cli.
///
/// The first bytes of `path` are sniffed with [`sniff_image_extension`].
/// `path` is returned unchanged in two cases:
/// - the format is unrecognised;
/// - the current extension already matches, case-insensitively, with `.jpeg`
///   accepted for JPEG.
///
/// Otherwise the file is renamed to the sniffed extension and the new path
/// is returned.
///
/// # Errors
/// Fails when the file cannot be opened or its header cannot be read, or when
/// the rename fails.
pub fn ensure_correct_image_extension(path: &Path) -> Result<PathBuf> {
    // At least as many bytes as the longest signature `sniff_image_extension`
    // inspects (12 for WebP).
    const HEADER_LEN: u64 = 16;
    let header = {
        use std::io::Read;
        let file = std::fs::File::open(path)
            .with_context(|| format!("opening input image {}", path.display()))?;
        let mut header = Vec::with_capacity(HEADER_LEN as usize);
        // A single `read` may return fewer bytes than requested, or fail with
        // `Interrupted`, even when more data exists. `take` + `read_to_end`
        // keeps reading until HEADER_LEN bytes or EOF, so a short read cannot
        // hide the signature.
        file.take(HEADER_LEN)
            .read_to_end(&mut header)
            .with_context(|| format!("reading input image header {}", path.display()))?;
        header
    };
    let Some(actual_ext) = sniff_image_extension(&header) else {
        return Ok(path.to_path_buf());
    };
    let current_ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(|e| e.to_ascii_lowercase());
    let matches = current_ext.as_deref() == Some(actual_ext)
        || (actual_ext == "jpg" && current_ext.as_deref() == Some("jpeg"));
    if matches {
        return Ok(path.to_path_buf());
    }
    let corrected = path.with_extension(actual_ext);
    // NOTE: `std::fs::rename` replaces any existing file at `corrected`.
    std::fs::rename(path, &corrected)
        .with_context(|| format!("renaming {} -> {}", path.display(), corrected.display()))?;
    info!(
        target: TRACE_TARGET,
        op = "sniff",
        from = %path.display(),
        to = %corrected.display(),
        actual_ext,
        "renamed input image to match its actual format for sd-cli"
    );
    Ok(corrected)
}
/// Best-effort removal of a temporary file.
///
/// A file that is already gone is the desired end state and is silently
/// accepted. Any other failure is logged as a warning naming the path rather
/// than returned, so callers on an error path can still return their original
/// error.
pub fn remove_temp_file(path: &Path) {
    let failure = match std::fs::remove_file(path) {
        Ok(()) => return,
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => return,
        Err(err) => err,
    };
    warn!(
        target: TRACE_TARGET,
        op = "cleanup",
        path = %path.display(),
        error = %failure,
        "failed to remove temp file"
    );
}
/// Deletes every registered path when dropped.
///
/// Removal goes through [`remove_temp_file`]. Paths that never materialised
/// are ignored, and any other removal failure is logged instead of
/// propagated.
#[derive(Debug, Default)]
pub struct TempFileGuard {
    /// Paths to remove on drop, in registration order.
    paths: Vec<PathBuf>,
}

impl TempFileGuard {
    /// Creates a guard with no registered paths.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `path` for removal when the guard is dropped.
    pub fn push(&mut self, path: PathBuf) {
        self.paths.push(path);
    }
}

impl Drop for TempFileGuard {
    fn drop(&mut self) {
        for path in &self.paths {
            remove_temp_file(path);
        }
    }
}
/// Returns the local path of `file` inside the model cache `dir`,
/// downloading it first when it is not already present.
///
/// The cache-hit check is existence-only: a regular file already at the
/// target path is returned without re-hashing. Otherwise the file is fetched
/// from `file.url`. When the registry supplies `file.sha256`, the download is
/// verified against it.
///
/// # Errors
/// Fails when `file.filename` is not a plain file name, or when the download
/// or its verification fails.
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn ensure_file(dir: &Path, file: &ModelFile) -> Result<PathBuf> {
    let local = model_cache_path(dir, &file.filename)?;
    if local.is_file() {
        tracing::debug!(
            target: TRACE_TARGET,
            op = "ensure_file",
            filename = file.filename.as_str(),
            path = %local.display(),
            "cached"
        );
    } else {
        let url = file.url.as_str();
        download_file_verified(url, &local, file.sha256.as_deref()).with_context(|| {
            format!("downloading {} ({url}) -> {}", file.filename, local.display())
        })?;
    }
    Ok(local)
}
/// Downloads `url` to `dest` without checksum verification.
///
/// Convenience wrapper over [`download_file_verified`] with no expected
/// SHA-256. The `Content-Length` check performed there still applies.
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn download_file(url: &str, dest: &Path) -> Result<()> {
    let no_expected_digest: Option<&str> = None;
    download_file_verified(url, dest, no_expected_digest)
}
/// Downloads `url` into `dest`, verifying the body before it appears at `dest`.
///
/// The body is streamed into a sibling temporary file (`dest`'s file name with
/// `.part` appended) and hashed on the way. The temporary file is renamed onto
/// `dest` only after all of the following hold:
/// - the byte count matches the server's `Content-Length`, when one was sent;
/// - the SHA-256 digest matches `expected_sha256`, when one is given;
/// - the data has been flushed to stable storage with `sync_all`.
///
/// After the temporary file exists, every failure removes it and logs a
/// warning describing what went wrong.
///
/// # Errors
/// Fails in any of these cases:
/// - `dest` has no file name;
/// - the parent directory or temporary file cannot be created;
/// - the HTTP client cannot be built or the request fails;
/// - the response status is not a success;
/// - streaming the body fails;
/// - the length or digest does not match;
/// - syncing or renaming the temporary file fails.
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn download_file_verified(url: &str, dest: &Path, expected_sha256: Option<&str>) -> Result<()> {
    if let Some(parent) = dest.parent() {
        std::fs::create_dir_all(parent)
            .with_context(|| format!("creating {}", parent.display()))?;
    }
    let part = part_path_for(dest)?;
    let client = reqwest::blocking::Client::builder()
        .timeout(std::time::Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
        .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
        .build()
        .context("building HTTP client")?;
    info!(
        target: TRACE_TARGET,
        op = "download",
        url,
        dest = %dest.display(),
        "starting"
    );
    let started = Instant::now();
    let mut response = match client.get(url).send() {
        Ok(response) => response,
        Err(e) => {
            warn!(
                target: TRACE_TARGET,
                op = "download",
                url,
                dest = %dest.display(),
                elapsed_ms = started.elapsed().as_millis() as u64,
                error = %e,
                "download failed: request error"
            );
            return Err(e).context("GET");
        }
    };
    let status = response.status();
    if !status.is_success() {
        warn!(
            target: TRACE_TARGET,
            op = "download",
            url,
            dest = %dest.display(),
            status = status.as_u16(),
            elapsed_ms = started.elapsed().as_millis() as u64,
            "download failed: non-success status"
        );
        bail!("GET {url} -> {status}");
    }
    let expected_len = response.content_length();
    let file =
        std::fs::File::create(&part).with_context(|| format!("creating {}", part.display()))?;
    let mut writer = HashingWriter {
        inner: file,
        hasher: Sha256::new(),
    };
    let copied = std::io::copy(&mut response, &mut writer);
    let HashingWriter { inner: file, hasher } = writer;
    let digest = hasher.finalize();
    // `ensure_file` treats any existing `dest` as a cache hit and never
    // re-verifies it. The data must therefore be durable *before* the rename
    // below; otherwise a crash or power loss shortly after the rename could
    // leave a zero-length or truncated file that looks cached. Only a fully
    // streamed body is worth syncing.
    let synced = if copied.is_ok() {
        file.sync_all()
    } else {
        Ok(())
    };
    // Close the handle before the temp file is verified, renamed or removed.
    drop(file);
    let bytes = match copied {
        Ok(bytes) => bytes,
        Err(e) => {
            remove_temp_file(&part);
            warn!(
                target: TRACE_TARGET,
                op = "download",
                url,
                dest = %dest.display(),
                elapsed_ms = started.elapsed().as_millis() as u64,
                error = %e,
                "download failed: streaming body"
            );
            return Err(e).context("streaming body");
        }
    };
    if let Err(e) = synced {
        remove_temp_file(&part);
        warn!(
            target: TRACE_TARGET,
            op = "download",
            url,
            dest = %dest.display(),
            bytes,
            elapsed_ms = started.elapsed().as_millis() as u64,
            error = %e,
            "download failed: syncing to disk"
        );
        return Err(e).with_context(|| format!("syncing {}", part.display()));
    }
    if let Err(e) = verify_download_len(bytes, expected_len) {
        remove_temp_file(&part);
        warn!(
            target: TRACE_TARGET,
            op = "download",
            url,
            dest = %dest.display(),
            bytes,
            elapsed_ms = started.elapsed().as_millis() as u64,
            error = %e,
            "download failed: size mismatch"
        );
        return Err(e).with_context(|| format!("downloading {url}"));
    }
    let actual_hex: String = digest.iter().map(|b| format!("{b:02x}")).collect();
    if let Err(e) = verify_sha256(&actual_hex, expected_sha256) {
        remove_temp_file(&part);
        warn!(
            target: TRACE_TARGET,
            op = "download",
            url,
            dest = %dest.display(),
            bytes,
            elapsed_ms = started.elapsed().as_millis() as u64,
            error = %e,
            "download failed: sha256 mismatch"
        );
        return Err(e).with_context(|| format!("downloading {url}"));
    }
    if let Err(e) = std::fs::rename(&part, dest) {
        // Retrying creates the temp file afresh, so a verified but unplaced
        // body is not reusable; delete it rather than leak it.
        remove_temp_file(&part);
        warn!(
            target: TRACE_TARGET,
            op = "download",
            url,
            dest = %dest.display(),
            bytes,
            elapsed_ms = started.elapsed().as_millis() as u64,
            error = %e,
            "download failed: rename into place"
        );
        return Err(e)
            .with_context(|| format!("renaming {} -> {}", part.display(), dest.display()));
    }
    let elapsed_ms = started.elapsed().as_millis() as u64;
    info!(
        target: TRACE_TARGET,
        op = "download",
        url,
        dest = %dest.display(),
        bytes,
        elapsed_ms,
        "done"
    );
    Ok(())
}

/// Returns the temporary download path for `dest`: its full file name with
/// `.part` appended (e.g. `model.gguf` -> `model.gguf.part`).
///
/// Appending, rather than replacing the extension, keeps files that share a
/// stem (such as `model.gguf` and `model.safetensors`) from streaming into the
/// same temporary file.
///
/// # Errors
/// Fails when `dest` has no file name component.
fn part_path_for(dest: &Path) -> Result<PathBuf> {
    let Some(name) = dest.file_name() else {
        bail!("download destination has no file name: {}", dest.display());
    };
    let mut part_name = name.to_os_string();
    part_name.push(".part");
    Ok(dest.with_file_name(part_name))
}
// Unit tests for this module's helpers. The `ensure_file` tests only cover
// paths that return before any network request is made.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
// WebP fixture loaded at compile time; the tests below expect it to sniff as `webp`.
const LOSSY_WEBP: &[u8] = include_bytes!("../../tests/fixtures/lossy-vp8.webp");
#[test]
fn sniff_image_extension_maps_each_magic_to_an_sd_cli_extension() {
assert_eq!(sniff_image_extension(LOSSY_WEBP), Some("webp"));
assert_eq!(
sniff_image_extension(&[0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10]),
Some("jpg"),
"JPEG (the bytes studio serves under .webp URLs)"
);
assert_eq!(
sniff_image_extension(&[0x89, b'P', b'N', b'G', 0x0d, 0x0a, 0x1a, 0x0a]),
Some("png")
);
assert_eq!(sniff_image_extension(b"GIF89a..."), Some("gif"));
assert_eq!(sniff_image_extension(b"BM......"), Some("bmp"));
assert_eq!(
sniff_image_extension(&[0x49, 0x49, 0x2a, 0x00]),
Some("tif")
);
// A RIFF container that is not WEBP (here WAVE) must not be mistaken for WebP.
assert_eq!(sniff_image_extension(b"RIFF\x00\x00\x00\x00WAVEfmt "), None);
assert_eq!(sniff_image_extension(b"\x00\x01\x02"), None);
assert_eq!(sniff_image_extension(b""), None);
}
#[test]
fn ensure_correct_image_extension_renames_jpeg_served_as_webp() {
let dir = tempdir().unwrap();
let mislabelled = dir.path().join("out-init.webp");
// JPEG magic bytes written under a `.webp` name.
std::fs::write(
&mislabelled,
[0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 0x4a, 0x46],
)
.unwrap();
let corrected = ensure_correct_image_extension(&mislabelled).unwrap();
assert_eq!(corrected, dir.path().join("out-init.jpg"));
assert!(corrected.exists(), "renamed file carries the bytes");
assert!(
!mislabelled.exists(),
"the misnamed file is gone after rename"
);
}
#[test]
fn ensure_correct_image_extension_renames_webp_served_as_png() {
let dir = tempdir().unwrap();
let mislabelled = dir.path().join("out-init.png");
std::fs::write(&mislabelled, LOSSY_WEBP).unwrap();
let corrected = ensure_correct_image_extension(&mislabelled).unwrap();
assert_eq!(corrected, dir.path().join("out-init.webp"));
assert!(corrected.exists() && !mislabelled.exists());
}
// Covers the three no-op cases: an already-correct extension, the `.jpeg`
// alias for JPEG, and bytes that match no known signature.
#[test]
fn ensure_correct_image_extension_leaves_correct_or_unknown_files_in_place() {
let dir = tempdir().unwrap();
let png = dir.path().join("out-mask.png");
std::fs::write(&png, [0x89, b'P', b'N', b'G', 0x0d, 0x0a, 0x1a, 0x0a]).unwrap();
assert_eq!(ensure_correct_image_extension(&png).unwrap(), png);
assert!(png.exists());
let jpeg = dir.path().join("out-ref.jpeg");
std::fs::write(&jpeg, [0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10]).unwrap();
assert_eq!(ensure_correct_image_extension(&jpeg).unwrap(), jpeg);
assert!(jpeg.exists() && !dir.path().join("out-ref.jpg").exists());
let unknown = dir.path().join("out-init.webp");
std::fs::write(&unknown, [0x00, 0x01, 0x02, 0x03]).unwrap();
assert_eq!(ensure_correct_image_extension(&unknown).unwrap(), unknown);
assert!(unknown.exists());
}
#[test]
fn model_cache_path_accepts_plain_filenames_only() {
let root = Path::new("/models");
assert_eq!(
model_cache_path(root, "model.gguf").unwrap(),
PathBuf::from("/models/model.gguf")
);
assert!(model_cache_path(root, "../outside.gguf").is_err());
assert!(model_cache_path(root, "nested/model.gguf").is_err());
assert!(model_cache_path(root, "/tmp/model.gguf").is_err());
// A backslash must be rejected even on Unix, where it is not a separator.
assert!(model_cache_path(root, r"nested\model.gguf").is_err());
assert!(model_cache_path(root, "").is_err());
}
#[test]
fn verify_download_len_accepts_exact_match() {
assert!(verify_download_len(2_700_000_000, Some(2_700_000_000)).is_ok());
}
#[test]
fn verify_download_len_accepts_when_length_unknown() {
assert!(verify_download_len(123, None).is_ok());
}
#[test]
fn verify_download_len_rejects_truncated_download() {
let err = verify_download_len(40, Some(100)).unwrap_err().to_string();
assert!(err.contains("size mismatch"), "got: {err}");
assert!(err.contains("40"), "got: {err}");
assert!(err.contains("100"), "got: {err}");
}
#[test]
fn verify_download_len_rejects_overlong_download() {
assert!(verify_download_len(120, Some(100)).is_err());
}
// Builds a minimal `ModelFile` with no size hint and no expected sha256.
fn test_file(filename: &str, url: &str) -> ModelFile {
ModelFile {
role: crate::types::ModelFileRole::Model,
url: url.to_string(),
filename: filename.to_string(),
approx_bytes: None,
sha256: None,
}
}
// The `.invalid` URL would fail if it were ever fetched, so success proves
// the cached path short-circuits the download.
#[test]
fn ensure_file_returns_cached_path_without_network() {
let dir = tempdir().unwrap();
std::fs::write(dir.path().join("cached.gguf"), b"already here").unwrap();
let path = ensure_file(
dir.path(),
&test_file("cached.gguf", "https://example.invalid/x"),
)
.unwrap();
assert_eq!(path, dir.path().join("cached.gguf"));
assert_eq!(std::fs::read(&path).unwrap(), b"already here");
}
#[test]
fn ensure_file_rejects_path_traversal_before_any_network() {
let dir = tempdir().unwrap();
let err = ensure_file(
dir.path(),
&test_file("../escape.gguf", "https://example.invalid/x"),
)
.unwrap_err()
.to_string();
assert!(err.contains("plain file name"), "got: {err}");
}
#[test]
fn verify_sha256_accepts_match_and_absence() {
assert!(verify_sha256("abc123", Some("abc123")).is_ok());
assert!(
verify_sha256("abc123", Some("ABC123")).is_ok(),
"case-insensitive"
);
assert!(
verify_sha256("abc123", Some(" abc123 ")).is_ok(),
"whitespace-tolerant"
);
assert!(
verify_sha256("abc123", None).is_ok(),
"legacy rows have no hash"
);
}
#[test]
fn verify_sha256_rejects_mismatch() {
let err = verify_sha256("abc123", Some("def456"))
.unwrap_err()
.to_string();
assert!(err.contains("sha256 mismatch"), "got: {err}");
assert!(
err.contains("abc123") && err.contains("def456"),
"must name both digests: {err}"
);
}
// Test double for `HashingWriter`: accepts at most `max_per_write` bytes per
// `write` call (forcing short writes) and counts `flush` calls.
struct ProbeWriter {
sink: Vec<u8>,
max_per_write: usize,
flushes: usize,
}
impl Write for ProbeWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let take = buf.len().min(self.max_per_write);
self.sink.extend_from_slice(&buf[..take]);
Ok(take)
}
fn flush(&mut self) -> std::io::Result<()> {
self.flushes += 1;
Ok(())
}
}
// Lower-case hex encoding, the same format `download_file_verified` uses for digests.
fn hex(bytes: &[u8]) -> String {
bytes.iter().map(|b| format!("{b:02x}")).collect()
}
#[test]
fn hashing_writer_hashes_only_the_bytes_the_inner_accepted() {
let mut writer = HashingWriter {
inner: ProbeWriter {
sink: Vec::new(),
max_per_write: 3,
flushes: 0,
},
hasher: Sha256::new(),
};
let written = writer.write(b"abcdefgh").unwrap();
assert_eq!(written, 3, "inner accepts at most 3 bytes per write");
assert_eq!(writer.inner.sink, b"abc", "only the prefix reaches inner");
assert_eq!(
hex(&writer.hasher.finalize()),
hex(&Sha256::digest(b"abc")),
"hash covers only the accepted prefix"
);
}
// `std::io::copy` retries after short writes; the digest must still cover
// every byte exactly once.
#[test]
fn hashing_writer_digest_matches_a_short_writing_stream_end_to_end() {
let source = b"the quick brown model weights".to_vec();
let mut reader = source.as_slice();
let mut writer = HashingWriter {
inner: ProbeWriter {
sink: Vec::new(),
max_per_write: 4,
flushes: 0,
},
hasher: Sha256::new(),
};
let copied = std::io::copy(&mut reader, &mut writer).unwrap();
assert_eq!(copied as usize, source.len());
assert_eq!(
writer.inner.sink, source,
"every byte reaches the cache file"
);
assert_eq!(
hex(&writer.hasher.finalize()),
hex(&Sha256::digest(&source)),
"digest matches the full body"
);
}
#[test]
fn hashing_writer_flush_delegates_to_the_inner_writer() {
let mut writer = HashingWriter {
inner: ProbeWriter {
sink: Vec::new(),
max_per_write: usize::MAX,
flushes: 0,
},
hasher: Sha256::new(),
};
writer.flush().unwrap();
writer.flush().unwrap();
assert_eq!(writer.inner.flushes, 2, "flush is forwarded to inner");
}
// NOTE(review): `crate::test_support::capture` appears to run the closure and
// return the log output it produced (inferred from how `out` is used below).
#[test]
fn remove_temp_file_deletes_an_existing_file_quietly() {
let dir = tempdir().unwrap();
let f = dir.path().join("artefact.webp");
std::fs::write(&f, b"bytes").unwrap();
let out = crate::test_support::capture({
let f = f.clone();
move || remove_temp_file(&f)
});
assert!(!f.exists(), "file should be gone after cleanup");
assert!(
!out.contains("failed to remove temp file"),
"the success path must not warn: {out:?}"
);
}
#[test]
fn remove_temp_file_ignores_a_missing_file() {
let dir = tempdir().unwrap();
let out = crate::test_support::capture({
let missing = dir.path().join("never.part");
move || remove_temp_file(&missing)
});
assert!(
!out.contains("failed to remove temp file"),
"a not-found temp file is the desired end state: {out:?}"
);
}
// `remove_file` on a directory fails with an error other than NotFound,
// which exercises the warning path.
#[test]
fn remove_temp_file_surfaces_a_failed_removal() {
let dir = tempdir().unwrap();
let stubborn = dir.path().join("subdir");
std::fs::create_dir(&stubborn).unwrap();
let out = crate::test_support::capture(move || remove_temp_file(&stubborn));
assert!(
out.contains("failed to remove temp file"),
"a failed removal must surface in the logs: {out:?}"
);
assert!(
out.contains("subdir"),
"the warning must name the offending path: {out:?}"
);
assert!(
out.contains("cleanup"),
"the warning should tag the cleanup op: {out:?}"
);
}
#[test]
fn temp_file_guard_removes_every_registered_file_on_drop() {
let dir = tempdir().unwrap();
let out = dir.path().join("out.webp");
let init = dir.path().join("out-init.png");
std::fs::write(&out, b"image").unwrap();
std::fs::write(&init, b"init").unwrap();
// The inner scope ends the guard's lifetime, triggering its `Drop`.
{
let mut guard = TempFileGuard::new();
guard.push(out.clone());
guard.push(init.clone());
assert!(out.exists() && init.exists(), "files present before drop");
}
assert!(!out.exists(), "output temp must be removed on drop");
assert!(!init.exists(), "init-image temp must be removed on drop");
}
#[test]
fn temp_file_guard_tolerates_a_file_that_never_materialised() {
let dir = tempdir().unwrap();
let missing = dir.path().join("never-written.webp");
let out = crate::test_support::capture(move || {
let mut guard = TempFileGuard::new();
guard.push(missing);
drop(guard);
});
assert!(
!out.contains("failed to remove temp file"),
"a never-created temp file must not warn on cleanup: {out:?}"
);
}
}