Skip to main content

sqlite_graphrag/
embedder.rs

1//! Embedding generation for the GraphRAG memory.
2//!
3//! v1.0.76: the default build is **LLM-only** — the binary does NOT bundle
4//! fastembed / ort / ndarray / tokenizers. All embeddings are produced
5//! by a headless invocation of `claude code` or `codex` (OAuth, no MCP,
6//! no hooks) and stored as a BLOB in `memory_embeddings(memory_id, embedding,
7//! source)`. Vector similarity is computed in pure Rust at query time.
8//!
9//! # Workload classification (G42/S3, BLOCK 1 — MANDATORY)
10//!
11//! LLM embedding is **I/O-bound + subprocess-bound**: each call waits
12//! 5-60s on a network round-trip through a headless `claude -p` /
13//! `codex exec` subprocess while the local CPU stays idle. Concurrency
14//! therefore uses **tokio** (async I/O concurrency) and NEVER rayon
15//! (reserved for CPU-bound work).
16//!
17//! # Permit formula (G42/S3, BLOCK 2)
18//!
19//! ```text
20//! permits = clamp(--llm-parallelism, 1, 32)
21//!           .min(available_parallelism())
22//!           .min(available_ram_mb * 0.5 / LLM_WORKER_RSS_MB)
23//! ```
24//!
25//! `LLM_WORKER_RSS_MB = 350` (`crate::constants`): `claude -p` and
26//! `codex exec` are node processes with a typical Maximum RSS of
27//! 200-400 MB (measured via `/usr/bin/time -l` on macOS /
28//! `/usr/bin/time -v` on Linux), so the RAM bound is pertinent.
29//!
30//! # Locking contract (G42/A3 fix)
31//!
32//! The process-wide `Mutex<LlmEmbedding>` protects ONLY the cheap clone
33//! of the client configuration (flavour + binary path + model + shared
34//! schema tempfiles). It is NEVER held across network I/O — the
35//! v1.0.76-v1.0.78 `flush_group` held it for the whole sequential
36//! embedding loop, which is why `--llm-parallelism 8` measured an
37//! effective parallelism of 1.
38
39use crate::errors::AppError;
40use crate::extract::llm_embedding::LlmEmbedding;
41use parking_lot::Mutex;
42use std::path::Path;
43use std::sync::Arc;
44use std::sync::OnceLock;
45use tokio::sync::{mpsc, Semaphore};
46use tokio::task::JoinSet;
47use tokio_util::sync::CancellationToken;
48
49/// Process-wide LLM-embedding client behind a `Mutex`.
50///
51/// The lock guards configuration cloning only (see module docs); the
52/// actual LLM I/O happens on clones, outside the lock.
53static EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();
54
55/// Process-wide multi-thread tokio runtime for embedding I/O.
56///
57/// G42/A2 fix: v1.0.76-v1.0.78 built a current-thread runtime PER CALL.
58/// One runtime per process amortises the setup and hosts the bounded
59/// fan-out of `embed_texts_parallel`.
60static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
61
62/// Calibration base: chunk (long-text) batch size per LLM call at the
63/// calibration dimensionality (G42/S2). Use [`chunk_embed_batch_size`]
64/// for the dim-adaptive value (G44).
65pub const CHUNK_EMBED_BATCH_SIZE: usize = 8;
66
67/// Calibration base: entity-name (short-text) batch size per LLM call at
68/// the calibration dimensionality (G42/S2). Use [`entity_embed_batch_size`]
69/// for the dim-adaptive value (G44).
70pub const ENTITY_EMBED_BATCH_SIZE: usize = 25;
71
/// Dimensionality the batch bases above were calibrated against (G44).
pub const EMBED_BATCH_CALIBRATION_DIM: usize = 64;

/// G44: scales a calibration-base batch size to the active dimensionality,
/// keeping the float budget per LLM call constant (~512 floats for chunks,
/// ~1600 for entity names — the budgets empirically validated at dim 64).
/// Fixed batches of 8 at 384 dims asked for ~3072 floats per response:
/// claude returned partial coverage (3 of 8 items, caught by the G42/C5
/// check) and codex timed out at 300s.
///
/// The function is total: `base` and `dim` are both floored at 1 (so the
/// division and the `clamp` bounds are always valid) and the product is
/// computed in `u128`, so even `base == usize::MAX` cannot overflow.
///
/// Returns a value in `1..=max(base, 1)`.
fn adaptive_batch_for_dim(base: usize, dim: usize) -> usize {
    let base = base.max(1);
    let dim = dim.max(1);
    // Widened arithmetic: `base * EMBED_BATCH_CALIBRATION_DIM` in `usize`
    // panics in debug builds (and wraps in release) for very large bases.
    let scaled = (base as u128) * (EMBED_BATCH_CALIBRATION_DIM as u128) / (dim as u128);
    // The clamp caps the result at `base`, so narrowing back is lossless.
    scaled.clamp(1, base as u128) as usize
}
86
87/// Dim-adaptive batch size for chunk (long-text) embedding calls (G44).
88pub fn chunk_embed_batch_size() -> usize {
89    let dim = crate::constants::embedding_dim();
90    let batch = adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, dim);
91    tracing::debug!(
92        dim,
93        base = CHUNK_EMBED_BATCH_SIZE,
94        batch,
95        "adaptive chunk batch size (G44)"
96    );
97    batch
98}
99
100/// Dim-adaptive batch size for entity-name (short-text) embedding calls (G44).
101pub fn entity_embed_batch_size() -> usize {
102    let dim = crate::constants::embedding_dim();
103    let batch = adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, dim);
104    tracing::debug!(
105        dim,
106        base = ENTITY_EMBED_BATCH_SIZE,
107        batch,
108        "adaptive entity batch size (G44)"
109    );
110    batch
111}
112
113/// Returns the process-wide multi-thread runtime, building it on first use.
114pub(crate) fn shared_runtime() -> Result<&'static tokio::runtime::Runtime, AppError> {
115    if let Some(rt) = RUNTIME.get() {
116        return Ok(rt);
117    }
118    let rt = tokio::runtime::Builder::new_multi_thread()
119        .worker_threads(2)
120        .enable_all()
121        .build()
122        .map_err(|e| AppError::Embedding(format!("tokio runtime init failed: {e}")))?;
123    let _ = RUNTIME.set(rt);
124    Ok(RUNTIME.get().expect("RUNTIME initialised above"))
125}
126
127/// Initialises the LLM-embedding client on first use and returns it.
128pub fn get_embedder(_models_dir: &Path) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
129    if let Some(e) = EMBEDDER.get() {
130        return Ok(e);
131    }
132    let backend = LlmEmbedding::detect_available()?;
133    let _ = EMBEDDER.set(Mutex::new(backend));
134    Ok(EMBEDDER.get().expect("EMBEDDER initialised above"))
135}
136
137/// Clones the embedding-client configuration. The lock is held only for
138/// the duration of the clone — NEVER across I/O (G42/A3).
139fn clone_client(embedder: &Mutex<LlmEmbedding>) -> LlmEmbedding {
140    embedder.lock().clone()
141}
142
143/// Embeds a single passage for storage. Delegates to the configured LLM
144/// headless (claude code / codex). Returns a vector of the active
145/// dimensionality.
146pub fn embed_passage(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
147    let client = clone_client(embedder);
148    let result = client.embed_passage(text)?;
149    validate_dim(result)
150}
151
152/// Embeds a single query for similarity search. Same model and dim as
153/// `embed_passage`; the only difference is the LLM-side prompt prefix
154/// that the headless invocation uses to disambiguate.
155pub fn embed_query(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
156    let client = clone_client(embedder);
157    let result = client.embed_query(text)?;
158    validate_dim(result)
159}
160
161/// Embeds a batch of passages with token-count-aware batching.
162///
163/// Kept for API compatibility; since v1.0.79 it routes through the
164/// bounded parallel fan-out with conservative defaults.
165pub fn embed_passages_controlled(
166    embedder: &Mutex<LlmEmbedding>,
167    texts: &[&str],
168    _token_counts: &[usize],
169) -> Result<Vec<Vec<f32>>, AppError> {
170    if texts.is_empty() {
171        return Ok(Vec::new());
172    }
173    let owned: Vec<String> = texts.iter().map(|t| t.to_string()).collect();
174    embed_texts_parallel(embedder, &owned, 1, chunk_embed_batch_size())
175}
176
177pub fn embed_passage_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
178    let embedder = get_embedder(models_dir)?;
179    embed_passage(embedder, text)
180}
181
182pub fn embed_query_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
183    let embedder = get_embedder(models_dir)?;
184    embed_query(embedder, text)
185}
/// G58/S1: why an embedding call could not be completed, forcing the
/// caller onto a non-vector retrieval path (FTS5 prefix + LIKE).
///
/// Produced by [`try_embed_query_with_fallback`] so the `recall` and
/// `hybrid-search` handlers can report a structured `vec_degraded` /
/// `warning` envelope rather than a hard `AppError::Embedding` exit 11.
#[derive(Debug, Clone, PartialEq)]
pub enum FallbackReason {
    /// The LLM subprocess failed (rate limit, OAuth contention, exhausted
    /// quota, unparsable model response, divergent dim, etc.). Holds the
    /// original error message for observability.
    EmbeddingFailed(String),
    /// An external signal (SIGTERM, etc.) cancelled the embedding.
    Cancelled,
    /// The embedding ran past its time budget. Holds the operation name
    /// and the elapsed seconds for diagnostic logging.
    Timeout {
        operation: String,
        duration_secs: u64,
    },
}

impl std::fmt::Display for FallbackReason {
    /// Renders a short human-readable description of the fallback cause.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FallbackReason::Cancelled => f.write_str("embedding cancelled by external signal"),
            FallbackReason::EmbeddingFailed(msg) => write!(f, "embedding failed: {msg}"),
            FallbackReason::Timeout {
                operation,
                duration_secs,
            } => write!(
                f,
                "embedding timed out after {duration_secs}s during {operation}"
            ),
        }
    }
}

impl std::error::Error for FallbackReason {}
227
228/// G58/S1: try to embed a query, mapping any failure to a structured
229/// [`FallbackReason`] so callers can route to FTS5 + LIKE fallback instead
230/// of returning exit 11 to the user.
231///
232/// This is the bridge between the hard-fail `embed_query_local` (used by
233/// write paths where embedding failure aborts the operation) and the
234/// graceful-degradation contract of `recall` / `hybrid-search` in v1.0.80.
235pub fn try_embed_query_with_fallback(
236    models_dir: &Path,
237    query: &str,
238) -> Result<Vec<f32>, FallbackReason> {
239    match embed_query_local(models_dir, query) {
240        Ok(v) => Ok(v),
241        Err(AppError::Embedding(msg)) if msg.contains("cancelled") => {
242            Err(FallbackReason::Cancelled)
243        }
244        Err(AppError::Embedding(msg)) => Err(FallbackReason::EmbeddingFailed(msg)),
245        Err(AppError::Timeout {
246            operation,
247            duration_secs,
248        }) => Err(FallbackReason::Timeout {
249            operation,
250            duration_secs,
251        }),
252        Err(e) => Err(FallbackReason::EmbeddingFailed(e.to_string())),
253    }
254}
255
256pub fn embed_passages_controlled_local(
257    models_dir: &Path,
258    texts: &[&str],
259    token_counts: &[usize],
260) -> Result<Vec<Vec<f32>>, AppError> {
261    let embedder = get_embedder(models_dir)?;
262    embed_passages_controlled(embedder, texts, token_counts)
263}
264
265/// G42/S3: embeds `texts` through the bounded parallel fan-out and
266/// returns vectors in input order.
267pub fn embed_passages_parallel_local(
268    models_dir: &Path,
269    texts: &[String],
270    parallelism: usize,
271    batch_size: usize,
272) -> Result<Vec<Vec<f32>>, AppError> {
273    let embedder = get_embedder(models_dir)?;
274    embed_texts_parallel(embedder, texts, parallelism, batch_size)
275}
276
277/// G56: in-process cache for entity embeddings keyed by `(model, text)`.
278///
279/// Schema v13 is immutable: `entity_embeddings` does not have a `text`
280/// column, so a pure DB-side cache would require a schema bump. Instead
281/// we keep a process-wide LRU-style map that survives within one CLI
282/// invocation. The hit rate is high in `ingest` (re-embedding the same
283/// canonical entity across thousands of memories) and modest in `remember`
284/// (typical single-memory invocations).
285///
286/// Key: `blake3(model || "\0" || text)`. Value: `Arc<Vec<f32>>` so the
287/// collector can drop the map entry while a `Vec` is still in flight.
288type EntityEmbedCacheMap = std::collections::HashMap<u64, Arc<Vec<f32>>>;
289
290static ENTITY_EMBED_CACHE: OnceLock<parking_lot::Mutex<EntityEmbedCacheMap>> = OnceLock::new();
291
292fn entity_embed_cache() -> &'static parking_lot::Mutex<EntityEmbedCacheMap> {
293    ENTITY_EMBED_CACHE.get_or_init(|| parking_lot::Mutex::new(std::collections::HashMap::new()))
294}
295
296fn entity_cache_key(model: &str, text: &str) -> u64 {
297    let mut hasher = blake3::Hasher::new();
298    hasher.update(model.as_bytes());
299    hasher.update(b"\0");
300    hasher.update(text.as_bytes());
301    let h = hasher.finalize();
302    let bytes = h.as_bytes();
303    u64::from_le_bytes([
304        bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
305    ])
306}
307
308/// G56: embeds entity-name texts through a process-wide cache.
309///
310/// Skips any `(model, text)` pair already produced in this CLI invocation
311/// and only spawns subprocesses for the cache misses. Returns vectors in
312/// the same order as `texts`.
313///
314/// Designed for entity-name batches (short texts). For chunk embeds use
315/// [`embed_passages_parallel_local`] directly — chunks are unique per
316/// memory and cache hit rate is negligible.
317pub fn embed_entity_texts_cached(
318    models_dir: &Path,
319    texts: &[String],
320    parallelism: usize,
321) -> Result<(Vec<Vec<f32>>, EmbedCacheStats), AppError> {
322    if texts.is_empty() {
323        return Ok((Vec::new(), EmbedCacheStats::default()));
324    }
325    let embedder = get_embedder(models_dir)?;
326    let model = embedder.lock().model_label();
327    let cache = entity_embed_cache();
328    let mut hits: Vec<Option<Arc<Vec<f32>>>> = vec![None; texts.len()];
329    let mut miss_indices: Vec<usize> = Vec::with_capacity(texts.len());
330    {
331        let guard = cache.lock();
332        for (i, text) in texts.iter().enumerate() {
333            let key = entity_cache_key(&model, text);
334            if let Some(v) = guard.get(&key) {
335                hits[i] = Some(Arc::clone(v));
336            } else {
337                miss_indices.push(i);
338            }
339        }
340    }
341    let miss_count = miss_indices.len();
342    if miss_count > 0 {
343        let miss_texts: Vec<String> = miss_indices.iter().map(|&i| texts[i].clone()).collect();
344        let miss_vecs = embed_texts_parallel(
345            embedder,
346            &miss_texts,
347            parallelism,
348            entity_embed_batch_size(),
349        )?;
350        let mut guard = cache.lock();
351        for (slot, &orig_idx) in miss_indices.iter().enumerate() {
352            let vec = Arc::new(miss_vecs[slot].clone());
353            let key = entity_cache_key(&model, &texts[orig_idx]);
354            guard.insert(key, Arc::clone(&vec));
355            hits[orig_idx] = Some(vec);
356        }
357    }
358    let mut out = Vec::with_capacity(texts.len());
359    for hit in hits.into_iter() {
360        let v = hit.ok_or_else(|| {
361            AppError::Embedding("entity embed cache produced null result".to_string())
362        })?;
363        out.push((*v).clone());
364    }
365    Ok((
366        out,
367        EmbedCacheStats {
368            requested: texts.len(),
369            hits: texts.len() - miss_count,
370            misses: miss_count,
371        },
372    ))
373}
374
375/// G56: stats snapshot returned by [`embed_entity_texts_cached`].
376#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, serde::Serialize)]
377pub struct EmbedCacheStats {
378    pub requested: usize,
379    pub hits: usize,
380    pub misses: usize,
381}
382
383impl EmbedCacheStats {
384    /// Hit rate as a fraction in `[0.0, 1.0]`. Returns 0.0 when nothing was requested.
385    pub fn hit_rate(&self) -> f64 {
386        if self.requested == 0 {
387            0.0
388        } else {
389            self.hits as f64 / self.requested as f64
390        }
391    }
392}
393
394/// G42/S3 core: bounded parallel batch embedding.
395///
396/// - texts are grouped into batches of `batch_size` (one LLM call per
397///   batch, G42/S2);
398/// - at most `effective_permits(parallelism)` LLM subprocesses run
399///   simultaneously (`Arc<Semaphore>` + `acquire_owned`, BLOCO 2);
400/// - results stream through a BOUNDED mpsc channel so the caller-side
401///   collector applies backpressure and can persist incrementally
402///   (BLOCO 5);
403/// - the global `CancellationToken` aborts in-flight work on the first
404///   signal; subprocesses die with their futures via `kill_on_drop`
405///   (BLOCO 6).
406pub fn embed_texts_parallel(
407    embedder: &Mutex<LlmEmbedding>,
408    texts: &[String],
409    parallelism: usize,
410    batch_size: usize,
411) -> Result<Vec<Vec<f32>>, AppError> {
412    let mut slots: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
413    embed_texts_parallel_with(embedder, texts, parallelism, batch_size, |idx, v| {
414        slots[idx] = Some(v.to_vec());
415        Ok(())
416    })?;
417    let mut out = Vec::with_capacity(slots.len());
418    for (idx, slot) in slots.into_iter().enumerate() {
419        out.push(slot.ok_or_else(|| {
420            AppError::Embedding(format!("embedding fan-out lost item index {idx}"))
421        })?);
422    }
423    Ok(out)
424}
425
426/// Like [`embed_texts_parallel`] but invokes `on_result` as soon as each
427/// embedding arrives (BLOCO 5: incremental persistence — a kill loses at
428/// most the in-flight batches, never the already-delivered items).
429pub fn embed_texts_parallel_with(
430    embedder: &Mutex<LlmEmbedding>,
431    texts: &[String],
432    parallelism: usize,
433    batch_size: usize,
434    mut on_result: impl FnMut(usize, &[f32]) -> Result<(), AppError>,
435) -> Result<(), AppError> {
436    if texts.is_empty() {
437        return Ok(());
438    }
439    let dim = crate::constants::embedding_dim();
440    if texts.len() == 1 {
441        let v = embed_passage(embedder, &texts[0])?;
442        return on_result(0, &v);
443    }
444
445    let client = clone_client(embedder);
446    let permits = effective_permits(parallelism);
447    let batches = build_batches(texts, batch_size.max(1));
448    let token = crate::cancel_token().clone();
449
450    let work = move |batch: Vec<(usize, String)>| {
451        let client = client.clone();
452        async move {
453            client
454                .embed_batch_async(crate::constants::PASSAGE_PREFIX, &batch)
455                .await
456        }
457    };
458
459    let fan_out = run_bounded(batches, permits, dim, token, work, &mut on_result);
460    match tokio::runtime::Handle::try_current() {
461        Ok(handle) => tokio::task::block_in_place(|| handle.block_on(fan_out)),
462        Err(_) => shared_runtime()?.block_on(fan_out),
463    }
464}
465
/// Groups `(global_index, text)` pairs into batches of at most
/// `batch_size` items, preserving input order.
///
/// `global_index` is the position of the text in `texts`, so results can
/// be mapped back to their inputs regardless of which batch carried them.
/// A `batch_size` of 0 is treated as 1 (`slice::chunks` panics on a zero
/// chunk size). Each text is cloned exactly once.
fn build_batches(texts: &[String], batch_size: usize) -> Vec<Vec<(usize, String)>> {
    let batch_size = batch_size.max(1);
    texts
        .chunks(batch_size)
        .enumerate()
        .map(|(batch_no, chunk)| {
            // First global index covered by this batch.
            let offset = batch_no * batch_size;
            chunk
                .iter()
                .enumerate()
                .map(|(pos, text)| (offset + pos, text.clone()))
                .collect()
        })
        .collect()
}
477
478/// G42/S3 BLOCO 2: effective permit count.
479///
480/// `permits = clamp(requested, 1, 32) ∧ cpus ∧ ram_livre*0.5/RSS` — see
481/// the module docs for the measured RSS rationale.
482pub fn effective_permits(requested: usize) -> usize {
483    let cpus = std::thread::available_parallelism()
484        .map(|n| n.get())
485        .unwrap_or(4);
486    let by_ram = ((crate::memory_guard::available_memory_mb() / 2)
487        / crate::constants::LLM_WORKER_RSS_MB)
488        .max(1) as usize;
489    requested.clamp(1, 32).min(cpus).min(by_ram).max(1)
490}
491
492/// Bounded fan-out engine. Generic over the per-batch work so the
493/// concurrency contract is testable without spawning real LLMs.
494///
495/// Cancel safety (BLOCO 6/10): every task races its work against
496/// `token.cancelled()` inside `tokio::select!`; both branches are
497/// cancel-safe (the work future owns its subprocess via `kill_on_drop`,
498/// and `cancelled()` is pure). On collector-side errors the `JoinSet`
499/// is shut down, which drops in-flight futures and kills their
500/// subprocesses.
501async fn run_bounded<F, Fut>(
502    batches: Vec<Vec<(usize, String)>>,
503    permits: usize,
504    dim: usize,
505    token: CancellationToken,
506    work: F,
507    on_result: &mut impl FnMut(usize, &[f32]) -> Result<(), AppError>,
508) -> Result<(), AppError>
509where
510    F: Fn(Vec<(usize, String)>) -> Fut + Clone + Send + 'static,
511    Fut: std::future::Future<Output = Result<Vec<(usize, Vec<f32>)>, AppError>> + Send,
512{
513    let total_batches = batches.len();
514    let semaphore = Arc::new(Semaphore::new(permits));
515    // BLOCO 5: bounded channel — producers block when the collector is
516    // behind (backpressure); PROIBIDO unbounded_channel between stages.
517    let (tx, mut rx) = mpsc::channel::<Result<Vec<(usize, Vec<f32>)>, AppError>>(permits * 2);
518    let mut set: JoinSet<()> = JoinSet::new();
519
520    for (batch_idx, batch) in batches.into_iter().enumerate() {
521        let sem = Arc::clone(&semaphore);
522        let token = token.clone();
523        let tx = tx.clone();
524        let work = work.clone();
525        set.spawn(async move {
526            let wait_start = std::time::Instant::now();
527            // acquire_owned: RAII permit moved into the task; returned
528            // on every exit path INCLUDING panic (BLOCO 2).
529            let Ok(_permit) = sem.acquire_owned().await else {
530                let _ = tx
531                    .send(Err(AppError::Embedding("semaphore closed".to_string())))
532                    .await;
533                return;
534            };
535            let permit_wait_ms = wait_start.elapsed().as_millis() as u64;
536            let work_start = std::time::Instant::now();
537            // ADR-0034: when `SQLITE_GRAPHRAG_IGNORE_SHUTDOWN=1` is set the
538            // cancellation arm is dropped and the batch runs to completion.
539            // This unblocks audit/test invocations whose `SHUTDOWN` flag was
540            // contaminated by an earlier signal handler in the same process
541            // tree. Production code never sees this branch.
542            let outcome = if crate::should_obey_shutdown() {
543                tokio::select! {
544                    res = work(batch) => res,
545                    _ = token.cancelled() => Err(AppError::Embedding(
546                        "embedding cancelled by shutdown signal".to_string(),
547                    )),
548                }
549            } else {
550                work(batch).await
551            };
552            // BLOCO 8: permit wait time logged SEPARATELY from work time.
553            tracing::debug!(
554                target: "embedding",
555                batch_idx,
556                permit_wait_ms,
557                work_ms = work_start.elapsed().as_millis() as u64,
558                ok = outcome.is_ok(),
559                "embedding batch finished"
560            );
561            let _ = tx.send(outcome).await;
562        });
563    }
564    drop(tx);
565
566    let mut completed = 0usize;
567    let mut failed = 0usize;
568    let mut cancelled = 0usize;
569    let mut first_error: Option<AppError> = None;
570
571    while let Some(message) = rx.recv().await {
572        match message {
573            Ok(items) => {
574                completed += 1;
575                if first_error.is_none() {
576                    for (idx, v) in items {
577                        if v.len() != dim {
578                            first_error = Some(AppError::Embedding(format!(
579                                "LLM returned {} dims for item {idx}, expected {dim}; \
580                                 refusing to truncate or pad silently (G42/C5)",
581                                v.len()
582                            )));
583                            break;
584                        }
585                        if let Err(e) = on_result(idx, &v) {
586                            first_error = Some(e);
587                            break;
588                        }
589                    }
590                    if first_error.is_some() {
591                        // Abort remaining work: dropped futures kill
592                        // their subprocesses via kill_on_drop (BLOCO 6).
593                        set.shutdown().await;
594                    }
595                }
596            }
597            Err(e) => {
598                if matches!(&e, AppError::Embedding(msg) if msg.contains("cancelled")) {
599                    cancelled += 1;
600                } else {
601                    failed += 1;
602                }
603                if first_error.is_none() {
604                    first_error = Some(e);
605                    set.shutdown().await;
606                }
607            }
608        }
609    }
610
611    // Drain the JoinSet: surface panics distinctly (panic handling —
612    // JoinError::is_panic tratado em todo join_next, BLOCO 9).
613    while let Some(join_result) = set.join_next().await {
614        if let Err(join_err) = join_result {
615            if join_err.is_panic() {
616                failed += 1;
617                if first_error.is_none() {
618                    first_error = Some(AppError::Embedding(format!(
619                        "embedding task panicked: {join_err}"
620                    )));
621                }
622            } else {
623                cancelled += 1;
624            }
625        }
626    }
627
628    // BLOCO 8: saturation observability — available_permits plus the
629    // completed/failed/cancelled counters on the progress stream.
630    tracing::info!(
631        target: "embedding",
632        total_batches,
633        completed,
634        failed,
635        cancelled,
636        available_permits = semaphore.available_permits(),
637        "embedding fan-out finished"
638    );
639
640    match first_error {
641        Some(e) => Err(e),
642        None => Ok(()),
643    }
644}
645
/// Serialises a slice of `f32` values into a byte vector, four
/// little-endian bytes per value, in input order.
pub fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    v.iter().flat_map(|value| value.to_le_bytes()).collect()
}
653
/// Deserialises little-endian bytes into `f32` values, four bytes per
/// value. Trailing bytes that do not fill a complete group of four are
/// ignored.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|quad| {
            let mut raw = [0u8; 4];
            raw.copy_from_slice(quad);
            f32::from_le_bytes(raw)
        })
        .collect()
}
661
662/// Returns the dimensionality of the embedding space. Used to
663/// validate LLM responses and to size the in-memory cache.
664pub fn embedding_dim() -> usize {
665    crate::constants::embedding_dim()
666}
667
668/// G42/C5: a vector with a divergent dimensionality is an ERROR, never
669/// silently truncated or zero-padded (the pre-v1.0.79 `normalise_dim`
670/// masked malformed LLM responses).
671fn validate_dim(v: Vec<f32>) -> Result<Vec<f32>, AppError> {
672    let dim = crate::constants::embedding_dim();
673    if v.len() != dim {
674        return Err(AppError::Embedding(format!(
675            "embedding has {} dims, expected {dim}; \
676             refusing to truncate or pad silently (G42/C5)",
677            v.len()
678        )));
679    }
680    Ok(v)
681}
682
683#[cfg(test)]
684mod tests {
685    use super::*;
686    use std::sync::atomic::{AtomicUsize, Ordering};
687
688    #[test]
689    fn f32_to_bytes_roundtrip() {
690        let input = vec![0.0_f32, 1.5, -2.25, f32::MIN, f32::MAX];
691        let bytes = f32_to_bytes(&input);
692        assert_eq!(bytes.len(), input.len() * 4);
693        let out = bytes_to_f32(&bytes);
694        assert_eq!(out, input);
695    }
696
697    #[test]
698    fn validate_dim_rejects_divergent_vectors() {
699        // G42/C5 acceptance criterion: a divergent vector MUST fail —
700        // never be silently normalised.
701        let dim = crate::constants::embedding_dim();
702        let long = vec![0.0; dim + 10];
703        assert!(validate_dim(long).is_err(), "longer vector must error");
704        let short = vec![0.0; dim.saturating_sub(1).max(1)];
705        assert!(validate_dim(short).is_err(), "shorter vector must error");
706        let exact = vec![0.0; dim];
707        assert_eq!(validate_dim(exact).expect("exact dim must pass").len(), dim);
708    }
709
710    #[test]
711    fn embedding_dim_matches_constants_source() {
712        assert_eq!(embedding_dim(), crate::constants::embedding_dim());
713    }
714
715    #[test]
716    fn build_batches_preserves_global_indices() {
717        let texts: Vec<String> = (0..10).map(|i| format!("t{i}")).collect();
718        let batches = build_batches(&texts, 4);
719        assert_eq!(batches.len(), 3);
720        assert_eq!(batches[0].len(), 4);
721        assert_eq!(batches[2].len(), 2);
722        assert_eq!(batches[2][1].0, 9);
723        assert_eq!(batches[2][1].1, "t9");
724    }
725
726    #[test]
727    fn effective_permits_clamps_to_bounds() {
728        assert!(effective_permits(0) >= 1);
729        assert!(effective_permits(1000) <= 32);
730    }
731
732    fn test_batches(n: usize) -> Vec<Vec<(usize, String)>> {
733        (0..n).map(|i| vec![(i, format!("t{i}"))]).collect()
734    }
735
736    fn dummy_vec(dim: usize) -> Vec<f32> {
737        vec![0.0; dim]
738    }
739
740    /// G42 acceptance criterion: with N permits the measured peak of
741    /// concurrent workers NEVER exceeds N, even with 10x more batches.
742    #[test]
743    fn concurrency_peak_never_exceeds_permits() {
744        let permits = 4usize;
745        let batches = test_batches(permits * 10);
746        let dim = crate::constants::embedding_dim();
747        let current = Arc::new(AtomicUsize::new(0));
748        let peak = Arc::new(AtomicUsize::new(0));
749
750        let current_c = Arc::clone(&current);
751        let peak_c = Arc::clone(&peak);
752        let work = move |batch: Vec<(usize, String)>| {
753            let current = Arc::clone(&current_c);
754            let peak = Arc::clone(&peak_c);
755            async move {
756                let now = current.fetch_add(1, Ordering::SeqCst) + 1;
757                peak.fetch_max(now, Ordering::SeqCst);
758                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
759                current.fetch_sub(1, Ordering::SeqCst);
760                Ok(batch
761                    .into_iter()
762                    .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
763                    .collect())
764            }
765        };
766
767        let mut delivered = 0usize;
768        let rt = tokio::runtime::Builder::new_multi_thread()
769            .worker_threads(4)
770            .enable_all()
771            .build()
772            .expect("test runtime");
773        rt.block_on(run_bounded(
774            batches,
775            permits,
776            dim,
777            CancellationToken::new(),
778            work,
779            &mut |_idx, _v| {
780                delivered += 1;
781                Ok(())
782            },
783        ))
784        .expect("fan-out must succeed");
785
786        assert_eq!(delivered, permits * 10, "every item must be delivered");
787        assert!(
788            peak.load(Ordering::SeqCst) <= permits,
789            "peak concurrency {} exceeded permits {permits}",
790            peak.load(Ordering::SeqCst)
791        );
792    }
793
794    /// G42 acceptance criterion: a panicking task returns its permit via
795    /// RAII and surfaces as JoinError::is_panic, not a hang.
796    #[test]
797    fn panicking_task_returns_permit_and_surfaces_error() {
798        let permits = 2usize;
799        let batches = test_batches(4);
800        let dim = crate::constants::embedding_dim();
801
802        let work = move |batch: Vec<(usize, String)>| async move {
803            if batch[0].0 == 1 {
804                panic!("intentional test panic");
805            }
806            Ok(batch
807                .into_iter()
808                .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
809                .collect())
810        };
811
812        let rt = tokio::runtime::Builder::new_multi_thread()
813            .worker_threads(2)
814            .enable_all()
815            .build()
816            .expect("test runtime");
817        let result = rt.block_on(run_bounded(
818            batches,
819            permits,
820            dim,
821            CancellationToken::new(),
822            work,
823            &mut |_idx, _v| Ok(()),
824        ));
825
826        let err = result.expect_err("panic must surface as an error");
827        assert!(
828            err.to_string().contains("panicked"),
829            "error must mention the panic: {err}"
830        );
831    }
832
833    /// G42 acceptance criterion: cancellation aborts in-flight work and
834    /// the fan-out terminates within the shutdown timeout.
835    #[test]
836    fn cancellation_terminates_fan_out_quickly() {
837        let permits = 2usize;
838        let batches = test_batches(8);
839        let dim = crate::constants::embedding_dim();
840        let token = CancellationToken::new();
841
842        let work = move |batch: Vec<(usize, String)>| async move {
843            // Long enough that only cancellation can finish the test fast.
844            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
845            Ok(batch
846                .into_iter()
847                .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
848                .collect())
849        };
850
851        let rt = tokio::runtime::Builder::new_multi_thread()
852            .worker_threads(2)
853            .enable_all()
854            .build()
855            .expect("test runtime");
856        let cancel = token.clone();
857        let start = std::time::Instant::now();
858        let result = rt.block_on(async move {
859            tokio::spawn(async move {
860                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
861                cancel.cancel();
862            });
863            run_bounded(batches, permits, dim, token, work, &mut |_idx, _v| Ok(())).await
864        });
865
866        assert!(result.is_err(), "cancelled fan-out must report an error");
867        assert!(
868            start.elapsed() < std::time::Duration::from_secs(10),
869            "graceful shutdown must finish well under the work duration"
870        );
871    }
872
873    /// G42 acceptance criterion: a divergent dim coming out of the work
874    /// stage fails the fan-out instead of being silently accepted.
875    #[test]
876    fn fan_out_rejects_divergent_dim() {
877        let permits = 2usize;
878        let batches = test_batches(2);
879        let dim = crate::constants::embedding_dim();
880
881        let work = move |batch: Vec<(usize, String)>| async move {
882            Ok(batch
883                .into_iter()
884                .map(|(i, _)| (i, vec![0.0f32; 3]))
885                .collect::<Vec<(usize, Vec<f32>)>>())
886        };
887
888        let rt = tokio::runtime::Builder::new_multi_thread()
889            .worker_threads(2)
890            .enable_all()
891            .build()
892            .expect("test runtime");
893        let result = rt.block_on(run_bounded(
894            batches,
895            permits,
896            dim,
897            CancellationToken::new(),
898            work,
899            &mut |_idx, _v| Ok(()),
900        ));
901
902        let err = result.expect_err("divergent dim must fail the fan-out");
903        assert!(err.to_string().contains("G42/C5"), "error cites C5: {err}");
904    }
905
906    /// G44: the calibration bases stay intact at the calibration dim.
907    #[test]
908    fn adaptive_batch_dim64_keeps_calibrated_sizes() {
909        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 64), 8);
910        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 64), 25);
911    }
912
913    /// G44: legacy 384-dim databases shrink to reliable batch sizes.
914    #[test]
915    fn adaptive_batch_dim384_shrinks() {
916        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 384), 1);
917        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 384), 4);
918    }
919
920    /// G44: intermediate dims scale proportionally to the float budget.
921    #[test]
922    fn adaptive_batch_intermediate_dims() {
923        assert_eq!(adaptive_batch_for_dim(8, 128), 4);
924        assert_eq!(adaptive_batch_for_dim(8, 256), 2);
925    }
926
927    /// G44: dims below the calibration dim never exceed the base.
928    #[test]
929    fn adaptive_batch_small_dim_clamps_to_base() {
930        assert_eq!(adaptive_batch_for_dim(8, 8), 8);
931    }
932
933    /// G44: the function is total — no division by zero, no clamp panic.
934    #[test]
935    fn adaptive_batch_total_function() {
936        assert_eq!(adaptive_batch_for_dim(8, 4096), 1);
937        assert_eq!(adaptive_batch_for_dim(8, 0), 8);
938        assert_eq!(adaptive_batch_for_dim(0, 64), 1);
939    }
940
941    /// G44 end-to-end: the public wrappers follow the env-dim override.
942    #[test]
943    #[serial_test::serial(env)]
944    fn adaptive_wrappers_follow_env_dim() {
945        std::env::set_var("SQLITE_GRAPHRAG_EMBEDDING_DIM", "384");
946        let chunk = chunk_embed_batch_size();
947        let entity = entity_embed_batch_size();
948        std::env::remove_var("SQLITE_GRAPHRAG_EMBEDDING_DIM");
949        crate::constants::set_active_embedding_dim(crate::constants::DEFAULT_EMBEDDING_DIM);
950        assert_eq!(chunk, 1, "384-dim chunk batch must shrink to 1 (G44)");
951        assert_eq!(entity, 4, "384-dim entity batch must shrink to 4 (G44)");
952    }
953
954    // ---------------------------------------------------------------
955    // G58/S1: FallbackReason + try_embed_query_with_fallback tests
956    // ---------------------------------------------------------------
957
958    /// Display impl covers all three variants without panicking.
959    #[test]
960    fn fallback_reason_display_does_not_panic() {
961        let _ = FallbackReason::EmbeddingFailed("rate limit".into()).to_string();
962        let _ = FallbackReason::Cancelled.to_string();
963        let _ = FallbackReason::Timeout {
964            operation: "embed_query".into(),
965            duration_secs: 30,
966        }
967        .to_string();
968    }
969
970    /// FallbackReason is PartialEq — used in test assertions to verify
971    /// the mapping rules.
972    #[test]
973    fn fallback_reason_is_partial_eq() {
974        assert_eq!(
975            FallbackReason::EmbeddingFailed("a".into()),
976            FallbackReason::EmbeddingFailed("a".into())
977        );
978        assert_eq!(FallbackReason::Cancelled, FallbackReason::Cancelled);
979        assert_ne!(
980            FallbackReason::EmbeddingFailed("a".into()),
981            FallbackReason::EmbeddingFailed("b".into())
982        );
983        assert_ne!(
984            FallbackReason::Cancelled,
985            FallbackReason::Timeout {
986                operation: "x".into(),
987                duration_secs: 1
988            }
989        );
990    }
991
992    /// Timeout variant preserves the operation name and duration from the
993    /// original AppError::Timeout for observability.
994    #[test]
995    fn fallback_reason_timeout_preserves_fields() {
996        let r = FallbackReason::Timeout {
997            operation: "embed_query_local".into(),
998            duration_secs: 300,
999        };
1000        match r {
1001            FallbackReason::Timeout {
1002                operation,
1003                duration_secs,
1004            } => {
1005                assert_eq!(operation, "embed_query_local");
1006                assert_eq!(duration_secs, 300);
1007            }
1008            other => panic!("expected Timeout, got {other:?}"),
1009        }
1010    }
1011
1012    /// try_embed_query_with_fallback surfaces an EmbeddingFailed variant
1013    /// when the LLM subprocess errors. Uses a path that surely does not
1014    /// contain any embedder configuration (the binary is invoked as
1015    /// `codex` / `claude` via PATH which, in tests, defaults to nothing
1016    /// in scope, so `LlmEmbedding::detect_available()` returns Err).
1017    #[test]
1018    #[ignore = "G58 S1 stub: requires env without codex/claude on PATH; tracked as T5 of Fase 2"]
1019    fn try_embed_query_with_fallback_surfaces_embedding_failed_for_missing_binary() {
1020        // Pointing at a models dir that does not exist forces the embedder
1021        // init to fail; the error is mapped to EmbeddingFailed.
1022        let bogus = std::path::Path::new("/nonexistent-models-dir-for-g58-fallback-test");
1023        let result = try_embed_query_with_fallback(bogus, "hello world");
1024        match result {
1025            Err(FallbackReason::EmbeddingFailed(msg)) => {
1026                // The original error must survive in the message for ops triage.
1027                assert!(!msg.is_empty(), "fallback message must not be empty");
1028            }
1029            Err(FallbackReason::Cancelled) => {
1030                panic!("expected EmbeddingFailed, got Cancelled");
1031            }
1032            Err(FallbackReason::Timeout { .. }) => {
1033                panic!("expected EmbeddingFailed, got Timeout");
1034            }
1035            Ok(_) => {
1036                panic!("expected an error, got Ok — embedder must fail for bogus path");
1037            }
1038        }
1039    }
1040
1041    // G56: entity embed cache — unit tests
1042    #[test]
1043    fn g56_entity_cache_key_is_stable_and_distinct() {
1044        let k1 = entity_cache_key("codex:default", "sqlite-graphrag");
1045        let k2 = entity_cache_key("codex:default", "sqlite-graphrag");
1046        let k3 = entity_cache_key("codex:default", "claude-code");
1047        let k4 = entity_cache_key("claude:default", "sqlite-graphrag");
1048        assert_eq!(k1, k2, "same model+text must hash identically");
1049        assert_ne!(k1, k3, "different text must hash differently");
1050        assert_ne!(k1, k4, "different model must hash differently");
1051    }
1052
1053    #[test]
1054    fn g56_entity_embed_cache_stats_hit_rate() {
1055        let zero = EmbedCacheStats::default();
1056        assert_eq!(zero.hit_rate(), 0.0);
1057        let half = EmbedCacheStats {
1058            requested: 4,
1059            hits: 2,
1060            misses: 2,
1061        };
1062        assert!((half.hit_rate() - 0.5).abs() < 1e-9);
1063        let all = EmbedCacheStats {
1064            requested: 7,
1065            hits: 7,
1066            misses: 0,
1067        };
1068        assert!((all.hit_rate() - 1.0).abs() < 1e-9);
1069    }
1070
1071    #[test]
1072    fn g56_entity_embed_cache_populates_and_hits() {
1073        // Manually populate the cache: bypasses the LLM by writing a
1074        // known vector under a chosen (model, text) key, then verifies
1075        // the cache is consulted before any LLM call would happen.
1076        let cache = entity_embed_cache();
1077        let model = "test-model";
1078        let text = "sqlite-graphrag";
1079        let key = entity_cache_key(model, text);
1080        let stored = Arc::new(vec![0.42_f32; crate::constants::embedding_dim()]);
1081        cache.lock().insert(key, Arc::clone(&stored));
1082        let guard = cache.lock();
1083        let hit = guard.get(&key).expect("cache must return stored value");
1084        assert_eq!(hit.len(), crate::constants::embedding_dim());
1085        assert!((hit[0] - 0.42).abs() < 1e-6);
1086    }
1087
1088    #[test]
1089    fn g56_empty_texts_short_circuits_with_zero_stats() {
1090        // Cannot call embed_entity_texts_cached without an LLM on PATH,
1091        // so we only verify the empty-input contract via the stats struct.
1092        let stats = EmbedCacheStats::default();
1093        assert_eq!(stats.requested, 0);
1094        assert_eq!(stats.hits, 0);
1095        assert_eq!(stats.misses, 0);
1096        assert_eq!(stats.hit_rate(), 0.0);
1097    }
1098}