use std::collections::{HashMap, HashSet, VecDeque};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use rayon::prelude::*;
use crate::parser::{ChunkType, Language};
use crate::{AnalysisError, Embedder, Embedding};
use crate::store::helpers::{CallGraph, SearchFilter};
use crate::store::SearchResult;
use crate::Store;
/// Default hard cap on how many nodes BFS expansion may accumulate.
/// Overridable at runtime via the `CQS_GATHER_MAX_NODES` env var (see
/// `gather_max_nodes`).
pub const DEFAULT_MAX_EXPANDED_NODES: usize = 200;
/// Tuning knobs for a gather run.
#[derive(Debug, Clone)]
pub struct GatherOptions {
    /// Maximum BFS depth when expanding seeds through the call graph
    /// (0 disables expansion entirely).
    pub expand_depth: usize,
    /// Which call-graph edges to follow during expansion.
    pub direction: GatherDirection,
    /// Maximum number of chunks in the final result.
    pub limit: usize,
    /// Maximum number of hits requested from the initial seed search.
    pub seed_limit: usize,
    /// Minimum score for a seed hit to be kept.
    pub seed_threshold: f32,
    /// Multiplier applied to a node's score for each BFS hop away from a seed.
    pub decay_factor: f32,
    /// Hard cap on how many names BFS expansion may accumulate.
    pub max_expanded_nodes: usize,
    /// Optional precomputed query embedding; when set, `gather` skips
    /// embedding the description text.
    pub query_embedding: Option<Embedding>,
}
impl GatherOptions {
    /// Builder: set the BFS expansion depth.
    pub fn with_expand_depth(mut self, depth: usize) -> Self {
        self.expand_depth = depth;
        self
    }

    /// Builder: set the call-graph traversal direction.
    pub fn with_direction(mut self, direction: GatherDirection) -> Self {
        self.direction = direction;
        self
    }

    /// Builder: set the maximum number of result chunks.
    pub fn with_limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Builder: set the seed-search result cap.
    pub fn with_seed_limit(mut self, limit: usize) -> Self {
        self.seed_limit = limit;
        self
    }

    /// Builder: set the minimum seed score. Non-finite values (NaN/±inf)
    /// are rejected with a warning and replaced by the default 0.3.
    pub fn with_seed_threshold(mut self, threshold: f32) -> Self {
        if !threshold.is_finite() {
            tracing::warn!(
                threshold = %threshold,
                "NaN/infinite seed threshold, using default 0.3"
            );
            self.seed_threshold = 0.3;
        } else {
            self.seed_threshold = threshold;
        }
        self
    }

    /// Builder: set the per-hop score decay. Finite inputs are clamped to
    /// [0.0, 1.0]; non-finite inputs leave the current factor untouched.
    pub fn with_decay_factor(mut self, factor: f32) -> Self {
        if factor.is_finite() {
            self.decay_factor = factor.clamp(0.0, 1.0);
        }
        self
    }

    /// Builder: set the BFS node cap.
    pub fn with_max_expanded_nodes(mut self, max: usize) -> Self {
        self.max_expanded_nodes = max;
        self
    }
}
/// Resolve the BFS node cap, computed once per process.
///
/// Reads `CQS_GATHER_MAX_NODES` on first call: a positive integer overrides
/// the default (logged at info level); an unset variable falls back silently;
/// anything unparseable or zero falls back with a warning.
fn gather_max_nodes() -> usize {
    static CAP: std::sync::OnceLock<usize> = std::sync::OnceLock::new();
    *CAP.get_or_init(|| {
        // Unset env var: use the compiled-in default without logging.
        let Ok(val) = std::env::var("CQS_GATHER_MAX_NODES") else {
            return DEFAULT_MAX_EXPANDED_NODES;
        };
        if let Ok(n) = val.parse::<usize>() {
            if n > 0 {
                tracing::info!(
                    max_nodes = n,
                    "BFS node cap overridden via CQS_GATHER_MAX_NODES"
                );
                return n;
            }
        }
        // Present but invalid (non-numeric or zero): warn and use default.
        tracing::warn!(
            value = %val,
            "Invalid CQS_GATHER_MAX_NODES, using default {}",
            DEFAULT_MAX_EXPANDED_NODES
        );
        DEFAULT_MAX_EXPANDED_NODES
    })
}
impl Default for GatherOptions {
    /// Conservative defaults: one hop of expansion in both directions,
    /// at most 10 result chunks from at most 5 seeds scoring >= 0.3, with a
    /// 0.8 per-hop score decay. The node cap comes from `gather_max_nodes()`
    /// and is therefore env-overridable.
    fn default() -> Self {
        Self {
            expand_depth: 1,
            direction: GatherDirection::Both,
            limit: 10,
            seed_limit: 5,
            seed_threshold: 0.3,
            decay_factor: 0.8,
            max_expanded_nodes: gather_max_nodes(),
            query_embedding: None,
        }
    }
}
/// Which call-graph edges BFS expansion follows from a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, clap::ValueEnum)]
pub enum GatherDirection {
    /// Follow both caller and callee edges.
    Both,
    /// Follow only reverse edges (functions that call the node).
    Callers,
    /// Follow only forward edges (functions the node calls).
    Callees,
}
impl std::str::FromStr for GatherDirection {
    type Err = String;

    /// Parse a lowercase direction name; any other input is an error
    /// listing the valid choices.
    fn from_str(s: &str) -> std::result::Result<Self, String> {
        match s {
            "callers" => Ok(Self::Callers),
            "callees" => Ok(Self::Callees),
            "both" => Ok(Self::Both),
            other => Err(format!(
                "Invalid direction '{}'. Valid: both, callers, callees",
                other
            )),
        }
    }
}
/// A single code chunk in a gather result, with its ranking and provenance.
#[derive(Debug, Clone, serde::Serialize)]
pub struct GatheredChunk {
    /// Symbol name of the chunk.
    pub name: String,
    /// File path, relative to the project root when possible.
    #[serde(serialize_with = "crate::serialize_path_normalized")]
    pub file: PathBuf,
    /// First line of the chunk in `file`.
    pub line_start: u32,
    /// Last line of the chunk in `file`.
    pub line_end: u32,
    /// Source language of the chunk.
    pub language: Language,
    /// Kind of definition (function, struct, ...).
    pub chunk_type: ChunkType,
    /// Declaration/signature text.
    pub signature: String,
    /// Full chunk body text.
    pub content: String,
    /// Relevance score (seed score, or decayed score for expanded nodes).
    pub score: f32,
    /// BFS distance from a seed (0 = seed itself).
    pub depth: usize,
    /// Name of the reference index this chunk came from; `None` for
    /// project-local chunks (and omitted from JSON in that case).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
}
impl GatheredChunk {
    /// Build a `GatheredChunk` from a raw search hit.
    ///
    /// The chunk's file path is made relative to `root` when it lives under
    /// it (otherwise kept unchanged); `score`, `depth`, and `source` are
    /// recorded verbatim.
    pub fn from_search(
        sr: &crate::store::SearchResult,
        root: &Path,
        score: f32,
        depth: usize,
        source: Option<String>,
    ) -> Self {
        let chunk = &sr.chunk;
        // Fall back to the absolute path when the chunk is outside `root`.
        let relative = chunk.file.strip_prefix(root).unwrap_or(&chunk.file);
        Self {
            name: chunk.name.clone(),
            file: relative.to_path_buf(),
            line_start: chunk.line_start,
            line_end: chunk.line_end,
            language: chunk.language,
            chunk_type: chunk.chunk_type,
            signature: chunk.signature.clone(),
            content: chunk.content.clone(),
            score,
            depth,
            source,
        }
    }
}
/// Outcome of a gather: the ranked chunks plus quality flags.
#[derive(Debug, Clone, serde::Serialize)]
pub struct GatherResult {
    /// Final chunk list (ranked, truncated, then ordered for presentation).
    pub chunks: Vec<GatheredChunk>,
    /// True when BFS expansion stopped early because the node cap was hit.
    pub expansion_capped: bool,
    /// True when a batch name lookup failed, so results may be incomplete.
    pub search_degraded: bool,
}
/// Breadth-first expansion of `name_scores` through the call graph.
///
/// The existing entries act as seeds (expected at depth 0) and are expanded
/// up to `opts.expand_depth` hops in `opts.direction`. A newly reached
/// neighbor is scored `parent_score * decay_factor`; a neighbor reached again
/// via a stronger path keeps the higher score and the smaller depth.
/// Expansion stops once `name_scores` holds `opts.max_expanded_nodes`
/// entries.
///
/// Returns `true` when the node cap cut the expansion short.
pub(crate) fn bfs_expand(
    name_scores: &mut HashMap<String, (f32, usize)>,
    graph: &CallGraph,
    opts: &GatherOptions,
) -> bool {
    let mut expansion_capped = false;
    // Depth 0 means "no expansion requested" — not a capped run.
    if opts.expand_depth == 0 {
        return false;
    }
    // Visited set uses Arc<str> so graph neighbor names can be shared with
    // the queue without a String allocation per hop.
    let mut visited: HashSet<Arc<str>> =
        name_scores.keys().map(|k| Arc::from(k.as_str())).collect();
    let initial_size = name_scores.len();
    let mut queue: VecDeque<(Arc<str>, usize)> = VecDeque::new();
    for name in name_scores.keys() {
        queue.push_back((Arc::from(name.as_str()), 0));
    }
    while let Some((name, depth)) = queue.pop_front() {
        if depth >= opts.expand_depth {
            continue;
        }
        // Cap check: only trip the flag once at least one new node was
        // actually added, so an oversized seed set alone does not mark the
        // run as capped here.
        if name_scores.len() >= opts.max_expanded_nodes && visited.len() > initial_size {
            expansion_capped = true;
            break;
        }
        let neighbors = get_neighbors(graph, &name, opts.direction);
        // Every queued name was inserted into name_scores (seeds up front,
        // neighbors below), so the 0.5 fallback is defensive only.
        let base_score = name_scores
            .get(name.as_ref())
            .map(|(s, _)| *s)
            .unwrap_or(0.5);
        let new_score = base_score * opts.decay_factor;
        for neighbor in neighbors {
            if name_scores.len() >= opts.max_expanded_nodes {
                expansion_capped = true;
                break;
            }
            if !visited.contains(&neighbor) {
                // First visit: record decayed score at the next depth and
                // keep expanding from it.
                visited.insert(Arc::clone(&neighbor));
                let key: String = neighbor.to_string();
                name_scores.insert(key, (new_score, depth + 1));
                queue.push_back((neighbor, depth + 1));
            } else if let Some(existing) = name_scores.get_mut(neighbor.as_ref()) {
                // Re-reached via a stronger path: upgrade score, keep the
                // minimum depth. NOTE(review): the improved score is not
                // re-propagated to this neighbor's descendants — appears to
                // be a deliberate cost tradeoff; confirm if exact scores
                // matter downstream.
                if new_score > existing.0 {
                    existing.0 = new_score;
                    existing.1 = existing.1.min(depth + 1);
                }
            }
        }
        if expansion_capped {
            break;
        }
    }
    expansion_capped
}
pub(crate) fn fetch_and_assemble(
store: &Store,
name_scores: &HashMap<String, (f32, usize)>,
root: &Path,
) -> (Vec<GatheredChunk>, bool) {
let all_names: Vec<&str> = name_scores.keys().map(|s| s.as_str()).collect();
let (batch_results, search_degraded) = match store.search_by_names_batch(&all_names, 1) {
Ok(r) => (r, false),
Err(e) => {
tracing::warn!(error = %e, "Batch name search failed, results may be incomplete");
(HashMap::new(), true)
}
};
let mut seen_ids: HashSet<String> = HashSet::new();
let mut chunks: Vec<GatheredChunk> = Vec::new();
for (name, (score, depth)) in name_scores {
if let Some(results) = batch_results.get(name) {
if let Some(r) = results.first() {
if seen_ids.contains(&r.chunk.id) {
continue;
}
seen_ids.insert(r.chunk.id.clone());
chunks.push(GatheredChunk::from_search(r, root, *score, *depth, None));
}
}
}
tracing::debug!(chunk_count = chunks.len(), "Chunks assembled");
(chunks, search_degraded)
}
/// Rank chunks, cut to `limit`, then reorder for stable presentation.
///
/// Selection is by descending score with name as a deterministic tiebreaker;
/// the survivors are then re-sorted by file, start line, and name so output
/// follows source order.
pub(crate) fn sort_and_truncate(chunks: &mut Vec<GatheredChunk>, limit: usize) {
    // Pick the top `limit` by relevance first.
    chunks.sort_by(|x, y| {
        y.score
            .total_cmp(&x.score)
            .then_with(|| x.name.cmp(&y.name))
    });
    chunks.truncate(limit);
    // Then present in source order.
    chunks.sort_by(|x, y| (&x.file, x.line_start, &x.name).cmp(&(&y.file, y.line_start, &y.name)));
}
/// Semantic gather entry point.
///
/// Embeds `description` (unless `opts.query_embedding` supplies a
/// precomputed vector), loads the call graph from `store`, and delegates to
/// [`gather_with_graph`].
pub fn gather(
    store: &Store,
    embedder: &Embedder,
    description: &str,
    opts: &GatherOptions,
    root: &Path,
) -> Result<GatherResult, AnalysisError> {
    let query_embedding = if let Some(emb) = &opts.query_embedding {
        emb.clone()
    } else {
        embedder.embed_query(description)?
    };
    let graph = store.get_call_graph()?;
    gather_with_graph(store, &query_embedding, description, opts, root, &graph)
}
/// Run a gather against an already-loaded call graph.
///
/// Pipeline: seed search → BFS expansion over `graph` → batched fetch of the
/// expanded names → rank/truncate. Returns an empty result immediately when
/// the seed search finds nothing.
pub fn gather_with_graph(
    store: &Store,
    query_embedding: &crate::Embedding,
    query_text: &str,
    opts: &GatherOptions,
    root: &Path,
    graph: &CallGraph,
) -> Result<GatherResult, AnalysisError> {
    let _span = tracing::info_span!(
        "gather",
        query_len = query_text.len(),
        expand_depth = opts.expand_depth,
        limit = opts.limit
    )
    .entered();

    // Seed: plain filtered vector search with RRF disabled.
    let filter = SearchFilter {
        query_text: query_text.to_string(),
        enable_rrf: false,
        ..SearchFilter::default()
    };
    let seeds =
        store.search_filtered(query_embedding, &filter, opts.seed_limit, opts.seed_threshold)?;
    tracing::debug!(seed_count = seeds.len(), "Seed search complete");
    if seeds.is_empty() {
        return Ok(GatherResult {
            chunks: Vec::new(),
            expansion_capped: false,
            search_degraded: false,
        });
    }

    // Seeds enter the expansion map at depth 0 with their search score.
    let mut scores: HashMap<String, (f32, usize)> = seeds
        .iter()
        .map(|r| (r.chunk.name.clone(), (r.score, 0)))
        .collect();
    let expansion_capped = bfs_expand(&mut scores, graph, opts);
    tracing::info!(
        expanded = scores.len(),
        capped = expansion_capped,
        "BFS expansion complete"
    );

    // Assemble, rank, and trim.
    let (mut chunks, search_degraded) = fetch_and_assemble(store, &scores, root);
    sort_and_truncate(&mut chunks, opts.limit);
    tracing::info!(final_chunks = chunks.len(), "Gather complete");
    Ok(GatherResult {
        chunks,
        expansion_capped,
        search_degraded,
    })
}
/// Cross-index gather without an explicit project vector index.
///
/// Thin wrapper: delegates to `gather_cross_index_with_index` with
/// `project_index = None`, letting the store pick its own index.
pub fn gather_cross_index(
    project_store: &Store,
    ref_idx: &crate::reference::ReferenceIndex,
    query_embedding: &crate::Embedding,
    query_text: &str,
    opts: &GatherOptions,
    root: &Path,
) -> Result<GatherResult, AnalysisError> {
    gather_cross_index_with_index(
        project_store,
        ref_idx,
        query_embedding,
        query_text,
        opts,
        root,
        None,
    )
}
/// Cross-index gather: seed in a reference index, "bridge" those hits into
/// the project index, then expand through the project call graph.
///
/// Steps (all visible below):
/// 1. Warn if project and reference report different embedding model names.
/// 2. Seed search in the reference index; empty seeds → empty result.
/// 3. Fetch the seeds' stored embeddings (falling back to the query
///    embedding per-seed on failure) and, in parallel, search the project
///    store with each seed's embedding ("bridge" search, top 3 per seed).
/// 4. Score each bridged project name as `project_score * seed_score`,
///    keeping the best score per name.
/// 5. If nothing bridged, return the reference seeds alone (truncated).
/// 6. Otherwise BFS-expand the bridged names in the project call graph,
///    fetch chunks, merge with the reference chunks, rank by score,
///    truncate, and finally order reference-sourced chunks first, then by
///    file/line/name.
pub fn gather_cross_index_with_index(
    project_store: &Store,
    ref_idx: &crate::reference::ReferenceIndex,
    query_embedding: &crate::Embedding,
    query_text: &str,
    opts: &GatherOptions,
    root: &Path,
    project_index: Option<&dyn crate::index::VectorIndex>,
) -> Result<GatherResult, AnalysisError> {
    let _span = tracing::info_span!(
        "gather_cross_index",
        ref_name = %ref_idx.name,
        query_len = query_text.len(),
        expand_depth = opts.expand_depth,
        limit = opts.limit,
    )
    .entered();
    // Mismatched embedding models make cross-index scores incomparable;
    // warn but continue. Metadata lookup failures are silently ignored.
    if let (Ok(proj_model), Ok(ref_model)) = (
        project_store.get_metadata("model_name"),
        ref_idx.store.get_metadata("model_name"),
    ) {
        if proj_model != ref_model {
            tracing::warn!(
                project = %proj_model,
                reference = %ref_model,
                "Model mismatch between project and reference — results may be inaccurate"
            );
        }
    }
    let filter = crate::store::helpers::SearchFilter {
        query_text: query_text.to_string(),
        enable_rrf: false,
        ..SearchFilter::default()
    };
    // Step 2: seed search inside the reference index.
    let ref_seeds = crate::reference::search_reference(
        ref_idx,
        query_embedding,
        &filter,
        opts.seed_limit,
        opts.seed_threshold,
        false,
    )?;
    tracing::debug!(
        ref_seed_count = ref_seeds.len(),
        "Reference seed search complete"
    );
    if ref_seeds.is_empty() {
        return Ok(GatherResult {
            chunks: Vec::new(),
            expansion_capped: false,
            search_degraded: false,
        });
    }
    // Step 3a: fetch the seeds' stored embeddings so bridge searches can use
    // each seed's own vector; an empty map makes every seed fall back to the
    // query embedding below.
    let ref_seed_ids: Vec<&str> = ref_seeds.iter().map(|r| r.chunk.id.as_str()).collect();
    let ref_embeddings = match ref_idx.store.get_embeddings_by_ids(&ref_seed_ids) {
        Ok(e) => e,
        Err(e) => {
            tracing::warn!(error = %e, "Failed to get ref seed embeddings, falling back to query embedding only");
            HashMap::new()
        }
    };
    // Reference seeds always join the candidate set at depth 0, tagged with
    // the reference index name as their source.
    let ref_chunks: Vec<GatheredChunk> = ref_seeds
        .iter()
        .map(|r| {
            GatheredChunk::from_search(r, Path::new(""), r.score, 0, Some(ref_idx.name.clone()))
        })
        .collect();
    let bridge_filter = SearchFilter {
        query_text: query_text.to_string(),
        enable_rrf: false,
        ..SearchFilter::default()
    };
    let bridge_limit = 3;
    // Step 3b: parallel bridge search — each reference seed queries the
    // project store; individual failures are counted, not fatal.
    let _bridge_span = tracing::info_span!("bridge_search", seed_count = ref_seeds.len()).entered();
    let bridge_error_count = std::sync::atomic::AtomicUsize::new(0);
    let bridge_results: Vec<(f32, Vec<SearchResult>)> = ref_seeds
        .par_iter()
        .filter_map(|seed| {
            let search_embedding = ref_embeddings
                .get(&seed.chunk.id)
                .unwrap_or(query_embedding);
            match project_store.search_filtered_with_index(
                search_embedding,
                &bridge_filter,
                bridge_limit,
                opts.seed_threshold,
                project_index,
            ) {
                Ok(r) if !r.is_empty() => Some((seed.score, r)),
                Ok(_) => None,
                Err(e) => {
                    tracing::warn!(
                        error = %e,
                        ref_seed = %seed.chunk.name,
                        "Bridge search failed for ref seed"
                    );
                    // Relaxed is sufficient: this is a simple counter read
                    // once after the parallel section.
                    bridge_error_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                    None
                }
            }
        })
        .collect();
    let total_bridge_errors = bridge_error_count.load(std::sync::atomic::Ordering::Relaxed);
    if total_bridge_errors > 0 {
        tracing::warn!(
            failed = total_bridge_errors,
            total = ref_seeds.len(),
            "Bridge searches failed — cross-index gather quality may be reduced"
        );
    }
    drop(_bridge_span);
    // Step 4: combine bridge hits per project name, scoring each as
    // project_score * seed_score and keeping the best-scoring entry.
    let mut bridge_scores: HashMap<String, (f32, String)> = HashMap::new();
    for (seed_score, results) in bridge_results {
        for pr in &results {
            let bridge_score = pr.score * seed_score;
            match bridge_scores.entry(pr.chunk.name.clone()) {
                std::collections::hash_map::Entry::Vacant(e) => {
                    e.insert((bridge_score, pr.chunk.id.clone()));
                }
                std::collections::hash_map::Entry::Occupied(mut e) => {
                    if bridge_score > e.get().0 {
                        e.insert((bridge_score, pr.chunk.id.clone()));
                    }
                }
            }
        }
    }
    tracing::debug!(bridge_count = bridge_scores.len(), "Bridge search complete");
    // Step 5: nothing bridged into the project — return reference seeds only.
    if bridge_scores.is_empty() {
        let mut result_chunks = ref_chunks;
        result_chunks.truncate(opts.limit);
        return Ok(GatherResult {
            chunks: result_chunks,
            expansion_capped: false,
            search_degraded: false,
        });
    }
    // Step 6: expand the bridged names through the project call graph.
    let graph = project_store.get_call_graph()?;
    let mut name_scores: HashMap<String, (f32, usize)> = HashMap::new();
    for (name, (score, _)) in &bridge_scores {
        name_scores.insert(name.clone(), (*score, 0));
    }
    let expansion_capped = bfs_expand(&mut name_scores, &graph, opts);
    tracing::debug!(
        expanded_nodes = name_scores.len(),
        expansion_capped,
        "Project BFS expansion complete"
    );
    let (project_chunks, search_degraded) = fetch_and_assemble(project_store, &name_scores, root);
    // Rank the merged set by score, truncate, then order for presentation:
    // reference-sourced chunks first, then by file/line/name.
    let mut all_chunks = ref_chunks;
    all_chunks.extend(project_chunks);
    all_chunks.sort_by(|a, b| b.score.total_cmp(&a.score).then(a.name.cmp(&b.name)));
    all_chunks.truncate(opts.limit);
    all_chunks.sort_by(|a, b| {
        let source_ord = match (&a.source, &b.source) {
            (Some(_), None) => std::cmp::Ordering::Less,
            (None, Some(_)) => std::cmp::Ordering::Greater,
            _ => std::cmp::Ordering::Equal,
        };
        source_ord
            .then(a.file.cmp(&b.file))
            .then(a.line_start.cmp(&b.line_start))
            .then(a.name.cmp(&b.name))
    });
    Ok(GatherResult {
        chunks: all_chunks,
        expansion_capped,
        search_degraded,
    })
}
/// Collect the call-graph neighbors of `name` for the requested direction.
///
/// Callees come from the forward edge map, callers from the reverse map;
/// `Both` yields callees followed by callers. Unknown names produce an
/// empty list.
fn get_neighbors(graph: &CallGraph, name: &str, direction: GatherDirection) -> Vec<Arc<str>> {
    let want_callees = matches!(direction, GatherDirection::Callees | GatherDirection::Both);
    let want_callers = matches!(direction, GatherDirection::Callers | GatherDirection::Both);
    let mut neighbors = Vec::new();
    if want_callees {
        if let Some(list) = graph.forward.get(name) {
            neighbors.extend(list.iter().map(Arc::clone));
        }
    }
    if want_callers {
        if let Some(list) = graph.reverse.get(name) {
            neighbors.extend(list.iter().map(Arc::clone));
        }
    }
    neighbors
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture graph: A calls B and C; B calls D (reverse edges mirrored).
    fn make_graph() -> CallGraph {
        let mut forward = HashMap::new();
        let mut reverse = HashMap::new();
        forward.insert("A".to_string(), vec!["B".to_string(), "C".to_string()]);
        forward.insert("B".to_string(), vec!["D".to_string()]);
        reverse.insert("B".to_string(), vec!["A".to_string()]);
        reverse.insert("C".to_string(), vec!["A".to_string()]);
        reverse.insert("D".to_string(), vec!["B".to_string()]);
        CallGraph::from_string_maps(forward, reverse)
    }

    /// FromStr accepts the three lowercase names and rejects anything else.
    #[test]
    fn test_direction_parse() {
        assert!(matches!(
            "both".parse::<GatherDirection>().unwrap(),
            GatherDirection::Both
        ));
        assert!(matches!(
            "callers".parse::<GatherDirection>().unwrap(),
            GatherDirection::Callers
        ));
        assert!(matches!(
            "callees".parse::<GatherDirection>().unwrap(),
            GatherDirection::Callees
        ));
        assert!("invalid".parse::<GatherDirection>().is_err());
    }

    /// Spot-check the documented defaults.
    #[test]
    fn test_default_options() {
        let opts = GatherOptions::default();
        assert_eq!(opts.expand_depth, 1);
        assert_eq!(opts.limit, 10);
        assert!(matches!(opts.direction, GatherDirection::Both));
    }

    /// Callees of A are exactly its forward edges.
    #[test]
    fn test_get_neighbors_callees() {
        let graph = make_graph();
        let neighbors = get_neighbors(&graph, "A", GatherDirection::Callees);
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.contains(&Arc::from("B")));
        assert!(neighbors.contains(&Arc::from("C")));
    }

    /// Callers of B are exactly its reverse edges.
    #[test]
    fn test_get_neighbors_callers() {
        let graph = make_graph();
        let neighbors = get_neighbors(&graph, "B", GatherDirection::Callers);
        assert_eq!(neighbors.len(), 1);
        assert_eq!(&*neighbors[0], "A");
    }

    /// `Both` returns the union of forward and reverse edges.
    #[test]
    fn test_get_neighbors_both() {
        let graph = make_graph();
        let neighbors = get_neighbors(&graph, "B", GatherDirection::Both);
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.contains(&Arc::from("D")));
        assert!(neighbors.contains(&Arc::from("A")));
    }

    /// Names absent from the graph yield no neighbors.
    #[test]
    fn test_get_neighbors_unknown_node() {
        let graph = make_graph();
        let neighbors = get_neighbors(&graph, "Z", GatherDirection::Both);
        assert!(neighbors.is_empty());
    }

    /// A leaf has callers but no callees.
    #[test]
    fn test_get_neighbors_leaf_node() {
        let graph = make_graph();
        let callees = get_neighbors(&graph, "D", GatherDirection::Callees);
        assert!(callees.is_empty());
        let callers = get_neighbors(&graph, "D", GatherDirection::Callers);
        assert_eq!(callers.len(), 1);
        assert_eq!(&*callers[0], "B");
    }

    /// Builder methods set every field they claim to.
    #[test]
    fn test_gather_options_builder() {
        let opts = GatherOptions::default()
            .with_expand_depth(3)
            .with_direction(GatherDirection::Callers)
            .with_limit(20)
            .with_seed_limit(10)
            .with_seed_threshold(0.5)
            .with_decay_factor(0.9);
        assert_eq!(opts.expand_depth, 3);
        assert!(matches!(opts.direction, GatherDirection::Callers));
        assert_eq!(opts.limit, 20);
        assert_eq!(opts.seed_limit, 10);
        assert!((opts.seed_threshold - 0.5).abs() < f32::EPSILON);
        assert!((opts.decay_factor - 0.9).abs() < f32::EPSILON);
    }

    /// Diamond graph (A -> B/C -> D): D is reachable by two depth-2 paths
    /// and must record depth 2, not a larger value from a later revisit.
    #[test]
    fn test_bfs_depth_preserves_minimum() {
        let mut forward = HashMap::new();
        let mut reverse = HashMap::new();
        forward.insert("A".to_string(), vec!["B".to_string(), "C".to_string()]);
        forward.insert("B".to_string(), vec!["D".to_string()]);
        forward.insert("C".to_string(), vec!["D".to_string()]);
        reverse.insert("B".to_string(), vec!["A".to_string()]);
        reverse.insert("C".to_string(), vec!["A".to_string()]);
        reverse.insert("D".to_string(), vec!["B".to_string(), "C".to_string()]);
        let graph = CallGraph::from_string_maps(forward, reverse);
        let mut name_scores = HashMap::new();
        name_scores.insert("A".to_string(), (1.0, 0));
        let opts = GatherOptions::default()
            .with_expand_depth(3)
            .with_direction(GatherDirection::Callees)
            .with_decay_factor(0.8);
        bfs_expand(&mut name_scores, &graph, &opts);
        let (_, depth) = name_scores["D"];
        assert_eq!(depth, 2, "D should preserve minimum depth of 2");
        assert_eq!(name_scores["B"].1, 1);
        assert_eq!(name_scores["C"].1, 1);
    }
}