use std::io::Write;
use std::path::Path;
use studio_worker::engine::sd_provision;
use tokio::sync::Mutex;
use wiremock::matchers::{method, path as match_path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Serializes the tests in this file that set or clear [`URL_ENV`]. Environment variables
/// are process-global, and the default test harness runs tests on parallel threads.
/// A `tokio` mutex is used because each test holds the guard across `.await` points.
static URL_ENV_LOCK: Mutex<()> = Mutex::const_new(());
/// Environment variable the tests set to make `sd_provision::provision` fetch the release
/// zip from a local wiremock server instead of its normal source.
const URL_ENV: &str = "STUDIO_WORKER_SDCPP_URL";
/// Runs `f` on a freshly spawned OS thread, blocks until it finishes, and returns its
/// result.
///
/// The tests route every `sd_provision::provision` call through here, so that call never
/// executes on the Tokio runtime's own thread.
/// NOTE(review): presumably because `provision` does blocking I/O that must not run inside
/// an async context; confirm against `sd_provision`.
///
/// # Panics
/// Panics with a message starting "worker thread panicked" if `f` panics on the worker.
fn detached<R: Send + 'static>(f: impl FnOnce() -> R + Send + 'static) -> R {
    let worker = std::thread::spawn(f);
    match worker.join() {
        Ok(output) => output,
        Err(payload) => panic!("worker thread panicked: {payload:?}"),
    }
}
/// Builds an in-memory zip archive that mimics an upstream sd.cpp release.
///
/// The archive always contains `libstable-diffusion.so` and `stable-diffusion.cpp.txt`,
/// both Deflate-compressed with placeholder contents. When `include_binary` is true, an
/// entry named `sd_provision::binary_name()` is written first. It holds a tiny shell
/// script and carries unix mode `0o755`. Passing `false` produces an archive that lacks
/// the CLI entirely.
fn fake_release_zip(include_binary: bool) -> Vec<u8> {
    // Entries present in every fake release, in the order they are written.
    const SIDE_FILES: [(&str, &[u8]); 2] = [
        ("libstable-diffusion.so", b"pretend shared library"),
        ("stable-diffusion.cpp.txt", b"MIT"),
    ];

    let mut bytes = Vec::new();
    // Inner scope: the writer mutably borrows `bytes` and must be gone before we return it.
    {
        let mut writer = zip::ZipWriter::new(std::io::Cursor::new(&mut bytes));

        if include_binary {
            let exec_opts: zip::write::FileOptions<()> = zip::write::FileOptions::default()
                .compression_method(zip::CompressionMethod::Deflated)
                .unix_permissions(0o755);
            writer.start_file(sd_provision::binary_name(), exec_opts).unwrap();
            writer.write_all(b"#!/bin/sh\necho fake-sd-cli\n").unwrap();
        }

        let plain_opts: zip::write::FileOptions<()> =
            zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Deflated);
        for (name, contents) in SIDE_FILES {
            writer.start_file(name, plain_opts).unwrap();
            writer.write_all(contents).unwrap();
        }

        writer.finish().unwrap();
    }
    bytes
}
/// Reports whether `path` has any execute bit (owner, group, or other) set.
///
/// # Panics
/// Panics if the file's metadata cannot be read.
#[cfg(unix)]
fn is_executable(path: &Path) -> bool {
    use std::os::unix::fs::PermissionsExt;
    const ANY_EXEC_BIT: u32 = 0o111;
    let mode = std::fs::metadata(path).unwrap().permissions().mode();
    (mode & ANY_EXEC_BIT) != 0
}

/// Non-unix fallback: there are no POSIX mode bits to inspect, so every path is reported
/// as executable.
#[cfg(not(unix))]
fn is_executable(_path: &Path) -> bool {
    true
}
/// End-to-end happy path for `sd_provision::provision`.
///
/// Serves a fake release zip from a mock server and points [`URL_ENV`] at it. It then
/// checks that the first call installs an executable CLI, plus the bundled shared library,
/// under `<models_root>/bin`. It also checks that no `.sd-cli*` scratch entries are left in
/// `models_root`. Finally, a second call must return the same path without downloading the
/// zip again.
#[tokio::test]
async fn provision_downloads_extracts_and_caches_sd_cli() {
    /// Clears [`URL_ENV`] when dropped, so a failed assertion or a panicking worker cannot
    /// leave the mock URL set after this test.
    struct ClearUrlOnDrop;
    impl Drop for ClearUrlOnDrop {
        fn drop(&mut self) {
            std::env::remove_var(URL_ENV);
        }
    }

    // Declared first so it is dropped last. `URL_ENV` is cleared (see `_clear_url`) while
    // the lock is still held.
    let _guard = URL_ENV_LOCK.lock().await;
    let server = MockServer::start().await;
    // Exactly one download is allowed: the cached second call below must not fetch again.
    Mock::given(method("GET"))
        .and(match_path("/sd.zip"))
        .respond_with(ResponseTemplate::new(200).set_body_bytes(fake_release_zip(true)))
        .expect(1)
        .mount(&server)
        .await;
    let url = format!("{}/sd.zip", server.uri());
    let dir = tempfile::tempdir().unwrap();
    let models_root = dir.path().to_path_buf();

    std::env::set_var(URL_ENV, &url);
    let _clear_url = ClearUrlOnDrop;

    let root = models_root.clone();
    let first = detached(move || sd_provision::provision(&root).unwrap());
    let expected = models_root.join("bin").join(sd_provision::binary_name());
    assert_eq!(first, expected);
    assert!(first.is_file(), "binary must be installed");
    assert_eq!(
        std::fs::read(&first).unwrap(),
        b"#!/bin/sh\necho fake-sd-cli\n"
    );
    assert!(is_executable(&first), "binary must be executable");
    assert!(models_root
        .join("bin")
        .join("libstable-diffusion.so")
        .is_file());
    let leftovers: Vec<_> = std::fs::read_dir(&models_root)
        .unwrap()
        .filter_map(|e| e.ok())
        .map(|e| e.file_name().to_string_lossy().into_owned())
        .filter(|n| n.starts_with(".sd-cli"))
        .collect();
    assert!(leftovers.is_empty(), "scratch litter left: {leftovers:?}");

    // `URL_ENV` deliberately still points at the mock. If the cache were ignored, the
    // re-download would hit the mock a second time and fail `verify` below. Otherwise it
    // could silently fall back to whatever source `provision` uses without the override
    // (possibly a real network fetch) and still pass.
    let root = models_root.clone();
    let second = detached(move || sd_provision::provision(&root).unwrap());
    assert_eq!(second, expected);
    server.verify().await;
}
/// Failure path: the release zip has no entry named `sd_provision::binary_name()`.
///
/// `provision` must return an error that mentions that name, and it must not leave anything
/// installed at `<models_root>/bin/<binary_name>`.
#[tokio::test]
async fn provision_errors_when_zip_lacks_the_binary() {
    /// Clears [`URL_ENV`] when dropped. If `provision` unexpectedly succeeds and
    /// `expect_err` panics, the mock URL still does not outlive this test.
    struct ClearUrlOnDrop;
    impl Drop for ClearUrlOnDrop {
        fn drop(&mut self) {
            std::env::remove_var(URL_ENV);
        }
    }

    // Declared first so it is dropped last. `URL_ENV` is cleared while the lock is held.
    let _guard = URL_ENV_LOCK.lock().await;
    let server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(match_path("/broken.zip"))
        .respond_with(ResponseTemplate::new(200).set_body_bytes(fake_release_zip(false)))
        .mount(&server)
        .await;
    let url = format!("{}/broken.zip", server.uri());
    let dir = tempfile::tempdir().unwrap();
    let models_root = dir.path().to_path_buf();

    std::env::set_var(URL_ENV, &url);
    let _clear_url = ClearUrlOnDrop;

    let root = models_root.clone();
    // The error is stringified on the worker, so only a `String` crosses back to this
    // thread. NOTE(review): presumably because the error type need not be `Send`.
    let err = detached(move || {
        sd_provision::provision(&root)
            .expect_err("a zip with no binary must error")
            .to_string()
    });
    assert!(err.contains(sd_provision::binary_name()), "got: {err}");
    assert!(!models_root
        .join("bin")
        .join(sd_provision::binary_name())
        .exists());
}