#![cfg(all(test, feature = "ort"))]
#![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
use std::io::Write;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use sha2::{Digest, Sha256};
use tempfile::TempDir;
use tiny_http::{Header, Method, Response, Server, StatusCode};
use super::manifest::{Manifest, ManifestFile};
use super::{ensure_with_manifest, BootstrapError};
/// Returns the deterministic 16-byte payload served for fixture `name`.
///
/// Each known artifact gets its own fill byte (`A` for the ONNX model, `B` for the tokenizer,
/// `C` for the config); any other name is filled with `D`.
fn fixture_bytes(name: &str) -> Vec<u8> {
    let fill = match name {
        "model_q4f16.onnx" => b'A',
        "tokenizer.json" => b'B',
        "config.json" => b'C',
        _ => b'D',
    };
    vec![fill; 16]
}
/// Hex-encoded SHA-256 digest of the payload `fixture_bytes(name)` produces.
fn fixture_sha256(name: &str) -> String {
    let payload = fixture_bytes(name);
    let digest = Sha256::digest(payload);
    hex::encode(digest)
}
fn spawn_stub_server<F>(
handler: F,
) -> (
String,
thread::JoinHandle<()>,
Arc<AtomicUsize>,
Arc<Server>,
)
where
F: Fn(&tiny_http::Request, usize) -> Response<std::io::Cursor<Vec<u8>>> + Send + Sync + 'static,
{
let server = Arc::new(Server::http("127.0.0.1:0").expect("bind 127.0.0.1:0"));
let port = server.server_addr().to_ip().expect("bind ip").port();
let base_url = format!("http://127.0.0.1:{port}");
let counter = Arc::new(AtomicUsize::new(0));
let counter_h = counter.clone();
let server_h = server.clone();
let handle = thread::spawn(move || {
for req in server_h.incoming_requests() {
let n = counter_h.fetch_add(1, Ordering::SeqCst);
let resp = handler(&req, n);
let _ = req.respond(resp);
}
});
(base_url, handle, counter, server)
}
/// Builds the three-artifact manifest shared by the tests.
///
/// Each file (`model_q4f16.onnx`, `tokenizer.json`, `config.json`, in that order) is listed
/// with its fixture size and digest, a primary URL of `{base_url}/{name}`, and no fallbacks.
fn build_test_manifest(base_url: &str) -> Manifest {
    let files = ["model_q4f16.onnx", "tokenizer.json", "config.json"]
        .into_iter()
        .map(|name| ManifestFile {
            name: name.to_owned(),
            size_bytes: fixture_bytes(name).len() as u64,
            sha256: fixture_sha256(name),
            primary_url: format!("{base_url}/{name}"),
            fallback_urls: Vec::new(),
        })
        .collect();

    Manifest {
        model_name: "test-model".to_string(),
        version: "v1".to_string(),
        chunk_count: 16,
        files,
        ..Default::default()
    }
}
/// Resolves a single `Range` header value against a body of `len` bytes.
///
/// Accepted forms are `bytes=N-M`, `bytes=N-` and `bytes=-N`. Returns the inclusive
/// `(start, end)` offsets to serve, or `None` in four cases: the value is malformed, it names
/// more than one range, it cannot be satisfied, or the body is empty. Callers then fall back
/// to sending the whole body.
fn parse_byte_range(value: &str, len: usize) -> Option<(usize, usize)> {
    let last_index = len.checked_sub(1)?;
    let spec = value.trim().trim_start_matches("bytes=");
    let mut bounds = spec.split('-');
    let (first, second) = (bounds.next()?, bounds.next()?);
    if bounds.next().is_some() {
        // Multi-range requests (or stray dashes) are not supported by this stub.
        return None;
    }
    let (start, end) = match (first.trim(), second.trim()) {
        // Suffix form `-N`: the final N bytes of the body.
        ("", suffix) => {
            let n: usize = suffix.parse().ok()?;
            if n == 0 {
                return None;
            }
            (len.saturating_sub(n), last_index)
        }
        // Open-ended form `N-`: from N through the last byte.
        (start, "") => (start.parse().ok()?, last_index),
        // Closed form `N-M`; an end past the body is clamped to the last byte.
        (start, end) => (
            start.parse().ok()?,
            end.parse::<usize>().ok()?.min(last_index),
        ),
    };
    (start <= end && start < len).then_some((start, end))
}

/// Stub handler that serves fixture bytes for the file named by the request path.
///
/// * `HEAD`: empty body with a `Content-Length` header equal to the fixture size, plus
///   `ETag: {etag}`.
/// * Single, satisfiable `Range` header: `206 Partial Content` with the requested slice and a
///   matching `Content-Range: bytes start-end/total` header.
/// * Anything else (no range, malformed, multi-range or unsatisfiable): `200` with the whole
///   fixture.
fn serve_chunk(req: &tiny_http::Request, etag: &str) -> Response<std::io::Cursor<Vec<u8>>> {
    let path = req.url().trim_start_matches('/');
    let bytes = fixture_bytes(path);
    let total = bytes.len();

    if req.method() == &Method::Head {
        return Response::from_data(Vec::<u8>::new())
            .with_header(Header::from_bytes(&b"Content-Length"[..], total.to_string()).unwrap())
            .with_header(Header::from_bytes(&b"ETag"[..], etag).unwrap());
    }

    let range = req
        .headers()
        .iter()
        .find(|h| h.field.as_str().as_str().eq_ignore_ascii_case("range"))
        .and_then(|h| parse_byte_range(h.value.as_str(), total));

    if let Some((start, end)) = range {
        // A single-part 206 must state which bytes it carries (RFC 7233 §4.1); clients that
        // validate partial responses would otherwise reject it.
        let content_range = format!("bytes {start}-{end}/{total}");
        return Response::from_data(bytes[start..=end].to_vec())
            .with_status_code(StatusCode(206))
            .with_header(Header::from_bytes(&b"Content-Range"[..], content_range).unwrap());
    }

    Response::from_data(bytes)
}
/// Downloads all three artifacts from a range-capable stub server and checks they land on
/// disk with the served bytes.
#[test]
fn test_happy_path_16chunk() {
    let (base_url, server_thread, _counter, server) =
        spawn_stub_server(|req, _n| serve_chunk(req, "\"v1\""));
    let tmp = TempDir::new().expect("tmp");
    let manifest = build_test_manifest(&base_url);
    let result = ensure_with_manifest(Some(tmp.path()), &manifest);

    // Dropping our `Arc<Server>` would not stop the accept loop, because the worker thread
    // owns another clone. `unblock` lets it exit, and joining surfaces any handler panic.
    server.unblock();
    server_thread.join().expect("stub server thread panicked");

    let paths = result.expect("happy path bootstrap should succeed");
    assert!(paths.onnx.exists());
    assert!(paths.tokenizer.exists());
    assert!(paths.config.exists());
    assert_eq!(
        std::fs::read(&paths.onnx).unwrap(),
        fixture_bytes("model_q4f16.onnx")
    );
}
/// Pre-seeds the cache directory with byte-correct artifacts and expects the bootstrap to
/// accept them without any network traffic.
///
/// Despite the name, no HTTP 304 exchange is exercised. The stub handler panics if it is ever
/// reached, and the request counter must stay at zero.
#[test]
fn test_etag_304_short_circuit() {
    let tmp = TempDir::new().expect("tmp");
    for name in ["model_q4f16.onnx", "tokenizer.json", "config.json"] {
        let path = tmp.path().join(name);
        let mut f = std::fs::File::create(&path).unwrap();
        f.write_all(&fixture_bytes(name)).unwrap();
    }
    let (base_url, server_thread, counter, server) = spawn_stub_server(|_req, _n| {
        panic!("server should NOT be hit when local cache is sha256-valid");
    });
    let manifest = build_test_manifest(&base_url);
    let result = ensure_with_manifest(Some(tmp.path()), &manifest);

    // Stop the accept loop; dropping our Arc alone would leave the worker thread blocked.
    server.unblock();
    let server_outcome = server_thread.join();

    let paths = result.expect("304 short-circuit should succeed");
    assert!(paths.onnx.exists());
    assert_eq!(
        counter.load(Ordering::SeqCst),
        0,
        "expected 0 server hits when local cache valid"
    );
    assert!(server_outcome.is_ok(), "stub server thread panicked");
}
/// A manifest digest that does not match the served bytes must produce `Sha256Mismatch` and
/// leave neither the artifact nor any `.partial.*` temp file in the cache directory.
#[test]
fn test_sha256_mismatch_deletes() {
    const BOGUS_SHA256: &str = "0000000000000000000000000000000000000000000000000000000000000000";

    let (base_url, server_thread, _counter, server) =
        spawn_stub_server(|req, _n| serve_chunk(req, "\"v1\""));
    let tmp = TempDir::new().expect("tmp");
    let mut manifest = build_test_manifest(&base_url);
    manifest.files[0].sha256 = BOGUS_SHA256.to_string();
    let result = ensure_with_manifest(Some(tmp.path()), &manifest);

    // Shut the stub down explicitly; dropping our Arc would leave its thread blocked forever.
    server.unblock();
    server_thread.join().expect("stub server thread panicked");

    match result {
        Err(BootstrapError::Sha256Mismatch { expected, actual }) => {
            assert_eq!(expected, BOGUS_SHA256);
            assert_ne!(expected, actual, "actual must differ from expected");
        }
        other => panic!("expected Sha256Mismatch, got {other:?}"),
    }

    let onnx_path = tmp.path().join("model_q4f16.onnx");
    assert!(
        !onnx_path.exists(),
        "expected onnx artifact deleted after sha256 mismatch"
    );
    let leftovers: Vec<_> = std::fs::read_dir(tmp.path())
        .unwrap()
        .filter_map(Result::ok)
        .filter(|e| e.file_name().to_string_lossy().starts_with(".partial."))
        .collect();
    assert!(
        leftovers.is_empty(),
        "expected 0 .partial.* leftovers, got {leftovers:?}"
    );
}
/// The primary host answers every request with `500`. The bootstrap must switch to the
/// fallback URL and still install every artifact.
#[test]
fn test_fallback_url_recovers() {
    let (primary_url, primary_thread, primary_hits, primary_server) =
        spawn_stub_server(|_req, _n| {
            Response::from_data(Vec::<u8>::new()).with_status_code(StatusCode(500))
        });
    let (fallback_url, fallback_thread, fallback_counter, fallback_server) =
        spawn_stub_server(|req, _n| serve_chunk(req, "\"v1\""));
    let tmp = TempDir::new().expect("tmp");

    // Same manifest as the other tests, plus one fallback mirror per file.
    let mut manifest = build_test_manifest(&primary_url);
    for file in &mut manifest.files {
        file.fallback_urls = vec![format!("{fallback_url}/{}", file.name)];
    }
    let result = ensure_with_manifest(Some(tmp.path()), &manifest);

    // Both accept loops own their own Arc clones; unblock them so the threads can exit.
    primary_server.unblock();
    fallback_server.unblock();
    primary_thread.join().expect("primary stub thread panicked");
    fallback_thread.join().expect("fallback stub thread panicked");

    let paths = result.expect("fallback should recover");
    assert!(paths.onnx.exists());
    assert!(
        primary_hits.load(Ordering::SeqCst) >= 1,
        "primary should be tried at least once"
    );
    assert!(
        fallback_counter.load(Ordering::SeqCst) >= 1,
        "fallback should serve at least one request"
    );
}
/// Points every artifact at localhost ports 1 and 2, on which nothing is expected to listen.
/// The bootstrap must report `NetworkUnreachable` naming both URLs and the last error.
#[test]
fn test_all_urls_fail() {
    let unreachable_primary = "http://127.0.0.1:1/file";
    let unreachable_fallback = "http://127.0.0.1:2/file";
    let tmp = TempDir::new().expect("tmp");

    let manifest = Manifest {
        model_name: "test-model".to_string(),
        version: "v1".to_string(),
        chunk_count: 16,
        files: ["model_q4f16.onnx", "tokenizer.json", "config.json"]
            .into_iter()
            .map(|name| ManifestFile {
                name: name.to_string(),
                size_bytes: fixture_bytes(name).len() as u64,
                sha256: fixture_sha256(name),
                primary_url: unreachable_primary.to_string(),
                fallback_urls: vec![unreachable_fallback.to_string()],
            })
            .collect(),
        ..Default::default()
    };

    match ensure_with_manifest(Some(tmp.path()), &manifest) {
        Err(BootstrapError::NetworkUnreachable {
            tried_urls,
            last_error,
        }) => {
            let primary = unreachable_primary.to_string();
            let fallback = unreachable_fallback.to_string();
            assert!(tried_urls.contains(&primary), "tried_urls must include primary");
            assert!(tried_urls.contains(&fallback), "tried_urls must include fallback");
            assert!(!last_error.is_empty(), "last_error must be populated");
        }
        other => panic!("expected NetworkUnreachable, got {other:?}"),
    }
}
/// Build-time guard: fails to compile unless the `ort` feature is enabled. It also keeps the
/// parent module's `ModelPaths` and `BootstrapError` types nameable from this test module.
#[test]
fn cfg_gate_active() {
    const { assert!(cfg!(feature = "ort")) };
    let _type_ids = (
        std::any::TypeId::of::<super::ModelPaths>(),
        std::any::TypeId::of::<super::BootstrapError>(),
    );
}
/// Smoke test: a full bootstrap against the local stub server must finish within a 5-minute
/// budget.
///
/// The elapsed time is checked after the call returns. A bootstrap that hangs forever will
/// hang this test rather than fail it.
#[test]
fn smoke_timeout_below_5min() {
    let budget = Duration::from_secs(300);
    let (base_url, server_thread, _counter, server) =
        spawn_stub_server(|req, _n| serve_chunk(req, "\"v1\""));
    let tmp = TempDir::new().expect("tmp");
    let manifest = build_test_manifest(&base_url);

    let started = std::time::Instant::now();
    let result = ensure_with_manifest(Some(tmp.path()), &manifest);
    let elapsed = started.elapsed();

    // Release the stub's accept loop so its thread exits instead of leaking.
    server.unblock();
    server_thread.join().expect("stub server thread panicked");

    result.expect("smoke bootstrap should succeed");
    assert!(
        elapsed < budget,
        "bootstrap took {elapsed:?}, exceeding the {budget:?} budget"
    );
}