pond/embed.rs
1//! The embedding stage: candle XLM-RoBERTa FP16 ([`CandleEmbedder`]) plus
2//! the batch-oriented [`EmbedWorker`] that fills `messages.vector` /
3//! `messages.embedding_model` (spec.md#search). One message produces one
4//! vector - there is no chunking.
5//!
//! [`LazyEmbedder`] caches a loaded backend for `pond mcp` / `pond serve`
//! and drops it after [`DEFAULT_IDLE_EVICTION`] of no use. The drop is
//! clean under macOS `phys_footprint` (after the drop the footprint falls to
//! ~107 MiB regardless of backend), so time-weighted RSS over an interactive
//! MCP session stays well under the per-instance budget despite the macOS
//! Metal buffer pool's `iokit_mapped` retention during active queries.
12//!
//! The worker accumulates messages and calls the model once per fixed-size
//! batch, never once per message, and writes each sort window's vectors
//! (several model batches) to `messages` in one column-update commit.
16
17use std::sync::Arc;
18use std::sync::OnceLock;
19use std::sync::atomic::{AtomicBool, Ordering};
20use std::time::{Duration, Instant};
21
22use anyhow::{Context, Result, anyhow};
23use candle_core::{DType, Device, Tensor};
24use candle_nn::VarBuilder;
25use candle_transformers::models::xlm_roberta::{Config, XLMRobertaModel};
26use tokenizers::Tokenizer;
27use tokio::sync::Mutex;
28use tokio_stream::StreamExt;
29
30use crate::sessions::{EmbeddedMessage, PendingMessage, Store, embedding_dim};
31
/// Upper bound on tokens per input, matching the context e5 was trained on.
/// The tokenizer cuts anything longer before inference, so every message
/// still yields exactly one vector at a bounded cost.
pub(crate) const MAX_TOKENS: usize = 512;
35
36/// The candle e5 backend: XLM-RoBERTa FP16 weights on the GPU (Metal on
37/// macOS, CUDA on a `cuda`-feature non-macOS build, CPU otherwise).
38/// `forward` is `&self`, so no interior mutability is needed.
39pub struct CandleEmbedder {
40 model: XLMRobertaModel,
41 tokenizer: Tokenizer,
42 device: Device,
43}
44
45impl CandleEmbedder {
46 /// Load the configured XLM-RoBERTa model from HuggingFace (cached after
47 /// the first download) onto the best available device.
48 pub fn load() -> Result<Self> {
49 let device = select_device();
50 let id = model_id();
51 let api = hf_hub::api::sync::Api::new().context("init HuggingFace hub client")?;
52 let repo = api.model(id.to_owned());
53 let fetch = |file: &str| {
54 repo.get(file)
55 .with_context(|| format!("fetch {file} for {id}"))
56 };
57
58 let config: Config =
59 serde_json::from_str(&std::fs::read_to_string(fetch("config.json")?)?)?;
60 if config.hidden_size != embedding_dim() {
61 return Err(anyhow!(
62 "[embeddings].dim = {} but model {id:?} reports hidden_size = {}; \
63 set [embeddings].dim to match the model's output width.",
64 embedding_dim(),
65 config.hidden_size,
66 ));
67 }
68 // mmap the safetensors file: candle's `safetensors::load` path uses
69 // `std::fs::read` which retains an owned `Vec<u8>` of the full FP32
70 // weights in the system allocator after drop on macOS. mmap avoids
71 // the owned-heap path. Note: candle's Metal pool retains FP32->F16
72 // cast transients regardless (iokit_mapped contribution to
73 // phys_footprint, candle-core/src/metal_backend/device.rs:44-57).
74 let model_path = fetch("model.safetensors")?;
75 #[allow(unsafe_code)]
76 let vb =
77 unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F16, &device)? };
78 let model = XLMRobertaModel::new(&config, vb)
79 .map_err(|error| anyhow!("load {id} weights: {error}"))?;
80
81 let mut tokenizer = Tokenizer::from_file(fetch("tokenizer.json")?)
82 .map_err(|error| anyhow!("load e5 tokenizer: {error}"))?;
83 tokenizer.with_padding(Some(tokenizers::PaddingParams {
84 strategy: tokenizers::PaddingStrategy::BatchLongest,
85 pad_id: config.pad_token_id,
86 ..Default::default()
87 }));
88 tokenizer
89 .with_truncation(Some(tokenizers::TruncationParams {
90 max_length: MAX_TOKENS,
91 ..Default::default()
92 }))
93 .map_err(|error| anyhow!("configure e5 tokenizer: {error}"))?;
94
95 tracing::info!(model = %id, device = device_label(&device), "loaded embedding model");
96 Ok(Self {
97 model,
98 tokenizer,
99 device,
100 })
101 }
102}
103
104impl Embedder for CandleEmbedder {
105 fn device(&self) -> &str {
106 device_label(&self.device)
107 }
108
109 fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
110 if texts.is_empty() {
111 return Ok(Vec::new());
112 }
113 let encodings = self
114 .tokenizer
115 .encode_batch(texts.to_vec(), true)
116 .map_err(|error| anyhow!("tokenize embedding batch: {error}"))?;
117 let mut ids = Vec::with_capacity(encodings.len());
118 let mut masks = Vec::with_capacity(encodings.len());
119 for encoding in &encodings {
120 ids.push(Tensor::new(encoding.get_ids(), &self.device)?);
121 masks.push(Tensor::new(encoding.get_attention_mask(), &self.device)?);
122 }
123 let input_ids = Tensor::stack(&ids, 0)?;
124 let attention_mask = Tensor::stack(&masks, 0)?;
125 let token_type_ids = input_ids.zeros_like()?;
126 let hidden = self
127 .model
128 .forward(
129 &input_ids,
130 &attention_mask,
131 &token_type_ids,
132 None,
133 None,
134 None,
135 )?
136 .to_dtype(DType::F32)?;
137 let mask = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?;
138 let summed = hidden.broadcast_mul(&mask)?.sum(1)?;
139 let counts = mask.sum(1)?;
140 let mean = summed.broadcast_div(&counts)?;
141 let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?;
142 mean.broadcast_div(&norm)?
143 .to_vec2::<f32>()
144 .map_err(|error| anyhow!("read embedding vectors: {error}"))
145 }
146}
147
148fn select_device() -> Device {
149 #[cfg(target_os = "macos")]
150 let device = Device::metal_if_available(0);
151 #[cfg(not(target_os = "macos"))]
152 let device = Device::cuda_if_available(0);
153 device.unwrap_or_else(|error| {
154 tracing::warn!(%error, "GPU device unavailable, falling back to CPU");
155 Device::Cpu
156 })
157}
158
159fn device_label(device: &Device) -> &'static str {
160 match device {
161 Device::Cpu => "cpu",
162 Device::Cuda(_) => "cuda",
163 Device::Metal(_) => "metal",
164 }
165}
166
167/// Arc-shared factory used by [`LazyEmbedder`] to build the backend on
168/// first call (or on reload after idle eviction). Arc so the loader can be
169/// cloned into `spawn_blocking` without consuming `&self`.
170type EmbedLoader = Arc<dyn Fn() -> Result<Arc<dyn Embedder>> + Send + Sync>;
171
/// Idle time after which [`LazyEmbedder::get`] drops the cached backend
/// instead of reusing it. Five minutes is sized to interactive-MCP pauses: a
/// model left untouched for a turn or two is released before the next quiet
/// stretch, while ordinary bursts of queries never pay the reload cost.
pub const DEFAULT_IDLE_EVICTION: Duration = Duration::from_secs(5 * 60);
178
179struct CachedBackend {
180 backend: Arc<dyn Embedder>,
181 last_use: Instant,
182}
183
184/// Lazy holder for an [`Embedder`] with idle eviction. The model isn't
185/// loaded until the first hybrid/vector call asks for it - idle `pond mcp`
186/// / `pond serve` processes pay nothing while no vector queries land. After
187/// `idle_threshold` of inactivity the cached backend is dropped on the
188/// next `get` call; under macOS `phys_footprint` the drop reclaims
189/// ~365-585 MiB cleanly (the post-drop floor is ~107 MiB regardless of
190/// backend). Reload cost is one synchronous model-load (300-500 ms),
191/// absorbed inside the human-paced gap between MCP queries.
192pub struct LazyEmbedder {
193 loader: EmbedLoader,
194 state: Mutex<Option<CachedBackend>>,
195 idle_threshold: Duration,
196}
197
198impl LazyEmbedder {
199 /// candle XLM-RoBERTa FP16 (Metal on macOS / CUDA with `--features cuda`
200 /// / CPU otherwise). The pond default for every entry point.
201 pub fn candle() -> Self {
202 Self::with_loader(Arc::new(|| {
203 Ok(Arc::new(CandleEmbedder::load()?) as Arc<dyn Embedder>)
204 }))
205 }
206
207 /// Build a `LazyEmbedder` from an explicit loader. Used by the bench
208 /// harness to override the idle threshold; production callers use
209 /// [`Self::candle`].
210 pub fn with_loader(loader: EmbedLoader) -> Self {
211 Self {
212 loader,
213 state: Mutex::new(None),
214 idle_threshold: DEFAULT_IDLE_EVICTION,
215 }
216 }
217
218 /// Override the idle-eviction threshold. Pass `Duration::MAX` to disable
219 /// eviction entirely - useful in benches that want a stable steady-state.
220 #[must_use]
221 pub fn with_idle_threshold(mut self, threshold: Duration) -> Self {
222 self.idle_threshold = threshold;
223 self
224 }
225
226 /// Pre-seed with an already-constructed backend. Used by integration
227 /// tests that want to inject a fake `Embedder` without paying the real
228 /// model-load cost. Eviction is disabled so the test fake survives the
229 /// whole test even if a test stalls.
230 pub fn from_loaded(backend: Arc<dyn Embedder>) -> Self {
231 let preloaded = Arc::clone(&backend);
232 let loader: EmbedLoader = Arc::new(move || Ok(Arc::clone(&preloaded)));
233 Self {
234 loader,
235 state: Mutex::new(Some(CachedBackend {
236 backend,
237 last_use: Instant::now(),
238 })),
239 idle_threshold: Duration::MAX,
240 }
241 }
242
243 /// Load (on first call or after eviction) or return the cached handle.
244 /// The candle load is synchronous and blocking, so it runs on
245 /// `spawn_blocking`; the async caller sees a clean `await` point.
246 pub async fn get(&self) -> Result<Arc<dyn Embedder>> {
247 let mut state = self.state.lock().await;
248 let now = Instant::now();
249 if let Some(cached) = &*state
250 && now.duration_since(cached.last_use) > self.idle_threshold
251 {
252 tracing::info!(
253 idle_secs = self.idle_threshold.as_secs(),
254 "evicting idle embedder",
255 );
256 *state = None;
257 }
258 if let Some(cached) = state.as_mut() {
259 cached.last_use = now;
260 return Ok(Arc::clone(&cached.backend));
261 }
262 let loader = Arc::clone(&self.loader);
263 let backend = tokio::task::spawn_blocking(move || loader())
264 .await
265 .map_err(|join_error| anyhow!("embedder load panicked: {join_error}"))??;
266 *state = Some(CachedBackend {
267 backend: Arc::clone(&backend),
268 last_use: now,
269 });
270 Ok(backend)
271 }
272}
273
/// Embedding model used when `[embeddings].model` is not configured
/// (spec.md#search). `pond embed` records the *runtime* id ([`model_id`]) in
/// `messages.embedding_model` next to every vector. e5-small (384-dim) is the
/// default: the paraphrase benchmark set showed no statistically-significant
/// quality loss against e5-base, while vector storage halves and model RSS
/// roughly halves.
pub const DEFAULT_MODEL_ID: &str = "intfloat/multilingual-e5-small";

/// Process-wide model id, installed once at startup from `[embeddings].model`
/// through [`init_model_id`]. It is a `OnceLock` rather than a `const` so an
/// experiment can switch to e5-small / e5-large via a temporary config without
/// touching call sites. While it is unset, [`model_id`] falls back to
/// [`DEFAULT_MODEL_ID`], which keeps unit tests free of config.
static MODEL_ID_RUNTIME: OnceLock<String> = OnceLock::new();

/// The model id in effect: whatever [`init_model_id`] installed, else
/// [`DEFAULT_MODEL_ID`] (tests, ad-hoc tooling).
pub fn model_id() -> &'static str {
    match MODEL_ID_RUNTIME.get() {
        Some(configured) => configured.as_str(),
        None => DEFAULT_MODEL_ID,
    }
}

/// Install the model id from config. Only the first call has any effect;
/// later ids, even different ones, are dropped without notice because the
/// process reads its config exactly once.
pub fn init_model_id(id: String) {
    // `set` hands the id back as `Err` when a value is already installed;
    // discarding it is the intended first-call-wins behavior.
    let _ = MODEL_ID_RUNTIME.set(id);
}
302
/// Messages per model-inference batch. e5 truncates at [`MAX_TOKENS`] tokens,
/// so a 32-row batch's padded attention transient stays bounded. Store writes
/// are *not* per batch: [`EmbedWorker`] commits once per sort window
/// ([`DEFAULT_SORT_WINDOW`]), which spans several model batches.
pub const DEFAULT_BATCH_SIZE: usize = 32;
306
/// How many messages are buffered and sorted by length before being cut into
/// model batches. Because the tokenizer pads each batch to its longest
/// member, a batch that mixes one short and one long message embeds the short
/// one at the long one's length. Sorting a window first groups similar
/// lengths, so each batch pads close to its own maximum rather than the
/// corpus worst case. The window is bounded so peak memory is one window, not
/// the whole backlog. Tunable via [`EmbedWorker::with_sort_window`].
pub const DEFAULT_SORT_WINDOW: usize = 2048;
314
/// Prepare a search query for the embedder by prepending `query: `. e5 is an
/// asymmetric retriever whose model card prescribes `query: ` for searches and
/// `passage: ` for documents. `pond_search` calls this before the
/// candle/Metal embed call.
pub fn format_query(query: &str) -> String {
    const PREFIX: &str = "query: ";
    let mut prefixed = String::with_capacity(PREFIX.len() + query.len());
    prefixed.push_str(PREFIX);
    prefixed.push_str(query);
    prefixed
}
322
/// Prepare a document (one message's `search_text`) for the embedder by
/// prepending `passage: ` - the document-side counterpart of
/// [`format_query`]. `EmbedWorker` applies it when batching messages for
/// `pond embed`.
pub fn format_passage(text: &str) -> String {
    ["passage: ", text].concat()
}
329
330/// The embedding seam (spec.md#search): text in, vectors out. The real
331/// backend is [`CandleEmbedder`]; tests substitute an instrumented fake
332/// to assert batching behavior. The vector width is checked at the write
333/// boundary and the model id is whatever [`model_id`] returns at the
334/// time of the write.
335pub trait Embedder: Send + Sync {
336 /// A short label naming the hardware/runtime: `"metal"`, `"cuda"`,
337 /// or `"cpu"`. Used by `pond embed` to surface what backend ran the
338 /// inference; benches print it alongside latency.
339 fn device(&self) -> &str;
340
341 /// Embed a batch of texts. The returned vectors are L2-normalized and
342 /// [`embedding_dim`] long, one per input.
343 fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
344}
345
/// Outcome of an [`EmbedWorker::run`] pass.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EmbedSummary {
    /// Messages embedded; one vector each.
    pub messages: usize,
    /// Model-inference batches issued (at most [`DEFAULT_BATCH_SIZE`]
    /// messages each). Store writes happen once per sort window, not once
    /// per batch, so this is not a count of commits.
    pub batches: usize,
    /// Set when the run exited via the cancel flag instead of stream end -
    /// the caller uses this to print an interrupted notice and decide whether
    /// to still rebuild downstream indices.
    pub cancelled: bool,
}
358
/// Snapshot handed to the progress callback after every model batch, so
/// `pond embed` can drive an `indicatif` bar without this module depending on
/// `indicatif`.
#[derive(Debug, Clone, Copy)]
pub struct BatchProgress {
    /// Size of the batch that just finished.
    pub batch_messages: usize,
    /// Messages embedded so far in this run, this batch included.
    pub total_messages: usize,
    /// Batches completed so far in this run, this batch included.
    pub total_batches: usize,
}

/// Boxed progress callback stored by [`EmbedWorker`], invoked once per batch.
type ProgressFn = Box<dyn Fn(BatchProgress) + Sync + Send>;
372
373/// Fills `messages.vector` / `messages.embedding_model` for the backlog of
374/// un-embedded messages. Reads `messages.search_text` directly, batches it
375/// through the backend one vector each, and writes each batch back to
376/// `messages` by primary key.
377pub struct EmbedWorker<'a, B: Embedder> {
378 store: &'a Store,
379 backend: &'a B,
380 include_stale: bool,
381 /// Optional cap on total messages embedded in one `run` - `None` in
382 /// production (embed everything), set by the benchmark harness to a fixed
383 /// count so a run is a stable, comparable workload.
384 limit: Option<usize>,
385 /// Messages buffered and length-sorted per `drain_window` pass
386 /// ([`DEFAULT_SORT_WINDOW`]); the benchmark sweeps it through
387 /// [`EmbedWorker::with_sort_window`].
388 sort_window: usize,
389 /// Optional per-batch progress callback. Called once per `flush()` with
390 /// the running totals; `pond embed` wires this to an `indicatif` bar.
391 progress: Option<ProgressFn>,
392 /// Set externally (Ctrl-C handler in `pond embed`): the pull loop drains
393 /// the in-memory window before exiting so partial work is committed.
394 cancel: Option<Arc<AtomicBool>>,
395}
396
397impl<'a, B: Embedder> EmbedWorker<'a, B> {
398 /// Build a worker over `store`'s un-embedded backlog. A backend whose
399 /// vectors are the wrong width is rejected at the write boundary
400 /// (`embedding_update_batch`), so there is nothing to validate here.
401 pub fn new(store: &'a Store, backend: &'a B) -> Self {
402 Self {
403 store,
404 backend,
405 include_stale: false,
406 limit: None,
407 sort_window: DEFAULT_SORT_WINDOW,
408 progress: None,
409 cancel: None,
410 }
411 }
412
413 /// Honour `flag` as a cooperative cancellation signal. The pull loop checks
414 /// it before each new stream message; once set, the worker drains the
415 /// current window (committing the embedded slice) and returns with
416 /// `EmbedSummary { cancelled: true, .. }`. `pond embed` wires this to a
417 /// Ctrl-C handler so an interrupted run doesn't lose its in-memory window.
418 pub fn with_cancel(mut self, flag: Arc<AtomicBool>) -> Self {
419 self.cancel = Some(flag);
420 self
421 }
422
423 fn cancelled(&self) -> bool {
424 self.cancel
425 .as_ref()
426 .is_some_and(|f| f.load(Ordering::Relaxed))
427 }
428
429 /// Override the length-sort window (default [`DEFAULT_SORT_WINDOW`]). The
430 /// benchmark harness sweeps this to size the padding-waste vs. throughput
431 /// trade-off; a window of [`DEFAULT_BATCH_SIZE`] disables sorting.
432 pub fn with_sort_window(mut self, window: usize) -> Self {
433 self.sort_window = window.max(DEFAULT_BATCH_SIZE);
434 self
435 }
436
437 /// Register a per-batch progress callback. Called once after each
438 /// `flush()` with the messages in the just-finished batch and the running
439 /// totals. `pond embed` uses this to drive an `indicatif` progress bar.
440 pub fn with_progress(
441 mut self,
442 callback: impl Fn(BatchProgress) + Send + Sync + 'static,
443 ) -> Self {
444 self.progress = Some(Box::new(callback));
445 self
446 }
447
448 /// Cap the run at `limit` messages (default: no cap). The benchmark harness
449 /// uses this to embed a fixed, comparable slice of a corpus.
450 pub fn with_limit(mut self, limit: usize) -> Self {
451 self.limit = Some(limit.max(1));
452 self
453 }
454
455 pub fn include_stale(mut self) -> Self {
456 self.include_stale = true;
457 self
458 }
459
460 /// Embed every message whose `vector` is still null. Idempotent: a re-run
461 /// over an already-embedded corpus finds an empty backlog and is a no-op.
462 ///
463 /// Messages are pulled from a streaming scan, so peak memory is one stream
464 /// page plus the staged batch - not the whole corpus.
465 pub async fn run(&self) -> Result<EmbedSummary> {
466 let mut summary = EmbedSummary::default();
467 let mut window: Vec<PendingMessage> = Vec::with_capacity(self.sort_window);
468 let mut pulled = 0usize;
469
470 let mut stream = if self.include_stale {
471 Box::pin(self.store.pending_or_stale_messages())
472 as std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<PendingMessage>> + '_>>
473 } else {
474 Box::pin(self.store.pending_embedding_messages())
475 as std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<PendingMessage>> + '_>>
476 };
477 while let Some(pending) = stream.next().await {
478 // Stop pulling once the message cap is reached or cancellation
479 // fires; the staged window is still drained below, so the
480 // already-embedded slice commits cleanly.
481 if self.limit.is_some_and(|limit| pulled >= limit) || self.cancelled() {
482 break;
483 }
484 window.push(pending?);
485 pulled += 1;
486 if window.len() >= self.sort_window {
487 self.drain_window(&mut window, &mut summary).await?;
488 }
489 }
490 self.drain_window(&mut window, &mut summary).await?;
491 summary.cancelled = self.cancelled();
492
493 tracing::info!(
494 model = model_id(),
495 messages = summary.messages,
496 batches = summary.batches,
497 cancelled = summary.cancelled,
498 "embed worker finished",
499 );
500 Ok(summary)
501 }
502
503 /// One `merge_update` per window, not per 32-row batch: each
504 /// `merge_update` streams the target column once, so amortizing it over
505 /// a window-sized batch beats issuing it per model batch. The
506 /// length-sort clusters similar lengths because the tokenizer pads each
507 /// batch to its longest member. Empties `window`.
508 async fn drain_window(
509 &self,
510 window: &mut Vec<PendingMessage>,
511 summary: &mut EmbedSummary,
512 ) -> Result<()> {
513 if window.is_empty() {
514 return Ok(());
515 }
516 window.sort_unstable_by_key(|message| message.search_text.len());
517 let mut batch: Vec<PendingMessage> = Vec::with_capacity(DEFAULT_BATCH_SIZE);
518 let mut accumulator: Vec<EmbeddedMessage> = Vec::with_capacity(window.len());
519 for message in window.drain(..) {
520 batch.push(message);
521 if batch.len() >= DEFAULT_BATCH_SIZE {
522 accumulator.extend(self.embed_batch(&mut batch, summary).await?);
523 }
524 }
525 accumulator.extend(self.embed_batch(&mut batch, summary).await?);
526 if !accumulator.is_empty() {
527 self.store.write_embeddings(&accumulator).await?;
528 }
529 Ok(())
530 }
531
532 /// Run one model batch; return the rows. Store write is batched in
533 /// [`drain_window`](Self::drain_window), one `merge_update` per window.
534 async fn embed_batch(
535 &self,
536 batch: &mut Vec<PendingMessage>,
537 summary: &mut EmbedSummary,
538 ) -> Result<Vec<EmbeddedMessage>> {
539 if batch.is_empty() {
540 return Ok(Vec::new());
541 }
542 let pending = std::mem::take(batch);
543 // Apply e5's `passage: ` document prefix at the model boundary; the
544 // stored `search_text` keeps its uncapped, unprefixed form for FTS.
545 let texts = pending
546 .iter()
547 .map(|message| format_passage(&message.search_text))
548 .collect::<Vec<_>>();
549 let vectors = self.backend.embed(&texts)?;
550 if vectors.len() != pending.len() {
551 return Err(anyhow!(
552 "backend returned {} vectors for {} messages",
553 vectors.len(),
554 pending.len(),
555 ));
556 }
557 let rows = pending
558 .into_iter()
559 .zip(vectors)
560 .map(|(message, vector)| EmbeddedMessage {
561 session_id: message.session_id,
562 id: message.id,
563 vector,
564 })
565 .collect::<Vec<_>>();
566 let batch_messages = rows.len();
567 summary.messages += batch_messages;
568 summary.batches += 1;
569 if let Some(progress) = &self.progress {
570 progress(BatchProgress {
571 batch_messages,
572 total_messages: summary.messages,
573 total_batches: summary.batches,
574 });
575 }
576 Ok(rows)
577 }
578}
579
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

    #[test]
    fn e5_prefixes_apply_the_asymmetric_retrieval_pair() {
        assert_eq!(
            format_query("how does retry backoff work"),
            "query: how does retry backoff work",
        );
        assert_eq!(
            format_passage("retry uses exponential backoff"),
            "passage: retry uses exponential backoff",
        );
    }

    /// Minimal `Embedder` stand-in. The eviction tests count loader calls
    /// (via the loader closure), so the fake itself does no work.
    struct CountingEmbedder;
    impl Embedder for CountingEmbedder {
        fn device(&self) -> &str {
            "test"
        }
        fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
            Ok(vec![])
        }
    }

    /// `LazyEmbedder` keys eviction on `std::time::Instant`, which isn't
    /// affected by `tokio::time::pause`. The test uses a tiny real
    /// threshold so the suite runs in <100 ms.
    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_evicts_after_idle_threshold() {
        let loads = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&loads);
        let loader: EmbedLoader = Arc::new(move || {
            counter.fetch_add(1, AtomicOrdering::SeqCst);
            Ok(Arc::new(CountingEmbedder) as Arc<dyn Embedder>)
        });
        let embedder =
            LazyEmbedder::with_loader(loader).with_idle_threshold(Duration::from_millis(20));

        embedder.get().await.unwrap();
        assert_eq!(
            loads.load(AtomicOrdering::SeqCst),
            1,
            "first get loads once"
        );

        embedder.get().await.unwrap();
        assert_eq!(
            loads.load(AtomicOrdering::SeqCst),
            1,
            "back-to-back get reuses the cached backend",
        );

        tokio::time::sleep(Duration::from_millis(60)).await;
        embedder.get().await.unwrap();
        assert_eq!(
            loads.load(AtomicOrdering::SeqCst),
            2,
            "get after the idle threshold triggers a reload",
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_from_loaded_never_evicts() {
        let backend: Arc<dyn Embedder> = Arc::new(CountingEmbedder);
        let preloaded = LazyEmbedder::from_loaded(Arc::clone(&backend));
        // Eviction is lazy and the fallback loader returns the same `Arc`, so
        // "never evicts" is observable only through the configured threshold.
        assert_eq!(
            preloaded.idle_threshold,
            Duration::MAX,
            "from_loaded disables idle eviction",
        );

        let first = preloaded.get().await.unwrap();
        // Wait past the thresholds the other tests use.
        tokio::time::sleep(Duration::from_millis(60)).await;
        let second = preloaded.get().await.unwrap();
        assert!(
            Arc::ptr_eq(&first, &backend),
            "get returns the injected backend",
        );
        assert!(
            Arc::ptr_eq(&second, &backend),
            "the injected backend survives an idle gap",
        );
    }
}