use std::collections::HashSet;
use std::path::Path;
use crate::graph::CodeGraph;
use crate::query::find::find_symbol;
use crate::rag::agent::QueryKind;
use crate::rag::embedding::EmbeddingEngine;
use crate::rag::vector_store::VectorStore;
/// Maximum number of citations returned by any retrieval strategy.
const MAX_CITATIONS: usize = 5;
/// Only the first this-many citations get an inline code snippet.
const CODE_SNIPPET_COUNT: usize = 3;
/// Upper bound on lines included in a single inline snippet.
const MAX_SNIPPET_LINES: usize = 40;
/// A numbered source reference attached to retrieved context.
///
/// `index` matches the `[n]` marker embedded in the context text so an
/// answer can cite the exact file/line a claim came from.
#[derive(Debug, Clone)]
pub struct Citation {
// 1-based position of this citation within the context text.
pub index: usize,
// Path of the cited file, as stored by the retrieval source.
pub file_path: String,
// 1-based line where the cited symbol starts.
pub line_start: usize,
// Name of the cited symbol (function, struct, etc.).
pub symbol_name: String,
}
/// Output of a retrieval pass: formatted context plus its provenance.
#[derive(Debug, Clone)]
pub struct RetrievalResult {
// Human/LLM-readable context with `[n]` citation markers.
pub context_text: String,
// Citations referenced by the markers in `context_text`.
pub citations: Vec<Citation>,
// Names of the retrieval tools that contributed (e.g. "find_symbol").
pub tools_used: Vec<String>,
}
/// Retrieves context for `query` using the strategy selected by `kind`.
///
/// - `Structural`: graph-based symbol lookup only.
/// - `Conceptual`: embedding/vector search only.
/// - `Hybrid`: runs both, then merges the citations (structural hits
///   first), deduplicating by `(file_path, symbol_name)` and renumbering
///   from 1, capped at `MAX_CITATIONS`.
///
/// # Errors
/// Propagates failures from embedding generation or the vector search;
/// structural lookup failures collapse to an empty result inside
/// `retrieve_structural`.
pub async fn retrieve(
    graph: &CodeGraph,
    vector_store: &VectorStore,
    engine: &EmbeddingEngine,
    query: &str,
    kind: QueryKind,
) -> anyhow::Result<RetrievalResult> {
    match kind {
        QueryKind::Structural => {
            let (context_text, citations, tools_used) = retrieve_structural(graph, query);
            Ok(RetrievalResult {
                context_text,
                citations,
                tools_used,
            })
        }
        QueryKind::Conceptual => {
            let (context_text, citations, tools_used) =
                retrieve_conceptual(vector_store, engine, query).await?;
            Ok(RetrievalResult {
                context_text,
                citations,
                tools_used,
            })
        }
        QueryKind::Hybrid => {
            let (ctx_s, cit_s, tools_s) = retrieve_structural(graph, query);
            let (ctx_c, cit_c, tools_c) =
                retrieve_conceptual(vector_store, engine, query).await?;

            // Merge citations from both strategies (structural first),
            // keeping only the first occurrence of each (file, symbol)
            // pair and capping the total at MAX_CITATIONS.
            let mut seen: HashSet<(String, String)> = HashSet::new();
            let mut context_lines: Vec<String> = Vec::new();
            let mut citations: Vec<Citation> = Vec::new();
            for c in cit_s.iter().chain(cit_c.iter()) {
                if citations.len() >= MAX_CITATIONS {
                    break;
                }
                if !seen.insert((c.file_path.clone(), c.symbol_name.clone())) {
                    continue; // duplicate (file, symbol) pair
                }
                let idx = citations.len() + 1;
                let mut entry = format!(
                    "[{idx}] `{}` in {}:{}",
                    c.symbol_name, c.file_path, c.line_start
                );
                // Only the leading entries carry a code snippet (bounds
                // context size); unreadable files fail soft to header-only.
                if idx <= CODE_SNIPPET_COUNT {
                    if let Some(snippet) = read_code_snippet(&c.file_path, c.line_start) {
                        entry.push_str(&format!("\n```\n{snippet}\n```"));
                    }
                }
                context_lines.push(entry);
                citations.push(Citation {
                    index: idx,
                    file_path: c.file_path.clone(),
                    line_start: c.line_start,
                    symbol_name: c.symbol_name.clone(),
                });
            }

            // Union of tool names, preserving order of first appearance.
            let mut tools_used: Vec<String> = tools_s;
            for t in tools_c {
                if !tools_used.contains(&t) {
                    tools_used.push(t);
                }
            }

            // Fall back to the raw per-strategy context when neither side
            // produced citations. Entries are joined with a blank line to
            // match the structural/conceptual single-strategy formats
            // (previously the hybrid path inconsistently used "\n").
            let context_text = if context_lines.is_empty() {
                format!("{ctx_s}\n{ctx_c}").trim().to_string()
            } else {
                context_lines.join("\n\n")
            };

            Ok(RetrievalResult {
                context_text,
                citations,
                tools_used,
            })
        }
    }
}
/// Graph-based retrieval: turns `query` into a symbol pattern, looks it up
/// in the code graph, and formats the matches as numbered citation entries.
///
/// Lookup failures are treated as "no matches" (`unwrap_or_default`), so
/// this function itself never fails.
pub fn retrieve_structural(graph: &CodeGraph, query: &str) -> (String, Vec<Citation>, Vec<String>) {
    let mut tools_used = vec!["find_symbol".to_string()];
    let pattern = extract_search_pattern(query);
    // NOTE(review): project root is hard-coded to the current directory.
    let root = Path::new(".");
    let matches = find_symbol(graph, &pattern, true, &[], None, root, None).unwrap_or_default();
    if matches.is_empty() {
        return (String::new(), Vec::new(), tools_used);
    }
    tools_used.push("get_context".to_string());

    let mut context_lines: Vec<String> = Vec::new();
    let mut citations: Vec<Citation> = Vec::new();
    for (i, hit) in matches.iter().take(MAX_CITATIONS).enumerate() {
        let idx = i + 1;
        let file = hit.file_path.to_string_lossy().to_string();
        let kind_str = crate::query::find::kind_to_str(&hit.kind);
        // Build the header once; append a snippet to the leading entries
        // when the source file is readable.
        let mut entry = format!(
            "[{idx}] {kind_str} `{}` in {}:{}",
            hit.symbol_name, file, hit.line
        );
        if i < CODE_SNIPPET_COUNT {
            if let Some(snippet) = read_code_snippet(&file, hit.line) {
                entry.push_str(&format!("\n```\n{snippet}\n```"));
            }
        }
        context_lines.push(entry);
        citations.push(Citation {
            index: idx,
            file_path: file,
            line_start: hit.line,
            symbol_name: hit.symbol_name.clone(),
        });
    }
    (context_lines.join("\n\n"), citations, tools_used)
}
/// Embedding-based retrieval: embeds `query`, pulls the nearest chunks
/// from the vector store, and formats them as numbered citation entries.
///
/// # Errors
/// Fails if the embedding engine errors (or returns an empty batch) or if
/// the vector-store search fails.
pub async fn retrieve_conceptual(
    vector_store: &VectorStore,
    engine: &EmbeddingEngine,
    query: &str,
) -> anyhow::Result<(String, Vec<Citation>, Vec<String>)> {
    let tools_used = vec!["vector_search".to_string()];

    // The engine is batch-oriented; embed a one-element batch and take the
    // single vector out of the result.
    let query_embedding = engine
        .embed_batch(vec![query.to_string()])
        .await?
        .into_iter()
        .next()
        .ok_or_else(|| anyhow::anyhow!("embedding engine returned no results"))?;

    let hits = vector_store.search(&query_embedding, MAX_CITATIONS)?;
    if hits.is_empty() {
        return Ok((String::new(), Vec::new(), tools_used));
    }

    let mut context_lines: Vec<String> = Vec::new();
    let mut citations: Vec<Citation> = Vec::new();
    for (i, (meta, _distance)) in hits.iter().take(MAX_CITATIONS).enumerate() {
        let idx = i + 1;
        let mut entry = format!(
            "[{idx}] {} `{}` in {}:{}",
            meta.kind, meta.symbol_name, meta.file_path, meta.line_start
        );
        // Only the first few hits get an inline snippet; unreadable files
        // degrade to a header-only entry.
        if i < CODE_SNIPPET_COUNT {
            if let Some(snippet) = read_code_snippet(&meta.file_path, meta.line_start) {
                entry.push_str(&format!("\n```\n{snippet}\n```"));
            }
        }
        context_lines.push(entry);
        citations.push(Citation {
            index: idx,
            file_path: meta.file_path.clone(),
            line_start: meta.line_start,
            symbol_name: meta.symbol_name.clone(),
        });
    }

    Ok((context_lines.join("\n\n"), citations, tools_used))
}
/// Reads up to `MAX_SNIPPET_LINES` lines from `file_path`, starting at
/// 1-based `line_start`. Returns `None` when the file cannot be read or
/// `line_start` is past the end of the file.
fn read_code_snippet(file_path: &str, line_start: usize) -> Option<String> {
    let content = std::fs::read_to_string(file_path).ok()?;
    let all_lines: Vec<&str> = content.lines().collect();
    // Convert to a 0-based index; line 0 and line 1 both map to the start.
    let begin = line_start.saturating_sub(1);
    if begin >= all_lines.len() {
        return None;
    }
    let end = all_lines.len().min(begin + MAX_SNIPPET_LINES);
    Some(all_lines[begin..end].join("\n"))
}
/// Derives a regex pattern from a natural-language query.
///
/// Lowercases the query, drops short words (<= 2 chars) and common stop
/// words, stems the survivors, and joins them with `|` (each alternative
/// regex-escaped). When nothing survives filtering, the whole query is
/// escaped and used verbatim.
fn extract_search_pattern(query: &str) -> String {
    // English stop words plus generic programming vocabulary that would
    // match far too many symbols to be useful as a search keyword.
    const STOP_WORDS: &[&str] = &[
        "where", "what", "which", "how", "why", "when", "who", "the", "a", "an",
        "this", "that", "these", "those", "its", "in", "of", "for", "with",
        "about", "to", "from", "by", "on", "at", "and", "or", "is", "are",
        "was", "were", "does", "do", "did", "has", "have", "had", "can",
        "could", "will", "would", "should", "calls", "find", "explain",
        "describe", "show", "locate", "uses", "used", "function", "method",
        "struct", "class", "module", "type", "enum", "me", "it", "all", "any",
        "some", "not", "be", "tool", "code", "codebase", "project", "file",
        "support", "work", "programming",
    ];
    let keywords: Vec<String> = query
        .split_whitespace()
        .map(str::to_lowercase)
        .filter(|w| w.len() > 2 && !STOP_WORDS.contains(&w.as_str()))
        .map(|w| stem_word(&w))
        .collect();
    if keywords.is_empty() {
        // No meaningful keyword survived; match the literal query instead.
        return regex::escape(query);
    }
    keywords
        .iter()
        .map(|w| regex::escape(w))
        .collect::<Vec<_>>()
        .join("|")
}
/// Crude suffix-stripping stemmer so that e.g. "languages" and "language"
/// produce the same search keyword.
///
/// The first suffix that strips down to a stem of at least 3 characters
/// wins; "ies" restores the trailing "y" ("queries" -> "query"). Words
/// that match no suffix (or whose stem would be too short) pass through
/// unchanged.
fn stem_word(word: &str) -> String {
    const SUFFIXES: [&str; 9] = ["ies", "ing", "tion", "sion", "ment", "ness", "ed", "ly", "s"];
    SUFFIXES
        .iter()
        .find_map(|suffix| {
            let stem = word.strip_suffix(suffix)?;
            // Too-short stems are rejected so the next suffix gets a try.
            (stem.len() >= 3).then(|| {
                if *suffix == "ies" {
                    format!("{stem}y")
                } else {
                    stem.to_string()
                }
            })
        })
        .unwrap_or_else(|| word.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn citation_fields_are_correct() {
        let citation = Citation {
            index: 3,
            file_path: "src/services/auth.rs".to_string(),
            line_start: 99,
            symbol_name: "verify_token".to_string(),
        };
        assert_eq!(citation.index, 3);
        assert_eq!(citation.symbol_name, "verify_token");
        assert_eq!(citation.file_path, "src/services/auth.rs");
        assert_eq!(citation.line_start, 99);
    }

    #[test]
    fn retrieval_result_context_text_contains_numbered_citations() {
        let citations = vec![
            Citation {
                index: 1,
                file_path: "src/auth.rs".to_string(),
                line_start: 10,
                symbol_name: "auth_handler".to_string(),
            },
            Citation {
                index: 2,
                file_path: "src/user.rs".to_string(),
                line_start: 20,
                symbol_name: "get_user".to_string(),
            },
        ];
        // Render the citations the same way production code numbers them.
        let mut rendered = Vec::new();
        for c in &citations {
            rendered.push(format!(
                "[{}] {} in {}:{}",
                c.index, c.symbol_name, c.file_path, c.line_start
            ));
        }
        let context_text = rendered.join("\n");
        assert!(context_text.contains("[1]"), "should have [1] marker");
        assert!(context_text.contains("[2]"), "should have [2] marker");
        assert!(
            context_text.contains("auth_handler"),
            "should mention symbol name"
        );
        assert!(
            context_text.contains("src/auth.rs:10"),
            "should have file:line"
        );
    }

    #[test]
    fn tools_used_accumulates_tool_names() {
        let (_ctx, _cit, tools) = retrieve_structural(&CodeGraph::new(), "find auth");
        assert!(
            tools.contains(&"find_symbol".to_string()),
            "structural retrieval should track find_symbol"
        );
    }

    #[test]
    fn retrieve_structural_empty_graph_returns_empty_context() {
        let (ctx, cit, _tools) = retrieve_structural(&CodeGraph::new(), "some query");
        assert!(ctx.is_empty(), "empty graph should produce empty context");
        assert!(cit.is_empty(), "empty graph should produce no citations");
    }

    #[test]
    fn extract_search_pattern_strips_stop_words() {
        assert!(
            extract_search_pattern("where is auth").contains("auth"),
            "should retain meaningful keyword 'auth'"
        );
        assert!(
            extract_search_pattern("find UserService").contains("userservice"),
            "should retain 'UserService' (lowercased)"
        );
    }

    #[test]
    fn extract_search_pattern_stems_plurals() {
        let pattern = extract_search_pattern("what languages does this tool support");
        assert!(
            pattern.contains("language"),
            "should stem 'languages' to 'language', got: {pattern}"
        );
        assert!(
            !pattern.contains("tool"),
            "generic word 'tool' should be stopped"
        );
        assert!(!pattern.contains("this"), "'this' should be stopped");
    }

    #[test]
    fn stem_word_handles_common_suffixes() {
        let cases = [
            ("languages", "language"),
            ("handlers", "handler"),
            ("queries", "query"),
            ("caching", "cach"),
            ("auth", "auth"),
        ];
        for (input, expected) in cases {
            assert_eq!(stem_word(input), expected);
        }
    }
}