pond/embed.rs
1//! The embedding stage: candle XLM-RoBERTa FP16 ([`CandleEmbedder`]) plus
2//! the batch-oriented [`EmbedWorker`] that fills `messages.vector` /
3//! `messages.embedding_model` (spec.md#search). One message produces one
4//! vector - there is no chunking.
5//!
6//! [`LazyEmbedder`] caches a loaded backend for `pond mcp` / `pond serve`
7//! and drops it after [`DEFAULT_IDLE_EVICTION`] of no use. The drop is
8//! clean under macOS `phys_footprint` (post-drop drops to ~107 MiB
9//! regardless of backend), so time-weighted RSS over an interactive MCP
10//! session stays well under the per-instance budget despite the macOS
11//! Metal buffer pool's `iokit_mapped` retention during active queries.
12//!
//! The worker accumulates messages and calls the model once per fixed-size
//! batch, never once per message. Vectors are written back to `messages`
//! once per length-sorted window, in one column-update commit.
16
17use std::sync::Arc;
18use std::sync::OnceLock;
19use std::sync::atomic::{AtomicBool, Ordering};
20use std::time::{Duration, Instant};
21
22use anyhow::{Context, Result, anyhow};
23use candle_core::{DType, Device, Tensor};
24use candle_nn::VarBuilder;
25use candle_transformers::models::xlm_roberta::{Config, XLMRobertaModel};
26use tokenizers::Tokenizer;
27use tokio::sync::Mutex;
28use tokio_stream::StreamExt;
29
30use crate::sessions::{EmbeddedMessage, PendingMessage, Store, embedding_dim};
31
/// Token budget per message: the context length e5 was trained with. The
/// tokenizer truncates longer input before inference, so every message still
/// maps to exactly one vector and the per-message embed cost stays bounded.
pub(crate) const MAX_TOKENS: usize = 512;
35
36/// The candle e5 backend: XLM-RoBERTa FP16 weights on the GPU (Metal on
37/// macOS, CUDA on a `cuda`-feature non-macOS build, CPU otherwise).
38/// `forward` is `&self`, so no interior mutability is needed.
39pub struct CandleEmbedder {
40 model: XLMRobertaModel,
41 tokenizer: Tokenizer,
42 device: Device,
43}
44
45impl CandleEmbedder {
46 /// Load the configured XLM-RoBERTa model from HuggingFace (cached after
47 /// the first download) onto the best available device.
48 pub fn load() -> Result<Self> {
49 let device = select_device();
50 let id = model_id();
51 let api = hf_hub::api::sync::Api::new().context("init HuggingFace hub client")?;
52 let repo = api.model(id.to_owned());
53 let fetch = |file: &str| {
54 repo.get(file)
55 .with_context(|| format!("fetch {file} for {id}"))
56 };
57
58 let config: Config =
59 serde_json::from_str(&std::fs::read_to_string(fetch("config.json")?)?)?;
60 if config.hidden_size != embedding_dim() {
61 return Err(anyhow!(
62 "[embeddings].dim = {} but model {id:?} reports hidden_size = {}; \
63 set [embeddings].dim to match the model's output width.",
64 embedding_dim(),
65 config.hidden_size,
66 ));
67 }
68 // mmap the safetensors file: candle's `safetensors::load` path uses
69 // `std::fs::read` which retains an owned `Vec<u8>` of the full FP32
70 // weights in the system allocator after drop on macOS. mmap avoids
71 // the owned-heap path. Note: candle's Metal pool retains FP32->F16
72 // cast transients regardless (iokit_mapped contribution to
73 // phys_footprint, candle-core/src/metal_backend/device.rs:44-57).
74 let model_path = fetch("model.safetensors")?;
75 #[allow(unsafe_code)]
76 let vb =
77 unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F16, &device)? };
78 let model = XLMRobertaModel::new(&config, vb)
79 .map_err(|error| anyhow!("load {id} weights: {error}"))?;
80
81 let mut tokenizer = Tokenizer::from_file(fetch("tokenizer.json")?)
82 .map_err(|error| anyhow!("load e5 tokenizer: {error}"))?;
83 tokenizer.with_padding(Some(tokenizers::PaddingParams {
84 strategy: tokenizers::PaddingStrategy::BatchLongest,
85 pad_id: config.pad_token_id,
86 ..Default::default()
87 }));
88 tokenizer
89 .with_truncation(Some(tokenizers::TruncationParams {
90 max_length: MAX_TOKENS,
91 ..Default::default()
92 }))
93 .map_err(|error| anyhow!("configure e5 tokenizer: {error}"))?;
94
95 tracing::info!(model = %id, device = device_label(&device), "loaded embedding model");
96 Ok(Self {
97 model,
98 tokenizer,
99 device,
100 })
101 }
102}
103
104impl Embedder for CandleEmbedder {
105 fn device(&self) -> &str {
106 device_label(&self.device)
107 }
108
109 fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
110 if texts.is_empty() {
111 return Ok(Vec::new());
112 }
113 let encodings = self
114 .tokenizer
115 .encode_batch(texts.to_vec(), true)
116 .map_err(|error| anyhow!("tokenize embedding batch: {error}"))?;
117 let mut ids = Vec::with_capacity(encodings.len());
118 let mut masks = Vec::with_capacity(encodings.len());
119 for encoding in &encodings {
120 ids.push(Tensor::new(encoding.get_ids(), &self.device)?);
121 masks.push(Tensor::new(encoding.get_attention_mask(), &self.device)?);
122 }
123 let input_ids = Tensor::stack(&ids, 0)?;
124 let attention_mask = Tensor::stack(&masks, 0)?;
125 let token_type_ids = input_ids.zeros_like()?;
126 let hidden = self
127 .model
128 .forward(
129 &input_ids,
130 &attention_mask,
131 &token_type_ids,
132 None,
133 None,
134 None,
135 )?
136 .to_dtype(DType::F32)?;
137 let mask = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?;
138 let summed = hidden.broadcast_mul(&mask)?.sum(1)?;
139 let counts = mask.sum(1)?;
140 let mean = summed.broadcast_div(&counts)?;
141 let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?;
142 mean.broadcast_div(&norm)?
143 .to_vec2::<f32>()
144 .map_err(|error| anyhow!("read embedding vectors: {error}"))
145 }
146}
147
148fn select_device() -> Device {
149 #[cfg(target_os = "macos")]
150 let device = Device::metal_if_available(0);
151 #[cfg(not(target_os = "macos"))]
152 let device = Device::cuda_if_available(0);
153 device.unwrap_or_else(|error| {
154 tracing::warn!(%error, "GPU device unavailable, falling back to CPU");
155 Device::Cpu
156 })
157}
158
159fn device_label(device: &Device) -> &'static str {
160 match device {
161 Device::Cpu => "cpu",
162 Device::Cuda(_) => "cuda",
163 Device::Metal(_) => "metal",
164 }
165}
166
167/// Arc-shared factory used by [`LazyEmbedder`] to build the backend on
168/// first call (or on reload after idle eviction). Arc so the loader can be
169/// cloned into `spawn_blocking` without consuming `&self`.
170type EmbedLoader = Arc<dyn Fn() -> Result<Arc<dyn Embedder>> + Send + Sync>;
171
/// Idle time after which [`LazyEmbedder::get`] drops the cached backend.
/// One minute hands the ~790 MB model back to the idle floor quickly between
/// interactive-MCP bursts; the price is one cached model-load (~358 ms) on
/// the first query after a quiet window.
pub const DEFAULT_IDLE_EVICTION: Duration = Duration::from_secs(60);
177
178struct CachedBackend {
179 backend: Arc<dyn Embedder>,
180 last_use: Instant,
181}
182
183/// Lazy holder for an [`Embedder`] with idle eviction. The model isn't
184/// loaded until the first hybrid/vector call asks for it - idle `pond mcp`
185/// / `pond serve` processes pay nothing while no vector queries land. After
186/// `idle_threshold` of inactivity the cached backend is dropped on the
187/// next `get` call; under macOS `phys_footprint` the drop reclaims
188/// ~365-585 MiB cleanly (the post-drop floor is ~107 MiB regardless of
189/// backend). Reload cost is one synchronous model-load (300-500 ms),
190/// absorbed inside the human-paced gap between MCP queries.
191pub struct LazyEmbedder {
192 loader: EmbedLoader,
193 state: Mutex<Option<CachedBackend>>,
194 idle_threshold: Duration,
195}
196
197impl LazyEmbedder {
198 /// candle XLM-RoBERTa FP16 (Metal on macOS / CUDA with `--features cuda`
199 /// / CPU otherwise). The pond default for every entry point.
200 pub fn candle() -> Self {
201 Self::with_loader(Arc::new(|| {
202 Ok(Arc::new(CandleEmbedder::load()?) as Arc<dyn Embedder>)
203 }))
204 }
205
206 /// Build a `LazyEmbedder` from an explicit loader. Used by the bench
207 /// harness to override the idle threshold; production callers use
208 /// [`Self::candle`].
209 pub fn with_loader(loader: EmbedLoader) -> Self {
210 Self {
211 loader,
212 state: Mutex::new(None),
213 idle_threshold: DEFAULT_IDLE_EVICTION,
214 }
215 }
216
217 /// Override the idle-eviction threshold. Pass `Duration::MAX` to disable
218 /// eviction entirely - useful in benches that want a stable steady-state.
219 #[must_use]
220 pub fn with_idle_threshold(mut self, threshold: Duration) -> Self {
221 self.idle_threshold = threshold;
222 self
223 }
224
225 /// Pre-seed with an already-constructed backend. Used by integration
226 /// tests that want to inject a fake `Embedder` without paying the real
227 /// model-load cost. Eviction is disabled so the test fake survives the
228 /// whole test even if a test stalls.
229 pub fn from_loaded(backend: Arc<dyn Embedder>) -> Self {
230 let preloaded = Arc::clone(&backend);
231 let loader: EmbedLoader = Arc::new(move || Ok(Arc::clone(&preloaded)));
232 Self {
233 loader,
234 state: Mutex::new(Some(CachedBackend {
235 backend,
236 last_use: Instant::now(),
237 })),
238 idle_threshold: Duration::MAX,
239 }
240 }
241
242 /// Load (on first call or after eviction) or return the cached handle.
243 /// The candle load is synchronous and blocking, so it runs on
244 /// `spawn_blocking`; the async caller sees a clean `await` point.
245 pub async fn get(&self) -> Result<Arc<dyn Embedder>> {
246 let mut state = self.state.lock().await;
247 let now = Instant::now();
248 if let Some(cached) = &*state
249 && now.duration_since(cached.last_use) > self.idle_threshold
250 {
251 tracing::info!(
252 idle_secs = self.idle_threshold.as_secs(),
253 "evicting idle embedder",
254 );
255 *state = None;
256 }
257 if let Some(cached) = state.as_mut() {
258 cached.last_use = now;
259 return Ok(Arc::clone(&cached.backend));
260 }
261 let loader = Arc::clone(&self.loader);
262 let backend = tokio::task::spawn_blocking(move || loader())
263 .await
264 .map_err(|join_error| anyhow!("embedder load panicked: {join_error}"))??;
265 *state = Some(CachedBackend {
266 backend: Arc::clone(&backend),
267 last_use: now,
268 });
269 Ok(backend)
270 }
271}
272
/// The embedding model pond ships a loader for by default (spec.md#search);
/// used whenever `[embeddings].model` is absent. `pond optimize` stamps the
/// runtime model id (see [`model_id`]) into `messages.embedding_model` with
/// every vector. e5-small (384-dim) is the default: the paraphrase benchmark
/// set showed no statistically-significant quality loss vs e5-base while
/// halving vector storage and ~halving model RSS.
pub const DEFAULT_MODEL_ID: &str = "intfloat/multilingual-e5-small";

/// Process-wide model id, installed at most once (at startup, from
/// `[embeddings].model`) through [`init_model_id`]. A `OnceLock` rather than
/// a `const` so a temporary config file can pick e5-small / e5-large for an
/// experiment without touching every call site. While it is unset,
/// [`model_id`] falls back to [`DEFAULT_MODEL_ID`], which keeps unit tests
/// config-free.
static MODEL_ID_RUNTIME: OnceLock<String> = OnceLock::new();

/// The active model id: whatever [`init_model_id`] installed, or
/// [`DEFAULT_MODEL_ID`] when nothing has been installed (tests, ad-hoc
/// tooling).
pub fn model_id() -> &'static str {
    match MODEL_ID_RUNTIME.get() {
        Some(configured) => configured.as_str(),
        None => DEFAULT_MODEL_ID,
    }
}

/// Seed [`model_id`] from config. The first call wins; a later call with a
/// different id is silently ignored - the process loads its config once.
pub fn init_model_id(id: String) {
    // `set` only stores into an empty cell. On later calls it returns the
    // rejected id in `Err`, which is dropped here on purpose.
    let _ = MODEL_ID_RUNTIME.set(id);
}
301
/// Number of messages sent through the model in one inference call. Because
/// e5 input is truncated at 512 tokens, the padded attention transient of a
/// 32-row batch stays bounded.
pub const DEFAULT_BATCH_SIZE: usize = 32;
305
/// How many messages are buffered and sorted by length before being cut
/// into model batches. The tokenizer pads each batch to its longest member,
/// so a short message batched with a long one is embedded at the long one's
/// length. Sorting a window first groups messages of similar length, so each
/// batch pads to roughly its own longest member rather than the corpus worst
/// case. The window is bounded so peak memory is one window, not the whole
/// backlog. See [`EmbedWorker::with_sort_window`].
pub const DEFAULT_SORT_WINDOW: usize = 2048;
313
/// Prepare a search query for the embedder. e5 is an asymmetric retriever:
/// its model card prescribes a `query: ` prefix on the search side and a
/// `passage: ` prefix on documents. `pond_search` calls this on the query
/// text before the candle/Metal embed call.
pub fn format_query(query: &str) -> String {
    const PREFIX: &str = "query: ";
    let mut formatted = String::with_capacity(PREFIX.len() + query.len());
    formatted.push_str(PREFIX);
    formatted.push_str(query);
    formatted
}
321
/// Prepare a document (one message's `search_text`) for the embedder by
/// adding e5's `passage: ` prefix, the document-side counterpart of the
/// `query: ` prefix applied to searches. `EmbedWorker` calls this when
/// batching messages for `pond optimize`.
pub fn format_passage(text: &str) -> String {
    ["passage: ", text].concat()
}
328
329/// The embedding seam (spec.md#search): text in, vectors out. The real
330/// backend is [`CandleEmbedder`]; tests substitute an instrumented fake
331/// to assert batching behavior. The vector width is checked at the write
332/// boundary and the model id is whatever [`model_id`] returns at the
333/// time of the write.
334pub trait Embedder: Send + Sync {
335 /// A short label naming the hardware/runtime: `"metal"`, `"cuda"`,
336 /// or `"cpu"`. Used by `pond optimize` to surface what backend ran the
337 /// inference; benches print it alongside latency.
338 fn device(&self) -> &str;
339
340 /// Embed a batch of texts. The returned vectors are L2-normalized and
341 /// [`embedding_dim`] long, one per input.
342 fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
343}
344
/// What one [`EmbedWorker::run`] pass accomplished.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct EmbedSummary {
    /// Number of messages embedded (one vector per message).
    pub messages: usize,
    /// Number of model-inference batches issued.
    pub batches: usize,
    /// True when the cancel flag was set by the time the run returned. The
    /// caller uses this to print an interrupted notice and to decide whether
    /// to still rebuild downstream indices.
    pub cancelled: bool,
}
357
/// Statistics passed to the progress callback after each model batch. This
/// lets `pond optimize` drive an `indicatif` bar without that crate leaking
/// into this module's API.
#[derive(Debug, Clone, Copy)]
pub struct BatchProgress {
    /// How many messages the just-finished batch embedded.
    pub batch_messages: usize,
    /// Messages embedded so far in this run, including this batch.
    pub total_messages: usize,
    /// Batches completed so far in this run, including this one.
    pub total_batches: usize,
}

/// Boxed progress callback stored on the worker.
type ProgressFn = Box<dyn Fn(BatchProgress) + Send + Sync>;
371
372/// Fills `messages.vector` / `messages.embedding_model` for the backlog of
373/// un-embedded messages. Reads `messages.search_text` directly, batches it
374/// through the backend one vector each, and writes each batch back to
375/// `messages` by primary key.
376pub struct EmbedWorker<'a, B: Embedder> {
377 store: &'a Store,
378 backend: &'a B,
379 include_stale: bool,
380 /// Optional cap on total messages embedded in one `run` - `None` in
381 /// production (embed everything), set by the benchmark harness to a fixed
382 /// count so a run is a stable, comparable workload.
383 limit: Option<usize>,
384 /// Messages buffered and length-sorted per `drain_window` pass
385 /// ([`DEFAULT_SORT_WINDOW`]); the benchmark sweeps it through
386 /// [`EmbedWorker::with_sort_window`].
387 sort_window: usize,
388 /// Optional per-batch progress callback. Called once per `flush()` with
389 /// the running totals; `pond optimize` wires this to an `indicatif` bar.
390 progress: Option<ProgressFn>,
391 /// Set externally (Ctrl-C handler in `pond optimize`): the pull loop drains
392 /// the in-memory window before exiting so partial work is committed.
393 cancel: Option<Arc<AtomicBool>>,
394}
395
396impl<'a, B: Embedder> EmbedWorker<'a, B> {
397 /// Build a worker over `store`'s un-embedded backlog. A backend whose
398 /// vectors are the wrong width is rejected at the write boundary
399 /// (`embedding_update_batch`), so there is nothing to validate here.
400 pub fn new(store: &'a Store, backend: &'a B) -> Self {
401 Self {
402 store,
403 backend,
404 include_stale: false,
405 limit: None,
406 sort_window: DEFAULT_SORT_WINDOW,
407 progress: None,
408 cancel: None,
409 }
410 }
411
412 /// Honour `flag` as a cooperative cancellation signal. The pull loop checks
413 /// it before each new stream message; once set, the worker drains the
414 /// current window (committing the embedded slice) and returns with
415 /// `EmbedSummary { cancelled: true, .. }`. `pond optimize` wires this to a
416 /// Ctrl-C handler so an interrupted run doesn't lose its in-memory window.
417 pub fn with_cancel(mut self, flag: Arc<AtomicBool>) -> Self {
418 self.cancel = Some(flag);
419 self
420 }
421
422 fn cancelled(&self) -> bool {
423 self.cancel
424 .as_ref()
425 .is_some_and(|f| f.load(Ordering::Relaxed))
426 }
427
428 /// Override the length-sort window (default [`DEFAULT_SORT_WINDOW`]). The
429 /// benchmark harness sweeps this to size the padding-waste vs. throughput
430 /// trade-off; a window of [`DEFAULT_BATCH_SIZE`] disables sorting.
431 pub fn with_sort_window(mut self, window: usize) -> Self {
432 self.sort_window = window.max(DEFAULT_BATCH_SIZE);
433 self
434 }
435
436 /// Register a per-batch progress callback. Called once after each
437 /// `flush()` with the messages in the just-finished batch and the running
438 /// totals. `pond optimize` uses this to drive an `indicatif` progress bar.
439 pub fn with_progress(
440 mut self,
441 callback: impl Fn(BatchProgress) + Send + Sync + 'static,
442 ) -> Self {
443 self.progress = Some(Box::new(callback));
444 self
445 }
446
447 /// Cap the run at `limit` messages (default: no cap). The benchmark harness
448 /// uses this to embed a fixed, comparable slice of a corpus.
449 pub fn with_limit(mut self, limit: usize) -> Self {
450 self.limit = Some(limit.max(1));
451 self
452 }
453
454 pub fn include_stale(mut self) -> Self {
455 self.include_stale = true;
456 self
457 }
458
459 /// Embed every message whose `vector` is still null. Idempotent: a re-run
460 /// over an already-embedded corpus finds an empty backlog and is a no-op.
461 ///
462 /// Messages are pulled from a streaming scan, so peak memory is one stream
463 /// page plus the staged batch - not the whole corpus.
464 pub async fn run(&self) -> Result<EmbedSummary> {
465 let mut summary = EmbedSummary::default();
466 let mut window: Vec<PendingMessage> = Vec::with_capacity(self.sort_window);
467 let mut pulled = 0usize;
468
469 let mut stream = if self.include_stale {
470 Box::pin(self.store.pending_or_stale_messages())
471 as std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<PendingMessage>> + '_>>
472 } else {
473 Box::pin(self.store.pending_embedding_messages())
474 as std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<PendingMessage>> + '_>>
475 };
476 while let Some(pending) = stream.next().await {
477 // Stop pulling once the message cap is reached or cancellation
478 // fires; the staged window is still drained below, so the
479 // already-embedded slice commits cleanly.
480 if self.limit.is_some_and(|limit| pulled >= limit) || self.cancelled() {
481 break;
482 }
483 window.push(pending?);
484 pulled += 1;
485 if window.len() >= self.sort_window {
486 self.drain_window(&mut window, &mut summary).await?;
487 }
488 }
489 self.drain_window(&mut window, &mut summary).await?;
490 summary.cancelled = self.cancelled();
491
492 tracing::info!(
493 model = model_id(),
494 messages = summary.messages,
495 batches = summary.batches,
496 cancelled = summary.cancelled,
497 "embed worker finished",
498 );
499 Ok(summary)
500 }
501
502 /// One `merge_update` per window, not per 32-row batch: each
503 /// `merge_update` streams the target column once, so amortizing it over
504 /// a window-sized batch beats issuing it per model batch. The
505 /// length-sort clusters similar lengths because the tokenizer pads each
506 /// batch to its longest member. Empties `window`.
507 async fn drain_window(
508 &self,
509 window: &mut Vec<PendingMessage>,
510 summary: &mut EmbedSummary,
511 ) -> Result<()> {
512 if window.is_empty() {
513 return Ok(());
514 }
515 window.sort_unstable_by_key(|message| message.search_text.len());
516 let mut batch: Vec<PendingMessage> = Vec::with_capacity(DEFAULT_BATCH_SIZE);
517 let mut accumulator: Vec<EmbeddedMessage> = Vec::with_capacity(window.len());
518 for message in window.drain(..) {
519 batch.push(message);
520 if batch.len() >= DEFAULT_BATCH_SIZE {
521 accumulator.extend(self.embed_batch(&mut batch, summary).await?);
522 }
523 }
524 accumulator.extend(self.embed_batch(&mut batch, summary).await?);
525 if !accumulator.is_empty() {
526 self.store.write_embeddings(&accumulator).await?;
527 }
528 Ok(())
529 }
530
531 /// Run one model batch; return the rows. Store write is batched in
532 /// [`drain_window`](Self::drain_window), one `merge_update` per window.
533 async fn embed_batch(
534 &self,
535 batch: &mut Vec<PendingMessage>,
536 summary: &mut EmbedSummary,
537 ) -> Result<Vec<EmbeddedMessage>> {
538 if batch.is_empty() {
539 return Ok(Vec::new());
540 }
541 let pending = std::mem::take(batch);
542 // Apply e5's `passage: ` document prefix at the model boundary; the
543 // stored `search_text` keeps its uncapped, unprefixed form for FTS.
544 let texts = pending
545 .iter()
546 .map(|message| format_passage(&message.search_text))
547 .collect::<Vec<_>>();
548 let vectors = self.backend.embed(&texts)?;
549 if vectors.len() != pending.len() {
550 return Err(anyhow!(
551 "backend returned {} vectors for {} messages",
552 vectors.len(),
553 pending.len(),
554 ));
555 }
556 let rows = pending
557 .into_iter()
558 .zip(vectors)
559 .map(|(message, vector)| EmbeddedMessage {
560 session_id: message.session_id,
561 id: message.id,
562 vector,
563 })
564 .collect::<Vec<_>>();
565 let batch_messages = rows.len();
566 summary.messages += batch_messages;
567 summary.batches += 1;
568 if let Some(progress) = &self.progress {
569 progress(BatchProgress {
570 batch_messages,
571 total_messages: summary.messages,
572 total_batches: summary.batches,
573 });
574 }
575 Ok(rows)
576 }
577}
578
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

    #[test]
    fn e5_prefixes_apply_the_asymmetric_retrieval_pair() {
        let cases = [
            (
                format_query("how does retry backoff work"),
                "query: how does retry backoff work",
            ),
            (
                format_passage("retry uses exponential backoff"),
                "passage: retry uses exponential backoff",
            ),
        ];
        for (actual, expected) in cases {
            assert_eq!(actual, expected);
        }
    }

    /// Inert backend handed out by the test loaders. The load counting
    /// itself lives in the loader closure, so the eviction test can detect
    /// reloads without spinning up a real model.
    struct CountingEmbedder;
    impl Embedder for CountingEmbedder {
        fn device(&self) -> &str {
            "test"
        }
        fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
            Ok(vec![])
        }
    }

    /// A loader that yields [`CountingEmbedder`]s, plus the counter it bumps
    /// on every invocation.
    fn counting_loader() -> (EmbedLoader, Arc<AtomicUsize>) {
        let loads = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&loads);
        let loader: EmbedLoader = Arc::new(move || {
            counter.fetch_add(1, AtomicOrdering::SeqCst);
            Ok(Arc::new(CountingEmbedder) as Arc<dyn Embedder>)
        });
        (loader, loads)
    }

    /// `LazyEmbedder` keys eviction on `std::time::Instant`, which isn't
    /// affected by `tokio::time::pause`, so this test sleeps for real -
    /// against a tiny threshold, to keep the suite under ~100 ms.
    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_evicts_after_idle_threshold() {
        let (loader, loads) = counting_loader();
        let load_count = || loads.load(AtomicOrdering::SeqCst);
        let embedder =
            LazyEmbedder::with_loader(loader).with_idle_threshold(Duration::from_millis(20));

        embedder.get().await.unwrap();
        assert_eq!(load_count(), 1, "first get loads once");

        embedder.get().await.unwrap();
        assert_eq!(
            load_count(),
            1,
            "back-to-back get reuses the cached backend",
        );

        tokio::time::sleep(Duration::from_millis(60)).await;
        embedder.get().await.unwrap();
        assert_eq!(
            load_count(),
            2,
            "get after the idle threshold triggers a reload",
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_from_loaded_never_evicts() {
        let backend: Arc<dyn Embedder> = Arc::new(CountingEmbedder);
        let preloaded = LazyEmbedder::from_loaded(backend);
        preloaded.get().await.unwrap();
        // Sleep past any reasonable threshold: `from_loaded` sets the
        // threshold to Duration::MAX, so the fake must survive the pause.
        tokio::time::sleep(Duration::from_millis(60)).await;
        preloaded.get().await.unwrap();
    }
}