Skip to main content

pond/
embed.rs

1//! The embedding stage: candle XLM-RoBERTa FP16 ([`CandleEmbedder`]) plus
2//! the batch-oriented [`EmbedWorker`] that fills `messages.vector` /
3//! `messages.embedding_model` (spec.md#search). One message produces one
4//! vector - there is no chunking.
5//!
//! [`LazyEmbedder`] caches a loaded backend for `pond mcp` / `pond serve`
//! and drops it after [`DEFAULT_IDLE_EVICTION`] of no use. The drop is
//! clean under macOS `phys_footprint` (after the drop the footprint falls to
//! ~107 MiB regardless of backend), so time-weighted RSS over an interactive
//! MCP session stays well under the per-instance budget despite the macOS
//! Metal buffer pool's `iokit_mapped` retention during active queries.
12//!
//! The worker accumulates messages and calls the model once per fixed-size
//! batch, never once per message. It writes the vectors of each
//! length-sorted window (many batches) to `messages` in one column-update
//! commit.
16
17use std::sync::Arc;
18use std::sync::OnceLock;
19use std::sync::atomic::{AtomicBool, Ordering};
20use std::time::{Duration, Instant};
21
22use anyhow::{Context, Result, anyhow};
23use candle_core::{DType, Device, Tensor};
24use candle_nn::VarBuilder;
25use candle_transformers::models::xlm_roberta::{Config, XLMRobertaModel};
26use tokenizers::Tokenizer;
27use tokio::sync::Mutex;
28use tokio_stream::StreamExt;
29
30use crate::sessions::{EmbeddedMessage, PendingMessage, Store, embedding_dim};
31
/// Per-input token budget, equal to e5's training context length. The
/// tokenizer cuts anything longer before inference, so every message still
/// yields exactly one vector at a bounded cost.
pub(crate) const MAX_TOKENS: usize = 512;
35
36/// The candle e5 backend: XLM-RoBERTa FP16 weights on the GPU (Metal on
37/// macOS, CUDA on a `cuda`-feature non-macOS build, CPU otherwise).
38/// `forward` is `&self`, so no interior mutability is needed.
39pub struct CandleEmbedder {
40    model: XLMRobertaModel,
41    tokenizer: Tokenizer,
42    device: Device,
43}
44
45impl CandleEmbedder {
46    /// Load the configured XLM-RoBERTa model from HuggingFace (cached after
47    /// the first download) onto the best available device.
48    pub fn load() -> Result<Self> {
49        let device = select_device();
50        let id = model_id();
51        let api = hf_hub::api::sync::Api::new().context("init HuggingFace hub client")?;
52        let repo = api.model(id.to_owned());
53        let fetch = |file: &str| {
54            repo.get(file)
55                .with_context(|| format!("fetch {file} for {id}"))
56        };
57
58        let config: Config =
59            serde_json::from_str(&std::fs::read_to_string(fetch("config.json")?)?)?;
60        if config.hidden_size != embedding_dim() {
61            return Err(anyhow!(
62                "[embeddings].dim = {} but model {id:?} reports hidden_size = {}; \
63                 set [embeddings].dim to match the model's output width.",
64                embedding_dim(),
65                config.hidden_size,
66            ));
67        }
68        // mmap the safetensors file: candle's `safetensors::load` path uses
69        // `std::fs::read` which retains an owned `Vec<u8>` of the full FP32
70        // weights in the system allocator after drop on macOS. mmap avoids
71        // the owned-heap path. Note: candle's Metal pool retains FP32->F16
72        // cast transients regardless (iokit_mapped contribution to
73        // phys_footprint, candle-core/src/metal_backend/device.rs:44-57).
74        let model_path = fetch("model.safetensors")?;
75        #[allow(unsafe_code)]
76        let vb =
77            unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F16, &device)? };
78        let model = XLMRobertaModel::new(&config, vb)
79            .map_err(|error| anyhow!("load {id} weights: {error}"))?;
80
81        let mut tokenizer = Tokenizer::from_file(fetch("tokenizer.json")?)
82            .map_err(|error| anyhow!("load e5 tokenizer: {error}"))?;
83        tokenizer.with_padding(Some(tokenizers::PaddingParams {
84            strategy: tokenizers::PaddingStrategy::BatchLongest,
85            pad_id: config.pad_token_id,
86            ..Default::default()
87        }));
88        tokenizer
89            .with_truncation(Some(tokenizers::TruncationParams {
90                max_length: MAX_TOKENS,
91                ..Default::default()
92            }))
93            .map_err(|error| anyhow!("configure e5 tokenizer: {error}"))?;
94
95        tracing::info!(model = %id, device = device_label(&device), "loaded embedding model");
96        Ok(Self {
97            model,
98            tokenizer,
99            device,
100        })
101    }
102}
103
104impl Embedder for CandleEmbedder {
105    fn device(&self) -> &str {
106        device_label(&self.device)
107    }
108
109    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
110        if texts.is_empty() {
111            return Ok(Vec::new());
112        }
113        let encodings = self
114            .tokenizer
115            .encode_batch(texts.to_vec(), true)
116            .map_err(|error| anyhow!("tokenize embedding batch: {error}"))?;
117        let mut ids = Vec::with_capacity(encodings.len());
118        let mut masks = Vec::with_capacity(encodings.len());
119        for encoding in &encodings {
120            ids.push(Tensor::new(encoding.get_ids(), &self.device)?);
121            masks.push(Tensor::new(encoding.get_attention_mask(), &self.device)?);
122        }
123        let input_ids = Tensor::stack(&ids, 0)?;
124        let attention_mask = Tensor::stack(&masks, 0)?;
125        let token_type_ids = input_ids.zeros_like()?;
126        let hidden = self
127            .model
128            .forward(
129                &input_ids,
130                &attention_mask,
131                &token_type_ids,
132                None,
133                None,
134                None,
135            )?
136            .to_dtype(DType::F32)?;
137        let mask = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?;
138        let summed = hidden.broadcast_mul(&mask)?.sum(1)?;
139        let counts = mask.sum(1)?;
140        let mean = summed.broadcast_div(&counts)?;
141        let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?;
142        mean.broadcast_div(&norm)?
143            .to_vec2::<f32>()
144            .map_err(|error| anyhow!("read embedding vectors: {error}"))
145    }
146}
147
148fn select_device() -> Device {
149    #[cfg(target_os = "macos")]
150    let device = Device::metal_if_available(0);
151    #[cfg(not(target_os = "macos"))]
152    let device = Device::cuda_if_available(0);
153    device.unwrap_or_else(|error| {
154        tracing::warn!(%error, "GPU device unavailable, falling back to CPU");
155        Device::Cpu
156    })
157}
158
159fn device_label(device: &Device) -> &'static str {
160    match device {
161        Device::Cpu => "cpu",
162        Device::Cuda(_) => "cuda",
163        Device::Metal(_) => "metal",
164    }
165}
166
167/// Arc-shared factory used by [`LazyEmbedder`] to build the backend on
168/// first call (or on reload after idle eviction). Arc so the loader can be
169/// cloned into `spawn_blocking` without consuming `&self`.
170type EmbedLoader = Arc<dyn Fn() -> Result<Arc<dyn Embedder>> + Send + Sync>;
171
/// Idle period after which [`LazyEmbedder::get`] discards the cached backend
/// instead of reusing it. Five minutes is meant to sit between two regimes:
/// an ordinary query burst, which should never pay the reload, and the
/// longer quiet spells of an interactive MCP conversation.
pub const DEFAULT_IDLE_EVICTION: Duration = Duration::from_secs(5 * 60);
178
179struct CachedBackend {
180    backend: Arc<dyn Embedder>,
181    last_use: Instant,
182}
183
184/// Lazy holder for an [`Embedder`] with idle eviction. The model isn't
185/// loaded until the first hybrid/vector call asks for it - idle `pond mcp`
186/// / `pond serve` processes pay nothing while no vector queries land. After
187/// `idle_threshold` of inactivity the cached backend is dropped on the
188/// next `get` call; under macOS `phys_footprint` the drop reclaims
189/// ~365-585 MiB cleanly (the post-drop floor is ~107 MiB regardless of
190/// backend). Reload cost is one synchronous model-load (300-500 ms),
191/// absorbed inside the human-paced gap between MCP queries.
192pub struct LazyEmbedder {
193    loader: EmbedLoader,
194    state: Mutex<Option<CachedBackend>>,
195    idle_threshold: Duration,
196}
197
198impl LazyEmbedder {
199    /// candle XLM-RoBERTa FP16 (Metal on macOS / CUDA with `--features cuda`
200    /// / CPU otherwise). The pond default for every entry point.
201    pub fn candle() -> Self {
202        Self::with_loader(Arc::new(|| {
203            Ok(Arc::new(CandleEmbedder::load()?) as Arc<dyn Embedder>)
204        }))
205    }
206
207    /// Build a `LazyEmbedder` from an explicit loader. Used by the bench
208    /// harness to override the idle threshold; production callers use
209    /// [`Self::candle`].
210    pub fn with_loader(loader: EmbedLoader) -> Self {
211        Self {
212            loader,
213            state: Mutex::new(None),
214            idle_threshold: DEFAULT_IDLE_EVICTION,
215        }
216    }
217
218    /// Override the idle-eviction threshold. Pass `Duration::MAX` to disable
219    /// eviction entirely - useful in benches that want a stable steady-state.
220    #[must_use]
221    pub fn with_idle_threshold(mut self, threshold: Duration) -> Self {
222        self.idle_threshold = threshold;
223        self
224    }
225
226    /// Pre-seed with an already-constructed backend. Used by integration
227    /// tests that want to inject a fake `Embedder` without paying the real
228    /// model-load cost. Eviction is disabled so the test fake survives the
229    /// whole test even if a test stalls.
230    pub fn from_loaded(backend: Arc<dyn Embedder>) -> Self {
231        let preloaded = Arc::clone(&backend);
232        let loader: EmbedLoader = Arc::new(move || Ok(Arc::clone(&preloaded)));
233        Self {
234            loader,
235            state: Mutex::new(Some(CachedBackend {
236                backend,
237                last_use: Instant::now(),
238            })),
239            idle_threshold: Duration::MAX,
240        }
241    }
242
243    /// Load (on first call or after eviction) or return the cached handle.
244    /// The candle load is synchronous and blocking, so it runs on
245    /// `spawn_blocking`; the async caller sees a clean `await` point.
246    pub async fn get(&self) -> Result<Arc<dyn Embedder>> {
247        let mut state = self.state.lock().await;
248        let now = Instant::now();
249        if let Some(cached) = &*state
250            && now.duration_since(cached.last_use) > self.idle_threshold
251        {
252            tracing::info!(
253                idle_secs = self.idle_threshold.as_secs(),
254                "evicting idle embedder",
255            );
256            *state = None;
257        }
258        if let Some(cached) = state.as_mut() {
259            cached.last_use = now;
260            return Ok(Arc::clone(&cached.backend));
261        }
262        let loader = Arc::clone(&self.loader);
263        let backend = tokio::task::spawn_blocking(move || loader())
264            .await
265            .map_err(|join_error| anyhow!("embedder load panicked: {join_error}"))??;
266        *state = Some(CachedBackend {
267            backend: Arc::clone(&backend),
268            last_use: now,
269        });
270        Ok(backend)
271    }
272}
273
/// Embedding model pond ships a loader for (spec.md#search), used whenever
/// `[embeddings].model` is not set. `pond embed` records the model id active
/// at run time ([`model_id`]) in `messages.embedding_model` next to every
/// vector. e5-small (384-dim) is the default: on the paraphrase benchmark
/// set it showed no statistically significant quality loss against e5-base,
/// while halving vector storage and roughly halving model RSS.
pub const DEFAULT_MODEL_ID: &str = "intfloat/multilingual-e5-small";

/// This process's model id, installed once at startup from
/// `[embeddings].model` via [`init_model_id`]. It is a `OnceLock` rather
/// than a `const` so a throwaway config can switch to e5-small / e5-large
/// for an experiment without editing call sites. While unset, [`model_id`]
/// reports [`DEFAULT_MODEL_ID`], so unit tests need no config.
static MODEL_ID_RUNTIME: OnceLock<String> = OnceLock::new();

/// The model id currently in effect. This is whatever [`init_model_id`]
/// installed, or [`DEFAULT_MODEL_ID`] if nothing has (tests, ad-hoc
/// tooling).
pub fn model_id() -> &'static str {
    match MODEL_ID_RUNTIME.get() {
        Some(installed) => installed.as_str(),
        None => DEFAULT_MODEL_ID,
    }
}

/// Install the process-wide model id from config. Only the first call takes
/// effect; later calls, even with a different id, are ignored without error,
/// since a process loads its config once.
pub fn init_model_id(id: String) {
    // `set` only fails when a value is already present - exactly the
    // "first call wins" rule - so the rejected id is dropped on purpose.
    let _ = MODEL_ID_RUNTIME.set(id);
}
302
/// Messages per model-inference batch. e5 truncates at 512 tokens, so a
/// 32-row batch's padded attention transient stays bounded. Store writes are
/// not issued per batch: [`EmbedWorker`] commits once per sort window (see
/// [`DEFAULT_SORT_WINDOW`]).
pub const DEFAULT_BATCH_SIZE: usize = 32;

/// Messages buffered and length-sorted before being cut into model batches;
/// also the granularity of store writes.
///
/// The tokenizer pads every batch to its longest member, so a batch mixing
/// a short and a long message embeds the short one at the long one's
/// length. Sorting a window first clusters similar-length messages, so each
/// batch pads near its own longest, not the corpus worst case. The window
/// is bounded so peak memory stays one window, not the whole backlog. See
/// [`EmbedWorker::with_sort_window`].
pub const DEFAULT_SORT_WINDOW: usize = 2048;
314
/// Prepare a search query for the embedder. e5 is an asymmetric retriever:
/// its model card asks for a `query: ` prefix on the search side and
/// `passage: ` on documents. `pond_search` applies this to the query text
/// before the candle/Metal embed call.
pub fn format_query(query: &str) -> String {
    ["query: ", query].concat()
}

/// Prepare a document (one message's `search_text`) for the embedder. This
/// is the `passage: ` half of the e5 prefix pair used by [`format_query`].
/// `EmbedWorker` applies it when batching messages for `pond embed`.
pub fn format_passage(text: &str) -> String {
    ["passage: ", text].concat()
}
329
330/// The embedding seam (spec.md#search): text in, vectors out. The real
331/// backend is [`CandleEmbedder`]; tests substitute an instrumented fake
332/// to assert batching behavior. The vector width is checked at the write
333/// boundary and the model id is whatever [`model_id`] returns at the
334/// time of the write.
335pub trait Embedder: Send + Sync {
336    /// A short label naming the hardware/runtime: `"metal"`, `"cuda"`,
337    /// or `"cpu"`. Used by `pond embed` to surface what backend ran the
338    /// inference; benches print it alongside latency.
339    fn device(&self) -> &str;
340
341    /// Embed a batch of texts. The returned vectors are L2-normalized and
342    /// [`embedding_dim`] long, one per input.
343    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
344}
345
/// Outcome of an [`EmbedWorker::run`] pass.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EmbedSummary {
    /// Messages embedded; one vector each.
    pub messages: usize,
    /// Model-inference batches issued. Store writes are coarser (one per
    /// length-sorted window), so this is not a write count.
    pub batches: usize,
    /// True when the cancel flag was set by the time the run finished.
    /// That covers two cases: the pull loop stopped on it, or it was raised
    /// while the final window was being drained. The caller uses this to
    /// print an interrupted notice and decide whether to still rebuild
    /// downstream indices.
    pub cancelled: bool,
}
358
/// Snapshot passed to the progress callback after each model batch. It lets
/// `pond embed` drive an `indicatif` bar without this module's API
/// depending on that crate.
#[derive(Debug, Clone, Copy)]
pub struct BatchProgress {
    /// Messages in the batch that just finished.
    pub batch_messages: usize,
    /// Messages embedded so far in this run, this batch included.
    pub total_messages: usize,
    /// Batches completed so far in this run, this batch included.
    pub total_batches: usize,
}

/// Type-erased progress callback held by [`EmbedWorker`].
type ProgressFn = Box<dyn Fn(BatchProgress) + Send + Sync>;
372
373/// Fills `messages.vector` / `messages.embedding_model` for the backlog of
374/// un-embedded messages. Reads `messages.search_text` directly, batches it
375/// through the backend one vector each, and writes each batch back to
376/// `messages` by primary key.
377pub struct EmbedWorker<'a, B: Embedder> {
378    store: &'a Store,
379    backend: &'a B,
380    include_stale: bool,
381    /// Optional cap on total messages embedded in one `run` - `None` in
382    /// production (embed everything), set by the benchmark harness to a fixed
383    /// count so a run is a stable, comparable workload.
384    limit: Option<usize>,
385    /// Messages buffered and length-sorted per `drain_window` pass
386    /// ([`DEFAULT_SORT_WINDOW`]); the benchmark sweeps it through
387    /// [`EmbedWorker::with_sort_window`].
388    sort_window: usize,
389    /// Optional per-batch progress callback. Called once per `flush()` with
390    /// the running totals; `pond embed` wires this to an `indicatif` bar.
391    progress: Option<ProgressFn>,
392    /// Set externally (Ctrl-C handler in `pond embed`): the pull loop drains
393    /// the in-memory window before exiting so partial work is committed.
394    cancel: Option<Arc<AtomicBool>>,
395}
396
397impl<'a, B: Embedder> EmbedWorker<'a, B> {
398    /// Build a worker over `store`'s un-embedded backlog. A backend whose
399    /// vectors are the wrong width is rejected at the write boundary
400    /// (`embedding_update_batch`), so there is nothing to validate here.
401    pub fn new(store: &'a Store, backend: &'a B) -> Self {
402        Self {
403            store,
404            backend,
405            include_stale: false,
406            limit: None,
407            sort_window: DEFAULT_SORT_WINDOW,
408            progress: None,
409            cancel: None,
410        }
411    }
412
413    /// Honour `flag` as a cooperative cancellation signal. The pull loop checks
414    /// it before each new stream message; once set, the worker drains the
415    /// current window (committing the embedded slice) and returns with
416    /// `EmbedSummary { cancelled: true, .. }`. `pond embed` wires this to a
417    /// Ctrl-C handler so an interrupted run doesn't lose its in-memory window.
418    pub fn with_cancel(mut self, flag: Arc<AtomicBool>) -> Self {
419        self.cancel = Some(flag);
420        self
421    }
422
423    fn cancelled(&self) -> bool {
424        self.cancel
425            .as_ref()
426            .is_some_and(|f| f.load(Ordering::Relaxed))
427    }
428
429    /// Override the length-sort window (default [`DEFAULT_SORT_WINDOW`]). The
430    /// benchmark harness sweeps this to size the padding-waste vs. throughput
431    /// trade-off; a window of [`DEFAULT_BATCH_SIZE`] disables sorting.
432    pub fn with_sort_window(mut self, window: usize) -> Self {
433        self.sort_window = window.max(DEFAULT_BATCH_SIZE);
434        self
435    }
436
437    /// Register a per-batch progress callback. Called once after each
438    /// `flush()` with the messages in the just-finished batch and the running
439    /// totals. `pond embed` uses this to drive an `indicatif` progress bar.
440    pub fn with_progress(
441        mut self,
442        callback: impl Fn(BatchProgress) + Send + Sync + 'static,
443    ) -> Self {
444        self.progress = Some(Box::new(callback));
445        self
446    }
447
448    /// Cap the run at `limit` messages (default: no cap). The benchmark harness
449    /// uses this to embed a fixed, comparable slice of a corpus.
450    pub fn with_limit(mut self, limit: usize) -> Self {
451        self.limit = Some(limit.max(1));
452        self
453    }
454
455    pub fn include_stale(mut self) -> Self {
456        self.include_stale = true;
457        self
458    }
459
460    /// Embed every message whose `vector` is still null. Idempotent: a re-run
461    /// over an already-embedded corpus finds an empty backlog and is a no-op.
462    ///
463    /// Messages are pulled from a streaming scan, so peak memory is one stream
464    /// page plus the staged batch - not the whole corpus.
465    pub async fn run(&self) -> Result<EmbedSummary> {
466        let mut summary = EmbedSummary::default();
467        let mut window: Vec<PendingMessage> = Vec::with_capacity(self.sort_window);
468        let mut pulled = 0usize;
469
470        let mut stream = if self.include_stale {
471            Box::pin(self.store.pending_or_stale_messages())
472                as std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<PendingMessage>> + '_>>
473        } else {
474            Box::pin(self.store.pending_embedding_messages())
475                as std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<PendingMessage>> + '_>>
476        };
477        while let Some(pending) = stream.next().await {
478            // Stop pulling once the message cap is reached or cancellation
479            // fires; the staged window is still drained below, so the
480            // already-embedded slice commits cleanly.
481            if self.limit.is_some_and(|limit| pulled >= limit) || self.cancelled() {
482                break;
483            }
484            window.push(pending?);
485            pulled += 1;
486            if window.len() >= self.sort_window {
487                self.drain_window(&mut window, &mut summary).await?;
488            }
489        }
490        self.drain_window(&mut window, &mut summary).await?;
491        summary.cancelled = self.cancelled();
492
493        tracing::info!(
494            model = model_id(),
495            messages = summary.messages,
496            batches = summary.batches,
497            cancelled = summary.cancelled,
498            "embed worker finished",
499        );
500        Ok(summary)
501    }
502
503    /// One `merge_update` per window, not per 32-row batch: each
504    /// `merge_update` streams the target column once, so amortizing it over
505    /// a window-sized batch beats issuing it per model batch. The
506    /// length-sort clusters similar lengths because the tokenizer pads each
507    /// batch to its longest member. Empties `window`.
508    async fn drain_window(
509        &self,
510        window: &mut Vec<PendingMessage>,
511        summary: &mut EmbedSummary,
512    ) -> Result<()> {
513        if window.is_empty() {
514            return Ok(());
515        }
516        window.sort_unstable_by_key(|message| message.search_text.len());
517        let mut batch: Vec<PendingMessage> = Vec::with_capacity(DEFAULT_BATCH_SIZE);
518        let mut accumulator: Vec<EmbeddedMessage> = Vec::with_capacity(window.len());
519        for message in window.drain(..) {
520            batch.push(message);
521            if batch.len() >= DEFAULT_BATCH_SIZE {
522                accumulator.extend(self.embed_batch(&mut batch, summary).await?);
523            }
524        }
525        accumulator.extend(self.embed_batch(&mut batch, summary).await?);
526        if !accumulator.is_empty() {
527            self.store.write_embeddings(&accumulator).await?;
528        }
529        Ok(())
530    }
531
532    /// Run one model batch; return the rows. Store write is batched in
533    /// [`drain_window`](Self::drain_window), one `merge_update` per window.
534    async fn embed_batch(
535        &self,
536        batch: &mut Vec<PendingMessage>,
537        summary: &mut EmbedSummary,
538    ) -> Result<Vec<EmbeddedMessage>> {
539        if batch.is_empty() {
540            return Ok(Vec::new());
541        }
542        let pending = std::mem::take(batch);
543        // Apply e5's `passage: ` document prefix at the model boundary; the
544        // stored `search_text` keeps its uncapped, unprefixed form for FTS.
545        let texts = pending
546            .iter()
547            .map(|message| format_passage(&message.search_text))
548            .collect::<Vec<_>>();
549        let vectors = self.backend.embed(&texts)?;
550        if vectors.len() != pending.len() {
551            return Err(anyhow!(
552                "backend returned {} vectors for {} messages",
553                vectors.len(),
554                pending.len(),
555            ));
556        }
557        let rows = pending
558            .into_iter()
559            .zip(vectors)
560            .map(|(message, vector)| EmbeddedMessage {
561                session_id: message.session_id,
562                id: message.id,
563                vector,
564            })
565            .collect::<Vec<_>>();
566        let batch_messages = rows.len();
567        summary.messages += batch_messages;
568        summary.batches += 1;
569        if let Some(progress) = &self.progress {
570            progress(BatchProgress {
571                batch_messages,
572                total_messages: summary.messages,
573                total_batches: summary.batches,
574            });
575        }
576        Ok(rows)
577    }
578}
579
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

    #[test]
    fn e5_prefixes_apply_the_asymmetric_retrieval_pair() {
        assert_eq!(
            format_query("how does retry backoff work"),
            "query: how does retry backoff work",
        );
        assert_eq!(
            format_passage("retry uses exponential backoff"),
            "passage: retry uses exponential backoff",
        );
    }

    /// Inert [`Embedder`] fake for the `LazyEmbedder` tests. It counts
    /// nothing itself: the tests count loads inside their loader closures.
    /// Its `embed` returns no vectors.
    struct CountingEmbedder;
    impl Embedder for CountingEmbedder {
        fn device(&self) -> &str {
            "test"
        }
        fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
            Ok(vec![])
        }
    }

    /// `LazyEmbedder` keys eviction on `std::time::Instant`, which isn't
    /// affected by `tokio::time::pause`. The test uses a tiny real
    /// threshold so the suite runs in <100 ms.
    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_evicts_after_idle_threshold() {
        let loads = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&loads);
        let loader: EmbedLoader = Arc::new(move || {
            counter.fetch_add(1, AtomicOrdering::SeqCst);
            Ok(Arc::new(CountingEmbedder) as Arc<dyn Embedder>)
        });
        let embedder =
            LazyEmbedder::with_loader(loader).with_idle_threshold(Duration::from_millis(20));

        embedder.get().await.unwrap();
        assert_eq!(
            loads.load(AtomicOrdering::SeqCst),
            1,
            "first get loads once"
        );

        embedder.get().await.unwrap();
        assert_eq!(
            loads.load(AtomicOrdering::SeqCst),
            1,
            "back-to-back get reuses the cached backend",
        );

        tokio::time::sleep(Duration::from_millis(60)).await;
        embedder.get().await.unwrap();
        assert_eq!(
            loads.load(AtomicOrdering::SeqCst),
            2,
            "get after the idle threshold triggers a reload",
        );
    }

    /// `from_loaded` must disable eviction and serve exactly the injected
    /// fake. The loader it builds returns the same `Arc`, so a reload can't
    /// be detected by pointer identity. The threshold assertion is what
    /// pins "never evicts".
    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_from_loaded_never_evicts() {
        let fake: Arc<dyn Embedder> = Arc::new(CountingEmbedder);
        let preloaded = LazyEmbedder::from_loaded(Arc::clone(&fake));
        assert_eq!(
            preloaded.idle_threshold,
            Duration::MAX,
            "from_loaded disables idle eviction",
        );

        let first = preloaded.get().await.unwrap();
        // Wait past any reasonable threshold; the from_loaded path uses
        // Duration::MAX so the fake stays alive for the whole test.
        tokio::time::sleep(Duration::from_millis(60)).await;
        let second = preloaded.get().await.unwrap();
        assert!(Arc::ptr_eq(&first, &fake), "serves the injected backend");
        assert!(Arc::ptr_eq(&second, &fake), "still serves it after a pause");
    }

    /// A failed load must not poison the cache: the error surfaces, nothing
    /// is stored, and the next `get` runs the loader again.
    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_retries_after_failed_load() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);
        let loader: EmbedLoader = Arc::new(move || {
            if counter.fetch_add(1, AtomicOrdering::SeqCst) == 0 {
                Err(anyhow!("first load fails"))
            } else {
                Ok(Arc::new(CountingEmbedder) as Arc<dyn Embedder>)
            }
        });
        let embedder = LazyEmbedder::with_loader(loader);

        assert!(embedder.get().await.is_err(), "first load error surfaces");
        embedder.get().await.unwrap();
        assert_eq!(
            attempts.load(AtomicOrdering::SeqCst),
            2,
            "second get retries the loader",
        );
    }
}
656}