Skip to main content

wombatkv_node/
block_prefetch.rs

1#![forbid(unsafe_code)]
2//! Background block-prefetch worker (RFC 0008 §6).
3//!
4//! Periodically snapshots the `MetadataIndex`, scores each entry per
5//! recency / chain-head / model-affinity heuristic, and selects the
6//! top-K candidates that the request hot path is most likely to hit
7//! next. Goal: warm the flat tier before requests arrive, so cold-S3
8//! load latency disappears from the user-visible TTFT.
9//!
10//! ## Scoring (RFC 0008 §6)
11//!
12//! ```text
13//!   score = w_recency * exp(-decay * (now - last_access_ns))
14//!         + w_chain   * is_chain_head_bonus
15//!         + w_model   * model_affinity_bonus
16//! ```
17//!
18//! With:
19//! - `w_recency = 1.0` (primary signal; recently-touched blocks rank highest)
20//! - `w_chain   = 0.3` (chain-head blocks anchor multi-turn prompts)
21//! - `w_model   = 0.2` (active model's blocks rank ahead of stragglers)
//! - `decay     = ln(2) / 600e9` (per-nanosecond rate; half-life = 600e9 ns = 10 minutes)
23//! - `is_chain_head_bonus = 1.0 iff BlockMeta.block_seq == 0`
24//! - `model_affinity_bonus = 1.0 iff BlockMeta.model_digest == active`
25//!
26//! ## v1 vs v2
27//!
28//! v1 was log-only: scored the candidates and emitted a `[MyelonInstr]`
29//! event listing the top-K, but never issued the GET. v2 (this module)
30//! actually fetches: per cycle it issues `WombatKVKvStore::get_kv` for
31//! each top-K miss, materializing the payload into the local flat
32//! cache so the next request hits the warm path.
33//!
34//! The fallback path is preserved behind the `WMBT_KV_PREFETCH_DRY_RUN=1`
35//! env: when set, the worker scores and logs but never issues GETs
36//! (matches v1 behavior for diagnostic / canary deployments).
37//!
38//! ### Sequential vs parallel
39//!
40//! v2 issues GETs sequentially per cycle. This mirrors the C ABI's
41//! per-block path (`Handle::put_kv_blocks` parallelizes via
42//! `std::thread::scope`, but `get_kv` itself is the cabi's per-block
//! call). Parallel fetch within a cycle is a v3 TODO: once there is
//! evidence that cycle time has become a bottleneck, swap in a
//! `std::thread::scope` fan-out bounded by `top_k`.
46
47use std::sync::atomic::{AtomicBool, Ordering};
48use std::sync::Arc;
49use std::thread::JoinHandle;
50use std::time::{Duration, Instant};
51
52use wombatkv_radix::{BlockHash, BlockMeta, MetadataIndex, ModelDigest};
53
/// Wall-clock budget for a single prefetch `get_kv` (5 s).
///
/// The budget is checked *after* a fetch returns (see `run_cycle_v2` for
/// which outcomes are checked). A fetch observed to exceed it ends the
/// current cycle early, so one slow backend cannot chain many slow calls
/// into a single cycle. It is not a timeout: it cannot interrupt a call
/// that is still in flight.
const FETCH_WALLCLOCK_CAP: Duration = Duration::from_millis(5_000);
58
59/// Tunables for the prefetch worker.
60#[derive(Clone, Debug)]
61pub struct PrefetchConfig {
62    /// Sleep between scoring cycles. Workers fire `tick()` every
63    /// `interval` and otherwise idle.
64    pub interval: Duration,
65    /// Maximum entries to materialize per cycle. Acts as a cap on the
66    /// prefetch-induced load on the flat tier + S3 bandwidth.
67    pub top_k: usize,
68    /// Active-model fingerprint for affinity scoring. Zeroed digest
69    /// disables the bonus (all blocks rank equally on model affinity).
70    pub model_digest: ModelDigest,
71    /// Namespace under which prefetch GETs are issued. Mirrors the
72    /// caller's `WombatKVKvStore::get_kv(namespace, key)` namespace -
73    /// in cabi this is `WMBT_KV_NAMESPACE` (default `tp-default`).
74    pub namespace: String,
75}
76
77impl Default for PrefetchConfig {
78    fn default() -> Self {
79        Self {
80            interval: Duration::from_millis(500),
81            top_k: 8,
82            model_digest: [0u8; 24],
83            namespace: String::new(),
84        }
85    }
86}
87
/// Handle to a background prefetch thread (returned by `spawn_worker` and
/// `spawn_worker_v2`).
///
/// Dropping the handle raises the stop flag and then joins the thread, so
/// the worker never outlives whoever owns this value.
pub struct PrefetchWorker {
    handle: Option<JoinHandle<()>>,
    stop: Arc<AtomicBool>,
}

impl PrefetchWorker {
    /// Raise the stop flag without waiting for the thread to exit.
    ///
    /// Dropping the worker afterwards still joins. This lets a caller that
    /// owns several workers signal all of them first, then join them one
    /// by one.
    pub fn signal_stop(&self) {
        self.stop.store(true, Ordering::SeqCst);
    }

    /// `true` while a worker thread is attached and has not yet returned.
    /// Exposed for tests.
    #[must_use]
    pub fn is_running(&self) -> bool {
        match &self.handle {
            Some(thread) => !thread.is_finished(),
            None => false,
        }
    }
}

impl Drop for PrefetchWorker {
    fn drop(&mut self) {
        self.signal_stop();
        // Best-effort join: a panic inside the worker surfaces as `Err`
        // here and is deliberately discarded. Panicking in `Drop` could
        // abort the process if we are already unwinding.
        if let Some(thread) = self.handle.take() {
            let _ = thread.join();
        }
    }
}
124
125/// Closure-based callback for "I would have prefetched this block".
126/// Lets tests assert the worker's choices without coupling to the
127/// log format.
128pub type PrefetchEmit = Arc<dyn Fn(&PrefetchPlan) + Send + Sync>;
129
130/// One cycle's prefetch plan. Held briefly inside the worker thread,
131/// then handed to `emit` (v1 / dry-run) or to the fetcher (v2).
132#[derive(Clone, Debug)]
133pub struct PrefetchPlan {
134    /// Total entries scored this cycle.
135    pub scored: usize,
136    /// Top-K selected (capped by `PrefetchConfig::top_k`).
137    pub selected: Vec<(BlockHash, BlockMeta, f64)>,
138    /// Wall-clock cost of the cycle.
139    pub elapsed: Duration,
140}
141
/// Backend that the v2 prefetch worker uses to materialize blocks.
///
/// The worker holds it as `Arc<dyn PrefetchFetcher>`. This keeps this
/// module independent of the embed module's generic `WombatKVKvStore<S>`
/// while still letting it issue GETs.
///
/// The worker calls `contains_flat` for every selected candidate on every
/// cycle, so it should be cheap (ideally a single filesystem stat).
/// `fetch_block` is expected to populate the local flat tier on success;
/// warming that cache before request time is the purpose of v2.
pub trait PrefetchFetcher: Send + Sync {
    /// Whether `key` in `namespace` is already materialized in the local
    /// flat tier. Candidates that report `true` are skipped, which keeps
    /// cycle cost bounded.
    fn contains_flat(&self, namespace: &str, key: &str) -> bool;

    /// Fetch `key` from `namespace`.
    ///
    /// - `Ok(Some(len))`: hit; `len` bytes were materialized into flat/foyer.
    /// - `Ok(None)`: the backend does not have the block.
    /// - `Err(message)`: backend error. The worker logs it and moves on to
    ///   the next candidate; one bad block does not end the cycle.
    fn fetch_block(&self, namespace: &str, key: &str) -> Result<Option<u64>, String>;
}
163
/// Per-cycle stage counters produced by `run_cycle_v2` and reported in the
/// `[MyelonInstr]` event written by [`default_v2_emit`].
#[derive(Clone, Debug, Default)]
pub struct PrefetchFetchOutcome {
    /// Entries scored (the metadata-index snapshot size).
    pub scored: usize,
    /// Candidates selected for this cycle (at most `top_k`).
    pub selected: usize,
    /// Candidates skipped because they were already in the flat tier.
    pub skipped_already_flat: usize,
    /// Fetches that returned `Ok(Some(_))`.
    pub fetched: usize,
    /// Fetches that materialized nothing: backend errors *and* misses
    /// (`Ok(None)`) are both counted here.
    pub failed: usize,
    /// Sum of byte lengths reported by successful fetches (saturating).
    pub bytes_materialized: u64,
    /// Wall-clock duration of the whole cycle, in milliseconds.
    pub elapsed_ms: u128,
}
176
177/// Score `BlockMeta` under the RFC 0008 §6 heuristic.
178///
179/// `now_ns` is passed in (not pulled from the wall clock here) so the
180/// caller can score a batch with one monotonic reading.
181#[must_use]
182pub fn score_block(meta: &BlockMeta, now_ns: u64, active_model: &ModelDigest) -> f64 {
183    // Weights and decay constant per RFC 0008 §6.
184    const W_RECENCY: f64 = 1.0;
185    const W_CHAIN: f64 = 0.3;
186    const W_MODEL: f64 = 0.2;
187    // ln(2) / 600e9 → half-life of 600 seconds in nanoseconds.
188    let decay: f64 = std::f64::consts::LN_2 / 600.0e9_f64;
189
190    // Recency: bigger when last_access_ns is close to now_ns.
191    // Negative age (touched in the future, clock skew) maps to 1.0
192    // by clamping age to 0.
193    let age_ns = now_ns.saturating_sub(meta.last_access_ns) as f64;
194    let recency = (-decay * age_ns).exp();
195
196    let chain_bonus = if meta.block_seq == 0 { 1.0 } else { 0.0 };
197    let model_bonus = if &meta.model_digest == active_model { 1.0 } else { 0.0 };
198
199    W_RECENCY * recency + W_CHAIN * chain_bonus + W_MODEL * model_bonus
200}
201
/// Current wall-clock time in nanoseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch. If the
/// `u128` nanosecond count no longer fits in a `u64` (around year 2554),
/// it saturates at `u64::MAX` instead of silently wrapping as a bare
/// `as` cast would.
fn now_ns() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // Clamp first, so the `as` cast below is lossless.
        .map_or(0, |d| d.as_nanos().min(u128::from(u64::MAX)) as u64)
}
207
208/// Compose the relative block key for a `BlockHash`. Mirrors
209/// `wombatkv-cabi::block_key_for_hash`, both call sites read the
210/// same `wombatkv_radix::BLOCK_KEY_PREFIX` so they can never skew.
211#[must_use]
212pub fn block_key_for_hash(hash: &BlockHash) -> String {
213    use wombatkv_radix::BLOCK_KEY_PREFIX;
214    let mut s = String::with_capacity(BLOCK_KEY_PREFIX.len() + 64);
215    s.push_str(BLOCK_KEY_PREFIX);
216    for b in hash {
217        s.push_str(&hex_pair(*b));
218    }
219    s
220}
221
/// Render `b` as two lower-case hexadecimal digits, high nibble first.
fn hex_pair(b: u8) -> String {
    let high = HEX[usize::from(b >> 4)];
    let low = HEX[usize::from(b & 0x0f)];
    [high, low].iter().collect()
}

/// Lower-case hexadecimal digit alphabet, indexed by nibble value (0..=15).
const HEX: [char; 16] = [
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
];
233
234/// Run one scoring + selection cycle against the supplied index.
235/// Returns the plan; callers decide what to do with it (emit, GET, etc.).
236///
237/// Splitting "score + select" from "act on the plan" keeps the scoring
238/// logic synchronous and unit-testable without spinning up a thread.
239#[must_use]
240pub fn run_cycle(index: &dyn MetadataIndex, config: &PrefetchConfig) -> PrefetchPlan {
241    let started = Instant::now();
242    let now = now_ns();
243    let snapshot = index.entries();
244    let scored = snapshot.len();
245
246    // Score every entry. The scoring loop is O(N); the sort below
247    // dominates at large N. With `top_k` typically ≤ 32, a partial
248    // sort would beat `sort_by`, but the simpler full sort keeps the
249    // code obvious, revisit when M exceeds 10^4.
250    let mut scored_entries: Vec<(BlockHash, BlockMeta, f64)> = snapshot
251        .into_iter()
252        .map(|(h, m)| {
253            let s = score_block(&m, now, &config.model_digest);
254            (h, m, s)
255        })
256        .collect();
257
258    scored_entries.sort_by(|a, b| {
259        // Higher score first; NaN treated as -inf (shouldn't occur,
260        // but guards against future scoring tweaks).
261        b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal)
262    });
263
264    let mut selected = scored_entries;
265    selected.truncate(config.top_k);
266
267    PrefetchPlan { scored, selected, elapsed: started.elapsed() }
268}
269
270/// Spawn the v1 / dry-run worker. Scores + logs, never fetches.
271///
272/// Retained for tests, diagnostic deployments, and the
273/// `WMBT_KV_PREFETCH_DRY_RUN=1` escape hatch (configured by the
274/// embed-side `start_prefetcher` when the env is set).
275///
276/// `index` is held inside the worker via the supplied Arc clone, so the
277/// thread can read it without holding a reference to the outer struct.
278/// `emit` is invoked once per cycle with the plan; pass `default_emit`
279/// for the standard `[MyelonInstr]` event, or a test-side closure to
280/// capture the plan inline.
281pub fn spawn_worker(
282    index: Arc<dyn MetadataIndex>,
283    config: PrefetchConfig,
284    emit: PrefetchEmit,
285) -> PrefetchWorker {
286    let stop = Arc::new(AtomicBool::new(false));
287    let stop_for_thread = stop.clone();
288    let handle = std::thread::Builder::new()
289        .name("wombatkv-prefetch".to_string())
290        .spawn(move || {
291            // Bounded sleep loop: we don't want to delay shutdown by
292            // a full `interval` on Drop, so sleep in slices and check
293            // the stop flag each slice.
294            let slice = Duration::from_millis(25);
295            loop {
296                if stop_for_thread.load(Ordering::SeqCst) {
297                    break;
298                }
299                let plan = run_cycle(index.as_ref(), &config);
300                emit(&plan);
301
302                // Sleep `interval`, but in `slice` chunks so the
303                // worker reacts to stop within ≤ slice (25 ms).
304                let mut remaining = config.interval;
305                while remaining > Duration::ZERO {
306                    if stop_for_thread.load(Ordering::SeqCst) {
307                        break;
308                    }
309                    let s = remaining.min(slice);
310                    std::thread::sleep(s);
311                    remaining = remaining.saturating_sub(s);
312                }
313            }
314        })
315        .expect("spawn prefetch worker");
316
317    PrefetchWorker { handle: Some(handle), stop }
318}
319
320/// Spawn the v2 worker that actually fetches the top-K each cycle.
321///
322/// Per cycle:
323///   1. Snapshot the metadata index + score per RFC 0008 §6.
324///   2. Take top-K candidates.
325///   3. Filter out candidates already present in the local flat cache.
326///   4. For each remaining: call `fetcher.fetch_block(namespace, key)`
327///      sequentially. On error, log + continue.
328///   5. Emit a per-cycle `[MyelonInstr]` event with detailed stage counts.
329///
330/// The fetch loop is sequential by design. v3 may parallelize via
331/// `std::thread::scope` when cycle-time becomes a bottleneck, see
332/// module docs.
333pub fn spawn_worker_v2(
334    index: Arc<dyn MetadataIndex>,
335    config: PrefetchConfig,
336    fetcher: Arc<dyn PrefetchFetcher>,
337    emit_outcome: Arc<dyn Fn(&PrefetchFetchOutcome) + Send + Sync>,
338) -> PrefetchWorker {
339    let stop = Arc::new(AtomicBool::new(false));
340    let stop_for_thread = stop.clone();
341    let handle = std::thread::Builder::new()
342        .name("wombatkv-prefetch-v2".to_string())
343        .spawn(move || {
344            let slice = Duration::from_millis(25);
345            loop {
346                if stop_for_thread.load(Ordering::SeqCst) {
347                    break;
348                }
349                let outcome =
350                    run_cycle_v2(index.as_ref(), fetcher.as_ref(), &config, &stop_for_thread);
351                emit_outcome(&outcome);
352
353                let mut remaining = config.interval;
354                while remaining > Duration::ZERO {
355                    if stop_for_thread.load(Ordering::SeqCst) {
356                        break;
357                    }
358                    let s = remaining.min(slice);
359                    std::thread::sleep(s);
360                    remaining = remaining.saturating_sub(s);
361                }
362            }
363        })
364        .expect("spawn prefetch worker v2");
365
366    PrefetchWorker { handle: Some(handle), stop }
367}
368
369/// Run one v2 cycle: score, select, fetch top-K, return the per-stage
370/// counts. Exposed for tests so the cycle can run inline (without
371/// thread orchestration).
372///
373/// `stop` is honored mid-fetch: if it flips to true between fetches,
374/// the loop exits and the partial outcome is returned.
375#[must_use]
376pub fn run_cycle_v2(
377    index: &dyn MetadataIndex,
378    fetcher: &dyn PrefetchFetcher,
379    config: &PrefetchConfig,
380    stop: &AtomicBool,
381) -> PrefetchFetchOutcome {
382    let started = Instant::now();
383    let plan = run_cycle(index, config);
384    let scored = plan.scored;
385    let selected = plan.selected.len();
386
387    let mut skipped_already_flat = 0_usize;
388    let mut fetched = 0_usize;
389    let mut failed = 0_usize;
390    let mut bytes_materialized = 0_u64;
391
392    for (hash, _meta, _score) in plan.selected {
393        if stop.load(Ordering::SeqCst) {
394            break;
395        }
396        let key = block_key_for_hash(&hash);
397        if fetcher.contains_flat(&config.namespace, &key) {
398            skipped_already_flat += 1;
399            continue;
400        }
401        let fetch_started = Instant::now();
402        match fetcher.fetch_block(&config.namespace, &key) {
403            Ok(Some(bytes_len)) => {
404                let cost = fetch_started.elapsed();
405                if cost > FETCH_WALLCLOCK_CAP {
406                    eprintln!(
407                        "wombatkv[prefetch v2]: get_kv({key}) took {cost:?} \
408                         (cap {FETCH_WALLCLOCK_CAP:?}); aborting remainder of cycle"
409                    );
410                    fetched += 1;
411                    bytes_materialized = bytes_materialized.saturating_add(bytes_len);
412                    break;
413                }
414                fetched += 1;
415                bytes_materialized = bytes_materialized.saturating_add(bytes_len);
416            }
417            Ok(None) => {
418                // Miss is not an error, the metadata index can ride
419                // ahead of S3 (e.g., during a delete-replay).
420                failed += 1;
421            }
422            Err(err) => {
423                failed += 1;
424                eprintln!("wombatkv[prefetch v2]: get_kv({key}) failed: {err}");
425            }
426        }
427    }
428
429    PrefetchFetchOutcome {
430        scored,
431        selected,
432        skipped_already_flat,
433        fetched,
434        failed,
435        bytes_materialized,
436        elapsed_ms: started.elapsed().as_millis(),
437    }
438}
439
440/// Default v1 / dry-run emit: a `[MyelonInstr]` JSON line on stderr per
441/// cycle. Mirrors the existing event shape in `embed.rs` so log parsers
442/// see one consistent envelope across read/write/prefetch paths.
443#[must_use]
444pub fn default_emit() -> PrefetchEmit {
445    Arc::new(|plan: &PrefetchPlan| {
446        let elapsed_ms = plan.elapsed.as_millis();
447        // v1 is log-only: the actual GET would land here. We emit the
448        // count of "would-materialize" as `materialized` so the event
449        // shape doesn't churn when v2 actually fetches.
450        let scored = plan.scored;
451        let materialized = plan.selected.len();
452        eprintln!(
453            "[MyelonInstr] {{\"scope\":\"wmbt_kv_timing\",\"fn\":\"prefetch_cycle\",\
454             \"stages\":{{\"scored\":{scored},\"materialized\":{materialized},\
455             \"elapsed_ms\":{elapsed_ms}}}}}"
456        );
457    })
458}
459
460/// Default v2 emit: a `[MyelonInstr]` JSON line per cycle with full
461/// stage counts (scored, selected, `skipped_already_flat`, fetched,
462/// failed, `bytes_materialized`, `elapsed_ms`).
463#[must_use]
464pub fn default_v2_emit() -> Arc<dyn Fn(&PrefetchFetchOutcome) + Send + Sync> {
465    Arc::new(|o: &PrefetchFetchOutcome| {
466        eprintln!(
467            "[MyelonInstr] {{\"scope\":\"wmbt_kv_timing\",\"fn\":\"prefetch_cycle_v2\",\
468             \"stages\":{{\"scored\":{},\"selected\":{},\"skipped_already_flat\":{},\
469             \"fetched\":{},\"failed\":{},\"bytes_materialized\":{},\"elapsed_ms\":{}}}}}",
470            o.scored,
471            o.selected,
472            o.skipped_already_flat,
473            o.fetched,
474            o.failed,
475            o.bytes_materialized,
476            o.elapsed_ms,
477        );
478    })
479}
480
/// Returns true if the v1 / dry-run fallback is requested via
/// `WMBT_KV_PREFETCH_DRY_RUN`.
///
/// Truthy values are `1`, `true`, `yes`, and `on`. They are matched
/// case-insensitively after trimming surrounding whitespace, so `True`,
/// `YES`, and ` 1 ` also count. An unset, non-UTF-8, or any other value
/// means false. When true, the embed-side `start_prefetcher` falls back to
/// log-only behavior.
#[must_use]
pub fn dry_run_enabled() -> bool {
    std::env::var("WMBT_KV_PREFETCH_DRY_RUN").is_ok_and(|value| is_truthy_flag(&value))
}

/// Whether `value` is one of the accepted truthy spellings (see
/// `dry_run_enabled`).
fn is_truthy_flag(value: &str) -> bool {
    let value = value.trim();
    ["1", "true", "yes", "on"]
        .iter()
        .any(|truthy| value.eq_ignore_ascii_case(truthy))
}
491
492#[cfg(test)]
493mod tests {
494    use super::*;
495    use std::sync::Mutex;
496    use std::time::Duration;
497    use wombatkv_radix::InMemoryMetadataIndex;
498
499    fn mk_meta(seq: u32, last_access_ns: u64, model: ModelDigest) -> BlockMeta {
500        let mut m = BlockMeta {
501            parent_hash: BlockMeta::ZERO_HASH,
502            block_seq: seq,
503            payload_bytes: 1024,
504            last_access_ns,
505            model_digest: model,
506            layout_tag: [0u8; 16],
507            ext_flags: 0,
508        };
509        // Don't use new_root/new_successor here, those re-stamp
510        // last_access_ns to now(), which fights the test fixture.
511        m.last_access_ns = last_access_ns;
512        m
513    }
514
515    fn make_hash(seed: u8) -> BlockHash {
516        let mut h = [0u8; 32];
517        h[0] = seed;
518        h
519    }
520
521    #[test]
522    fn score_recency_decay_monotone() {
523        // Same model, same seq → score driven purely by recency.
524        let model = [7u8; 24];
525        let now = 10_000_000_000_000_u64; // 10 s
526        let fresh = mk_meta(1, now - 1_000_000_000, model); // 1 s ago
527        let stale = mk_meta(1, now - 600_000_000_000, model); // 10 min ago
528        let ancient = mk_meta(1, now - 3_600_000_000_000, model); // 1 h ago
529
530        let s_fresh = score_block(&fresh, now, &model);
531        let s_stale = score_block(&stale, now, &model);
532        let s_ancient = score_block(&ancient, now, &model);
533
534        assert!(s_fresh > s_stale);
535        assert!(s_stale > s_ancient);
536        // Half-life is 10 minutes → stale should be ~half of fresh's
537        // recency component plus the model bonus (0.2, same for all).
538        // 0.2 model bonus dominates the small recency tail at 10 min,
539        // so check the recency-only delta instead.
540        let recency_fresh = s_fresh - 0.2_f64;
541        let recency_stale = s_stale - 0.2_f64;
542        let ratio = recency_stale / recency_fresh;
543        // Allow generous slack, should be near 0.5 (half-life).
544        assert!((0.40..0.60).contains(&ratio), "stale/fresh recency ratio {ratio} not near 0.5");
545    }
546
547    #[test]
548    fn score_chain_head_outranks_successor_at_equal_recency() {
549        let model = [7u8; 24];
550        let now = 10_000_000_000_000_u64;
551        let head = mk_meta(0, now - 1_000_000_000, model);
552        let succ = mk_meta(5, now - 1_000_000_000, model);
553        assert!(score_block(&head, now, &model) > score_block(&succ, now, &model));
554    }
555
556    #[test]
557    fn score_model_affinity_outranks_others() {
558        let active = [7u8; 24];
559        let other = [42u8; 24];
560        let now = 10_000_000_000_000_u64;
561        let same = mk_meta(1, now - 1_000_000_000, active);
562        let diff = mk_meta(1, now - 1_000_000_000, other);
563        assert!(score_block(&same, now, &active) > score_block(&diff, now, &active));
564    }
565
566    #[test]
567    fn run_cycle_picks_top_k_by_score() {
568        let idx = InMemoryMetadataIndex::new();
569        let active = [7u8; 24];
570        let now = now_ns();
571
572        // 100 synthetic entries. Hash 0..49 → "old" (stale recency).
573        // Hash 50..99 → "fresh", and hash %3==0 within fresh are chain heads.
574        //
575        // We use `bulk_load` (not `insert`) because `MetadataIndex::insert`
576        // calls `meta.touch()` which overwrites our fixture's last_access_ns
577        // with the wall clock, defeating the staleness control.
578        let entries: Vec<(BlockHash, BlockMeta)> = (0..100_u8)
579            .map(|i| {
580                let h = make_hash(i);
581                let is_fresh = i >= 50;
582                let last = if is_fresh {
583                    now.saturating_sub(1_000_000_000) // 1 s ago
584                } else {
585                    now.saturating_sub(3_600_000_000_000) // 1 h ago
586                };
587                let seq = if is_fresh && i % 3 == 0 { 0 } else { u32::from(i) + 1 };
588                (h, mk_meta(seq, last, active))
589            })
590            .collect();
591        idx.bulk_load(entries);
592        assert_eq!(idx.len(), 100);
593
594        let cfg = PrefetchConfig {
595            interval: Duration::from_mins(1),
596            top_k: 10,
597            model_digest: active,
598            namespace: String::new(),
599        };
600        let plan = run_cycle(&idx, &cfg);
601        assert_eq!(plan.scored, 100);
602        assert_eq!(plan.selected.len(), 10);
603
604        // Every selected block must come from the "fresh" half
605        // (hash >= 50). Old blocks have age = 1 h ≫ half-life, so
606        // their recency component is ≈ 2^-6 ≈ 0.016, well below
607        // anything in the fresh half.
608        for (h, _, _) in &plan.selected {
609            assert!(h[0] >= 50, "selected stale block {}", h[0]);
610        }
611
612        // Scores must be sorted descending.
613        for w in plan.selected.windows(2) {
614            assert!(w[0].2 >= w[1].2, "scores not descending");
615        }
616    }
617
618    #[test]
619    fn worker_runs_cycle_and_stops_within_one_second() {
620        let idx: Arc<dyn MetadataIndex> = Arc::new(InMemoryMetadataIndex::new());
621        // Cast back to insert; we hold the InMemoryMetadataIndex via
622        // the trait object, and `MetadataIndex::insert` works directly.
623        let active = [7u8; 24];
624        idx.insert(make_hash(1), mk_meta(0, now_ns(), active));
625        idx.insert(make_hash(2), mk_meta(1, now_ns(), active));
626
627        let cycles = Arc::new(Mutex::new(0_usize));
628        let cycles_cb = cycles.clone();
629        let emit: PrefetchEmit = Arc::new(move |_plan: &PrefetchPlan| {
630            *cycles_cb.lock().unwrap() += 1;
631        });
632
633        let cfg = PrefetchConfig {
634            interval: Duration::from_millis(50),
635            top_k: 4,
636            model_digest: active,
637            namespace: String::new(),
638        };
639
640        let started = Instant::now();
641        let worker = spawn_worker(idx, cfg, emit);
642
643        // Wait long enough for ≥ 2 cycles.
644        std::thread::sleep(Duration::from_millis(200));
645        let observed = *cycles.lock().unwrap();
646        assert!(observed >= 2, "expected ≥ 2 cycles, got {observed}");
647
648        // Drop the worker → triggers stop + join. Must complete
649        // well under 1 s.
650        drop(worker);
651        let drop_time = started.elapsed();
652        assert!(drop_time < Duration::from_secs(2), "worker shutdown took {drop_time:?}");
653    }
654
655    #[test]
656    fn worker_handles_empty_index_gracefully() {
657        let idx: Arc<dyn MetadataIndex> = Arc::new(InMemoryMetadataIndex::new());
658        let cycles = Arc::new(Mutex::new(0_usize));
659        let cycles_cb = cycles.clone();
660        let emit: PrefetchEmit = Arc::new(move |plan: &PrefetchPlan| {
661            assert_eq!(plan.scored, 0);
662            assert!(plan.selected.is_empty());
663            *cycles_cb.lock().unwrap() += 1;
664        });
665
666        let cfg = PrefetchConfig {
667            interval: Duration::from_millis(30),
668            top_k: 8,
669            model_digest: [0u8; 24],
670            namespace: String::new(),
671        };
672        let worker = spawn_worker(idx, cfg, emit);
673        std::thread::sleep(Duration::from_millis(120));
674        let n = *cycles.lock().unwrap();
675        drop(worker);
676        assert!(n >= 2, "expected ≥ 2 cycles on empty index, got {n}");
677    }
678
679    #[test]
680    fn signal_stop_makes_drop_fast_even_with_long_interval() {
681        let idx: Arc<dyn MetadataIndex> = Arc::new(InMemoryMetadataIndex::new());
682        let emit: PrefetchEmit = Arc::new(|_| {});
683        let cfg = PrefetchConfig {
684            // Long interval, without sliced sleep, drop would
685            // wait this long.
686            interval: Duration::from_secs(10),
687            top_k: 1,
688            model_digest: [0u8; 24],
689            namespace: String::new(),
690        };
691        let worker = spawn_worker(idx, cfg, emit);
692        // Tiny sleep so the worker has run at least one cycle and is
693        // now sleeping inside the inner slice loop.
694        std::thread::sleep(Duration::from_millis(50));
695        let t = Instant::now();
696        worker.signal_stop();
697        drop(worker);
698        assert!(t.elapsed() < Duration::from_millis(500), "drop took {:?}", t.elapsed());
699    }
700
701    #[test]
702    fn block_key_for_hash_matches_cabi_format() {
703        // Mirror the cabi: `wombatkv/v1/block/b3=<64-char-lower-hex>`.
704        use wombatkv_radix::BLOCK_KEY_PREFIX;
705        let mut h = [0u8; 32];
706        h[0] = 0xab;
707        h[1] = 0xcd;
708        h[31] = 0xef;
709        let key = block_key_for_hash(&h);
710        assert_eq!(key.len(), BLOCK_KEY_PREFIX.len() + 64);
711        assert!(key.starts_with(BLOCK_KEY_PREFIX));
712        let hex = &key[BLOCK_KEY_PREFIX.len()..];
713        assert!(hex.starts_with("abcd"), "got hex prefix {hex:?}");
714        assert!(hex.ends_with("ef"));
715        assert!(hex.chars().all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase()));
716    }
717
718    // ============================================================
719    // v2 tests (RFC 0008 §6, second iteration).
720    //
721    // The v2 worker is wired to a `PrefetchFetcher`. We use a recording
722    // mock so tests can assert exactly which keys would have hit the
723    // store, without spinning up a `WombatKVKvStore<InMemoryObjectStore>`
724    // (that round-trip is covered by the cabi/embed integration tests).
725    // ============================================================
726
727    struct MockFetcher {
728        flat_keys: Mutex<Vec<String>>,
729        // Optional canned errors per key, consumed once.
730        errors: Mutex<std::collections::HashMap<String, String>>,
731        // Optional bytes per key (hits). Missing = miss.
732        bytes: Mutex<std::collections::HashMap<String, u64>>,
733        // Recorded fetch calls (namespace, key) for assertion.
734        calls: Mutex<Vec<(String, String)>>,
735    }
736
737    impl MockFetcher {
738        fn new() -> Self {
739            Self {
740                flat_keys: Mutex::new(Vec::new()),
741                errors: Mutex::new(std::collections::HashMap::new()),
742                bytes: Mutex::new(std::collections::HashMap::new()),
743                calls: Mutex::new(Vec::new()),
744            }
745        }
746
747        fn stock_hit(&self, key: &str, len: u64) {
748            self.bytes.lock().unwrap().insert(key.to_string(), len);
749        }
750
751        fn stock_err(&self, key: &str, msg: &str) {
752            self.errors.lock().unwrap().insert(key.to_string(), msg.to_string());
753        }
754
755        fn pre_warm_flat(&self, key: &str) {
756            self.flat_keys.lock().unwrap().push(key.to_string());
757        }
758
759        fn calls(&self) -> Vec<(String, String)> {
760            self.calls.lock().unwrap().clone()
761        }
762    }
763
764    impl PrefetchFetcher for MockFetcher {
765        fn contains_flat(&self, _namespace: &str, key: &str) -> bool {
766            self.flat_keys.lock().unwrap().iter().any(|k| k == key)
767        }
768
769        fn fetch_block(&self, namespace: &str, key: &str) -> Result<Option<u64>, String> {
770            self.calls.lock().unwrap().push((namespace.to_string(), key.to_string()));
771            if let Some(err) = self.errors.lock().unwrap().remove(key) {
772                return Err(err);
773            }
774            if let Some(len) = self.bytes.lock().unwrap().get(key).copied() {
775                // Simulate the store populating flat on a successful
776                // fetch so a follow-up cycle would skip this key.
777                self.flat_keys.lock().unwrap().push(key.to_string());
778                return Ok(Some(len));
779            }
780            Ok(None)
781        }
782    }
783
784    fn seed_recent_blocks(
785        idx: &InMemoryMetadataIndex,
786        count: u8,
787        model: ModelDigest,
788    ) -> Vec<BlockHash> {
789        let now = now_ns();
790        let mut hashes = Vec::with_capacity(count as usize);
791        for i in 0..count {
792            let h = make_hash(i);
793            // Stagger last_access slightly so scores are deterministic
794            // (older index = older time, so earlier entries score
795            // slightly higher and the worker's selection is stable).
796            let last = now.saturating_sub(1_000_000_000_u64.saturating_mul(u64::from(i)));
797            idx.bulk_load(std::iter::once((h, mk_meta(0, last, model))));
798            hashes.push(h);
799        }
800        hashes
801    }
802
803    #[test]
804    fn v2_fetches_top_k_against_kvstore() {
805        // Setup: 100 entries in the metadata index. The mock fetcher
806        // has bytes stocked for every key. top_k=10 → exactly 10 calls
807        // to fetch_block, all hitting.
808        let idx = Arc::new(InMemoryMetadataIndex::new());
809        let model = [7u8; 24];
810        let hashes = seed_recent_blocks(&idx, 100, model);
811
812        let fetcher = Arc::new(MockFetcher::new());
813        for h in &hashes {
814            let key = block_key_for_hash(h);
815            fetcher.stock_hit(&key, 1024);
816        }
817
818        let cfg = PrefetchConfig {
819            interval: Duration::from_mins(1),
820            top_k: 10,
821            model_digest: model,
822            namespace: "ns-a".to_string(),
823        };
824
825        let stop = AtomicBool::new(false);
826        let outcome = run_cycle_v2(
827            idx.as_ref() as &dyn MetadataIndex,
828            fetcher.as_ref() as &dyn PrefetchFetcher,
829            &cfg,
830            &stop,
831        );
832
833        assert_eq!(outcome.scored, 100);
834        assert_eq!(outcome.selected, 10);
835        assert_eq!(outcome.skipped_already_flat, 0);
836        assert_eq!(outcome.fetched, 10);
837        assert_eq!(outcome.failed, 0);
838        assert_eq!(outcome.bytes_materialized, 10 * 1024);
839
840        let calls = fetcher.calls();
841        assert_eq!(calls.len(), 10);
842        for (ns, key) in &calls {
843            assert_eq!(ns, "ns-a");
844            assert!(key.starts_with(wombatkv_radix::BLOCK_KEY_PREFIX));
845        }
846    }
847
848    #[test]
849    fn v2_skips_already_flat() {
850        // 10 entries in the index; top_k=10. Pre-warm 5 of them in
851        // the mock's flat tier. The worker must skip those and only
852        // fetch the other 5.
853        let idx = Arc::new(InMemoryMetadataIndex::new());
854        let model = [7u8; 24];
855        let hashes = seed_recent_blocks(&idx, 10, model);
856
857        let fetcher = Arc::new(MockFetcher::new());
858        for h in &hashes {
859            let key = block_key_for_hash(h);
860            fetcher.stock_hit(&key, 256);
861        }
862        // Pre-warm flat for the first 5 keys.
863        for h in &hashes[..5] {
864            let key = block_key_for_hash(h);
865            fetcher.pre_warm_flat(&key);
866        }
867
868        let cfg = PrefetchConfig {
869            interval: Duration::from_mins(1),
870            top_k: 10,
871            model_digest: model,
872            namespace: "ns-a".to_string(),
873        };
874        let stop = AtomicBool::new(false);
875        let outcome = run_cycle_v2(
876            idx.as_ref() as &dyn MetadataIndex,
877            fetcher.as_ref() as &dyn PrefetchFetcher,
878            &cfg,
879            &stop,
880        );
881
882        assert_eq!(outcome.scored, 10);
883        assert_eq!(outcome.selected, 10);
884        assert_eq!(outcome.skipped_already_flat, 5);
885        assert_eq!(outcome.fetched, 5);
886        assert_eq!(outcome.failed, 0);
887        assert_eq!(outcome.bytes_materialized, 5 * 256);
888
889        let calls = fetcher.calls();
890        assert_eq!(calls.len(), 5, "should only fetch the 5 not-yet-flat keys");
891    }
892
893    #[test]
894    fn v2_handles_get_kv_errors_gracefully() {
895        let idx = Arc::new(InMemoryMetadataIndex::new());
896        let model = [7u8; 24];
897        let hashes = seed_recent_blocks(&idx, 6, model);
898
899        let fetcher = Arc::new(MockFetcher::new());
900        for (i, h) in hashes.iter().enumerate() {
901            let key = block_key_for_hash(h);
902            if i % 2 == 0 {
903                fetcher.stock_err(&key, "synthetic backend error");
904            } else {
905                fetcher.stock_hit(&key, 100);
906            }
907        }
908
909        let cfg = PrefetchConfig {
910            interval: Duration::from_mins(1),
911            top_k: 6,
912            model_digest: model,
913            namespace: "ns-a".to_string(),
914        };
915        let stop = AtomicBool::new(false);
916        let outcome = run_cycle_v2(
917            idx.as_ref() as &dyn MetadataIndex,
918            fetcher.as_ref() as &dyn PrefetchFetcher,
919            &cfg,
920            &stop,
921        );
922
923        // Worker logged + continued through every error; cycle did
924        // not crash. 3 successful fetches, 3 logged failures.
925        assert_eq!(outcome.scored, 6);
926        assert_eq!(outcome.selected, 6);
927        assert_eq!(outcome.fetched, 3);
928        assert_eq!(outcome.failed, 3);
929        assert_eq!(outcome.bytes_materialized, 3 * 100);
930        assert_eq!(fetcher.calls().len(), 6);
931    }
932
933    #[test]
934    fn v2_dry_run_does_not_fetch() {
935        // With WMBT_KV_PREFETCH_DRY_RUN=1, the embed-side
936        // `start_prefetcher` is expected to route to `spawn_worker`
937        // (v1) rather than `spawn_worker_v2`. Validate the gate
938        // helper + that a v1 worker over the same index + emit makes
939        // zero fetch calls on the fetcher.
940        //
941        // We don't `std::env::set_var` here because that's process-
942        // global and would pollute other tests; we exercise the gate
943        // by directly using the v1 path.
944        let idx: Arc<dyn MetadataIndex> = Arc::new(InMemoryMetadataIndex::new());
945        let model = [7u8; 24];
946        // Mutate the concrete impl via downcast-equivalent: use
947        // InMemoryMetadataIndex through the Arc, since we constructed
948        // it ourselves.
949        let inner = InMemoryMetadataIndex::new();
950        seed_recent_blocks(&inner, 4, model);
951        // Snapshot into the Arc-held index via bulk_load.
952        // (The Arc<dyn ..> is what spawn_worker holds.)
953        let snapshot = inner.entries();
954        if let Some(_concrete) = idx.as_ref().entries().first() {
955            // Already populated; no-op.
956        }
957        // Push snapshot through the dyn index via insert. (No
958        // bulk_load on the trait object, concrete InMemoryMetadataIndex
959        // would be needed; the v1 path doesn't care, since we're
960        // asserting "no fetcher calls".)
961        for (h, m) in snapshot {
962            idx.insert(h, m);
963        }
964
965        let fetcher = Arc::new(MockFetcher::new());
966        let plan_count = Arc::new(Mutex::new(0_usize));
967        let pc = plan_count.clone();
968        let emit: PrefetchEmit = Arc::new(move |_plan: &PrefetchPlan| {
969            *pc.lock().unwrap() += 1;
970        });
971
972        let cfg = PrefetchConfig {
973            interval: Duration::from_millis(40),
974            top_k: 4,
975            model_digest: model,
976            namespace: "ns-dry".to_string(),
977        };
978        let worker = spawn_worker(idx, cfg, emit);
979        std::thread::sleep(Duration::from_millis(150));
980        drop(worker);
981
982        // v1 emit fired at least once.
983        assert!(*plan_count.lock().unwrap() >= 1);
984        // ...and the fetcher was never touched.
985        assert!(
986            fetcher.calls().is_empty(),
987            "dry-run path must not call PrefetchFetcher: got {} calls",
988            fetcher.calls().len()
989        );
990    }
991
992    #[test]
993    fn dry_run_env_helper_reads_truthy_values() {
994        // Wrap each set_var in a single-threaded section so other
995        // tests don't see the env flicker. We're already inside a
996        // #[test], so cargo serializes vs other ENV-touching tests
997        // only by chance; keep the env restored on the way out.
998        let saved = std::env::var("WMBT_KV_PREFETCH_DRY_RUN").ok();
999        std::env::remove_var("WMBT_KV_PREFETCH_DRY_RUN");
1000        assert!(!dry_run_enabled());
1001        std::env::set_var("WMBT_KV_PREFETCH_DRY_RUN", "1");
1002        assert!(dry_run_enabled());
1003        std::env::set_var("WMBT_KV_PREFETCH_DRY_RUN", "yes");
1004        assert!(dry_run_enabled());
1005        std::env::set_var("WMBT_KV_PREFETCH_DRY_RUN", "0");
1006        assert!(!dry_run_enabled());
1007        std::env::remove_var("WMBT_KV_PREFETCH_DRY_RUN");
1008        if let Some(v) = saved {
1009            std::env::set_var("WMBT_KV_PREFETCH_DRY_RUN", v);
1010        }
1011    }
1012
1013    #[test]
1014    fn v2_worker_runs_and_stops() {
1015        // End-to-end: spawn the v2 worker against a real-ish (mock)
1016        // fetcher, let it run a couple of cycles, then drop.
1017        let idx_concrete = Arc::new(InMemoryMetadataIndex::new());
1018        let model = [3u8; 24];
1019        let hashes = seed_recent_blocks(&idx_concrete, 5, model);
1020        let idx: Arc<dyn MetadataIndex> = idx_concrete.clone();
1021
1022        let fetcher = Arc::new(MockFetcher::new());
1023        for h in &hashes {
1024            fetcher.stock_hit(&block_key_for_hash(h), 64);
1025        }
1026
1027        let cfg = PrefetchConfig {
1028            interval: Duration::from_millis(30),
1029            top_k: 5,
1030            model_digest: model,
1031            namespace: "ns-x".to_string(),
1032        };
1033
1034        let outcomes = Arc::new(Mutex::new(Vec::<PrefetchFetchOutcome>::new()));
1035        let outcomes_for_cb = outcomes.clone();
1036        let emit_outcome: Arc<dyn Fn(&PrefetchFetchOutcome) + Send + Sync> =
1037            Arc::new(move |o: &PrefetchFetchOutcome| {
1038                outcomes_for_cb.lock().unwrap().push(o.clone());
1039            });
1040
1041        let fetcher_dyn: Arc<dyn PrefetchFetcher> = fetcher.clone();
1042        let started = Instant::now();
1043        let worker = spawn_worker_v2(idx, cfg, fetcher_dyn, emit_outcome);
1044        std::thread::sleep(Duration::from_millis(140));
1045        drop(worker);
1046        assert!(started.elapsed() < Duration::from_secs(2));
1047
1048        let observed = outcomes.lock().unwrap();
1049        assert!(observed.len() >= 2);
1050        // First cycle should have done all 5 fetches; subsequent
1051        // cycles should see them all as already-flat.
1052        assert_eq!(observed[0].fetched, 5);
1053        if observed.len() >= 2 {
1054            assert_eq!(observed[1].fetched, 0);
1055            assert_eq!(observed[1].skipped_already_flat, 5);
1056        }
1057    }
1058}