use std::io::{Read, Write};
use std::net::TcpListener;
use studio_worker::engine::download;
use studio_worker::test_support::capture as captured_logs_for;
use studio_worker::types::{ModelFile, ModelFileRole};
use wiremock::matchers::{method, path as match_path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Lower-case hex SHA-256 digest that these tests expect for the
/// `b"a tiny pretend model"` payload served by the mock endpoints; the
/// "tampered" tests serve different bytes against it to force a mismatch.
const BODY_SHA256: &str = "7820645b979bcfe59530fcd3b377c10e20bffed93396b8b3ffbd506f06aaacfe";
/// Builds a [`ModelFile`] with the `Model` role and no size hint.
///
/// `filename` is the cache name, `url` the download source, and `sha256` the
/// optional expected digest (copied into an owned `String` when present).
fn model_file(filename: &str, url: &str, sha256: Option<&str>) -> ModelFile {
    let expected_digest = sha256.map(|digest| digest.to_owned());
    ModelFile {
        filename: filename.to_owned(),
        url: url.to_owned(),
        role: ModelFileRole::Model,
        approx_bytes: None,
        sha256: expected_digest,
    }
}
/// Runs `f` on a freshly spawned OS thread, waits for it, and returns its
/// result.
///
/// The tests route `download` calls through this so they do not run on the
/// caller's own thread (for `#[tokio::test]`s, the async runtime's thread).
///
/// # Panics
/// Panics with `worker thread panicked` if `f` panics.
fn detached<R: Send + 'static>(f: impl FnOnce() -> R + Send + 'static) -> R {
    let worker = std::thread::spawn(f);
    worker.join().expect("worker thread panicked")
}
/// Starts a one-shot HTTP server on an ephemeral loopback port and returns its
/// base URL (`http://127.0.0.1:PORT`, no trailing slash).
///
/// The server accepts a single connection and reads the request head. It then
/// answers `200 OK` advertising `Content-Length: 9999` but sends only a few
/// body bytes before closing, which simulates a truncated download.
fn serve_truncated_response() -> String {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind loopback");
    let addr = listener.local_addr().expect("local addr");
    std::thread::spawn(move || {
        if let Ok((mut stream, _)) = listener.accept() {
            // Drain the whole request head before replying. One `read` may
            // return only part of the request. Closing a socket that still has
            // unread inbound bytes makes the kernel send RST instead of FIN,
            // so the client would see a connection reset rather than a short body.
            let mut request = Vec::new();
            let mut chunk = [0u8; 1024];
            while !request.windows(4).any(|w| w == b"\r\n\r\n") {
                match stream.read(&mut chunk) {
                    Ok(0) | Err(_) => break,
                    Ok(n) => request.extend_from_slice(&chunk[..n]),
                }
            }
            let _ = stream
                .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 9999\r\n\r\nonly a few bytes");
            let _ = stream.flush();
        }
    });
    format!("http://{addr}")
}
/// Returns an `http://127.0.0.1:PORT/model.gguf` URL whose port was just
/// released, so connecting to it is expected to be refused.
///
/// The port is obtained by binding an ephemeral listener and freed by dropping
/// it. Another process could in principle grab the port in between.
fn refused_loopback_url() -> String {
    // `probe` is dropped at the end of this block, releasing the port.
    let released = {
        let probe = TcpListener::bind("127.0.0.1:0").expect("bind loopback");
        probe.local_addr().expect("local addr")
    };
    format!("http://{released}/model.gguf")
}
/// A successful GET is written verbatim to the destination, and no
/// `.part` sibling is left behind.
#[tokio::test]
async fn download_file_writes_the_served_bytes() {
    let server = MockServer::start().await;
    let payload = b"a tiny pretend model".to_vec();
    let route = Mock::given(method("GET"))
        .and(match_path("/model.gguf"))
        .respond_with(ResponseTemplate::new(200).set_body_bytes(payload.clone()));
    route.mount(&server).await;

    let source = format!("{}/model.gguf", server.uri());
    let scratch = tempfile::tempdir().unwrap();
    let target = scratch.path().join("model.gguf");
    {
        let target = target.clone();
        detached(move || download::download_file(&source, &target).unwrap());
    }

    let written = std::fs::read(&target).unwrap();
    assert_eq!(written, payload);
    let scratch_part = target.with_extension("part");
    assert!(!scratch_part.exists());
}
/// `ensure_file` returns a path to an existing file when the served bytes
/// match the expected digest.
#[tokio::test]
async fn ensure_file_accepts_a_matching_sha256() {
    let server = MockServer::start().await;
    let served = b"a tiny pretend model".to_vec();
    Mock::given(method("GET"))
        .and(match_path("/verified.gguf"))
        .respond_with(ResponseTemplate::new(200).set_body_bytes(served))
        .mount(&server)
        .await;

    let source = format!("{}/verified.gguf", server.uri());
    let cache = tempfile::tempdir().unwrap();
    let cache_root = cache.path().to_path_buf();
    let spec = model_file("verified.gguf", &source, Some(BODY_SHA256));
    let resolved = detached(move || download::ensure_file(&cache_root, &spec).unwrap());
    assert!(resolved.is_file());
}
/// Serving bytes whose digest differs from `BODY_SHA256` must fail with an
/// error mentioning sha256 and leave neither the final file nor a `.part`.
#[tokio::test]
async fn ensure_file_rejects_a_sha256_mismatch_and_caches_nothing() {
    let server = MockServer::start().await;
    let forged = Mock::given(method("GET"))
        .and(match_path("/tampered.gguf"))
        .respond_with(ResponseTemplate::new(200).set_body_bytes(b"evil bytes".to_vec()));
    forged.mount(&server).await;

    let source = format!("{}/tampered.gguf", server.uri());
    let cache = tempfile::tempdir().unwrap();
    let cache_root = cache.path().to_path_buf();
    let spec = model_file("tampered.gguf", &source, Some(BODY_SHA256));
    let err = detached(move || download::ensure_file(&cache_root, &spec).unwrap_err());

    // Check the top-level message first, then fall back to the full chain.
    let names_sha256 =
        err.to_string().contains("sha256") || format!("{err:#}").contains("sha256");
    assert!(names_sha256, "error must name the sha256 mismatch: {err:#}");

    let cache_dir = cache.path();
    assert!(!cache_dir.join("tampered.gguf").exists());
    assert!(!cache_dir.join("tampered.part").exists());
}
/// The same payload as the matching-hash case, but the expected digest is
/// given upper-cased; verification must still accept it.
#[tokio::test]
async fn ensure_file_is_case_insensitive_about_the_expected_hash() {
    let server = MockServer::start().await;
    let upper_route = Mock::given(method("GET"))
        .and(match_path("/upper.gguf"))
        .respond_with(ResponseTemplate::new(200).set_body_bytes(b"a tiny pretend model".to_vec()));
    upper_route.mount(&server).await;

    let source = format!("{}/upper.gguf", server.uri());
    let cache = tempfile::tempdir().unwrap();
    let cache_root = cache.path().to_path_buf();
    let shouting_hash = BODY_SHA256.to_uppercase();
    let spec = model_file("upper.gguf", &source, Some(shouting_hash.as_str()));
    let resolved = detached(move || download::ensure_file(&cache_root, &spec).unwrap());
    assert!(resolved.is_file());
}
#[tokio::test]
async fn download_file_surfaces_a_non_success_status() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(match_path("/missing.gguf"))
.respond_with(ResponseTemplate::new(404))
.mount(&server)
.await;
let url = format!("{}/missing.gguf", server.uri());
let dir = tempfile::tempdir().unwrap();
let dest = dir.path().join("missing.gguf");
let dest_for_thread = dest.clone();
let err = detached(move || {
download::download_file(&url, &dest_for_thread)
.expect_err("404 must error")
.to_string()
});
assert!(err.contains("404"), "got: {err}");
assert!(!dest.exists());
}
/// A refused TCP connection must surface as an error carrying the GET
/// context, and must leave neither the destination nor a `.part` file behind.
#[test]
fn download_surfaces_a_connection_level_failure_and_caches_nothing() {
    let url = refused_loopback_url();
    let dir = tempfile::tempdir().unwrap();
    let dest = dir.path().join("refused.gguf");
    let dest_for_thread = dest.clone();
    let err = detached(move || {
        download::download_file(&url, &dest_for_thread)
            .expect_err("a refused connection must error")
            .to_string()
    });
    // Look for the GET context rather than requiring the message to be exactly
    // "GET". The top-level message may carry extra detail (e.g. the URL), and
    // the sibling tests likewise match error text with `contains`.
    assert!(
        err.contains("GET"),
        "expected the connection-level GET context, got: {err}"
    );
    assert!(!dest.exists(), "no file committed on a refused connection");
    assert!(
        !dest.with_extension("part").exists(),
        ".part scratch litter left behind"
    );
}
/// A body shorter than its advertised `Content-Length` must be rejected by the
/// streaming read, and no partial file (final or `.part`) may remain.
#[test]
fn download_rejects_a_truncated_body_and_caches_nothing() {
    let source = format!("{}/truncated.gguf", serve_truncated_response());
    let scratch = tempfile::tempdir().unwrap();
    let target = scratch.path().join("truncated.gguf");
    let worker_target = target.clone();
    let message = detached(move || {
        let failure = download::download_file(&source, &worker_target)
            .expect_err("a truncated download must be rejected");
        failure.to_string()
    });

    assert!(
        message.contains("streaming body"),
        "expected the streaming-read integrity guard to fire, got: {message}"
    );
    assert!(!target.exists(), "no truncated file may be committed");
    let scratch_part = target.with_extension("part");
    assert!(!scratch_part.exists(), ".part scratch litter left behind");
}
/// A truncated download must log a WARN event tagged `op="download"` that
/// names the streaming-body failure and records `elapsed_ms`.
#[test]
fn truncated_download_emits_a_warn_breadcrumb() {
    let source = format!("{}/short.gguf", serve_truncated_response());
    let scratch = tempfile::tempdir().unwrap();
    let target = scratch.path().join("short.gguf");
    let logs = captured_logs_for(move || {
        let _ = download::download_file(&source, &target);
    });

    // (substring that must appear in the logs, description used on failure)
    let expectations = [
        ("WARN", "WARN event"),
        ("op=\"download\"", "op field"),
        ("download failed: streaming body", "the streaming-body failure breadcrumb"),
        ("elapsed_ms", "elapsed_ms field"),
    ];
    for (needle, what) in expectations {
        assert!(logs.contains(needle), "expected {what}, got: {logs}");
    }
}
/// A 404 must log a WARN breadcrumb carrying the op, the status code,
/// `elapsed_ms`, and the requested file name, and must commit nothing.
#[tokio::test]
async fn download_failure_emits_a_warn_breadcrumb_with_status() {
    let server = MockServer::start().await;
    let gone_route = Mock::given(method("GET"))
        .and(match_path("/gone.gguf"))
        .respond_with(ResponseTemplate::new(404));
    gone_route.mount(&server).await;

    let source = format!("{}/gone.gguf", server.uri());
    let scratch = tempfile::tempdir().unwrap();
    let target = scratch.path().join("gone.gguf");
    let worker_target = target.clone();
    let logs = captured_logs_for(move || {
        let _ = download::download_file(&source, &worker_target);
    });

    // (substring that must appear in the logs, description used on failure)
    let expectations = [
        ("WARN", "WARN event"),
        ("op=\"download\"", "op field"),
        ("status=404", "status field"),
        ("elapsed_ms", "elapsed_ms field"),
        ("gone.gguf", "the dest/url in the breadcrumb"),
    ];
    for (needle, what) in expectations {
        assert!(logs.contains(needle), "expected {what}, got: {logs}");
    }
    assert!(!target.exists(), "no file committed on a failed download");
}
/// A refused connection must log a WARN `op="download"` breadcrumb naming the
/// request error with `elapsed_ms`, and must commit nothing.
#[test]
fn connection_level_failure_emits_a_warn_breadcrumb() {
    let source = refused_loopback_url();
    let scratch = tempfile::tempdir().unwrap();
    let target = scratch.path().join("refused.gguf");
    let worker_target = target.clone();
    let logs = captured_logs_for(move || {
        let _ = download::download_file(&source, &worker_target);
    });

    // (substring that must appear in the logs, description used on failure)
    let expectations = [
        ("WARN", "WARN event"),
        ("op=\"download\"", "op field"),
        ("download failed: request error", "the connection-level failure breadcrumb"),
        ("elapsed_ms", "elapsed_ms field"),
    ];
    for (needle, what) in expectations {
        assert!(logs.contains(needle), "expected {what}, got: {logs}");
    }
    assert!(!target.exists(), "no file committed on a refused connection");
}
/// A digest mismatch in `ensure_file` must log a WARN `op="download"`
/// breadcrumb that mentions sha256.
#[tokio::test]
async fn sha256_mismatch_emits_a_warn_breadcrumb() {
    let server = MockServer::start().await;
    let forged = Mock::given(method("GET"))
        .and(match_path("/tampered.gguf"))
        .respond_with(ResponseTemplate::new(200).set_body_bytes(b"evil bytes".to_vec()));
    forged.mount(&server).await;

    let source = format!("{}/tampered.gguf", server.uri());
    let cache = tempfile::tempdir().unwrap();
    let cache_root = cache.path().to_path_buf();
    let spec = model_file("tampered.gguf", &source, Some(BODY_SHA256));
    let logs = captured_logs_for(move || {
        let _ = download::ensure_file(&cache_root, &spec);
    });

    // (substring that must appear in the logs, description used on failure)
    let expectations = [
        ("WARN", "WARN event"),
        ("op=\"download\"", "op field"),
        ("sha256", "the sha256 reason in the breadcrumb"),
    ];
    for (needle, what) in expectations {
        assert!(logs.contains(needle), "expected {what}, got: {logs}");
    }
}
/// The first `ensure_file` call downloads the file. A second call for the same
/// spec returns the same path. The mock's `.expect(1)` requires exactly one GET
/// across both calls; wiremock verifies it when `server` is dropped.
#[tokio::test]
async fn ensure_file_downloads_when_missing_then_reuses_the_cache() {
    let server = MockServer::start().await;
    let body = b"cached model bytes".to_vec();
    Mock::given(method("GET"))
        .and(match_path("/once.gguf"))
        .respond_with(ResponseTemplate::new(200).set_body_bytes(body.clone()))
        .expect(1)
        .mount(&server)
        .await;

    let source = format!("{}/once.gguf", server.uri());
    let cache = tempfile::tempdir().unwrap();
    // Each call builds a fresh spec and cache-root path, since `detached`
    // needs owned inputs.
    let fetch = || {
        let cache_root = cache.path().to_path_buf();
        let spec = model_file("once.gguf", &source, None);
        detached(move || download::ensure_file(&cache_root, &spec).unwrap())
    };

    let first = fetch();
    assert_eq!(std::fs::read(&first).unwrap(), body);
    let second = fetch();
    assert_eq!(first, second);
}