Skip to main content

sqlite_graphrag/
embedder.rs

//! Embedding generation for the GraphRAG memory.
//!
//! v1.0.76: the default build is **LLM-only** — the binary does NOT bundle
//! fastembed / ort / ndarray / tokenizers. All embeddings are produced
//! by a headless invocation of `claude code` or `codex` (OAuth, no MCP,
//! no hooks) and stored as a BLOB in `memory_embeddings(memory_id, embedding,
//! source)`. Vector similarity is computed in pure Rust at query time.
//!
//! # Workload classification (G42/S3, BLOCO 1 — mandatory)
//!
//! LLM embedding is **I/O-bound and subprocess-bound**: each call waits
//! 5-60 s on a network round-trip through a headless `claude -p` /
//! `codex exec` subprocess while the local CPU stays idle. Concurrency
//! therefore uses **tokio** (async I/O concurrency) and NEVER rayon,
//! which is reserved for CPU-bound work.
//!
//! # Permit formula (G42/S3, BLOCO 2)
//!
//! ```text
//! permits = clamp(--llm-parallelism, 1, 32)
//!           .min(available_parallelism())
//!           .min(max(1, available_ram_mb / 2 / LLM_WORKER_RSS_MB))
//! ```
//!
//! `LLM_WORKER_RSS_MB = 350` (`crate::constants`): `claude -p` and
//! `codex exec` are node processes with a typical maximum RSS of
//! 200-400 MB (measured via `/usr/bin/time -l` on macOS /
//! `/usr/bin/time -v` on Linux), so bounding by free RAM matters.
//!
//! # Locking contract (G42/A3 fix)
//!
//! The process-wide `Mutex<LlmEmbedding>` protects ONLY the cheap clone
//! of the client configuration (flavour + binary path + model + shared
//! schema tempfiles). It is NEVER held across network I/O — the
//! v1.0.76-v1.0.78 `flush_group` held it for the whole sequential
//! embedding loop, which is why `--llm-parallelism 8` measured an
//! effective parallelism of 1.
38
39use crate::errors::AppError;
40use crate::extract::llm_embedding::LlmEmbedding;
41use parking_lot::Mutex;
42use std::path::Path;
43use std::sync::Arc;
44use std::sync::OnceLock;
45use tokio::sync::{mpsc, Semaphore};
46use tokio::task::JoinSet;
47use tokio_util::sync::CancellationToken;
48
49/// Process-wide LLM-embedding client behind a `Mutex`.
50///
51/// The lock guards configuration cloning only (see module docs); the
52/// actual LLM I/O happens on clones, outside the lock.
53static EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();
54
55/// Process-wide multi-thread tokio runtime for embedding I/O.
56///
57/// G42/A2 fix: v1.0.76-v1.0.78 built a current-thread runtime PER CALL.
58/// One runtime per process amortises the setup and hosts the bounded
59/// fan-out of `embed_texts_parallel`.
60static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
61
/// Chunk (long-text) batch size per LLM call, calibrated at
/// [`EMBED_BATCH_CALIBRATION_DIM`] dimensions (G42/S2). Prefer
/// [`chunk_embed_batch_size`], which rescales it for the active
/// dimensionality (G44).
pub const CHUNK_EMBED_BATCH_SIZE: usize = 8;

/// Entity-name (short-text) batch size per LLM call, calibrated at
/// [`EMBED_BATCH_CALIBRATION_DIM`] dimensions (G42/S2). Prefer
/// [`entity_embed_batch_size`], which rescales it for the active
/// dimensionality (G44).
pub const ENTITY_EMBED_BATCH_SIZE: usize = 25;

/// Dimensionality at which both batch bases above were calibrated (G44).
pub const EMBED_BATCH_CALIBRATION_DIM: usize = 64;

/// G44: rescales a calibration-base batch size to `dim`, keeping the number
/// of floats requested per LLM call roughly constant (~512 for chunks,
/// ~1600 for entity names — the budgets validated at dim 64).
///
/// Background: fixed batches of 8 at 384 dims asked for ~3072 floats per
/// response. Claude returned partial coverage (3 of 8 items, caught by the
/// G42/C5 check), and codex timed out at 300 s.
///
/// The result is always in `1..=max(base, 1)`. Both `base` and `dim` are
/// floored at 1 first, so the function never divides by zero. It also never
/// hits the `clamp` panic that occurs when the upper bound is below the
/// lower one.
fn adaptive_batch_for_dim(base: usize, dim: usize) -> usize {
    let ceiling = std::cmp::max(base, 1);
    let divisor = std::cmp::max(dim, 1);
    let scaled = ceiling * EMBED_BATCH_CALIBRATION_DIM / divisor;
    scaled.clamp(1, ceiling)
}
86
87/// Dim-adaptive batch size for chunk (long-text) embedding calls (G44).
88pub fn chunk_embed_batch_size() -> usize {
89    let dim = crate::constants::embedding_dim();
90    let batch = adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, dim);
91    tracing::debug!(
92        dim,
93        base = CHUNK_EMBED_BATCH_SIZE,
94        batch,
95        "adaptive chunk batch size (G44)"
96    );
97    batch
98}
99
100/// Dim-adaptive batch size for entity-name (short-text) embedding calls (G44).
101pub fn entity_embed_batch_size() -> usize {
102    let dim = crate::constants::embedding_dim();
103    let batch = adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, dim);
104    tracing::debug!(
105        dim,
106        base = ENTITY_EMBED_BATCH_SIZE,
107        batch,
108        "adaptive entity batch size (G44)"
109    );
110    batch
111}
112
113/// Returns the process-wide multi-thread runtime, building it on first use.
114pub(crate) fn shared_runtime() -> Result<&'static tokio::runtime::Runtime, AppError> {
115    if let Some(rt) = RUNTIME.get() {
116        return Ok(rt);
117    }
118    let rt = tokio::runtime::Builder::new_multi_thread()
119        .worker_threads(2)
120        .enable_all()
121        .build()
122        .map_err(|e| AppError::Embedding(format!("tokio runtime init failed: {e}")))?;
123    let _ = RUNTIME.set(rt);
124    Ok(RUNTIME.get().expect("RUNTIME initialised above"))
125}
126
127/// Initialises the LLM-embedding client on first use and returns it.
128pub fn get_embedder(_models_dir: &Path) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
129    if let Some(e) = EMBEDDER.get() {
130        return Ok(e);
131    }
132    let backend = LlmEmbedding::detect_available()?;
133    let _ = EMBEDDER.set(Mutex::new(backend));
134    Ok(EMBEDDER.get().expect("EMBEDDER initialised above"))
135}
136
137/// Clones the embedding-client configuration. The lock is held only for
138/// the duration of the clone — NEVER across I/O (G42/A3).
139fn clone_client(embedder: &Mutex<LlmEmbedding>) -> LlmEmbedding {
140    embedder.lock().clone()
141}
142
143/// Embeds a single passage for storage. Delegates to the configured LLM
144/// headless (claude code / codex). Returns a vector of the active
145/// dimensionality.
146pub fn embed_passage(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
147    let client = clone_client(embedder);
148    let result = client.embed_passage(text)?;
149    validate_dim(result)
150}
151
152/// Embeds a single query for similarity search. Same model and dim as
153/// `embed_passage`; the only difference is the LLM-side prompt prefix
154/// that the headless invocation uses to disambiguate.
155pub fn embed_query(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
156    let client = clone_client(embedder);
157    let result = client.embed_query(text)?;
158    validate_dim(result)
159}
160
161/// Embeds a batch of passages with token-count-aware batching.
162///
163/// Kept for API compatibility; since v1.0.79 it routes through the
164/// bounded parallel fan-out with conservative defaults.
165pub fn embed_passages_controlled(
166    embedder: &Mutex<LlmEmbedding>,
167    texts: &[&str],
168    _token_counts: &[usize],
169) -> Result<Vec<Vec<f32>>, AppError> {
170    if texts.is_empty() {
171        return Ok(Vec::new());
172    }
173    let owned: Vec<String> = texts.iter().map(|t| t.to_string()).collect();
174    embed_texts_parallel(embedder, &owned, 1, chunk_embed_batch_size())
175}
176
177pub fn embed_passage_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
178    let embedder = get_embedder(models_dir)?;
179    embed_passage(embedder, text)
180}
181
182pub fn embed_query_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
183    let embedder = get_embedder(models_dir)?;
184    embed_query(embedder, text)
185}
186
187pub fn embed_passages_controlled_local(
188    models_dir: &Path,
189    texts: &[&str],
190    token_counts: &[usize],
191) -> Result<Vec<Vec<f32>>, AppError> {
192    let embedder = get_embedder(models_dir)?;
193    embed_passages_controlled(embedder, texts, token_counts)
194}
195
196/// G42/S3: embeds `texts` through the bounded parallel fan-out and
197/// returns vectors in input order.
198pub fn embed_passages_parallel_local(
199    models_dir: &Path,
200    texts: &[String],
201    parallelism: usize,
202    batch_size: usize,
203) -> Result<Vec<Vec<f32>>, AppError> {
204    let embedder = get_embedder(models_dir)?;
205    embed_texts_parallel(embedder, texts, parallelism, batch_size)
206}
207
208/// G42/S3 core: bounded parallel batch embedding.
209///
210/// - texts are grouped into batches of `batch_size` (one LLM call per
211///   batch, G42/S2);
212/// - at most `effective_permits(parallelism)` LLM subprocesses run
213///   simultaneously (`Arc<Semaphore>` + `acquire_owned`, BLOCO 2);
214/// - results stream through a BOUNDED mpsc channel so the caller-side
215///   collector applies backpressure and can persist incrementally
216///   (BLOCO 5);
217/// - the global `CancellationToken` aborts in-flight work on the first
218///   signal; subprocesses die with their futures via `kill_on_drop`
219///   (BLOCO 6).
220pub fn embed_texts_parallel(
221    embedder: &Mutex<LlmEmbedding>,
222    texts: &[String],
223    parallelism: usize,
224    batch_size: usize,
225) -> Result<Vec<Vec<f32>>, AppError> {
226    let mut slots: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
227    embed_texts_parallel_with(embedder, texts, parallelism, batch_size, |idx, v| {
228        slots[idx] = Some(v.to_vec());
229        Ok(())
230    })?;
231    let mut out = Vec::with_capacity(slots.len());
232    for (idx, slot) in slots.into_iter().enumerate() {
233        out.push(slot.ok_or_else(|| {
234            AppError::Embedding(format!("embedding fan-out lost item index {idx}"))
235        })?);
236    }
237    Ok(out)
238}
239
240/// Like [`embed_texts_parallel`] but invokes `on_result` as soon as each
241/// embedding arrives (BLOCO 5: incremental persistence — a kill loses at
242/// most the in-flight batches, never the already-delivered items).
243pub fn embed_texts_parallel_with(
244    embedder: &Mutex<LlmEmbedding>,
245    texts: &[String],
246    parallelism: usize,
247    batch_size: usize,
248    mut on_result: impl FnMut(usize, &[f32]) -> Result<(), AppError>,
249) -> Result<(), AppError> {
250    if texts.is_empty() {
251        return Ok(());
252    }
253    let dim = crate::constants::embedding_dim();
254    if texts.len() == 1 {
255        let v = embed_passage(embedder, &texts[0])?;
256        return on_result(0, &v);
257    }
258
259    let client = clone_client(embedder);
260    let permits = effective_permits(parallelism);
261    let batches = build_batches(texts, batch_size.max(1));
262    let token = crate::cancel_token().clone();
263
264    let work = move |batch: Vec<(usize, String)>| {
265        let client = client.clone();
266        async move {
267            client
268                .embed_batch_async(crate::constants::PASSAGE_PREFIX, &batch)
269                .await
270        }
271    };
272
273    let fan_out = run_bounded(batches, permits, dim, token, work, &mut on_result);
274    match tokio::runtime::Handle::try_current() {
275        Ok(handle) => tokio::task::block_in_place(|| handle.block_on(fan_out)),
276        Err(_) => shared_runtime()?.block_on(fan_out),
277    }
278}
279
/// Groups `(global_index, text)` pairs into consecutive batches of at most
/// `batch_size` items. Input order is preserved, and the last batch may be
/// shorter.
///
/// Each text is cloned exactly once. A `batch_size` of 0 is treated as 1,
/// because `slice::chunks` panics on 0. Empty input yields no batches.
fn build_batches(texts: &[String], batch_size: usize) -> Vec<Vec<(usize, String)>> {
    let batch_size = batch_size.max(1);
    texts
        .chunks(batch_size)
        .enumerate()
        .map(|(batch_idx, chunk)| {
            // Global index of this batch's first element.
            let first = batch_idx * batch_size;
            chunk
                .iter()
                .enumerate()
                .map(|(offset, text)| (first + offset, text.clone()))
                .collect()
        })
        .collect()
}
291
292/// G42/S3 BLOCO 2: effective permit count.
293///
294/// `permits = clamp(requested, 1, 32) ∧ cpus ∧ ram_livre*0.5/RSS` — see
295/// the module docs for the measured RSS rationale.
296pub fn effective_permits(requested: usize) -> usize {
297    let cpus = std::thread::available_parallelism()
298        .map(|n| n.get())
299        .unwrap_or(4);
300    let by_ram = ((crate::memory_guard::available_memory_mb() / 2)
301        / crate::constants::LLM_WORKER_RSS_MB)
302        .max(1) as usize;
303    requested.clamp(1, 32).min(cpus).min(by_ram).max(1)
304}
305
306/// Bounded fan-out engine. Generic over the per-batch work so the
307/// concurrency contract is testable without spawning real LLMs.
308///
309/// Cancel safety (BLOCO 6/10): every task races its work against
310/// `token.cancelled()` inside `tokio::select!`; both branches are
311/// cancel-safe (the work future owns its subprocess via `kill_on_drop`,
312/// and `cancelled()` is pure). On collector-side errors the `JoinSet`
313/// is shut down, which drops in-flight futures and kills their
314/// subprocesses.
315async fn run_bounded<F, Fut>(
316    batches: Vec<Vec<(usize, String)>>,
317    permits: usize,
318    dim: usize,
319    token: CancellationToken,
320    work: F,
321    on_result: &mut impl FnMut(usize, &[f32]) -> Result<(), AppError>,
322) -> Result<(), AppError>
323where
324    F: Fn(Vec<(usize, String)>) -> Fut + Clone + Send + 'static,
325    Fut: std::future::Future<Output = Result<Vec<(usize, Vec<f32>)>, AppError>> + Send,
326{
327    let total_batches = batches.len();
328    let semaphore = Arc::new(Semaphore::new(permits));
329    // BLOCO 5: bounded channel — producers block when the collector is
330    // behind (backpressure); PROIBIDO unbounded_channel between stages.
331    let (tx, mut rx) = mpsc::channel::<Result<Vec<(usize, Vec<f32>)>, AppError>>(permits * 2);
332    let mut set: JoinSet<()> = JoinSet::new();
333
334    for (batch_idx, batch) in batches.into_iter().enumerate() {
335        let sem = Arc::clone(&semaphore);
336        let token = token.clone();
337        let tx = tx.clone();
338        let work = work.clone();
339        set.spawn(async move {
340            let wait_start = std::time::Instant::now();
341            // acquire_owned: RAII permit moved into the task; returned
342            // on every exit path INCLUDING panic (BLOCO 2).
343            let Ok(_permit) = sem.acquire_owned().await else {
344                let _ = tx
345                    .send(Err(AppError::Embedding("semaphore closed".to_string())))
346                    .await;
347                return;
348            };
349            let permit_wait_ms = wait_start.elapsed().as_millis() as u64;
350            let work_start = std::time::Instant::now();
351            let outcome = tokio::select! {
352                res = work(batch) => res,
353                _ = token.cancelled() => Err(AppError::Embedding(
354                    "embedding cancelled by shutdown signal".to_string(),
355                )),
356            };
357            // BLOCO 8: permit wait time logged SEPARATELY from work time.
358            tracing::debug!(
359                target: "embedding",
360                batch_idx,
361                permit_wait_ms,
362                work_ms = work_start.elapsed().as_millis() as u64,
363                ok = outcome.is_ok(),
364                "embedding batch finished"
365            );
366            let _ = tx.send(outcome).await;
367        });
368    }
369    drop(tx);
370
371    let mut completed = 0usize;
372    let mut failed = 0usize;
373    let mut cancelled = 0usize;
374    let mut first_error: Option<AppError> = None;
375
376    while let Some(message) = rx.recv().await {
377        match message {
378            Ok(items) => {
379                completed += 1;
380                if first_error.is_none() {
381                    for (idx, v) in items {
382                        if v.len() != dim {
383                            first_error = Some(AppError::Embedding(format!(
384                                "LLM returned {} dims for item {idx}, expected {dim}; \
385                                 refusing to truncate or pad silently (G42/C5)",
386                                v.len()
387                            )));
388                            break;
389                        }
390                        if let Err(e) = on_result(idx, &v) {
391                            first_error = Some(e);
392                            break;
393                        }
394                    }
395                    if first_error.is_some() {
396                        // Abort remaining work: dropped futures kill
397                        // their subprocesses via kill_on_drop (BLOCO 6).
398                        set.shutdown().await;
399                    }
400                }
401            }
402            Err(e) => {
403                if matches!(&e, AppError::Embedding(msg) if msg.contains("cancelled")) {
404                    cancelled += 1;
405                } else {
406                    failed += 1;
407                }
408                if first_error.is_none() {
409                    first_error = Some(e);
410                    set.shutdown().await;
411                }
412            }
413        }
414    }
415
416    // Drain the JoinSet: surface panics distinctly (panic handling —
417    // JoinError::is_panic tratado em todo join_next, BLOCO 9).
418    while let Some(join_result) = set.join_next().await {
419        if let Err(join_err) = join_result {
420            if join_err.is_panic() {
421                failed += 1;
422                if first_error.is_none() {
423                    first_error = Some(AppError::Embedding(format!(
424                        "embedding task panicked: {join_err}"
425                    )));
426                }
427            } else {
428                cancelled += 1;
429            }
430        }
431    }
432
433    // BLOCO 8: saturation observability — available_permits plus the
434    // completed/failed/cancelled counters on the progress stream.
435    tracing::info!(
436        target: "embedding",
437        total_batches,
438        completed,
439        failed,
440        cancelled,
441        available_permits = semaphore.available_permits(),
442        "embedding fan-out finished"
443    );
444
445    match first_error {
446        Some(e) => Err(e),
447        None => Ok(()),
448    }
449}
450
/// Serialises `v` as consecutive little-endian `f32` values, 4 bytes each.
/// This is the BLOB layout decoded by [`bytes_to_f32`].
pub fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * std::mem::size_of::<f32>());
    v.iter()
        .for_each(|value| bytes.extend_from_slice(&value.to_le_bytes()));
    bytes
}

/// Decodes a little-endian `f32` BLOB produced by [`f32_to_bytes`].
///
/// Trailing bytes that do not form a whole 4-byte value are ignored.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect()
}
466
467/// Returns the dimensionality of the embedding space. Used to
468/// validate LLM responses and to size the in-memory cache.
469pub fn embedding_dim() -> usize {
470    crate::constants::embedding_dim()
471}
472
473/// G42/C5: a vector with a divergent dimensionality is an ERROR, never
474/// silently truncated or zero-padded (the pre-v1.0.79 `normalise_dim`
475/// masked malformed LLM responses).
476fn validate_dim(v: Vec<f32>) -> Result<Vec<f32>, AppError> {
477    let dim = crate::constants::embedding_dim();
478    if v.len() != dim {
479        return Err(AppError::Embedding(format!(
480            "embedding has {} dims, expected {dim}; \
481             refusing to truncate or pad silently (G42/C5)",
482            v.len()
483        )));
484    }
485    Ok(v)
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn f32_to_bytes_roundtrip() {
        let input = vec![0.0_f32, 1.5, -2.25, f32::MIN, f32::MAX];
        let bytes = f32_to_bytes(&input);
        assert_eq!(bytes.len(), input.len() * 4);
        let out = bytes_to_f32(&bytes);
        assert_eq!(out, input);
    }

    /// G42/C5 acceptance criterion: a divergent vector MUST fail, never be
    /// silently normalised.
    ///
    /// Serialised on the `env` key. `validate_dim` re-reads the global dim,
    /// and `adaptive_wrappers_follow_env_dim` temporarily overrides it to
    /// 384, which could flip the verdicts if the two ran concurrently.
    #[test]
    #[serial_test::serial(env)]
    fn validate_dim_rejects_divergent_vectors() {
        let dim = crate::constants::embedding_dim();
        let long = vec![0.0; dim + 10];
        assert!(validate_dim(long).is_err(), "longer vector must error");
        let short = vec![0.0; dim.saturating_sub(1).max(1)];
        assert!(validate_dim(short).is_err(), "shorter vector must error");
        let exact = vec![0.0; dim];
        assert_eq!(validate_dim(exact).expect("exact dim must pass").len(), dim);
    }

    /// Serialised on `env` for the same reason: both sides read the global
    /// dim, which the env-override test changes temporarily.
    #[test]
    #[serial_test::serial(env)]
    fn embedding_dim_matches_constants_source() {
        assert_eq!(embedding_dim(), crate::constants::embedding_dim());
    }

    #[test]
    fn build_batches_preserves_global_indices() {
        let texts: Vec<String> = (0..10).map(|i| format!("t{i}")).collect();
        let batches = build_batches(&texts, 4);
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[2].len(), 2);
        assert_eq!(batches[2][1].0, 9);
        assert_eq!(batches[2][1].1, "t9");
    }

    #[test]
    fn effective_permits_clamps_to_bounds() {
        assert!(effective_permits(0) >= 1);
        assert!(effective_permits(1000) <= 32);
    }

    /// `n` single-item batches with global indices `0..n`.
    fn test_batches(n: usize) -> Vec<Vec<(usize, String)>> {
        (0..n).map(|i| vec![(i, format!("t{i}"))]).collect()
    }

    /// A zero vector of length `dim`.
    fn dummy_vec(dim: usize) -> Vec<f32> {
        vec![0.0; dim]
    }

    /// G42 acceptance criterion: with N permits, the measured peak of
    /// concurrent workers NEVER exceeds N, even with 10x more batches.
    ///
    /// The work closure reuses the `dim` captured up front instead of
    /// re-reading the global. Otherwise a concurrent env-override test could
    /// make it return vectors of a length `run_bounded` rejects.
    #[test]
    fn concurrency_peak_never_exceeds_permits() {
        let permits = 4usize;
        let batches = test_batches(permits * 10);
        let dim = crate::constants::embedding_dim();
        let current = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let current_c = Arc::clone(&current);
        let peak_c = Arc::clone(&peak);
        let work = move |batch: Vec<(usize, String)>| {
            let current = Arc::clone(&current_c);
            let peak = Arc::clone(&peak_c);
            async move {
                let now = current.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now, Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
                current.fetch_sub(1, Ordering::SeqCst);
                Ok(batch
                    .into_iter()
                    .map(|(i, _)| (i, dummy_vec(dim)))
                    .collect())
            }
        };

        let mut delivered = 0usize;
        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(4)
            .enable_all()
            .build()
            .expect("test runtime");
        rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| {
                delivered += 1;
                Ok(())
            },
        ))
        .expect("fan-out must succeed");

        assert_eq!(delivered, permits * 10, "every item must be delivered");
        assert!(
            peak.load(Ordering::SeqCst) <= permits,
            "peak concurrency {} exceeded permits {permits}",
            peak.load(Ordering::SeqCst)
        );
    }

    /// G42 acceptance criterion: a panicking task returns its permit via
    /// RAII and surfaces as JoinError::is_panic, not a hang.
    #[test]
    fn panicking_task_returns_permit_and_surfaces_error() {
        let permits = 2usize;
        let batches = test_batches(4);
        let dim = crate::constants::embedding_dim();

        let work = move |batch: Vec<(usize, String)>| async move {
            if batch[0].0 == 1 {
                panic!("intentional test panic");
            }
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, dummy_vec(dim)))
                .collect())
        };

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("test runtime");
        let result = rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| Ok(()),
        ));

        let err = result.expect_err("panic must surface as an error");
        assert!(
            err.to_string().contains("panicked"),
            "error must mention the panic: {err}"
        );
    }

    /// G42 acceptance criterion: cancellation aborts in-flight work, and the
    /// fan-out terminates within the shutdown timeout.
    #[test]
    fn cancellation_terminates_fan_out_quickly() {
        let permits = 2usize;
        let batches = test_batches(8);
        let dim = crate::constants::embedding_dim();
        let token = CancellationToken::new();

        let work = move |batch: Vec<(usize, String)>| async move {
            // Long enough that only cancellation can finish the test fast.
            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, dummy_vec(dim)))
                .collect())
        };

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("test runtime");
        let cancel = token.clone();
        let start = std::time::Instant::now();
        let result = rt.block_on(async move {
            tokio::spawn(async move {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                cancel.cancel();
            });
            run_bounded(batches, permits, dim, token, work, &mut |_idx, _v| Ok(())).await
        });

        assert!(result.is_err(), "cancelled fan-out must report an error");
        assert!(
            start.elapsed() < std::time::Duration::from_secs(10),
            "graceful shutdown must finish well under the work duration"
        );
    }

    /// G42 acceptance criterion: a divergent dim coming out of the work
    /// stage fails the fan-out instead of being silently accepted.
    #[test]
    fn fan_out_rejects_divergent_dim() {
        let permits = 2usize;
        let batches = test_batches(2);
        let dim = crate::constants::embedding_dim();

        let work = move |batch: Vec<(usize, String)>| async move {
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, vec![0.0f32; 3]))
                .collect::<Vec<(usize, Vec<f32>)>>())
        };

        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .expect("test runtime");
        let result = rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| Ok(()),
        ));

        let err = result.expect_err("divergent dim must fail the fan-out");
        assert!(err.to_string().contains("G42/C5"), "error cites C5: {err}");
    }

    /// G44: the calibration bases stay intact at the calibration dim.
    #[test]
    fn adaptive_batch_dim64_keeps_calibrated_sizes() {
        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 64), 8);
        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 64), 25);
    }

    /// G44: legacy 384-dim databases shrink to reliable batch sizes.
    #[test]
    fn adaptive_batch_dim384_shrinks() {
        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 384), 1);
        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 384), 4);
    }

    /// G44: intermediate dims scale proportionally to the float budget.
    #[test]
    fn adaptive_batch_intermediate_dims() {
        assert_eq!(adaptive_batch_for_dim(8, 128), 4);
        assert_eq!(adaptive_batch_for_dim(8, 256), 2);
    }

    /// G44: dims below the calibration dim never exceed the base.
    #[test]
    fn adaptive_batch_small_dim_clamps_to_base() {
        assert_eq!(adaptive_batch_for_dim(8, 8), 8);
    }

    /// G44: the function is total — no division by zero, no clamp panic.
    #[test]
    fn adaptive_batch_total_function() {
        assert_eq!(adaptive_batch_for_dim(8, 4096), 1);
        assert_eq!(adaptive_batch_for_dim(8, 0), 8);
        assert_eq!(adaptive_batch_for_dim(0, 64), 1);
    }

    /// G44 end-to-end: the public wrappers follow the env-dim override.
    #[test]
    #[serial_test::serial(env)]
    fn adaptive_wrappers_follow_env_dim() {
        /// Undoes the override on drop, including during a panic unwind, so
        /// later tests never observe the 384-dim setting.
        struct RestoreDefaultDim;
        impl Drop for RestoreDefaultDim {
            fn drop(&mut self) {
                std::env::remove_var("SQLITE_GRAPHRAG_EMBEDDING_DIM");
                crate::constants::set_active_embedding_dim(
                    crate::constants::DEFAULT_EMBEDDING_DIM,
                );
            }
        }

        let restore = RestoreDefaultDim;
        std::env::set_var("SQLITE_GRAPHRAG_EMBEDDING_DIM", "384");
        let chunk = chunk_embed_batch_size();
        let entity = entity_embed_batch_size();
        drop(restore);
        assert_eq!(chunk, 1, "384-dim chunk batch must shrink to 1 (G44)");
        assert_eq!(entity, 4, "384-dim entity batch must shrink to 4 (G44)");
    }
}