Skip to main content

pond/
embed.rs

//! The embedding stage: candle XLM-RoBERTa FP16 ([`CandleEmbedder`]) plus
//! the batch-oriented [`EmbedWorker`] that fills `messages.vector` /
//! `messages.embedding_model` (spec.md#search). One message produces one
//! vector - there is no chunking.
//!
//! [`LazyEmbedder`] caches a loaded backend for `pond mcp` / `pond serve`
//! and drops it after [`DEFAULT_IDLE_EVICTION`] of no use. The drop is
//! clean under macOS `phys_footprint` (the footprint falls back to
//! ~107 MiB after the drop, regardless of backend), so time-weighted RSS
//! over an interactive MCP session stays well under the per-instance budget
//! despite the macOS Metal buffer pool's `iokit_mapped` retention during
//! active queries.
//!
//! The worker accumulates messages and calls the model once per fixed-size
//! batch, never once per message. Vectors are written back to `messages`
//! in one column-update commit per length-sorted window of batches.
16
17use std::sync::Arc;
18use std::sync::OnceLock;
19use std::sync::atomic::{AtomicBool, Ordering};
20use std::time::{Duration, Instant};
21
22use anyhow::{Context, Result, anyhow};
23use candle_core::{DType, Device, Tensor};
24use candle_nn::VarBuilder;
25use candle_transformers::models::xlm_roberta::{Config, XLMRobertaModel};
26use tokenizers::Tokenizer;
27use tokio::sync::Mutex;
28use tokio_stream::StreamExt;
29
30use crate::sessions::{EmbeddedMessage, PendingMessage, Store, embedding_dim};
31
/// Token budget per input, equal to e5's training context. The tokenizer is
/// configured to truncate anything longer before inference, so every message
/// still maps to exactly one vector and the embed cost stays bounded.
pub(crate) const MAX_TOKENS: usize = 512;
35
36/// The candle e5 backend: XLM-RoBERTa FP16 weights on the GPU (Metal on
37/// macOS, CUDA on a `cuda`-feature non-macOS build, CPU otherwise).
38/// `forward` is `&self`, so no interior mutability is needed.
39pub struct CandleEmbedder {
40    model: XLMRobertaModel,
41    tokenizer: Tokenizer,
42    device: Device,
43}
44
45impl CandleEmbedder {
46    /// Load the configured XLM-RoBERTa model from HuggingFace (cached after
47    /// the first download) onto the best available device.
48    pub fn load() -> Result<Self> {
49        let device = select_device();
50        let id = model_id();
51        let api = hf_hub::api::sync::Api::new().context("init HuggingFace hub client")?;
52        let repo = api.model(id.to_owned());
53        let fetch = |file: &str| {
54            repo.get(file)
55                .with_context(|| format!("fetch {file} for {id}"))
56        };
57
58        let config: Config =
59            serde_json::from_str(&std::fs::read_to_string(fetch("config.json")?)?)?;
60        if config.hidden_size != embedding_dim() {
61            return Err(anyhow!(
62                "[embeddings].dim = {} but model {id:?} reports hidden_size = {}; \
63                 set [embeddings].dim to match the model's output width.",
64                embedding_dim(),
65                config.hidden_size,
66            ));
67        }
68        // mmap the safetensors file: candle's `safetensors::load` path uses
69        // `std::fs::read` which retains an owned `Vec<u8>` of the full FP32
70        // weights in the system allocator after drop on macOS. mmap avoids
71        // the owned-heap path. Note: candle's Metal pool retains FP32->F16
72        // cast transients regardless (iokit_mapped contribution to
73        // phys_footprint, candle-core/src/metal_backend/device.rs:44-57).
74        let model_path = fetch("model.safetensors")?;
75        #[allow(unsafe_code)]
76        let vb =
77            unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F16, &device)? };
78        let model = XLMRobertaModel::new(&config, vb)
79            .map_err(|error| anyhow!("load {id} weights: {error}"))?;
80
81        let mut tokenizer = Tokenizer::from_file(fetch("tokenizer.json")?)
82            .map_err(|error| anyhow!("load e5 tokenizer: {error}"))?;
83        tokenizer.with_padding(Some(tokenizers::PaddingParams {
84            strategy: tokenizers::PaddingStrategy::BatchLongest,
85            pad_id: config.pad_token_id,
86            ..Default::default()
87        }));
88        tokenizer
89            .with_truncation(Some(tokenizers::TruncationParams {
90                max_length: MAX_TOKENS,
91                ..Default::default()
92            }))
93            .map_err(|error| anyhow!("configure e5 tokenizer: {error}"))?;
94
95        tracing::info!(model = %id, device = device_label(&device), "loaded embedding model");
96        Ok(Self {
97            model,
98            tokenizer,
99            device,
100        })
101    }
102}
103
104impl Embedder for CandleEmbedder {
105    fn device(&self) -> &str {
106        device_label(&self.device)
107    }
108
109    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
110        if texts.is_empty() {
111            return Ok(Vec::new());
112        }
113        let encodings = self
114            .tokenizer
115            .encode_batch(texts.to_vec(), true)
116            .map_err(|error| anyhow!("tokenize embedding batch: {error}"))?;
117        let mut ids = Vec::with_capacity(encodings.len());
118        let mut masks = Vec::with_capacity(encodings.len());
119        for encoding in &encodings {
120            ids.push(Tensor::new(encoding.get_ids(), &self.device)?);
121            masks.push(Tensor::new(encoding.get_attention_mask(), &self.device)?);
122        }
123        let input_ids = Tensor::stack(&ids, 0)?;
124        let attention_mask = Tensor::stack(&masks, 0)?;
125        let token_type_ids = input_ids.zeros_like()?;
126        let hidden = self
127            .model
128            .forward(
129                &input_ids,
130                &attention_mask,
131                &token_type_ids,
132                None,
133                None,
134                None,
135            )?
136            .to_dtype(DType::F32)?;
137        let mask = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?;
138        let summed = hidden.broadcast_mul(&mask)?.sum(1)?;
139        let counts = mask.sum(1)?;
140        let mean = summed.broadcast_div(&counts)?;
141        let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?;
142        mean.broadcast_div(&norm)?
143            .to_vec2::<f32>()
144            .map_err(|error| anyhow!("read embedding vectors: {error}"))
145    }
146}
147
148fn select_device() -> Device {
149    #[cfg(target_os = "macos")]
150    let device = Device::metal_if_available(0);
151    #[cfg(not(target_os = "macos"))]
152    let device = Device::cuda_if_available(0);
153    device.unwrap_or_else(|error| {
154        tracing::warn!(%error, "GPU device unavailable, falling back to CPU");
155        Device::Cpu
156    })
157}
158
159fn device_label(device: &Device) -> &'static str {
160    match device {
161        Device::Cpu => "cpu",
162        Device::Cuda(_) => "cuda",
163        Device::Metal(_) => "metal",
164    }
165}
166
167/// Arc-shared factory used by [`LazyEmbedder`] to build the backend on
168/// first call (or on reload after idle eviction). Arc so the loader can be
169/// cloned into `spawn_blocking` without consuming `&self`.
170type EmbedLoader = Arc<dyn Fn() -> Result<Arc<dyn Embedder>> + Send + Sync>;
171
/// Idle time after which [`LazyEmbedder::get`] drops the cached backend.
/// One minute hands the ~790 MB model back to the idle floor quickly between
/// interactive-MCP bursts; the price is a single cached model-load (~358 ms)
/// on the first query that follows a quiet window.
pub const DEFAULT_IDLE_EVICTION: Duration = Duration::from_secs(60);
177
178struct CachedBackend {
179    backend: Arc<dyn Embedder>,
180    last_use: Instant,
181}
182
183/// Lazy holder for an [`Embedder`] with idle eviction. The model isn't
184/// loaded until the first hybrid/vector call asks for it - idle `pond mcp`
185/// / `pond serve` processes pay nothing while no vector queries land. After
186/// `idle_threshold` of inactivity the cached backend is dropped on the
187/// next `get` call; under macOS `phys_footprint` the drop reclaims
188/// ~365-585 MiB cleanly (the post-drop floor is ~107 MiB regardless of
189/// backend). Reload cost is one synchronous model-load (300-500 ms),
190/// absorbed inside the human-paced gap between MCP queries.
191pub struct LazyEmbedder {
192    loader: EmbedLoader,
193    state: Mutex<Option<CachedBackend>>,
194    idle_threshold: Duration,
195}
196
197impl LazyEmbedder {
198    /// candle XLM-RoBERTa FP16 (Metal on macOS / CUDA with `--features cuda`
199    /// / CPU otherwise). The pond default for every entry point.
200    pub fn candle() -> Self {
201        Self::with_loader(Arc::new(|| {
202            Ok(Arc::new(CandleEmbedder::load()?) as Arc<dyn Embedder>)
203        }))
204    }
205
206    /// Build a `LazyEmbedder` from an explicit loader. Used by the bench
207    /// harness to override the idle threshold; production callers use
208    /// [`Self::candle`].
209    pub fn with_loader(loader: EmbedLoader) -> Self {
210        Self {
211            loader,
212            state: Mutex::new(None),
213            idle_threshold: DEFAULT_IDLE_EVICTION,
214        }
215    }
216
217    /// Override the idle-eviction threshold. Pass `Duration::MAX` to disable
218    /// eviction entirely - useful in benches that want a stable steady-state.
219    #[must_use]
220    pub fn with_idle_threshold(mut self, threshold: Duration) -> Self {
221        self.idle_threshold = threshold;
222        self
223    }
224
225    /// Pre-seed with an already-constructed backend. Used by integration
226    /// tests that want to inject a fake `Embedder` without paying the real
227    /// model-load cost. Eviction is disabled so the test fake survives the
228    /// whole test even if a test stalls.
229    pub fn from_loaded(backend: Arc<dyn Embedder>) -> Self {
230        let preloaded = Arc::clone(&backend);
231        let loader: EmbedLoader = Arc::new(move || Ok(Arc::clone(&preloaded)));
232        Self {
233            loader,
234            state: Mutex::new(Some(CachedBackend {
235                backend,
236                last_use: Instant::now(),
237            })),
238            idle_threshold: Duration::MAX,
239        }
240    }
241
242    /// Load (on first call or after eviction) or return the cached handle.
243    /// The candle load is synchronous and blocking, so it runs on
244    /// `spawn_blocking`; the async caller sees a clean `await` point.
245    pub async fn get(&self) -> Result<Arc<dyn Embedder>> {
246        let mut state = self.state.lock().await;
247        let now = Instant::now();
248        if let Some(cached) = &*state
249            && now.duration_since(cached.last_use) > self.idle_threshold
250        {
251            tracing::info!(
252                idle_secs = self.idle_threshold.as_secs(),
253                "evicting idle embedder",
254            );
255            *state = None;
256        }
257        if let Some(cached) = state.as_mut() {
258            cached.last_use = now;
259            return Ok(Arc::clone(&cached.backend));
260        }
261        let loader = Arc::clone(&self.loader);
262        let backend = tokio::task::spawn_blocking(move || loader())
263            .await
264            .map_err(|join_error| anyhow!("embedder load panicked: {join_error}"))??;
265        *state = Some(CachedBackend {
266            backend: Arc::clone(&backend),
267            last_use: now,
268        });
269        Ok(backend)
270    }
271}
272
/// The embedding model pond ships a loader for (spec.md#search) and falls
/// back to when `[embeddings].model` is absent. `pond optimize` stamps the
/// runtime model id (see [`model_id`]) into `messages.embedding_model` with
/// every vector. The default is e5-small (384-dim): on the paraphrase
/// benchmark set it showed no statistically-significant quality loss vs
/// e5-base while halving vector storage and roughly halving model RSS.
pub const DEFAULT_MODEL_ID: &str = "intfloat/multilingual-e5-small";
280
/// Model id for the whole process, written at most once at startup from
/// `[embeddings].model` through [`init_model_id`]. It is a `OnceLock` rather
/// than a `const` so a temporary config file can pick e5-small / e5-large
/// for an experiment without touching every call site. While it is unset,
/// [`model_id`] answers with [`DEFAULT_MODEL_ID`], which keeps unit tests
/// config-free.
static MODEL_ID_RUNTIME: OnceLock<String> = OnceLock::new();
286
287/// The active model id. Returns the value installed by [`init_model_id`] or
288/// [`DEFAULT_MODEL_ID`] when nothing has installed one (tests, ad-hoc tooling).
289pub fn model_id() -> &'static str {
290    MODEL_ID_RUNTIME
291        .get()
292        .map(String::as_str)
293        .unwrap_or(DEFAULT_MODEL_ID)
294}
295
296/// Seed [`model_id`] from config. First call wins; later calls with a different
297/// id are silently ignored - the process loads its config once.
298pub fn init_model_id(id: String) {
299    MODEL_ID_RUNTIME.get_or_init(|| id);
300}
301
/// Number of messages sent through the model in one inference call. Since
/// e5 truncates at 512 tokens, the padded attention transient of a 32-row
/// batch stays bounded.
pub const DEFAULT_BATCH_SIZE: usize = 32;
305
/// How many messages are buffered and sorted by length before being sliced
/// into model batches. The tokenizer pads a batch to its longest member, so
/// a short message sharing a batch with a long one is embedded at the long
/// one's length. Sorting a window first groups messages of similar length,
/// which keeps each batch's padding close to its own longest text rather
/// than the corpus worst case. The window is bounded so peak memory is one
/// window, not the whole backlog. See [`EmbedWorker::with_sort_window`].
pub const DEFAULT_SORT_WINDOW: usize = 2048;
313
/// Prepare a search query for the embedder by prepending e5's `query: `
/// marker. e5 is an asymmetric retriever: its model card prescribes
/// `query: ` on the search side and `passage: ` on documents. `pond_search`
/// calls this on the query text before the candle/Metal embed call.
pub fn format_query(query: &str) -> String {
    ["query: ", query].concat()
}
321
/// Prepare a document (one message's `search_text`) for the embedder by
/// prepending e5's `passage: ` marker - the document-side counterpart of the
/// `query: ` marker applied to search queries. `EmbedWorker` applies it when
/// batching messages for `pond optimize`.
pub fn format_passage(text: &str) -> String {
    ["passage: ", text].concat()
}
328
329/// The embedding seam (spec.md#search): text in, vectors out. The real
330/// backend is [`CandleEmbedder`]; tests substitute an instrumented fake
331/// to assert batching behavior. The vector width is checked at the write
332/// boundary and the model id is whatever [`model_id`] returns at the
333/// time of the write.
334pub trait Embedder: Send + Sync {
335    /// A short label naming the hardware/runtime: `"metal"`, `"cuda"`,
336    /// or `"cpu"`. Used by `pond optimize` to surface what backend ran the
337    /// inference; benches print it alongside latency.
338    fn device(&self) -> &str;
339
340    /// Embed a batch of texts. The returned vectors are L2-normalized and
341    /// [`embedding_dim`] long, one per input.
342    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
343}
344
/// What one [`EmbedWorker::run`] pass accomplished.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EmbedSummary {
    /// Number of messages embedded (each contributes exactly one vector).
    pub messages: usize,
    /// Number of model-inference batches issued.
    pub batches: usize,
    /// True when the cancel flag was set by the time the run finished. The
    /// caller uses it to print an interrupted notice and to decide whether
    /// downstream indices should still be rebuilt.
    pub cancelled: bool,
}
357
/// Counters passed to the progress callback after each model batch. They let
/// `pond optimize` drive an `indicatif` bar without that crate appearing in
/// this module's API.
#[derive(Debug, Clone, Copy)]
pub struct BatchProgress {
    /// Size of the batch that just finished.
    pub batch_messages: usize,
    /// Messages embedded so far in this run, including this batch.
    pub total_messages: usize,
    /// Batches completed so far in this run, including this one.
    pub total_batches: usize,
}
369
370type ProgressFn = Box<dyn Fn(BatchProgress) + Send + Sync>;
371
372/// Fills `messages.vector` / `messages.embedding_model` for the backlog of
373/// un-embedded messages. Reads `messages.search_text` directly, batches it
374/// through the backend one vector each, and writes each batch back to
375/// `messages` by primary key.
376pub struct EmbedWorker<'a, B: Embedder> {
377    store: &'a Store,
378    backend: &'a B,
379    include_stale: bool,
380    /// Optional cap on total messages embedded in one `run` - `None` in
381    /// production (embed everything), set by the benchmark harness to a fixed
382    /// count so a run is a stable, comparable workload.
383    limit: Option<usize>,
384    /// Messages buffered and length-sorted per `drain_window` pass
385    /// ([`DEFAULT_SORT_WINDOW`]); the benchmark sweeps it through
386    /// [`EmbedWorker::with_sort_window`].
387    sort_window: usize,
388    /// Optional per-batch progress callback. Called once per `flush()` with
389    /// the running totals; `pond optimize` wires this to an `indicatif` bar.
390    progress: Option<ProgressFn>,
391    /// Set externally (Ctrl-C handler in `pond optimize`): the pull loop drains
392    /// the in-memory window before exiting so partial work is committed.
393    cancel: Option<Arc<AtomicBool>>,
394}
395
396impl<'a, B: Embedder> EmbedWorker<'a, B> {
397    /// Build a worker over `store`'s un-embedded backlog. A backend whose
398    /// vectors are the wrong width is rejected at the write boundary
399    /// (`embedding_update_batch`), so there is nothing to validate here.
400    pub fn new(store: &'a Store, backend: &'a B) -> Self {
401        Self {
402            store,
403            backend,
404            include_stale: false,
405            limit: None,
406            sort_window: DEFAULT_SORT_WINDOW,
407            progress: None,
408            cancel: None,
409        }
410    }
411
412    /// Honour `flag` as a cooperative cancellation signal. The pull loop checks
413    /// it before each new stream message; once set, the worker drains the
414    /// current window (committing the embedded slice) and returns with
415    /// `EmbedSummary { cancelled: true, .. }`. `pond optimize` wires this to a
416    /// Ctrl-C handler so an interrupted run doesn't lose its in-memory window.
417    pub fn with_cancel(mut self, flag: Arc<AtomicBool>) -> Self {
418        self.cancel = Some(flag);
419        self
420    }
421
422    fn cancelled(&self) -> bool {
423        self.cancel
424            .as_ref()
425            .is_some_and(|f| f.load(Ordering::Relaxed))
426    }
427
428    /// Override the length-sort window (default [`DEFAULT_SORT_WINDOW`]). The
429    /// benchmark harness sweeps this to size the padding-waste vs. throughput
430    /// trade-off; a window of [`DEFAULT_BATCH_SIZE`] disables sorting.
431    pub fn with_sort_window(mut self, window: usize) -> Self {
432        self.sort_window = window.max(DEFAULT_BATCH_SIZE);
433        self
434    }
435
436    /// Register a per-batch progress callback. Called once after each
437    /// `flush()` with the messages in the just-finished batch and the running
438    /// totals. `pond optimize` uses this to drive an `indicatif` progress bar.
439    pub fn with_progress(
440        mut self,
441        callback: impl Fn(BatchProgress) + Send + Sync + 'static,
442    ) -> Self {
443        self.progress = Some(Box::new(callback));
444        self
445    }
446
447    /// Cap the run at `limit` messages (default: no cap). The benchmark harness
448    /// uses this to embed a fixed, comparable slice of a corpus.
449    pub fn with_limit(mut self, limit: usize) -> Self {
450        self.limit = Some(limit.max(1));
451        self
452    }
453
454    pub fn include_stale(mut self) -> Self {
455        self.include_stale = true;
456        self
457    }
458
459    /// Embed every message whose `vector` is still null. Idempotent: a re-run
460    /// over an already-embedded corpus finds an empty backlog and is a no-op.
461    ///
462    /// Messages are pulled from a streaming scan, so peak memory is one stream
463    /// page plus the staged batch - not the whole corpus.
464    pub async fn run(&self) -> Result<EmbedSummary> {
465        let mut summary = EmbedSummary::default();
466        let mut window: Vec<PendingMessage> = Vec::with_capacity(self.sort_window);
467        let mut pulled = 0usize;
468
469        let mut stream = if self.include_stale {
470            Box::pin(self.store.pending_or_stale_messages())
471                as std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<PendingMessage>> + '_>>
472        } else {
473            Box::pin(self.store.pending_embedding_messages())
474                as std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<PendingMessage>> + '_>>
475        };
476        while let Some(pending) = stream.next().await {
477            // Stop pulling once the message cap is reached or cancellation
478            // fires; the staged window is still drained below, so the
479            // already-embedded slice commits cleanly.
480            if self.limit.is_some_and(|limit| pulled >= limit) || self.cancelled() {
481                break;
482            }
483            window.push(pending?);
484            pulled += 1;
485            if window.len() >= self.sort_window {
486                self.drain_window(&mut window, &mut summary).await?;
487            }
488        }
489        self.drain_window(&mut window, &mut summary).await?;
490        summary.cancelled = self.cancelled();
491
492        tracing::info!(
493            model = model_id(),
494            messages = summary.messages,
495            batches = summary.batches,
496            cancelled = summary.cancelled,
497            "embed worker finished",
498        );
499        Ok(summary)
500    }
501
502    /// One `merge_update` per window, not per 32-row batch: each
503    /// `merge_update` streams the target column once, so amortizing it over
504    /// a window-sized batch beats issuing it per model batch. The
505    /// length-sort clusters similar lengths because the tokenizer pads each
506    /// batch to its longest member. Empties `window`.
507    async fn drain_window(
508        &self,
509        window: &mut Vec<PendingMessage>,
510        summary: &mut EmbedSummary,
511    ) -> Result<()> {
512        if window.is_empty() {
513            return Ok(());
514        }
515        window.sort_unstable_by_key(|message| message.search_text.len());
516        let mut batch: Vec<PendingMessage> = Vec::with_capacity(DEFAULT_BATCH_SIZE);
517        let mut accumulator: Vec<EmbeddedMessage> = Vec::with_capacity(window.len());
518        for message in window.drain(..) {
519            batch.push(message);
520            if batch.len() >= DEFAULT_BATCH_SIZE {
521                accumulator.extend(self.embed_batch(&mut batch, summary).await?);
522            }
523        }
524        accumulator.extend(self.embed_batch(&mut batch, summary).await?);
525        if !accumulator.is_empty() {
526            self.store.write_embeddings(&accumulator).await?;
527        }
528        Ok(())
529    }
530
531    /// Run one model batch; return the rows. Store write is batched in
532    /// [`drain_window`](Self::drain_window), one `merge_update` per window.
533    async fn embed_batch(
534        &self,
535        batch: &mut Vec<PendingMessage>,
536        summary: &mut EmbedSummary,
537    ) -> Result<Vec<EmbeddedMessage>> {
538        if batch.is_empty() {
539            return Ok(Vec::new());
540        }
541        let pending = std::mem::take(batch);
542        // Apply e5's `passage: ` document prefix at the model boundary; the
543        // stored `search_text` keeps its uncapped, unprefixed form for FTS.
544        let texts = pending
545            .iter()
546            .map(|message| format_passage(&message.search_text))
547            .collect::<Vec<_>>();
548        let vectors = self.backend.embed(&texts)?;
549        if vectors.len() != pending.len() {
550            return Err(anyhow!(
551                "backend returned {} vectors for {} messages",
552                vectors.len(),
553                pending.len(),
554            ));
555        }
556        let rows = pending
557            .into_iter()
558            .zip(vectors)
559            .map(|(message, vector)| EmbeddedMessage {
560                session_id: message.session_id,
561                id: message.id,
562                vector,
563            })
564            .collect::<Vec<_>>();
565        let batch_messages = rows.len();
566        summary.messages += batch_messages;
567        summary.batches += 1;
568        if let Some(progress) = &self.progress {
569            progress(BatchProgress {
570                batch_messages,
571                total_messages: summary.messages,
572                total_batches: summary.batches,
573            });
574        }
575        Ok(rows)
576    }
577}
578
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

    #[test]
    fn e5_prefixes_apply_the_asymmetric_retrieval_pair() {
        let query = format_query("how does retry backoff work");
        assert_eq!(query, "query: how does retry backoff work");

        let passage = format_passage("retry uses exponential backoff");
        assert_eq!(passage, "passage: retry uses exponential backoff");
    }

    /// Stand-in backend for the `LazyEmbedder` tests. It does no work itself;
    /// the loader closure that builds it counts invocations, which lets the
    /// idle-eviction test detect reloads without spinning up a real model.
    struct CountingEmbedder;

    impl Embedder for CountingEmbedder {
        fn device(&self) -> &str {
            "test"
        }

        fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
            Ok(Vec::new())
        }
    }

    /// `LazyEmbedder` measures idleness with `std::time::Instant`, which
    /// `tokio::time::pause` does not affect, so the test really sleeps -
    /// with a tiny threshold to keep the suite under 100 ms.
    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_evicts_after_idle_threshold() {
        let load_count = Arc::new(AtomicUsize::new(0));
        let loader: EmbedLoader = {
            let load_count = Arc::clone(&load_count);
            Arc::new(move || {
                load_count.fetch_add(1, AtomicOrdering::SeqCst);
                Ok(Arc::new(CountingEmbedder) as Arc<dyn Embedder>)
            })
        };
        let lazy = LazyEmbedder::with_loader(loader)
            .with_idle_threshold(Duration::from_millis(20));
        let loads_so_far = || load_count.load(AtomicOrdering::SeqCst);

        lazy.get().await.unwrap();
        assert_eq!(loads_so_far(), 1, "first get loads once");

        lazy.get().await.unwrap();
        assert_eq!(
            loads_so_far(),
            1,
            "back-to-back get reuses the cached backend",
        );

        tokio::time::sleep(Duration::from_millis(60)).await;
        lazy.get().await.unwrap();
        assert_eq!(
            loads_so_far(),
            2,
            "get after the idle threshold triggers a reload",
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_from_loaded_never_evicts() {
        let fake: Arc<dyn Embedder> = Arc::new(CountingEmbedder);
        let lazy = LazyEmbedder::from_loaded(fake);
        lazy.get().await.unwrap();
        // Sleep well past the threshold used by the eviction test; the
        // from_loaded path sets Duration::MAX, so the fake must still be
        // served afterwards.
        tokio::time::sleep(Duration::from_millis(60)).await;
        lazy.get().await.unwrap();
    }
}