use std::collections::HashMap;
use anyhow::{Context, Result};
use indicatif::{ProgressBar, ProgressStyle};
use cqs::{Embedder, Embedding, Store};
/// Runs one enrichment pass over the entire index: re-embeds each chunk whose
/// call-graph context (callers/callees), LLM summary, or hyde prediction has
/// changed relative to the enrichment hash stored for it.
///
/// Chunks with an up-to-date stored hash are skipped (counted for logging).
/// Returns the number of chunks whose embeddings were actually updated.
pub(crate) fn enrichment_pass(store: &Store, embedder: &Embedder, quiet: bool) -> Result<usize> {
    let _span = tracing::info_span!("enrichment_pass").entered();
    let stats = store.stats().context("Failed to get index stats")?;
    // f32 because it is only used as the denominator of the per-callee
    // document frequencies computed just below.
    let total_chunks = stats.total_chunks as f32;
    if total_chunks < 1.0 {
        // Empty index: nothing to enrich.
        return Ok(0);
    }
    // Document frequency per callee name (fraction of all chunks that call it).
    // High-frequency callees are filtered out downstream (hash + NL generation).
    let callee_freq = store
        .callee_caller_counts()
        .context("Failed to compute callee frequencies")?;
    let callee_doc_freq: HashMap<String, f32> = callee_freq
        .into_iter()
        .map(|(name, count)| (name, count as f32 / total_chunks))
        .collect();
    let mut enriched_count = 0usize;
    let mut cursor = 0i64;
    const ENRICHMENT_PAGE_SIZE: usize = 500;
    // Count how many chunk identities share each name.
    // NOTE(review): presumably one identity per file, making a count > 1 mean
    // "name defined in multiple files" — confirm against all_chunk_identities().
    let identities = store
        .all_chunk_identities()
        .context("Failed to load chunk identities")?;
    let mut name_file_count: HashMap<&str, usize> = HashMap::new();
    for ci in &identities {
        *name_file_count.entry(ci.name.as_str()).or_insert(0) += 1;
    }
    let progress = if quiet {
        ProgressBar::hidden()
    } else {
        let pb = ProgressBar::new(stats.total_chunks);
        pb.set_style(
            ProgressStyle::default_bar()
                .template("{spinner:.green} [{bar:40}] {pos}/{len} enriching ({eta})")
                .expect("valid progress template")
                .progress_chars("=>-"),
        );
        pb
    };
    // (chunk id, enriched NL text, enrichment hash) triples awaiting embedding.
    let mut embed_batch: Vec<(String, String, String)> = Vec::new();
    let enrich_embed_batch: usize = super::pipeline::embed_batch_size();
    let mut skipped_count = 0usize;
    // Pre-fetch summaries / hyde texts / stored hashes in bulk up front.
    // Failures degrade to empty maps: the pass proceeds without that signal
    // rather than aborting.
    let all_summaries = match store.get_all_summaries("summary") {
        Ok(s) => s,
        Err(e) => {
            tracing::warn!(error = %e, "Failed to pre-fetch LLM summaries for enrichment");
            HashMap::new()
        }
    };
    let all_hyde = match store.get_all_summaries("hyde") {
        Ok(s) => s,
        Err(e) => {
            tracing::warn!(error = %e, "Failed to pre-fetch hyde predictions for enrichment");
            HashMap::new()
        }
    };
    let all_enrichment_hashes = match store.get_all_enrichment_hashes() {
        Ok(h) => h,
        Err(e) => {
            tracing::warn!(error = %e, "Failed to pre-fetch enrichment hashes");
            HashMap::new()
        }
    };
    // The paging loop runs inside a closure so the progress bar is cleaned up
    // on both success and error before the `?` below propagates the result.
    let result: Result<usize> = (|| {
        loop {
            let (chunks, next_cursor) = store
                .chunks_paged(cursor, ENRICHMENT_PAGE_SIZE)
                .context("Failed to page chunks")?;
            if chunks.is_empty() {
                break;
            }
            cursor = next_cursor;
            let page_names: Vec<&str> = chunks.iter().map(|cs| cs.name.as_str()).collect();
            // `page` logs the *next* cursor (cursor was advanced just above).
            tracing::debug!(
                page = cursor,
                names = page_names.len(),
                "Loading callers/callees for enrichment page"
            );
            // One batched round-trip per page for each edge direction.
            let callers_map = store
                .get_callers_full_batch(&page_names)
                .context("Failed to batch-fetch callers for page")?;
            let callees_map = store
                .get_callees_full_batch(&page_names)
                .context("Failed to batch-fetch callees for page")?;
            for cs in &chunks {
                progress.inc(1);
                let callers = callers_map.get(&cs.name);
                let callees = callees_map.get(&cs.name);
                let has_callers = callers.is_some_and(|v| !v.is_empty());
                let has_callees = callees.is_some_and(|v| !v.is_empty());
                let summary = all_summaries.get(&cs.content_hash).map(|s| s.as_str());
                let hyde = all_hyde.get(&cs.content_hash).map(|s| s.as_str());
                // No enrichment signal at all: skip this chunk.
                if !has_callers && !has_callees && summary.is_none() && hyde.is_none() {
                    continue;
                }
                // Ambiguous name (appears more than once in the index): the
                // call edges cannot be attributed reliably, so only enrich
                // when a summary or hyde text exists.
                if name_file_count.get(cs.name.as_str()).copied().unwrap_or(0) > 1
                    && summary.is_none()
                    && hyde.is_none()
                {
                    continue;
                }
                let ctx = cqs::CallContext {
                    callers: callers
                        .map(|v| v.iter().map(|c| c.name.clone()).collect())
                        .unwrap_or_default(),
                    callees: callees
                        .map(|v| v.iter().map(|(name, _)| name.clone()).collect())
                        .unwrap_or_default(),
                };
                // Cheap change detection: skip re-embedding when the stored
                // hash of the enrichment inputs is unchanged.
                let enrichment_hash =
                    compute_enrichment_hash_with_summary(&ctx, &callee_doc_freq, summary, hyde);
                if let Some(stored) = all_enrichment_hashes.get(&cs.id) {
                    if *stored == enrichment_hash {
                        skipped_count += 1;
                        continue;
                    }
                }
                let chunk: cqs::parser::Chunk = cs.into();
                // NOTE(review): the two 5s presumably cap how many callers and
                // callees are folded into the NL text — confirm against the
                // signature of generate_nl_with_call_context_and_summary.
                let enriched_nl = cqs::generate_nl_with_call_context_and_summary(
                    &chunk,
                    &ctx,
                    &callee_doc_freq,
                    5, 5, summary,
                    hyde,
                );
                embed_batch.push((cs.id.clone(), enriched_nl, enrichment_hash));
                if embed_batch.len() >= enrich_embed_batch {
                    enriched_count += flush_enrichment_batch(store, embedder, &mut embed_batch)?;
                }
            }
        }
        // Flush any remainder smaller than one full batch.
        if !embed_batch.is_empty() {
            enriched_count += flush_enrichment_batch(store, embedder, &mut embed_batch)?;
        }
        Ok(enriched_count)
    })();
    progress.finish_and_clear();
    let enriched_count = result?;
    tracing::info!(enriched_count, skipped_count, "Enrichment pass complete");
    if !quiet {
        if skipped_count > 0 {
            eprintln!(
                "Enriched {} chunks with call graph context ({} already up-to-date)",
                enriched_count, skipped_count
            );
        } else {
            eprintln!("Enriched {} chunks with call graph context", enriched_count);
        }
    }
    Ok(enriched_count)
}
/// Computes a stable 32-hex-char fingerprint of a chunk's enrichment inputs:
/// its callers, its low-frequency callees, and the optional summary / hyde
/// texts. Caller and callee order never affects the result (both lists are
/// sorted), and whitespace-only changes in summary/hyde text are ignored.
fn compute_enrichment_hash_with_summary(
    ctx: &cqs::CallContext,
    callee_doc_freq: &HashMap<String, f32>,
    summary: Option<&str>,
    hyde: Option<&str>,
) -> String {
    // Collapse any whitespace run to a single space so formatting-only
    // differences in LLM output do not invalidate the hash.
    fn normalize_ws(text: &str) -> String {
        text.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    let mut canonical = String::new();

    // Callers, sorted so the hash is independent of retrieval order.
    let mut caller_names: Vec<&str> = ctx.callers.iter().map(String::as_str).collect();
    caller_names.sort_unstable();
    for name in caller_names {
        canonical.push_str("c:");
        canonical.push_str(name);
        canonical.push('|');
    }

    // Callees appearing in >=10% of all chunks carry little signal and are
    // excluded; the remainder are sorted for order independence.
    let mut callee_names: Vec<&str> = ctx
        .callees
        .iter()
        .map(String::as_str)
        .filter(|name| (callee_doc_freq.get(*name).copied().unwrap_or(0.0) as f64) < 0.1_f64)
        .collect();
    callee_names.sort_unstable();
    for name in callee_names {
        canonical.push_str("e:");
        canonical.push_str(name);
        canonical.push('|');
    }

    if let Some(text) = summary {
        canonical.push_str("s:");
        canonical.push_str(&normalize_ws(text));
    }
    if let Some(text) = hyde {
        canonical.push_str("h:");
        canonical.push_str(&normalize_ws(text));
    }
    // NOTE(review): segments are not length-delimited, so pathological inputs
    // (e.g. a summary that literally ends in "h:...") could collide with a
    // summary+hyde pair. Harmless for change detection here, but worth
    // confirming upstream names/summaries cannot embed these markers.

    let digest = blake3::hash(canonical.as_bytes());
    digest.to_hex()[..32].to_string()
}
/// Embeds all queued enriched-NL texts and persists the resulting vectors
/// (together with their enrichment hashes) in one batched store update.
///
/// On success the batch is emptied and the number of updated rows returned.
/// On failure the batch is left intact (the error propagates to the caller).
fn flush_enrichment_batch(
    store: &Store,
    embedder: &Embedder,
    batch: &mut Vec<(String, String, String)>,
) -> Result<usize> {
    let _span = tracing::info_span!("flush_enrichment_batch", count = batch.len()).entered();

    let nl_texts: Vec<&str> = batch.iter().map(|(_, nl, _)| nl.as_str()).collect();
    let wanted = nl_texts.len();

    let vectors = embedder
        .embed_documents(&nl_texts)
        .context("Failed to embed enriched NL batch")?;
    // Guard against a misbehaving embedder returning the wrong number of
    // vectors, which would silently mispair ids and embeddings below.
    anyhow::ensure!(
        vectors.len() == wanted,
        "Embedding count mismatch: expected {}, got {}",
        wanted,
        vectors.len()
    );

    // Pair each (id, nl, hash) entry with its vector; the NL text itself is
    // not persisted — only the embedding and the hash.
    let rows: Vec<(String, Embedding, Option<String>)> = batch
        .iter()
        .zip(vectors)
        .map(|((id, _, hash), vector)| (id.clone(), vector, Some(hash.clone())))
        .collect();
    store
        .update_embeddings_with_hashes_batch(&rows)
        .context("Failed to update enriched embeddings")?;

    let flushed = rows.len();
    batch.clear();
    Ok(flushed)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `CallContext` from string slices for concise test setup.
    fn ctx_of(callers: &[&str], callees: &[&str]) -> cqs::CallContext {
        cqs::CallContext {
            callers: callers.iter().map(|s| s.to_string()).collect(),
            callees: callees.iter().map(|s| s.to_string()).collect(),
        }
    }

    /// Thin alias so the property tests below read compactly.
    fn hash_of(
        ctx: &cqs::CallContext,
        freq: &HashMap<String, f32>,
        summary: Option<&str>,
        hyde: Option<&str>,
    ) -> String {
        compute_enrichment_hash_with_summary(ctx, freq, summary, hyde)
    }

    #[test]
    fn enrichment_hash_deterministic_same_inputs() {
        let ctx = ctx_of(&["caller_a", "caller_b"], &["callee_x", "callee_y"]);
        let freq = HashMap::new();
        let summary = Some("Processes raw data");
        assert_eq!(
            hash_of(&ctx, &freq, summary, None),
            hash_of(&ctx, &freq, summary, None),
            "Same inputs must produce identical hashes"
        );
    }

    #[test]
    fn enrichment_hash_deterministic_regardless_of_caller_order() {
        let forward = ctx_of(&["caller_a", "caller_b"], &["callee_x"]);
        let reversed = ctx_of(&["caller_b", "caller_a"], &["callee_x"]);
        let freq = HashMap::new();
        assert_eq!(
            hash_of(&forward, &freq, None, None),
            hash_of(&reversed, &freq, None, None),
            "Caller order must not affect hash"
        );
    }

    #[test]
    fn enrichment_hash_deterministic_regardless_of_callee_order() {
        let forward = ctx_of(&[], &["callee_x", "callee_y"]);
        let reversed = ctx_of(&[], &["callee_y", "callee_x"]);
        let freq = HashMap::new();
        assert_eq!(
            hash_of(&forward, &freq, None, None),
            hash_of(&reversed, &freq, None, None),
            "Callee order must not affect hash"
        );
    }

    #[test]
    fn enrichment_hash_changes_with_different_callers() {
        let freq = HashMap::new();
        assert_ne!(
            hash_of(&ctx_of(&["caller_a"], &["callee_x"]), &freq, None, None),
            hash_of(&ctx_of(&["caller_b"], &["callee_x"]), &freq, None, None),
            "Different callers must produce different hashes"
        );
    }

    #[test]
    fn enrichment_hash_changes_with_summary() {
        let ctx = ctx_of(&["caller_a"], &["callee_x"]);
        let freq = HashMap::new();
        assert_ne!(
            hash_of(&ctx, &freq, None, None),
            hash_of(&ctx, &freq, Some("a summary"), None),
            "Adding a summary must change the hash"
        );
    }

    #[test]
    fn enrichment_hash_filters_high_freq_callees() {
        // "log" sits above the 10% document-frequency cutoff and must be
        // dropped from the hash input; "rare_fn" stays below it.
        let mut freq = HashMap::new();
        freq.insert("log".to_string(), 0.15_f32);
        freq.insert("rare_fn".to_string(), 0.02_f32);
        let with_log = hash_of(&ctx_of(&[], &["log", "rare_fn"]), &freq, None, None);
        let without_log = hash_of(&ctx_of(&[], &["rare_fn"]), &HashMap::new(), None, None);
        assert_eq!(
            with_log, without_log,
            "High-frequency callees (>=10% IDF) must be excluded from hash"
        );
    }

    #[test]
    fn enrichment_hash_changes_with_hyde() {
        let ctx = ctx_of(&["caller_a"], &[]);
        let freq = HashMap::new();
        assert_ne!(
            hash_of(&ctx, &freq, None, None),
            hash_of(&ctx, &freq, None, Some("how to search")),
            "Adding hyde must change the hash"
        );
    }
}