fathomdb-embedder 0.8.0

FathomDB embedder runtime — built-in embedder implementations for the fathomdb-embedder-api trait.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
//! Integration tests for the default-embedder loader.
//!
//! Per `dev/plans/prompts/0.7.1-EMBEDDER-UNDEFER-HANDOFF.md` §EU-3, this slice
//! ships five required tests that drive the loader contract:
//!
//! 1. `loads_pinned_model_with_correct_sha`
//! 2. `rejects_checksum_mismatch`
//! 3. `resumes_partial_download`
//! 4. `concurrent_loaders_serialize_via_filelock`
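//! 5. `auth_token_sent_when_env_set`
//!
//! The file additionally carries follow-up coverage beyond those five:
//! `respects_timeout_env_overrides`, `hf_hub_compat_probe_reads_from_hub_layout`,
//! and the compile-time check `public_api_exists`.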
//!
//! All tests run against a local `httpmock` server so the suite never touches
//! the network. The entire file is gated behind the `default-embedder` and
//! `loader-test-hooks` Cargo features (both must be enabled; see the `cfg`
//! attribute below): without `default-embedder` the crate stays a tiny
//! `NoopEmbedder` holder with zero optional deps.
//!
//! Concurrency-test variant choice (see §EU-3 test 4): we assert that across
//! N=4 concurrent loaders the mock observes **exactly one** complete set of
//! fetches (one config + one tokenizer + one model). The fs2 exclusive lock
//! serializes the first-use cold path; the late-arriving threads observe the
//! verified cache files after the lock releases and short-circuit before
//! hitting HTTP at all. This variant is cleaner to assert and exercises the
//! "cache-hit path does NOT take the lock" property.

#![cfg(all(feature = "default-embedder", feature = "loader-test-hooks"))]

use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use httpmock::prelude::*;
use sha2::{Digest, Sha256};
use tempfile::TempDir;

use fathomdb_embedder::loader::{
    load_pinned_default_embedder, load_with_config, EmbedderEvent, EmbedderLoadError,
    LoadedWeights, LoaderConfig,
};

/// Revision identifier used by these tests: it is the path segment after
/// `/resolve/` in every mocked URL (see `resolve_path`) and the name of the
/// `snapshots/` sub-directory staged in the HF-hub layout test.
///
/// NOTE(review): presumably identical to the revision pinned in `loader.rs`;
/// confirm against the loader if the pin is ever bumped.
const HF_REVISION: &str = "5c38ec7c405ec4b44b94cc5a9bb96e735b38267a";

/// Fixture bytes for each pinned file. Content is small + deterministic so the
/// tests can pin sha256 values directly. The real HF SHAs in `loader.rs` are
/// for production fetches; tests override the pinned constants via
/// `LoaderConfig::with_test_pins`.
///
/// The same bytes play two roles in every test: they are the bodies the mock
/// server returns, and their digests are the pins handed to the loader.
struct Fixture {
    /// Body served for `config.json`.
    config_bytes: Vec<u8>,
    /// Body served for `tokenizer.json`.
    tokenizer_bytes: Vec<u8>,
    /// Body served for `model.safetensors`.
    model_bytes: Vec<u8>,
}

impl Fixture {
    /// Builds the three deterministic payloads used by every test.
    fn new() -> Self {
        let config_bytes = br#"{"model_type":"bert","hidden_size":384}"#.to_vec();
        let tokenizer_bytes = br#"{"version":"1.0","model":{"type":"WordPiece"}}"#.to_vec();

        // 8 KiB "model": the little-endian encoding of the counter 0..2048.
        let mut model_bytes = Vec::with_capacity(2048 * 4);
        for n in 0u32..2048 {
            model_bytes.extend_from_slice(&n.to_le_bytes());
        }

        Self {
            config_bytes,
            tokenizer_bytes,
            model_bytes,
        }
    }

    /// Lower-case hex SHA-256 digest of `bytes`.
    fn sha_hex(bytes: &[u8]) -> String {
        let digest = Sha256::digest(bytes);
        format!("{:x}", digest)
    }

    /// Digest of the `config.json` payload.
    fn config_sha(&self) -> String {
        Self::sha_hex(self.config_bytes.as_slice())
    }

    /// Digest of the `tokenizer.json` payload.
    fn tokenizer_sha(&self) -> String {
        Self::sha_hex(self.tokenizer_bytes.as_slice())
    }

    /// Digest of the `model.safetensors` payload.
    fn model_sha(&self) -> String {
        Self::sha_hex(self.model_bytes.as_slice())
    }
}

/// URL path under which the mock server exposes `file` for the pinned
/// revision: `/BAAI/bge-small-en-v1.5/resolve/<HF_REVISION>/<file>`.
fn resolve_path(file: &str) -> String {
    let revision = HF_REVISION;
    format!("/BAAI/bge-small-en-v1.5/resolve/{}/{}", revision, file)
}

/// Builds a loader configuration that talks to the mock server at
/// `server_base`, caches under `cache_root`, and pins the digests of the
/// fixture payloads instead of the production ones.
fn test_config(server_base: &str, cache_root: &Path, fix: &Fixture) -> LoaderConfig {
    let base_url = server_base.to_owned();
    let cache_dir = cache_root.to_owned();

    let cfg = LoaderConfig::for_tests();
    let cfg = cfg.with_base_url(base_url);
    let cfg = cfg.with_cache_root(cache_dir);
    cfg.with_test_pins(fix.config_sha(), fix.tokenizer_sha(), fix.model_sha())
}

/// Cold-cache happy path: all three files are fetched from the mock server,
/// land on disk with the pinned digest, and the load reports a download.
#[test]
fn loads_pinned_model_with_correct_sha() {
    let fix = Fixture::new();
    let server = MockServer::start();
    let cache_dir = TempDir::new().unwrap();

    let config_mock = server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).body(&fix.config_bytes);
    });
    let tokenizer_mock = server.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).body(&fix.tokenizer_bytes);
    });
    let model_mock = server.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors"));
        then.status(200).body(&fix.model_bytes);
    });

    let cfg = test_config(&server.base_url(), cache_dir.path(), &fix);
    let loaded: LoadedWeights = load_with_config(cfg).expect("loader ok");

    // Every reported path must point at a real file.
    assert!(loaded.config_json_path.is_file());
    assert!(loaded.tokenizer_json_path.is_file());
    assert!(loaded.model_safetensors_path.is_file());

    // The model bytes on disk hash to the pinned value.
    let model_on_disk = fs::read(&loaded.model_safetensors_path).unwrap();
    assert_eq!(Fixture::sha_hex(&model_on_disk), fix.model_sha());
    assert!(loaded.bytes_downloaded > 0);

    // Per design §7, a fresh fetch surfaces a DefaultEmbedderDownload event.
    let saw_download_event = loaded
        .events
        .iter()
        .any(|event| matches!(event, EmbedderEvent::DefaultEmbedderDownload { .. }));
    assert!(saw_download_event);

    // Each endpoint was requested.
    config_mock.assert();
    tokenizer_mock.assert();
    model_mock.assert();
}

/// A model body whose digest differs from the pin must fail the load with
/// `ChecksumMismatch` and leave no model file (final or `.partial`) behind.
#[test]
fn rejects_checksum_mismatch() {
    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    let cache = tmp.path().to_path_buf();

    server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).body(&fix.config_bytes);
    });
    server.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).body(&fix.tokenizer_bytes);
    });
    // The model endpoint serves bytes that cannot hash to the pinned value,
    // because the pin was computed from `fix.model_bytes`.
    let tampered = b"not the real model bytes".to_vec();
    server.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors"));
        then.status(200).body(&tampered);
    });

    let outcome = load_with_config(test_config(&server.base_url(), &cache, &fix));
    let err = outcome.expect_err("must fail closed on sha mismatch");
    assert!(
        matches!(err, EmbedderLoadError::ChecksumMismatch { .. }),
        "expected ChecksumMismatch, got {err:?}"
    );

    // Per design §6: file removed on mismatch. Neither the final name nor the
    // `.partial` form may survive. `walkdir` yields nothing when the directory
    // does not exist, so no separate existence check is needed.
    let embedders_dir = cache.join("fathomdb").join("embedders");
    let found_model = walkdir(&embedders_dir).iter().any(|path| {
        path.file_name()
            .and_then(|name| name.to_str())
            .map_or(false, |name| name.contains("model.safetensors"))
    });
    assert!(!found_model, "model.safetensors (or .partial) must be removed on checksum mismatch");
}

/// With the first half of the model already staged as
/// `model.safetensors.partial`, the loader must issue a Range request and
/// assemble a file whose digest matches the pin.
#[test]
fn resumes_partial_download() {
    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    let cfg = test_config(&server.base_url(), tmp.path(), &fix);

    // The two small files are served whole.
    server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).body(&fix.config_bytes);
    });
    server.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).body(&fix.tokenizer_bytes);
    });

    // Split the model in two: `already_have` is pre-staged on disk,
    // `still_missing` is what the server hands out on resume.
    let midpoint = fix.model_bytes.len() / 2;
    let (already_have, still_missing) = fix.model_bytes.split_at(midpoint);

    let staging_dir = cfg.expected_cache_dir();
    fs::create_dir_all(&staging_dir).unwrap();
    {
        // Scoped so the handle is closed before the loader opens the file.
        let staged_path = staging_dir.join("model.safetensors.partial");
        let mut staged = fs::File::create(&staged_path).unwrap();
        staged.write_all(already_have).unwrap();
        staged.sync_all().unwrap();
    }

    // Only requests carrying a Range header match; the reply is
    // 206 Partial Content with the second half of the model.
    let range_mock = server.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors")).header_exists("range");
        then.status(206).body(still_missing);
    });

    let loaded = load_with_config(cfg).expect("resume load ok");
    let assembled = fs::read(&loaded.model_safetensors_path).unwrap();
    assert_eq!(Fixture::sha_hex(&assembled), fix.model_sha());
    range_mock.assert();
}

/// N concurrent cold-cache loaders must together cause exactly one fetch of
/// each pinned file (variant documented in the module header).
///
/// The request count is taken from the mock server's own hit counters. The
/// closure passed to `MockServer::mock` is configuration code that runs once
/// when the mock is registered, not once per request, so a counter bumped
/// inside it would read 1 no matter how many requests were actually served.
#[test]
fn concurrent_loaders_serialize_via_filelock() {
    const LOADERS: usize = 4;

    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    let cache = tmp.path().to_path_buf();

    // Slow handlers so that even if threads race to acquire the lock, the
    // first holder is unambiguously the one doing the network work and the
    // rest must observe the cache after release.
    let m_cfg = server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).delay(Duration::from_millis(50)).body(&fix.config_bytes);
    });
    let m_tok = server.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).delay(Duration::from_millis(50)).body(&fix.tokenizer_bytes);
    });
    let m_mdl = server.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors"));
        then.status(200).delay(Duration::from_millis(50)).body(&fix.model_bytes);
    });

    // Start gate: each loader thread checks in, then yields until all of its
    // peers have checked in too, so the calls begin together instead of
    // depending on how quickly the threads happen to be spawned.
    let checked_in = Arc::new(AtomicUsize::new(0));

    let base = server.base_url();
    let mut handles = Vec::with_capacity(LOADERS);
    for _ in 0..LOADERS {
        let cfg = test_config(&base, &cache, &fix);
        let checked_in = Arc::clone(&checked_in);
        handles.push(thread::spawn(move || {
            checked_in.fetch_add(1, Ordering::SeqCst);
            while checked_in.load(Ordering::SeqCst) < LOADERS {
                thread::yield_now();
            }
            load_with_config(cfg)
        }));
    }

    for h in handles {
        h.join().unwrap().expect("each thread loads ok");
    }

    // Variant chosen (documented in module header): exactly one set of fetches
    // observed by the mock. The first thread acquires the fs2 exclusive lock,
    // downloads + verifies + renames; the other loaders observe the cached
    // files after the lock releases and short-circuit before HTTP.
    m_cfg.assert_hits(1);
    m_tok.assert_hits(1);
    m_mdl.assert_hits(1);
}

/// With a token configured, every request must carry
/// `Authorization: Bearer <token>`; with no token, no request may carry an
/// `Authorization` header at all.
#[test]
fn auth_token_sent_when_env_set() {
    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    let cache = tmp.path().to_path_buf();

    // First pass: the mocks only match requests bearing the exact bearer
    // header, so a missing or wrong header turns into an unmatched request and
    // a failed load.
    let m_cfg = server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json")).header("authorization", "Bearer sekret");
        then.status(200).body(&fix.config_bytes);
    });
    let m_tok = server.mock(|when, then| {
        when.method(GET)
            .path(resolve_path("tokenizer.json"))
            .header("authorization", "Bearer sekret");
        then.status(200).body(&fix.tokenizer_bytes);
    });
    let m_mdl = server.mock(|when, then| {
        when.method(GET)
            .path(resolve_path("model.safetensors"))
            .header("authorization", "Bearer sekret");
        then.status(200).body(&fix.model_bytes);
    });

    let cfg = test_config(&server.base_url(), &cache, &fix).with_hf_token(Some("sekret".into()));
    load_with_config(cfg).expect("loads with bearer");

    m_cfg.assert();
    m_tok.assert();
    m_mdl.assert();

    // Second pass: token unset → no request may bear an Authorization header.
    // Use a fresh cache and a fresh server so the loader actually re-fetches.
    let tmp2 = TempDir::new().unwrap();
    let server2 = MockServer::start();

    // Tripwire: matches any request that carries an Authorization header and
    // answers 401, so a leaked header fails the load and is also visible in
    // this mock's hit count. It is registered before the file mocks on the
    // assumption that httpmock gives precedence to the earliest matching mock
    // (NOTE(review): confirm for the pinned httpmock version; if precedence
    // were reversed the file mocks would simply win and this check would be
    // no weaker than having no tripwire).
    let m_auth_leak = server2.mock(|when, then| {
        when.header_exists("authorization");
        then.status(401);
    });
    let m_cfg2 = server2.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).body(&fix.config_bytes);
    });
    let m_tok2 = server2.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).body(&fix.tokenizer_bytes);
    });
    let m_mdl2 = server2.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors"));
        then.status(200).body(&fix.model_bytes);
    });

    let cfg2 = test_config(&server2.base_url(), tmp2.path(), &fix).with_hf_token(None);
    load_with_config(cfg2).expect("loads without token");

    m_auth_leak.assert_hits(0);
    m_cfg2.assert();
    m_tok2.assert();
    m_mdl2.assert();
}

/// Checks the timeout env overrides: valid values apply, invalid or unset
/// values fall back to the defaults (10 s connect, 60 s read).
#[test]
fn respects_timeout_env_overrides() {
    const CONNECT_VAR: &str = "FATHOMDB_EMBEDDER_CONNECT_TIMEOUT_S";
    const READ_VAR: &str = "FATHOMDB_EMBEDDER_READ_TIMEOUT_S";

    /// Captures one env var at construction and puts it back on drop, so the
    /// process environment is restored even when an assertion below panics.
    struct RestoreEnv {
        key: &'static str,
        saved: Option<std::ffi::OsString>,
    }

    impl RestoreEnv {
        fn capture(key: &'static str) -> Self {
            // `var_os` keeps values that are not valid UTF-8 intact.
            Self { key, saved: std::env::var_os(key) }
        }
    }

    impl Drop for RestoreEnv {
        fn drop(&mut self) {
            match self.saved.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    // EU-3 FIX-2 #2: design §2 promises `FATHOMDB_EMBEDDER_CONNECT_TIMEOUT_S`
    // and `FATHOMDB_EMBEDDER_READ_TIMEOUT_S` env overrides parse as u64
    // seconds; invalid → default with a warning (no panic, no unwrap).
    //
    // We assert the parsing logic directly via `for_tests_reading_timeout_env`
    // which goes through the same `parse_secs_env_or_default` path the
    // production constructor uses. Holding the env-mutex prevents races with
    // other tests that touch the same vars; a poisoned lock is still usable
    // because the guarded value is `()`.
    let _g = ENV_GUARD.lock().unwrap_or_else(|e| e.into_inner());

    // Declared after the lock guard, so these drop (and restore the vars)
    // before the lock is released.
    let _restore_connect = RestoreEnv::capture(CONNECT_VAR);
    let _restore_read = RestoreEnv::capture(READ_VAR);

    // Valid overrides parse and apply.
    std::env::set_var(CONNECT_VAR, "7");
    std::env::set_var(READ_VAR, "111");
    let cfg = LoaderConfig::for_tests_reading_timeout_env();
    assert_eq!(cfg.connect_timeout(), Duration::from_secs(7));
    assert_eq!(cfg.read_timeout(), Duration::from_secs(111));

    // Invalid → default, no panic.
    std::env::set_var(CONNECT_VAR, "not-a-number");
    std::env::set_var(READ_VAR, "");
    let cfg = LoaderConfig::for_tests_reading_timeout_env();
    assert_eq!(cfg.connect_timeout(), Duration::from_secs(10), "invalid → default 10s");
    assert_eq!(cfg.read_timeout(), Duration::from_secs(60), "invalid → default 60s");

    // Unset → default.
    std::env::remove_var(CONNECT_VAR);
    std::env::remove_var(READ_VAR);
    let cfg = LoaderConfig::for_tests_reading_timeout_env();
    assert_eq!(cfg.connect_timeout(), Duration::from_secs(10));
    assert_eq!(cfg.read_timeout(), Duration::from_secs(60));
}

/// EU-3 FIX-2 #6: a file already present under the HF-hub layout with the
/// pinned digest is taken from there instead of the network, and the hub copy
/// is left untouched.
#[test]
fn hf_hub_compat_probe_reads_from_hub_layout() {
    let fix = Fixture::new();
    let server = MockServer::start();
    let tmp = TempDir::new().unwrap();
    let cache = tmp.path().to_path_buf();

    // Stage only `config.json` under
    // `<hf_home>/hub/models--BAAI--bge-small-en-v1.5/snapshots/<revision>/`.
    // The other two files come from the mock so the test can tell exactly
    // which requests were made.
    let hf_home = tmp.path().join("hf_home");
    let mut snapshot_dir = hf_home.clone();
    snapshot_dir.extend(["hub", "models--BAAI--bge-small-en-v1.5", "snapshots", HF_REVISION]);
    fs::create_dir_all(&snapshot_dir).unwrap();
    let staged_config = snapshot_dir.join("config.json");
    fs::write(&staged_config, &fix.config_bytes).unwrap();

    // A `config.json` endpoint exists on the server, but it must never be
    // hit; the zero-hit assertion further down enforces that.
    let config_mock = server.mock(|when, then| {
        when.method(GET).path(resolve_path("config.json"));
        then.status(200).body(&fix.config_bytes);
    });
    let tokenizer_mock = server.mock(|when, then| {
        when.method(GET).path(resolve_path("tokenizer.json"));
        then.status(200).body(&fix.tokenizer_bytes);
    });
    let model_mock = server.mock(|when, then| {
        when.method(GET).path(resolve_path("model.safetensors"));
        then.status(200).body(&fix.model_bytes);
    });

    let cfg = test_config(&server.base_url(), &cache, &fix).with_hf_hub_root(Some(hf_home));
    let loaded = load_with_config(cfg).expect("loader ok with hub-probe hit");

    // Network side: config.json came from the hub, the rest from HTTP.
    config_mock.assert_hits(0);
    tokenizer_mock.assert();
    model_mock.assert();

    // Event side: a cache-hit event names config.json.
    let mut cache_hit_files: Vec<&str> = Vec::new();
    for event in loaded.events.iter() {
        if let EmbedderEvent::DefaultEmbedderCacheHit { file, .. } = event {
            cache_hit_files.push(file.as_str());
        }
    }
    assert!(
        cache_hit_files.contains(&"config.json"),
        "expected DefaultEmbedderCacheHit for config.json, got {cache_hit_files:?}"
    );

    // The hub copy is byte-for-byte what was staged (read-only probe).
    let hub_bytes = fs::read(&staged_config).unwrap();
    assert_eq!(hub_bytes, fix.config_bytes, "hub source must not be modified");

    // The loader's own cache now holds the same content.
    let materialized = fs::read(&loaded.config_json_path).unwrap();
    assert_eq!(materialized, fix.config_bytes);
}

/// Serializes tests that mutate the process env so set/restore cycles
/// don't race with each other.
///
/// The guarded value is `()`: only holding the lock matters. Because there is
/// no data to corrupt, a poisoned lock can be recovered with `into_inner()`,
/// which is what `respects_timeout_env_overrides` does.
static ENV_GUARD: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Compile-time check that the zero-argument public entry point referenced by
/// EU-4 and EU-5 exists with the documented signature.
#[test]
fn public_api_exists() {
    /// The documented shape of the zero-argument loader entry point.
    type ZeroArgLoader = fn() -> Result<LoadedWeights, EmbedderLoadError>;

    // Only coerced to a function pointer, never called: calling it would hit
    // the real network. Behavior is covered by the GREEN-side integration
    // tests.
    let _entry_point: ZeroArgLoader = load_pinned_default_embedder;
}

// Minimal recursive walker (avoids pulling walkdir as a dev-dep).
//
// Returns every non-directory entry found beneath `root`. Directories that
// cannot be read (including a `root` that does not exist) and entries that
// fail to enumerate are skipped silently, so the result may be empty.
fn walkdir(root: &std::path::Path) -> Vec<PathBuf> {
    let mut files = Vec::new();
    // Directories still waiting to be listed; popped last-in, first-out.
    let mut pending = vec![root.to_path_buf()];

    while let Some(dir) = pending.pop() {
        let listing = match fs::read_dir(&dir) {
            Ok(listing) => listing,
            Err(_) => continue,
        };
        for child in listing.flatten().map(|entry| entry.path()) {
            let bucket = if child.is_dir() { &mut pending } else { &mut files };
            bucket.push(child);
        }
    }

    files
}