#![cfg(all(feature = "default-embedder", feature = "loader-test-hooks"))]
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use httpmock::prelude::*;
use sha2::{Digest, Sha256};
use tempfile::TempDir;
use fathomdb_embedder::loader::{
load_pinned_default_embedder, load_with_config, EmbedderEvent, EmbedderLoadError,
LoadedWeights, LoaderConfig,
};
/// Pinned Hugging Face revision of `BAAI/bge-small-en-v1.5`.
///
/// Used both in the mocked `resolve/<revision>/<file>` URL paths and in the hub
/// `snapshots/<revision>` directory layout exercised by the hub-probe test.
const HF_REVISION: &str = "5c38ec7c405ec4b44b94cc5a9bb96e735b38267a";
/// Deterministic in-memory stand-ins for the three files the loader fetches.
struct Fixture {
    /// Body served for `config.json`.
    config_bytes: Vec<u8>,
    /// Body served for `tokenizer.json`.
    tokenizer_bytes: Vec<u8>,
    /// Body served for `model.safetensors`: little-endian encodings of `0u32..2048` (8 KiB).
    model_bytes: Vec<u8>,
}
impl Fixture {
    /// Builds the fixture with fixed JSON payloads and a synthetic model blob.
    fn new() -> Self {
        let config_bytes = br#"{"model_type":"bert","hidden_size":384}"#.to_vec();
        let tokenizer_bytes = br#"{"version":"1.0","model":{"type":"WordPiece"}}"#.to_vec();
        let model_bytes = (0u32..2048).flat_map(u32::to_le_bytes).collect();
        Self {
            config_bytes,
            tokenizer_bytes,
            model_bytes,
        }
    }
    /// Lowercase hex SHA-256 of `bytes`, the same form passed to the loader as a pin.
    fn sha_hex(bytes: &[u8]) -> String {
        let digest = Sha256::digest(bytes);
        format!("{:x}", digest)
    }
    /// Pin for `config.json`.
    fn config_sha(&self) -> String {
        Fixture::sha_hex(self.config_bytes.as_slice())
    }
    /// Pin for `tokenizer.json`.
    fn tokenizer_sha(&self) -> String {
        Fixture::sha_hex(self.tokenizer_bytes.as_slice())
    }
    /// Pin for `model.safetensors`.
    fn model_sha(&self) -> String {
        Fixture::sha_hex(self.model_bytes.as_slice())
    }
}
/// URL path of `file` under the Hugging Face `resolve` endpoint for the pinned
/// `BAAI/bge-small-en-v1.5` revision.
fn resolve_path(file: &str) -> String {
    let repo = "/BAAI/bge-small-en-v1.5";
    format!("{repo}/resolve/{HF_REVISION}/{file}")
}
/// Test-mode loader configuration that points at the mock server at `server_base`,
/// caches under `cache_root`, and pins the fixture's SHA-256 digests.
fn test_config(server_base: &str, cache_root: &Path, fix: &Fixture) -> LoaderConfig {
    let (config_pin, tokenizer_pin, model_pin) =
        (fix.config_sha(), fix.tokenizer_sha(), fix.model_sha());
    LoaderConfig::for_tests()
        .with_base_url(server_base.to_owned())
        .with_cache_root(cache_root.to_owned())
        .with_test_pins(config_pin, tokenizer_pin, model_pin)
}
/// A cold cache downloads all three pinned files exactly once, writes them to disk, and
/// reports a download event.
#[test]
fn loads_pinned_model_with_correct_sha() {
    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();

    // One mock per pinned file, each serving the fixture body for that file.
    let mocks = [
        ("config.json", &fix.config_bytes),
        ("tokenizer.json", &fix.tokenizer_bytes),
        ("model.safetensors", &fix.model_bytes),
    ]
    .map(|(file, body)| {
        server.mock(|when, then| {
            when.method(GET).path(resolve_path(file));
            then.status(200).body(body);
        })
    });

    let cfg = test_config(&server.base_url(), tmp.path(), &fix);
    let loaded: LoadedWeights = load_with_config(cfg).expect("loader ok");

    assert!(loaded.config_json_path.is_file());
    assert!(loaded.tokenizer_json_path.is_file());
    assert!(loaded.model_safetensors_path.is_file());

    let model_on_disk = fs::read(&loaded.model_safetensors_path).unwrap();
    assert_eq!(Fixture::sha_hex(&model_on_disk), fix.model_sha());
    assert!(loaded.bytes_downloaded > 0);

    let saw_download = loaded
        .events
        .iter()
        .any(|event| matches!(event, EmbedderEvent::DefaultEmbedderDownload { .. }));
    assert!(saw_download);

    // Each file must have been requested exactly once.
    for mock in &mocks {
        mock.assert();
    }
}
/// A model body whose SHA-256 does not match the pin must fail closed with
/// `ChecksumMismatch` and leave no model artifact (final or partial) in the cache.
#[test]
fn rejects_checksum_mismatch() {
    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).body(&fix.config_bytes);
    });
    server.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).body(&fix.tokenizer_bytes);
    });
    let wrong = b"not the real model bytes".to_vec();
    let m_mdl = server.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors"));
        then.status(200).body(&wrong);
    });
    let cache = tmp.path().to_path_buf();
    let err = load_with_config(test_config(&server.base_url(), &cache, &fix))
        .expect_err("must fail closed on sha mismatch");
    assert!(
        matches!(err, EmbedderLoadError::ChecksumMismatch { .. }),
        "expected ChecksumMismatch, got {err:?}"
    );
    // Make sure the tampered body was actually served, so the mismatch came from it.
    assert!(m_mdl.hits() >= 1, "model.safetensors was never requested");
    // Scan the whole temp cache root (it belongs to this test alone) rather than a hard-coded
    // sub-path: if the loader's layout differs, a narrower scan would silently find nothing.
    let leftovers: Vec<PathBuf> = walkdir(&cache)
        .into_iter()
        .filter(|p| {
            p.file_name()
                .and_then(|n| n.to_str())
                .unwrap_or("")
                .contains("model.safetensors")
        })
        .collect();
    assert!(
        leftovers.is_empty(),
        "model.safetensors (or .partial) must be removed on checksum mismatch, found {leftovers:?}"
    );
}
/// An existing `model.safetensors.partial` holding the first half of the model is resumed
/// with an HTTP range request for exactly the missing suffix, and the completed file
/// matches the pinned SHA-256.
#[test]
fn resumes_partial_download() {
    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    let cache = tmp.path().to_path_buf();
    server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).body(&fix.config_bytes);
    });
    server.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).body(&fix.tokenizer_bytes);
    });
    let total = fix.model_bytes.len();
    let half = total / 2;
    let cfg = test_config(&server.base_url(), &cache, &fix);
    // Seed the loader's own cache directory with the first half of the model, the same
    // state an interrupted download would leave behind.
    let partial_dir = cfg.expected_cache_dir();
    fs::create_dir_all(&partial_dir).unwrap();
    let partial_path = partial_dir.join("model.safetensors.partial");
    {
        let mut f = fs::File::create(&partial_path).unwrap();
        f.write_all(&fix.model_bytes[..half]).unwrap();
        f.sync_all().unwrap();
    }
    // Only a request for exactly the missing suffix is answered. Matching the header value
    // (not just its presence) catches a wrong resume offset, which would otherwise go
    // unnoticed because this mock always returns the correct tail.
    let m_range = server.mock(|when, then| {
        when.method(GET)
            .path(resolve_path("model.safetensors"))
            .header("range", format!("bytes={half}-"));
        then.status(206)
            .header("content-range", format!("bytes {half}-{}/{total}", total - 1))
            .body(&fix.model_bytes[half..]);
    });
    let loaded = load_with_config(cfg).expect("resume load ok");
    let bytes = fs::read(&loaded.model_safetensors_path).unwrap();
    assert_eq!(Fixture::sha_hex(&bytes), fix.model_sha());
    m_range.assert();
}
/// Several loaders started at the same time against one cache root must serialize on the
/// cache lock, so each pinned file is fetched from the server exactly once.
#[test]
fn concurrent_loaders_serialize_via_filelock() {
    const THREADS: usize = 4;
    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    let cache = tmp.path().to_path_buf();
    // Responses are delayed so that racing loaders overlap while a download is in flight.
    let m_cfg = server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200)
            .delay(Duration::from_millis(50))
            .body(&fix.config_bytes);
    });
    let m_tok = server.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200)
            .delay(Duration::from_millis(50))
            .body(&fix.tokenizer_bytes);
    });
    let m_mdl = server.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors"));
        then.status(200)
            .delay(Duration::from_millis(50))
            .body(&fix.model_bytes);
    });
    let base = server.base_url();
    // Release every loader at once so they really contend for the lock.
    let start = Arc::new(std::sync::Barrier::new(THREADS));
    let handles: Vec<_> = (0..THREADS)
        .map(|_| {
            let cfg = test_config(&base, &cache, &fix);
            let start = Arc::clone(&start);
            thread::spawn(move || {
                start.wait();
                load_with_config(cfg)
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap().expect("each thread loads ok");
    }
    // Count requests through the server. The closure passed to `server.mock` runs only once,
    // when the mock is registered, so a counter incremented inside it is always 1.
    m_cfg.assert_hits(1);
    m_tok.assert_hits(1);
    m_mdl.assert_hits(1);
}
/// Checks two things:
/// - A configured Hugging Face token is sent as `Authorization: Bearer <token>` on every
///   file request.
/// - No `Authorization` header is sent when the token is `None`.
///
/// The token is supplied through `LoaderConfig::with_hf_token`. Despite the test's name,
/// the process environment is not touched.
#[test]
fn auth_token_sent_when_env_set() {
    let fix = Fixture::new();

    // Phase 1: token configured. Each mock only matches when the bearer header is present.
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    let cache = tmp.path().to_path_buf();
    let m_cfg = server.mock(|when, then| {
        when.method(GET)
            .path(resolve_path("config.json"))
            .header("authorization", "Bearer sekret");
        then.status(200).body(&fix.config_bytes);
    });
    let m_tok = server.mock(|when, then| {
        when.method(GET)
            .path(resolve_path("tokenizer.json"))
            .header("authorization", "Bearer sekret");
        then.status(200).body(&fix.tokenizer_bytes);
    });
    let m_mdl = server.mock(|when, then| {
        when.method(GET)
            .path(resolve_path("model.safetensors"))
            .header("authorization", "Bearer sekret");
        then.status(200).body(&fix.model_bytes);
    });
    let cfg = test_config(&server.base_url(), &cache, &fix).with_hf_token(Some("sekret".into()));
    load_with_config(cfg).expect("loads with bearer");
    m_cfg.assert();
    m_tok.assert();
    m_mdl.assert();

    // Phase 2: no token.
    let tmp2 = TempDir::new().unwrap();
    let server2 = MockServer::start();
    // Registered before the file mocks. httpmock is expected to answer with the
    // earliest-registered matching mock, so a leaked Authorization header is caught here
    // (and refused with 401). Without it, the header-agnostic mocks below would accept it.
    let m_leaked_auth = server2.mock(|when, then| {
        when.header_exists("authorization");
        then.status(401);
    });
    let m_cfg2 = server2.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).body(&fix.config_bytes);
    });
    let m_tok2 = server2.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).body(&fix.tokenizer_bytes);
    });
    let m_mdl2 = server2.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors"));
        then.status(200).body(&fix.model_bytes);
    });
    let cfg2 = test_config(&server2.base_url(), tmp2.path(), &fix).with_hf_token(None);
    load_with_config(cfg2).expect("loads without token");
    m_leaked_auth.assert_hits(0);
    m_cfg2.assert();
    m_tok2.assert();
    m_mdl2.assert();
}
/// `LoaderConfig::for_tests_reading_timeout_env` reads the connect/read timeouts from the
/// environment. Unparseable, empty or unset values fall back to 10s (connect) and
/// 60s (read).
#[test]
fn respects_timeout_env_overrides() {
    /// Puts an environment variable back to its pre-test value when dropped, so the
    /// environment is restored even if an assertion below panics.
    struct EnvRestore {
        key: &'static str,
        prev: Option<std::ffi::OsString>,
    }
    impl Drop for EnvRestore {
        fn drop(&mut self) {
            match self.prev.take() {
                Some(v) => std::env::set_var(self.key, v),
                None => std::env::remove_var(self.key),
            }
        }
    }
    const CONNECT: &str = "FATHOMDB_EMBEDDER_CONNECT_TIMEOUT_S";
    const READ: &str = "FATHOMDB_EMBEDDER_READ_TIMEOUT_S";

    let _g = ENV_GUARD.lock().unwrap_or_else(|e| e.into_inner());
    // Declared after `_g`, so these drop first and restore the environment while the lock
    // is still held. `var_os` also keeps previous values that are not valid UTF-8.
    let _restore_connect = EnvRestore {
        key: CONNECT,
        prev: std::env::var_os(CONNECT),
    };
    let _restore_read = EnvRestore {
        key: READ,
        prev: std::env::var_os(READ),
    };

    std::env::set_var(CONNECT, "7");
    std::env::set_var(READ, "111");
    let cfg = LoaderConfig::for_tests_reading_timeout_env();
    assert_eq!(cfg.connect_timeout(), Duration::from_secs(7));
    assert_eq!(cfg.read_timeout(), Duration::from_secs(111));

    std::env::set_var(CONNECT, "not-a-number");
    std::env::set_var(READ, "");
    let cfg = LoaderConfig::for_tests_reading_timeout_env();
    assert_eq!(cfg.connect_timeout(), Duration::from_secs(10), "invalid → default 10s");
    assert_eq!(cfg.read_timeout(), Duration::from_secs(60), "invalid → default 60s");

    std::env::remove_var(CONNECT);
    std::env::remove_var(READ);
    let cfg = LoaderConfig::for_tests_reading_timeout_env();
    assert_eq!(cfg.connect_timeout(), Duration::from_secs(10));
    assert_eq!(cfg.read_timeout(), Duration::from_secs(60));
}
/// When `config.json` already exists in a Hugging Face hub snapshot for the pinned
/// revision, the loader uses it instead of fetching it. It emits a cache-hit event and
/// leaves the hub copy untouched.
#[test]
fn hf_hub_compat_probe_reads_from_hub_layout() {
    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    let cache = tmp.path().to_path_buf();

    // Seed a hub snapshot directory holding only config.json.
    let hf_home = tmp.path().join("hf_home");
    let hub_dir = ["hub", "models--BAAI--bge-small-en-v1.5", "snapshots", HF_REVISION]
        .iter()
        .fold(hf_home.clone(), |dir, segment| dir.join(segment));
    fs::create_dir_all(&hub_dir).unwrap();
    let hub_config = hub_dir.join("config.json");
    fs::write(&hub_config, &fix.config_bytes).unwrap();

    // config.json is served as well, but must never be requested.
    let m_cfg_must_not_hit = server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).body(&fix.config_bytes);
    });
    let m_tok = server.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).body(&fix.tokenizer_bytes);
    });
    let m_mdl = server.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors"));
        then.status(200).body(&fix.model_bytes);
    });

    let cfg = test_config(&server.base_url(), &cache, &fix)
        .with_hf_hub_root(Some(hf_home.clone()));
    let loaded = load_with_config(cfg).expect("loader ok with hub-probe hit");

    m_cfg_must_not_hit.assert_hits(0);
    m_tok.assert();
    m_mdl.assert();

    let mut cache_hit_files: Vec<&str> = Vec::new();
    for event in loaded.events.iter() {
        if let EmbedderEvent::DefaultEmbedderCacheHit { file, .. } = event {
            cache_hit_files.push(file.as_str());
        }
    }
    assert!(
        cache_hit_files.contains(&"config.json"),
        "expected DefaultEmbedderCacheHit for config.json, got {cache_hit_files:?}"
    );

    let hub_bytes = fs::read(&hub_config).unwrap();
    assert_eq!(hub_bytes, fix.config_bytes, "hub source must not be modified");
    let on_disk = fs::read(&loaded.config_json_path).unwrap();
    assert_eq!(on_disk, fix.config_bytes);
}
/// Serializes tests that mutate process environment variables (see
/// `respects_timeout_env_overrides`). The environment is process-global, and tests
/// presumably run in parallel threads under the default harness.
static ENV_GUARD: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// Compile-time check that `load_pinned_default_embedder` keeps its zero-argument signature
/// returning `Result<LoadedWeights, EmbedderLoadError>`. The function is never called.
#[test]
fn public_api_exists() {
    let entry_point: fn() -> Result<LoadedWeights, EmbedderLoadError> =
        load_pinned_default_embedder;
    let _ = entry_point;
}
/// Recursively collects the paths of every non-directory entry under `root`.
///
/// Directories that cannot be read, including a missing `root`, are skipped. A missing
/// root therefore gives an empty result rather than an error. Symlinks are not followed:
/// a link is reported as a leaf entry, so a symlink cycle cannot make the walk loop forever.
fn walkdir(root: &Path) -> Vec<PathBuf> {
    let mut files = Vec::new();
    let mut pending = vec![root.to_path_buf()];
    while let Some(dir) = pending.pop() {
        if let Ok(entries) = fs::read_dir(&dir) {
            for entry in entries.flatten() {
                let path = entry.path();
                // `DirEntry::file_type` does not traverse symlinks (unlike `Path::is_dir`),
                // and usually needs no extra stat call.
                if matches!(entry.file_type(), Ok(ft) if ft.is_dir()) {
                    pending.push(path);
                } else {
                    files.push(path);
                }
            }
        }
    }
    files
}