1use crate::errors::AppError;
40use crate::extract::llm_embedding::LlmEmbedding;
41use parking_lot::Mutex;
42use std::path::Path;
43use std::sync::Arc;
44use std::sync::OnceLock;
45use tokio::sync::{mpsc, Semaphore};
46use tokio::task::JoinSet;
47use tokio_util::sync::CancellationToken;
48
49static EMBEDDER: OnceLock<Mutex<LlmEmbedding>> = OnceLock::new();
54
55static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
61
/// Base number of chunk texts per embedding batch, before dimension scaling.
pub const CHUNK_EMBED_BATCH_SIZE: usize = 8;

/// Base number of entity texts per embedding batch, before dimension scaling.
pub const ENTITY_EMBED_BATCH_SIZE: usize = 25;

/// Embedding dimension at which the base batch sizes above apply unscaled.
pub const EMBED_BATCH_CALIBRATION_DIM: usize = 64;
74
75fn adaptive_batch_for_dim(base: usize, dim: usize) -> usize {
83 let base = base.max(1);
84 (base * EMBED_BATCH_CALIBRATION_DIM / dim.max(1)).clamp(1, base)
85}
86
87pub fn chunk_embed_batch_size() -> usize {
89 let dim = crate::constants::embedding_dim();
90 let batch = adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, dim);
91 tracing::debug!(
92 dim,
93 base = CHUNK_EMBED_BATCH_SIZE,
94 batch,
95 "adaptive chunk batch size (G44)"
96 );
97 batch
98}
99
100pub fn entity_embed_batch_size() -> usize {
102 let dim = crate::constants::embedding_dim();
103 let batch = adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, dim);
104 tracing::debug!(
105 dim,
106 base = ENTITY_EMBED_BATCH_SIZE,
107 batch,
108 "adaptive entity batch size (G44)"
109 );
110 batch
111}
112
113pub(crate) fn shared_runtime() -> Result<&'static tokio::runtime::Runtime, AppError> {
115 if let Some(rt) = RUNTIME.get() {
116 return Ok(rt);
117 }
118 let rt = tokio::runtime::Builder::new_multi_thread()
119 .worker_threads(2)
120 .enable_all()
121 .build()
122 .map_err(|e| AppError::Embedding(format!("tokio runtime init failed: {e}")))?;
123 let _ = RUNTIME.set(rt);
124 Ok(RUNTIME.get().expect("RUNTIME initialised above"))
125}
126
127pub fn get_embedder(_models_dir: &Path) -> Result<&'static Mutex<LlmEmbedding>, AppError> {
129 if let Some(e) = EMBEDDER.get() {
130 return Ok(e);
131 }
132 let backend = LlmEmbedding::detect_available()?;
133 let _ = EMBEDDER.set(Mutex::new(backend));
134 Ok(EMBEDDER.get().expect("EMBEDDER initialised above"))
135}
136
137fn clone_client(embedder: &Mutex<LlmEmbedding>) -> LlmEmbedding {
140 embedder.lock().clone()
141}
142
143pub fn embed_passage(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
147 let client = clone_client(embedder);
148 let result = client.embed_passage(text)?;
149 validate_dim(result)
150}
151
152pub fn embed_query(embedder: &Mutex<LlmEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
156 let client = clone_client(embedder);
157 let result = client.embed_query(text)?;
158 validate_dim(result)
159}
160
161pub fn embed_passages_controlled(
166 embedder: &Mutex<LlmEmbedding>,
167 texts: &[&str],
168 _token_counts: &[usize],
169) -> Result<Vec<Vec<f32>>, AppError> {
170 if texts.is_empty() {
171 return Ok(Vec::new());
172 }
173 let owned: Vec<String> = texts.iter().map(|t| t.to_string()).collect();
174 embed_texts_parallel(embedder, &owned, 1, chunk_embed_batch_size())
175}
176
177pub fn embed_passage_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
178 let _slot_guard = acquire_llm_slot_for_embedding()?;
179 let embedder = get_embedder(models_dir)?;
180 embed_passage(embedder, text)
181}
182
183pub fn embed_query_local(models_dir: &Path, text: &str) -> Result<Vec<f32>, AppError> {
184 let _slot_guard = acquire_llm_slot_for_embedding()?;
185 let embedder = get_embedder(models_dir)?;
186 embed_query(embedder, text)
187}
188
189pub fn embed_passage_with_choice(
206 models_dir: &Path,
207 text: &str,
208 choice: Option<crate::cli::LlmBackendChoice>,
209) -> Result<Vec<f32>, AppError> {
210 let _slot_guard = acquire_llm_slot_for_embedding()?;
211 match choice {
212 None => {
213 let embedder = get_embedder(models_dir)?;
214 embed_passage(embedder, text)
215 }
216 Some(choice) => embed_with_fallback(models_dir, text, &choice.to_chain(), false),
217 }
218}
219
220pub fn try_embed_query_with_choice(
227 models_dir: &Path,
228 text: &str,
229 choice: Option<crate::cli::LlmBackendChoice>,
230) -> Result<Vec<f32>, FallbackReason> {
231 match embed_passage_with_choice(models_dir, text, choice) {
232 Ok(v) => Ok(v),
233 Err(AppError::Embedding(msg)) if msg.contains("cancelled") => {
234 Err(FallbackReason::Cancelled)
235 }
236 Err(AppError::Embedding(msg)) => Err(FallbackReason::EmbeddingFailed(msg)),
237 Err(AppError::Timeout {
238 operation,
239 duration_secs,
240 }) => Err(FallbackReason::Timeout {
241 operation,
242 duration_secs,
243 }),
244 Err(e) => Err(FallbackReason::EmbeddingFailed(e.to_string())),
245 }
246}
247
248fn acquire_llm_slot_for_embedding() -> Result<crate::llm_slots::LlmSlotGuard, AppError> {
261 use crate::constants::{CLI_LOCK_DEFAULT_WAIT_SECS, LLM_WORKER_RSS_MB};
262 let max = std::env::var("SQLITE_GRAPHRAG_LLM_MAX_HOST_CONCURRENCY")
263 .ok()
264 .and_then(|s| s.parse::<u32>().ok())
265 .filter(|n| *n >= 1)
266 .unwrap_or_else(crate::llm_slots::default_max_concurrency);
267 let wait_secs = if std::env::var("SQLITE_GRAPHRAG_LLM_SLOT_NO_WAIT").is_ok() {
268 0
269 } else {
270 std::env::var("SQLITE_GRAPHRAG_LLM_SLOT_WAIT_SECS")
271 .ok()
272 .and_then(|s| s.parse::<u64>().ok())
273 .unwrap_or(CLI_LOCK_DEFAULT_WAIT_SECS)
274 };
275 let _ = LLM_WORKER_RSS_MB; crate::llm_slots::acquire_llm_slot(max, wait_secs)
277}
/// Why an embedding attempt did not produce a vector, as reported by the
/// `try_embed_query_*` helpers.
#[derive(Debug, Clone, PartialEq)]
pub enum FallbackReason {
    /// The embedding failed; the payload is the error message.
    EmbeddingFailed(String),
    /// The embedding was cancelled.
    Cancelled,
    /// The embedding timed out.
    Timeout {
        /// Name of the operation that timed out.
        operation: String,
        /// How long the operation ran, in seconds.
        duration_secs: u64,
    },
}

impl std::fmt::Display for FallbackReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cancelled => f.write_str("embedding cancelled by external signal"),
            Self::EmbeddingFailed(msg) => write!(f, "embedding failed: {msg}"),
            Self::Timeout {
                operation,
                duration_secs,
            } => write!(
                f,
                "embedding timed out after {duration_secs}s during {operation}"
            ),
        }
    }
}

impl std::error::Error for FallbackReason {}
319
320pub fn try_embed_query_with_fallback(
328 models_dir: &Path,
329 query: &str,
330) -> Result<Vec<f32>, FallbackReason> {
331 match embed_query_local(models_dir, query) {
332 Ok(v) => Ok(v),
333 Err(AppError::Embedding(msg)) if msg.contains("cancelled") => {
334 Err(FallbackReason::Cancelled)
335 }
336 Err(AppError::Embedding(msg)) => Err(FallbackReason::EmbeddingFailed(msg)),
337 Err(AppError::Timeout {
338 operation,
339 duration_secs,
340 }) => Err(FallbackReason::Timeout {
341 operation,
342 duration_secs,
343 }),
344 Err(e) => Err(FallbackReason::EmbeddingFailed(e.to_string())),
345 }
346}
347
348pub fn embed_with_fallback(
369 models_dir: &Path,
370 text: &str,
371 chain: &[LlmBackendKind],
372 skip_on_failure: bool,
373) -> Result<Vec<f32>, AppError> {
374 use crate::llm::exit_code_hints::LlmBackendError;
375 let effective: Vec<LlmBackendKind> = if chain.is_empty() {
376 vec![
377 LlmBackendKind::Codex,
378 LlmBackendKind::Claude,
379 LlmBackendKind::None,
380 ]
381 } else {
382 chain.to_vec()
383 };
384
385 let mut last_err: Option<AppError> = None;
386 for backend in &effective {
387 match embed_via_backend(models_dir, text, backend) {
388 Ok(v) => return Ok(v),
389 Err(e) => {
390 tracing::warn!(
391 target: "embedding",
392 backend = ?backend,
393 error = %e,
394 "embed_with_fallback: backend failed, trying next"
395 );
396 last_err = Some(e);
397 }
398 }
399 }
400 if skip_on_failure {
401 return Ok(Vec::new());
406 }
407 Err(last_err
408 .unwrap_or_else(|| AppError::Embedding(LlmBackendError::NoBackendsAvailable.to_string())))
409}
410
/// Backend selector used by the embedding fallback chain.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LlmBackendKind {
    /// The Codex backend.
    Codex,
    /// The Claude backend.
    Claude,
    /// No backend; `embed_via_backend` answers with an empty vector.
    None,
}
423
424pub fn embed_via_backend(
428 models_dir: &Path,
429 text: &str,
430 backend: &LlmBackendKind,
431) -> Result<Vec<f32>, AppError> {
432 match backend {
433 LlmBackendKind::None => Ok(Vec::new()),
434 LlmBackendKind::Codex => embed_passage_local(models_dir, text),
435 LlmBackendKind::Claude => {
436 embed_passage_local(models_dir, text)
444 }
445 }
446}
447
448pub fn embed_passages_controlled_local(
449 models_dir: &Path,
450 texts: &[&str],
451 token_counts: &[usize],
452) -> Result<Vec<Vec<f32>>, AppError> {
453 let embedder = get_embedder(models_dir)?;
454 embed_passages_controlled(embedder, texts, token_counts)
455}
456
457pub fn embed_passages_parallel_local(
460 models_dir: &Path,
461 texts: &[String],
462 parallelism: usize,
463 batch_size: usize,
464) -> Result<Vec<Vec<f32>>, AppError> {
465 let embedder = get_embedder(models_dir)?;
466 embed_texts_parallel(embedder, texts, parallelism, batch_size)
467}
468
469type EntityEmbedCacheMap = std::collections::HashMap<u64, Arc<Vec<f32>>>;
481
482static ENTITY_EMBED_CACHE: OnceLock<parking_lot::Mutex<EntityEmbedCacheMap>> = OnceLock::new();
483
484fn entity_embed_cache() -> &'static parking_lot::Mutex<EntityEmbedCacheMap> {
485 ENTITY_EMBED_CACHE.get_or_init(|| parking_lot::Mutex::new(std::collections::HashMap::new()))
486}
487
488fn entity_cache_key(model: &str, text: &str) -> u64 {
489 let mut hasher = blake3::Hasher::new();
490 hasher.update(model.as_bytes());
491 hasher.update(b"\0");
492 hasher.update(text.as_bytes());
493 let h = hasher.finalize();
494 let bytes = h.as_bytes();
495 u64::from_le_bytes([
496 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
497 ])
498}
499
500pub fn embed_entity_texts_cached(
510 models_dir: &Path,
511 texts: &[String],
512 parallelism: usize,
513) -> Result<(Vec<Vec<f32>>, EmbedCacheStats), AppError> {
514 if texts.is_empty() {
515 return Ok((Vec::new(), EmbedCacheStats::default()));
516 }
517 let embedder = get_embedder(models_dir)?;
518 let model = embedder.lock().model_label();
519 let cache = entity_embed_cache();
520 let mut hits: Vec<Option<Arc<Vec<f32>>>> = vec![None; texts.len()];
521 let mut miss_indices: Vec<usize> = Vec::with_capacity(texts.len());
522 {
523 let guard = cache.lock();
524 for (i, text) in texts.iter().enumerate() {
525 let key = entity_cache_key(&model, text);
526 if let Some(v) = guard.get(&key) {
527 hits[i] = Some(Arc::clone(v));
528 } else {
529 miss_indices.push(i);
530 }
531 }
532 }
533 let miss_count = miss_indices.len();
534 if miss_count > 0 {
535 let miss_texts: Vec<String> = miss_indices.iter().map(|&i| texts[i].clone()).collect();
536 let miss_vecs = embed_texts_parallel(
537 embedder,
538 &miss_texts,
539 parallelism,
540 entity_embed_batch_size(),
541 )?;
542 let mut guard = cache.lock();
543 for (slot, &orig_idx) in miss_indices.iter().enumerate() {
544 let vec = Arc::new(miss_vecs[slot].clone());
545 let key = entity_cache_key(&model, &texts[orig_idx]);
546 guard.insert(key, Arc::clone(&vec));
547 hits[orig_idx] = Some(vec);
548 }
549 }
550 let mut out = Vec::with_capacity(texts.len());
551 for hit in hits.into_iter() {
552 let v = hit.ok_or_else(|| {
553 AppError::Embedding("entity embed cache produced null result".to_string())
554 })?;
555 out.push((*v).clone());
556 }
557 Ok((
558 out,
559 EmbedCacheStats {
560 requested: texts.len(),
561 hits: texts.len() - miss_count,
562 misses: miss_count,
563 },
564 ))
565}
566
567#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, serde::Serialize)]
569pub struct EmbedCacheStats {
570 pub requested: usize,
571 pub hits: usize,
572 pub misses: usize,
573}
574
575impl EmbedCacheStats {
576 pub fn hit_rate(&self) -> f64 {
578 if self.requested == 0 {
579 0.0
580 } else {
581 self.hits as f64 / self.requested as f64
582 }
583 }
584}
585
586pub fn embed_texts_parallel(
599 embedder: &Mutex<LlmEmbedding>,
600 texts: &[String],
601 parallelism: usize,
602 batch_size: usize,
603) -> Result<Vec<Vec<f32>>, AppError> {
604 let mut slots: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
605 embed_texts_parallel_with(embedder, texts, parallelism, batch_size, |idx, v| {
606 slots[idx] = Some(v.to_vec());
607 Ok(())
608 })?;
609 let mut out = Vec::with_capacity(slots.len());
610 for (idx, slot) in slots.into_iter().enumerate() {
611 out.push(slot.ok_or_else(|| {
612 AppError::Embedding(format!("embedding fan-out lost item index {idx}"))
613 })?);
614 }
615 Ok(out)
616}
617
618pub fn embed_texts_parallel_with(
622 embedder: &Mutex<LlmEmbedding>,
623 texts: &[String],
624 parallelism: usize,
625 batch_size: usize,
626 mut on_result: impl FnMut(usize, &[f32]) -> Result<(), AppError>,
627) -> Result<(), AppError> {
628 if texts.is_empty() {
629 return Ok(());
630 }
631 let dim = crate::constants::embedding_dim();
632 if texts.len() == 1 {
633 let v = embed_passage(embedder, &texts[0])?;
634 return on_result(0, &v);
635 }
636
637 let client = clone_client(embedder);
638 let permits = effective_permits(parallelism);
639 let batches = build_batches(texts, batch_size.max(1));
640 let token = crate::cancel_token().clone();
641
642 let work = move |batch: Vec<(usize, String)>| {
643 let client = client.clone();
644 async move {
645 client
646 .embed_batch_async(crate::constants::PASSAGE_PREFIX, &batch)
647 .await
648 }
649 };
650
651 let fan_out = run_bounded(batches, permits, dim, token, work, &mut on_result);
652 match tokio::runtime::Handle::try_current() {
653 Ok(handle) => tokio::task::block_in_place(|| handle.block_on(fan_out)),
654 Err(_) => shared_runtime()?.block_on(fan_out),
655 }
656}
657
/// Splits `texts` into consecutive batches of at most `batch_size` items,
/// pairing each text with its index in `texts`.
///
/// A `batch_size` of 0 is treated as 1 instead of panicking, and each string
/// is cloned exactly once.
fn build_batches(texts: &[String], batch_size: usize) -> Vec<Vec<(usize, String)>> {
    let batch_size = batch_size.max(1);
    texts
        .chunks(batch_size)
        .enumerate()
        .map(|(batch_no, chunk)| {
            // Every chunk except possibly the last is full, so this is the
            // index in `texts` of the chunk's first element.
            let first_idx = batch_no * batch_size;
            chunk
                .iter()
                .enumerate()
                .map(|(offset, text)| (first_idx + offset, text.clone()))
                .collect()
        })
        .collect()
}
669
670pub fn effective_permits(requested: usize) -> usize {
675 let cpus = std::thread::available_parallelism()
676 .map(|n| n.get())
677 .unwrap_or(4);
678 let by_ram = ((crate::memory_guard::available_memory_mb() / 2)
679 / crate::constants::LLM_WORKER_RSS_MB)
680 .max(1) as usize;
681 requested.clamp(1, 32).min(cpus).min(by_ram).max(1)
682}
683
684async fn run_bounded<F, Fut>(
694 batches: Vec<Vec<(usize, String)>>,
695 permits: usize,
696 dim: usize,
697 token: CancellationToken,
698 work: F,
699 on_result: &mut impl FnMut(usize, &[f32]) -> Result<(), AppError>,
700) -> Result<(), AppError>
701where
702 F: Fn(Vec<(usize, String)>) -> Fut + Clone + Send + 'static,
703 Fut: std::future::Future<Output = Result<Vec<(usize, Vec<f32>)>, AppError>> + Send,
704{
705 let total_batches = batches.len();
706 let semaphore = Arc::new(Semaphore::new(permits));
707 let (tx, mut rx) = mpsc::channel::<Result<Vec<(usize, Vec<f32>)>, AppError>>(permits * 2);
710 let mut set: JoinSet<()> = JoinSet::new();
711
712 for (batch_idx, batch) in batches.into_iter().enumerate() {
713 let sem = Arc::clone(&semaphore);
714 let token = token.clone();
715 let tx = tx.clone();
716 let work = work.clone();
717 set.spawn(async move {
718 let wait_start = std::time::Instant::now();
719 let Ok(_permit) = sem.acquire_owned().await else {
722 let _ = tx
723 .send(Err(AppError::Embedding("semaphore closed".to_string())))
724 .await;
725 return;
726 };
727 let permit_wait_ms = wait_start.elapsed().as_millis() as u64;
728 let work_start = std::time::Instant::now();
729 let outcome = if crate::should_obey_shutdown() {
735 tokio::select! {
736 res = work(batch) => res,
737 _ = token.cancelled() => Err(AppError::Embedding(
738 "embedding cancelled by shutdown signal".to_string(),
739 )),
740 }
741 } else {
742 work(batch).await
743 };
744 tracing::debug!(
746 target: "embedding",
747 batch_idx,
748 permit_wait_ms,
749 work_ms = work_start.elapsed().as_millis() as u64,
750 ok = outcome.is_ok(),
751 "embedding batch finished"
752 );
753 let _ = tx.send(outcome).await;
754 });
755 }
756 drop(tx);
757
758 let mut completed = 0usize;
759 let mut failed = 0usize;
760 let mut cancelled = 0usize;
761 let mut first_error: Option<AppError> = None;
762
763 while let Some(message) = rx.recv().await {
764 match message {
765 Ok(items) => {
766 completed += 1;
767 if first_error.is_none() {
768 for (idx, v) in items {
769 if v.len() != dim {
770 first_error = Some(AppError::Embedding(format!(
771 "LLM returned {} dims for item {idx}, expected {dim}; \
772 refusing to truncate or pad silently (G42/C5)",
773 v.len()
774 )));
775 break;
776 }
777 if let Err(e) = on_result(idx, &v) {
778 first_error = Some(e);
779 break;
780 }
781 }
782 if first_error.is_some() {
783 set.shutdown().await;
786 }
787 }
788 }
789 Err(e) => {
790 if matches!(&e, AppError::Embedding(msg) if msg.contains("cancelled")) {
791 cancelled += 1;
792 } else {
793 failed += 1;
794 }
795 if first_error.is_none() {
796 first_error = Some(e);
797 set.shutdown().await;
798 }
799 }
800 }
801 }
802
803 while let Some(join_result) = set.join_next().await {
806 if let Err(join_err) = join_result {
807 if join_err.is_panic() {
808 failed += 1;
809 if first_error.is_none() {
810 first_error = Some(AppError::Embedding(format!(
811 "embedding task panicked: {join_err}"
812 )));
813 }
814 } else {
815 cancelled += 1;
816 }
817 }
818 }
819
820 tracing::info!(
823 target: "embedding",
824 total_batches,
825 completed,
826 failed,
827 cancelled,
828 available_permits = semaphore.available_permits(),
829 "embedding fan-out finished"
830 );
831
832 match first_error {
833 Some(e) => Err(e),
834 None => Ok(()),
835 }
836}
837
/// Serialises floats into their little-endian bytes, four bytes per value.
pub fn f32_to_bytes(v: &[f32]) -> Vec<u8> {
    v.iter().flat_map(|value| value.to_le_bytes()).collect()
}
845
/// Decodes little-endian bytes into floats, four bytes per value.
///
/// Trailing bytes that do not fill a complete group of four are ignored.
pub fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
        .collect()
}
853
854pub fn embedding_dim() -> usize {
857 crate::constants::embedding_dim()
858}
859
860fn validate_dim(v: Vec<f32>) -> Result<Vec<f32>, AppError> {
864 let dim = crate::constants::embedding_dim();
865 if v.len() != dim {
866 return Err(AppError::Embedding(format!(
867 "embedding has {} dims, expected {dim}; \
868 refusing to truncate or pad silently (G42/C5)",
869 v.len()
870 )));
871 }
872 Ok(v)
873}
874
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds an isolated multi-thread runtime so each fan-out test owns its
    /// scheduler.
    fn test_runtime(workers: usize) -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(workers)
            .enable_all()
            .build()
            .expect("test runtime")
    }

    /// Produces `n` single-item batches whose item index equals the batch index.
    fn test_batches(n: usize) -> Vec<Vec<(usize, String)>> {
        (0..n).map(|i| vec![(i, format!("t{i}"))]).collect()
    }

    /// A zero vector of the given length.
    fn dummy_vec(dim: usize) -> Vec<f32> {
        vec![0.0; dim]
    }

    #[test]
    fn f32_to_bytes_roundtrip() {
        let input = vec![0.0_f32, 1.5, -2.25, f32::MIN, f32::MAX];
        let bytes = f32_to_bytes(&input);
        assert_eq!(bytes.len(), input.len() * 4);
        let out = bytes_to_f32(&bytes);
        assert_eq!(out, input);
    }

    #[test]
    fn validate_dim_rejects_divergent_vectors() {
        let dim = crate::constants::embedding_dim();
        let long = vec![0.0; dim + 10];
        assert!(validate_dim(long).is_err(), "longer vector must error");
        let short = vec![0.0; dim.saturating_sub(1).max(1)];
        assert!(validate_dim(short).is_err(), "shorter vector must error");
        let exact = vec![0.0; dim];
        assert_eq!(validate_dim(exact).expect("exact dim must pass").len(), dim);
    }

    #[test]
    fn embedding_dim_matches_constants_source() {
        assert_eq!(embedding_dim(), crate::constants::embedding_dim());
    }

    #[test]
    fn build_batches_preserves_global_indices() {
        let texts: Vec<String> = (0..10).map(|i| format!("t{i}")).collect();
        let batches = build_batches(&texts, 4);
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].len(), 4);
        assert_eq!(batches[2].len(), 2);
        assert_eq!(batches[2][1].0, 9);
        assert_eq!(batches[2][1].1, "t9");
    }

    #[test]
    fn effective_permits_clamps_to_bounds() {
        assert!(effective_permits(0) >= 1);
        assert!(effective_permits(1000) <= 32);
    }

    #[test]
    fn concurrency_peak_never_exceeds_permits() {
        let permits = 4usize;
        let batches = test_batches(permits * 10);
        let dim = crate::constants::embedding_dim();
        let current = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let current_c = Arc::clone(&current);
        let peak_c = Arc::clone(&peak);
        let work = move |batch: Vec<(usize, String)>| {
            let current = Arc::clone(&current_c);
            let peak = Arc::clone(&peak_c);
            async move {
                // Record how many batches are in flight while this one runs.
                let now = current.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(now, Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_millis(20)).await;
                current.fetch_sub(1, Ordering::SeqCst);
                Ok(batch
                    .into_iter()
                    .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
                    .collect())
            }
        };

        let mut delivered = 0usize;
        let rt = test_runtime(4);
        rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| {
                delivered += 1;
                Ok(())
            },
        ))
        .expect("fan-out must succeed");

        assert_eq!(delivered, permits * 10, "every item must be delivered");
        assert!(
            peak.load(Ordering::SeqCst) <= permits,
            "peak concurrency {} exceeded permits {permits}",
            peak.load(Ordering::SeqCst)
        );
    }

    #[test]
    fn panicking_task_returns_permit_and_surfaces_error() {
        let permits = 2usize;
        let batches = test_batches(4);
        let dim = crate::constants::embedding_dim();

        let work = move |batch: Vec<(usize, String)>| async move {
            if batch[0].0 == 1 {
                panic!("intentional test panic");
            }
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
                .collect())
        };

        let rt = test_runtime(2);
        let result = rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| Ok(()),
        ));

        let err = result.expect_err("panic must surface as an error");
        assert!(
            err.to_string().contains("panicked"),
            "error must mention the panic: {err}"
        );
    }

    #[test]
    fn cancellation_terminates_fan_out_quickly() {
        let permits = 2usize;
        let batches = test_batches(8);
        let dim = crate::constants::embedding_dim();
        let token = CancellationToken::new();

        let work = move |batch: Vec<(usize, String)>| async move {
            // Far longer than the test's deadline: only cancellation can end it.
            tokio::time::sleep(std::time::Duration::from_secs(30)).await;
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, dummy_vec(crate::constants::embedding_dim())))
                .collect())
        };

        let rt = test_runtime(2);
        let cancel = token.clone();
        let start = std::time::Instant::now();
        let result = rt.block_on(async move {
            tokio::spawn(async move {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                cancel.cancel();
            });
            run_bounded(batches, permits, dim, token, work, &mut |_idx, _v| Ok(())).await
        });

        assert!(result.is_err(), "cancelled fan-out must report an error");
        assert!(
            start.elapsed() < std::time::Duration::from_secs(10),
            "graceful shutdown must finish well under the work duration"
        );
    }

    #[test]
    fn fan_out_rejects_divergent_dim() {
        let permits = 2usize;
        let batches = test_batches(2);
        let dim = crate::constants::embedding_dim();

        let work = move |batch: Vec<(usize, String)>| async move {
            Ok(batch
                .into_iter()
                .map(|(i, _)| (i, vec![0.0f32; 3]))
                .collect::<Vec<(usize, Vec<f32>)>>())
        };

        let rt = test_runtime(2);
        let result = rt.block_on(run_bounded(
            batches,
            permits,
            dim,
            CancellationToken::new(),
            work,
            &mut |_idx, _v| Ok(()),
        ));

        let err = result.expect_err("divergent dim must fail the fan-out");
        assert!(err.to_string().contains("G42/C5"), "error cites C5: {err}");
    }

    #[test]
    fn adaptive_batch_dim64_keeps_calibrated_sizes() {
        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 64), 8);
        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 64), 25);
    }

    #[test]
    fn adaptive_batch_dim384_shrinks() {
        assert_eq!(adaptive_batch_for_dim(CHUNK_EMBED_BATCH_SIZE, 384), 1);
        assert_eq!(adaptive_batch_for_dim(ENTITY_EMBED_BATCH_SIZE, 384), 4);
    }

    #[test]
    fn adaptive_batch_intermediate_dims() {
        assert_eq!(adaptive_batch_for_dim(8, 128), 4);
        assert_eq!(adaptive_batch_for_dim(8, 256), 2);
    }

    #[test]
    fn adaptive_batch_small_dim_clamps_to_base() {
        assert_eq!(adaptive_batch_for_dim(8, 8), 8);
    }

    #[test]
    fn adaptive_batch_total_function() {
        assert_eq!(adaptive_batch_for_dim(8, 4096), 1);
        assert_eq!(adaptive_batch_for_dim(8, 0), 8);
        assert_eq!(adaptive_batch_for_dim(0, 64), 1);
    }

    #[test]
    #[serial_test::serial(env)]
    fn adaptive_wrappers_follow_env_dim() {
        std::env::set_var("SQLITE_GRAPHRAG_EMBEDDING_DIM", "384");
        let chunk = chunk_embed_batch_size();
        let entity = entity_embed_batch_size();
        // Restore global state before asserting so a failure cannot leak it.
        std::env::remove_var("SQLITE_GRAPHRAG_EMBEDDING_DIM");
        crate::constants::set_active_embedding_dim(crate::constants::DEFAULT_EMBEDDING_DIM);
        assert_eq!(chunk, 1, "384-dim chunk batch must shrink to 1 (G44)");
        assert_eq!(entity, 4, "384-dim entity batch must shrink to 4 (G44)");
    }

    #[test]
    fn fallback_reason_display_does_not_panic() {
        let _ = FallbackReason::EmbeddingFailed("rate limit".into()).to_string();
        let _ = FallbackReason::Cancelled.to_string();
        let _ = FallbackReason::Timeout {
            operation: "embed_query".into(),
            duration_secs: 30,
        }
        .to_string();
    }

    #[test]
    fn fallback_reason_is_partial_eq() {
        assert_eq!(
            FallbackReason::EmbeddingFailed("a".into()),
            FallbackReason::EmbeddingFailed("a".into())
        );
        assert_eq!(FallbackReason::Cancelled, FallbackReason::Cancelled);
        assert_ne!(
            FallbackReason::EmbeddingFailed("a".into()),
            FallbackReason::EmbeddingFailed("b".into())
        );
        assert_ne!(
            FallbackReason::Cancelled,
            FallbackReason::Timeout {
                operation: "x".into(),
                duration_secs: 1
            }
        );
    }

    #[test]
    fn fallback_reason_timeout_preserves_fields() {
        let r = FallbackReason::Timeout {
            operation: "embed_query_local".into(),
            duration_secs: 300,
        };
        match r {
            FallbackReason::Timeout {
                operation,
                duration_secs,
            } => {
                assert_eq!(operation, "embed_query_local");
                assert_eq!(duration_secs, 300);
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "G58 S1 stub: requires env without codex/claude on PATH; tracked as T5 of Fase 2"]
    fn try_embed_query_with_fallback_surfaces_embedding_failed_for_missing_binary() {
        let bogus = std::path::Path::new("/nonexistent-models-dir-for-g58-fallback-test");
        let result = try_embed_query_with_fallback(bogus, "hello world");
        match result {
            Err(FallbackReason::EmbeddingFailed(msg)) => {
                assert!(!msg.is_empty(), "fallback message must not be empty");
            }
            Err(FallbackReason::Cancelled) => {
                panic!("expected EmbeddingFailed, got Cancelled");
            }
            Err(FallbackReason::Timeout { .. }) => {
                panic!("expected EmbeddingFailed, got Timeout");
            }
            Ok(_) => {
                panic!("expected an error, got Ok — embedder must fail for bogus path");
            }
        }
    }

    #[test]
    fn g56_entity_cache_key_is_stable_and_distinct() {
        let k1 = entity_cache_key("codex:default", "sqlite-graphrag");
        let k2 = entity_cache_key("codex:default", "sqlite-graphrag");
        let k3 = entity_cache_key("codex:default", "claude-code");
        let k4 = entity_cache_key("claude:default", "sqlite-graphrag");
        assert_eq!(k1, k2, "same model+text must hash identically");
        assert_ne!(k1, k3, "different text must hash differently");
        assert_ne!(k1, k4, "different model must hash differently");
    }

    #[test]
    fn g56_entity_embed_cache_stats_hit_rate() {
        let zero = EmbedCacheStats::default();
        assert_eq!(zero.hit_rate(), 0.0);
        let half = EmbedCacheStats {
            requested: 4,
            hits: 2,
            misses: 2,
        };
        assert!((half.hit_rate() - 0.5).abs() < 1e-9);
        let all = EmbedCacheStats {
            requested: 7,
            hits: 7,
            misses: 0,
        };
        assert!((all.hit_rate() - 1.0).abs() < 1e-9);
    }

    #[test]
    fn g56_entity_embed_cache_populates_and_hits() {
        let cache = entity_embed_cache();
        let model = "test-model";
        let text = "sqlite-graphrag";
        let key = entity_cache_key(model, text);
        let stored = Arc::new(vec![0.42_f32; crate::constants::embedding_dim()]);
        cache.lock().insert(key, Arc::clone(&stored));
        let guard = cache.lock();
        let hit = guard.get(&key).expect("cache must return stored value");
        assert_eq!(hit.len(), crate::constants::embedding_dim());
        assert!((hit[0] - 0.42).abs() < 1e-6);
    }

    #[test]
    fn g56_empty_texts_short_circuits_with_zero_stats() {
        // Empty input returns before the embedder is resolved, so a bogus
        // models directory is safe here.
        let (vectors, stats) = embed_entity_texts_cached(
            std::path::Path::new("/nonexistent-models-dir-for-g56-empty-test"),
            &[],
            1,
        )
        .expect("empty input must short-circuit without an embedder");
        assert!(vectors.is_empty());
        assert_eq!(stats.requested, 0);
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.hit_rate(), 0.0);
    }
}
1291
#[cfg(test)]
mod embed_with_fallback_tests {
    use super::*;
    use crate::llm::exit_code_hints::LlmBackendError;

    /// Models directory that does not exist; the `None` backend never reads it.
    const BOGUS_MODELS_DIR: &str = "/nonexistent-models-dir-for-gap005-test";

    #[test]
    fn none_backend_returns_empty_vector_without_calling_llm() {
        let models_dir = std::path::Path::new("/nonexistent");
        let v = embed_via_backend(models_dir, "any text", &LlmBackendKind::None)
            .expect("None backend never fails");
        assert!(v.is_empty());
    }

    #[test]
    fn empty_chain_defaults_to_codex_claude_none() {
        // Only pins the length of the documented default chain; it does not
        // call `embed_with_fallback`.
        let defaults = [
            LlmBackendKind::Codex,
            LlmBackendKind::Claude,
            LlmBackendKind::None,
        ];
        assert_eq!(defaults.len(), 3);
    }

    #[test]
    fn embed_with_fallback_succeeds_via_none_when_chain_exhausts() {
        let chain = vec![LlmBackendKind::None];
        let models_dir = std::path::Path::new(BOGUS_MODELS_DIR);
        let v = embed_with_fallback(models_dir, "hello", &chain, false)
            .expect("chain ending in None must always succeed");
        assert!(v.is_empty());
    }

    #[test]
    fn embed_with_fallback_skip_on_failure_with_only_none_returns_empty() {
        let chain = vec![LlmBackendKind::None];
        let models_dir = std::path::Path::new(BOGUS_MODELS_DIR);
        let v = embed_with_fallback(models_dir, "hello", &chain, true)
            .expect("None chain is always Ok");
        assert!(v.is_empty());
    }

    #[test]
    fn llm_backend_error_no_backends_default_message() {
        let hint = LlmBackendError::NoBackendsAvailable.hint();
        assert!(hint.contains("--llm-fallback"));
    }

    #[test]
    fn llm_backend_error_nonzero_exit_carries_stderr_tail() {
        let rendered = LlmBackendError::NonZeroExit {
            exit_code: Some(137),
            signal: Some(9),
            stdout_tail: "out".into(),
            stderr_tail: "OOM killed".into(),
            binary: "codex".into(),
            hint: "OOM".into(),
        }
        .to_string();
        assert!(rendered.contains("codex"));
        assert!(rendered.contains("OOM killed"));
        assert!(rendered.contains("signal 9") || rendered.contains("exit 137"));
    }
}