use std::collections::HashSet;
use std::fs;
use std::path::Path;
use crate::db::Database;
use crate::errors::Result;
use crate::graph::GraphTraverser;
use crate::types::*;
/// Assembles task-oriented context for a natural-language query by combining
/// full-text node search, graph traversal, and source-snippet extraction.
pub struct ContextBuilder<'a> {
/// Graph database used for node search and BFS traversal.
db: &'a Database,
/// Directory that every node `file_path` is resolved against; `get_code`
/// refuses to read files that resolve outside it.
project_root: &'a Path,
}
impl<'a> ContextBuilder<'a> {
    /// Creates a builder that searches `db` and resolves code locations
    /// relative to `project_root`.
    pub fn new(db: &'a Database, project_root: &'a Path) -> Self {
        Self { db, project_root }
    }

    /// Builds a complete `TaskContext` for `query`: entry-point nodes found
    /// via search, the graph neighborhood around them, optional code
    /// snippets, the distinct files involved, and a one-line summary.
    pub async fn build_context(&self, query: &str, options: &BuildContextOptions) -> Result<TaskContext> {
        debug_assert!(!query.is_empty(), "build_context called with empty query");
        debug_assert!(options.max_nodes > 0, "max_nodes must be positive");
        let symbols = extract_symbols_from_query(query);
        let entry_points = self.find_entry_points(query, &symbols, options).await?;
        let subgraph = self.expand_subgraph(&entry_points, options).await?;
        // Snippet extraction hits the filesystem, so it is gated by an option.
        let code_blocks = if options.include_code {
            self.extract_code_blocks(&entry_points, options).await?
        } else {
            Vec::new()
        };
        let related_files = self.collect_related_files(&subgraph);
        let summary = self.build_summary(query, &entry_points, &subgraph);
        Ok(TaskContext {
            query: query.to_string(),
            summary,
            subgraph,
            entry_points,
            code_blocks,
            related_files,
        })
    }

    /// Like `build_context`, but returns only the expanded subgraph —
    /// skipping code extraction, related-file collection, and the summary.
    pub async fn find_relevant_context(
        &self,
        query: &str,
        options: &BuildContextOptions,
    ) -> Result<Subgraph> {
        let symbols = extract_symbols_from_query(query);
        let entry_points = self.find_entry_points(query, &symbols, options).await?;
        self.expand_subgraph(&entry_points, options).await
    }

    /// Reads the source snippet for `node` from disk.
    ///
    /// `start_line`/`end_line` are treated as 1-based and inclusive (line 0
    /// means "unknown location"). Returns `Ok(None)` rather than an error
    /// whenever a snippet cannot be produced: the path resolves outside the
    /// project root, the file is unreadable, the range is missing, out of
    /// bounds, or inverted, or the selected text is empty.
    pub async fn get_code(&self, node: &Node) -> Result<Option<String>> {
        debug_assert!(!node.file_path.is_empty(), "get_code called with empty file_path");
        debug_assert!(!node.id.is_empty(), "get_code called with empty node id");
        let file_path = self.project_root.join(&node.file_path);
        // Path-traversal guard: refuse paths that resolve outside the project
        // root. If canonicalization fails (e.g. the file does not exist) the
        // read below fails instead and we still return None.
        if let (Ok(canonical), Ok(root)) = (file_path.canonicalize(), self.project_root.canonicalize()) {
            if !canonical.starts_with(&root) {
                return Ok(None);
            }
        }
        // NOTE(review): blocking read inside an async fn — consider
        // spawn_blocking or an async fs API if this runs on a shared runtime.
        let content = match fs::read_to_string(&file_path) {
            Ok(c) => c,
            Err(_) => return Ok(None),
        };
        let lines: Vec<&str> = content.lines().collect();
        if node.start_line == 0 || node.end_line == 0 {
            return Ok(None);
        }
        let start = (node.start_line as usize).saturating_sub(1);
        let end = (node.end_line as usize).min(lines.len());
        // Fix: a single guard covers both an out-of-range start and an
        // inverted range (end_line < start_line), either of which would
        // previously panic (or could panic) when slicing.
        if start >= end {
            return Ok(None);
        }
        let snippet: String = lines[start..end].join("\n");
        if snippet.is_empty() {
            Ok(None)
        } else {
            Ok(Some(snippet))
        }
    }

    /// Finds entry-point nodes by running search passes for the raw query,
    /// each extracted symbol, and each caller-supplied extra keyword,
    /// deduplicating by node id and capping the result at `options.max_nodes`.
    async fn find_entry_points(
        &self,
        query: &str,
        symbols: &[String],
        options: &BuildContextOptions,
    ) -> Result<Vec<Node>> {
        debug_assert!(!query.is_empty(), "find_entry_points called with empty query");
        debug_assert!(options.search_limit > 0, "search_limit must be positive");
        let mut seen_ids: HashSet<String> = HashSet::new();
        let mut entry_points: Vec<Node> = Vec::new();
        // Pass 1: the raw query text.
        let search_results = self.db.search_nodes(query, options.search_limit).await?;
        for sr in search_results {
            if self.score_passes(sr.score, options.min_score) && seen_ids.insert(sr.node.id.clone()) {
                entry_points.push(sr.node);
            }
        }
        // Pass 2: each extracted symbol, until the node budget is filled.
        for symbol in symbols {
            if entry_points.len() >= options.max_nodes {
                break;
            }
            let results = self.db.search_nodes(symbol, options.search_limit).await?;
            for sr in results {
                if self.score_passes(sr.score, options.min_score)
                    && seen_ids.insert(sr.node.id.clone())
                {
                    entry_points.push(sr.node);
                }
            }
        }
        // Pass 3: caller-supplied extra keywords.
        for keyword in &options.extra_keywords {
            if entry_points.len() >= options.max_nodes {
                break;
            }
            let results = self.db.search_nodes(keyword, options.search_limit).await?;
            for sr in results {
                if self.score_passes(sr.score, options.min_score)
                    && seen_ids.insert(sr.node.id.clone())
                {
                    entry_points.push(sr.node);
                }
            }
        }
        // Individual passes may overshoot the budget; trim once at the end.
        entry_points.truncate(options.max_nodes);
        Ok(entry_points)
    }

    /// Expands the neighborhood around `entry_points` with BFS traversals,
    /// merging the per-entry subgraphs while deduplicating nodes (by id),
    /// edges (by source/target/kind), and roots. The node count is capped at
    /// `options.max_nodes`, and edges are pruned so both endpoints are always
    /// present in the returned node set.
    async fn expand_subgraph(
        &self,
        entry_points: &[Node],
        options: &BuildContextOptions,
    ) -> Result<Subgraph> {
        debug_assert!(options.traversal_depth > 0, "traversal_depth must be positive");
        debug_assert!(options.max_nodes > 0, "max_nodes must be positive for expand_subgraph");
        let traverser = GraphTraverser::new(self.db);
        let mut all_nodes: Vec<Node> = Vec::new();
        let mut all_edges: Vec<Edge> = Vec::new();
        let mut all_roots: Vec<String> = Vec::new();
        let mut seen_node_ids: HashSet<String> = HashSet::new();
        let mut seen_edge_keys: HashSet<(String, String, String)> = HashSet::new();
        let traversal_opts = TraversalOptions {
            max_depth: options.traversal_depth as u32,
            edge_kinds: None,
            node_kinds: None,
            direction: TraversalDirection::Both,
            limit: options.max_nodes as u32,
            include_start: true,
        };
        for node in entry_points {
            let sub = traverser.traverse_bfs(&node.id, &traversal_opts).await?;
            for root in sub.roots {
                if !all_roots.contains(&root) {
                    all_roots.push(root);
                }
            }
            for n in sub.nodes {
                if seen_node_ids.insert(n.id.clone()) {
                    all_nodes.push(n);
                }
            }
            for e in sub.edges {
                let key = (
                    e.source.clone(),
                    e.target.clone(),
                    e.kind.as_str().to_string(),
                );
                if seen_edge_keys.insert(key) {
                    all_edges.push(e);
                }
            }
            if all_nodes.len() >= options.max_nodes {
                break;
            }
        }
        all_nodes.truncate(options.max_nodes);
        // Fix: truncation can drop nodes that merged edges still reference.
        // Keep only edges whose endpoints both survived, so the subgraph
        // stays internally consistent for consumers.
        let kept_ids: HashSet<&str> = all_nodes.iter().map(|n| n.id.as_str()).collect();
        all_edges.retain(|e| {
            kept_ids.contains(e.source.as_str()) && kept_ids.contains(e.target.as_str())
        });
        Ok(Subgraph {
            nodes: all_nodes,
            edges: all_edges,
            roots: all_roots,
        })
    }

    /// Reads the snippet for each entry point, truncating oversized snippets
    /// at a UTF-8 character boundary (preferring the last complete line) and
    /// appending "..." to mark the cut. Stops once `max_code_blocks` blocks
    /// have been collected; nodes without readable code are skipped.
    async fn extract_code_blocks(
        &self,
        entry_points: &[Node],
        options: &BuildContextOptions,
    ) -> Result<Vec<CodeBlock>> {
        debug_assert!(options.max_code_blocks > 0, "max_code_blocks must be positive");
        debug_assert!(options.max_code_block_size > 0, "max_code_block_size must be positive");
        let mut blocks: Vec<CodeBlock> = Vec::new();
        for node in entry_points {
            if blocks.len() >= options.max_code_blocks {
                break;
            }
            if let Some(code) = self.get_code(node).await? {
                let truncated = if code.len() > options.max_code_block_size {
                    // Walk back to a char boundary so slicing cannot panic on
                    // multi-byte UTF-8 sequences.
                    let mut end = options.max_code_block_size;
                    while !code.is_char_boundary(end) && end > 0 {
                        end -= 1;
                    }
                    // Prefer cutting at the end of the last complete line.
                    if let Some(pos) = code[..end].rfind('\n') {
                        end = pos;
                    }
                    format!("{}...", &code[..end])
                } else {
                    code
                };
                blocks.push(CodeBlock {
                    content: truncated,
                    file_path: node.file_path.clone(),
                    start_line: node.start_line,
                    end_line: node.end_line,
                    node_id: Some(node.id.clone()),
                });
            }
        }
        Ok(blocks)
    }

    /// A search hit qualifies when its score is strictly positive and meets
    /// the configured minimum (so a min_score of 0 still rejects zero-score
    /// hits).
    fn score_passes(&self, score: f64, min_score: f64) -> bool {
        score > 0.0 && score >= min_score
    }

    /// Collects the distinct file paths of the subgraph's nodes, preserving
    /// first-seen order.
    fn collect_related_files(&self, subgraph: &Subgraph) -> Vec<String> {
        let mut seen: HashSet<String> = HashSet::new();
        let mut files: Vec<String> = Vec::new();
        for node in &subgraph.nodes {
            if seen.insert(node.file_path.clone()) {
                files.push(node.file_path.clone());
            }
        }
        files
    }

    /// Produces a one-line, human-readable summary of the build result.
    fn build_summary(&self, query: &str, entry_points: &[Node], subgraph: &Subgraph) -> String {
        let ep_count = entry_points.len();
        let node_count = subgraph.nodes.len();
        let edge_count = subgraph.edges.len();
        if ep_count == 0 {
            format!("No matching symbols found for \"{query}\"")
        } else {
            format!(
                "Found {ep_count} entry point(s) for \"{query}\" with {node_count} related node(s) and {edge_count} edge(s)"
            )
        }
    }
}
/// Extracts candidate code symbols from a free-text query: snake_case and
/// CamelCase identifiers, qualified `::` paths, and the parts of compound
/// names. Results are deduplicated, keep first-seen order, and exclude
/// common English/programming stop words.
pub fn extract_symbols_from_query(query: &str) -> Vec<String> {
    debug_assert!(!query.is_empty(), "extract_symbols_from_query called with empty query");
    let stop_words: HashSet<&str> = SYMBOL_STOP_WORDS.iter().copied().collect();
    let mut symbols: Vec<String> = Vec::new();
    let mut seen: HashSet<String> = HashSet::new();
    for token in query.split_whitespace() {
        // Trim punctuation from the token edges only. Interior "::" in
        // qualified paths is untouched because trimming stops at the first
        // identifier character; a trailing ':' (as in "UserService:") is now
        // removed so such tokens are no longer dropped by the camel-case
        // check. (Previously ':' was kept, which silently discarded them.)
        let clean = token.trim_matches(|c: char| !c.is_alphanumeric() && c != '_');
        classify_token(clean, &stop_words, &mut symbols, &mut seen);
    }
    symbols
}

/// Words too generic to be useful search symbols: English function words
/// plus common programming vocabulary.
const SYMBOL_STOP_WORDS: &[&str] = &[
    "the", "is", "in", "for", "to", "a", "an", "of", "and", "or", "not",
    "this", "that", "it", "with", "on", "at", "by", "from", "as", "be",
    "was", "are", "been", "being", "have", "has", "had", "do", "does", "did",
    "will", "would", "could", "should", "may", "might", "can", "shall",
    "how", "what", "where", "when", "who", "which", "why",
    "if", "then", "else", "but", "so", "up", "out", "no", "yes",
    "all", "any", "each", "every",
    "fix", "look", "update", "add", "remove", "delete", "change", "check",
    "find", "get", "set", "use", "make", "call",
    "function", "method", "class", "struct", "type", "module", "file",
    "handler", "implement", "create", "about",
    "interface", "trait", "enum", "variable", "import", "export",
    "return", "error", "test", "spec", "helper", "util",
    "config", "service", "model", "view", "controller",
    "code", "new", "init", "default", "value", "data", "result",
];

/// Classifies one cleaned token and appends any symbols it yields.
///
/// Qualified paths contribute both the final segment and the full path;
/// snake_case and CamelCase identifiers contribute the identifier itself
/// plus its compound parts.
fn classify_token(
    clean: &str,
    stop_words: &HashSet<&str>,
    symbols: &mut Vec<String>,
    seen: &mut HashSet<String>,
) {
    if clean.is_empty() {
        return;
    }
    if clean.contains("::") {
        // The final segment is the symbol proper (e.g. "Node" in
        // "crate::types::Node"); rsplit yields it first.
        if let Some(last) = clean.rsplit("::").next() {
            push_symbol(last, stop_words, symbols, seen);
        }
        // The full path is always kept: a qualified path is meaningful even
        // when its last segment happens to be a stop word.
        if seen.insert(clean.to_string()) {
            symbols.push(clean.to_string());
        }
        return;
    }
    if clean.contains('_') {
        push_symbol(clean, stop_words, symbols, seen);
        push_parts(clean, stop_words, symbols, seen);
        return;
    }
    if is_camel_case(clean) {
        push_symbol(clean, stop_words, symbols, seen);
        push_parts(clean, stop_words, symbols, seen);
    }
}

/// Records `candidate` if it is non-empty, not a stop word
/// (case-insensitive), and not already seen.
fn push_symbol(
    candidate: &str,
    stop_words: &HashSet<&str>,
    symbols: &mut Vec<String>,
    seen: &mut HashSet<String>,
) {
    if !candidate.is_empty()
        && !stop_words.contains(candidate.to_lowercase().as_str())
        && seen.insert(candidate.to_string())
    {
        symbols.push(candidate.to_string());
    }
}

/// Records each compound part of `name` that is at least 3 characters long
/// (shorter parts are treated as noise).
fn push_parts(
    name: &str,
    stop_words: &HashSet<&str>,
    symbols: &mut Vec<String>,
    seen: &mut HashSet<String>,
) {
    for part in split_compound(name) {
        if part.len() >= 3 {
            push_symbol(part, stop_words, symbols, seen);
        }
    }
}

/// Splits a compound identifier into parts: snake_case on underscores,
/// CamelCase on lower-to-upper boundaries and at the end of acronym runs
/// ("HTTPServer" -> ["HTTP", "Server"]).
fn split_compound(name: &str) -> Vec<&str> {
    if name.contains('_') {
        return name.split('_').filter(|s| !s.is_empty()).collect();
    }
    // Byte offsets are valid slice indices here because the CamelCase path
    // is only reached for ASCII identifiers (enforced by is_camel_case).
    let bytes = name.as_bytes();
    let mut parts = Vec::new();
    let mut start = 0;
    for i in 1..bytes.len() {
        let cur = bytes[i] as char;
        let prev = bytes[i - 1] as char;
        // lower-to-upper transition: "userService" splits before 'S'.
        let boundary = prev.is_ascii_lowercase() && cur.is_ascii_uppercase();
        // uppercase run followed by lowercase: "HTTPServer" splits before the
        // 'S' that starts "Server".
        let acronym_end = i + 1 < bytes.len()
            && prev.is_ascii_uppercase()
            && cur.is_ascii_uppercase()
            && (bytes[i + 1] as char).is_ascii_lowercase();
        if boundary || acronym_end {
            if i > start {
                parts.push(&name[start..i]);
            }
            start = i;
        }
    }
    if start < name.len() {
        parts.push(&name[start..]);
    }
    parts
}

/// True for ASCII-alphanumeric words of length >= 2 that contain an
/// uppercase letter after the first character — i.e. camelCase and
/// multi-word PascalCase, but not single capitalized words like "Process".
fn is_camel_case(word: &str) -> bool {
    if word.len() < 2 {
        return false;
    }
    if !word.chars().all(|c| c.is_ascii_alphanumeric()) {
        return false;
    }
    word[1..].chars().any(|c| c.is_ascii_uppercase())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_snake_case() {
        let found = extract_symbols_from_query("fix the process_request function");
        assert!(found.iter().any(|s| s == "process_request"));
    }

    #[test]
    fn test_extract_camel_case() {
        let found = extract_symbols_from_query("update UserService handler");
        assert!(found.iter().any(|s| s == "UserService"));
    }

    #[test]
    fn test_extract_screaming_snake() {
        let found = extract_symbols_from_query("increase MAX_RETRIES limit");
        assert!(found.iter().any(|s| s == "MAX_RETRIES"));
    }

    #[test]
    fn test_extract_qualified_path() {
        let found = extract_symbols_from_query("look at crate::types::Node");
        assert!(found.iter().any(|s| s.contains("Node")));
    }

    #[test]
    fn test_filters_stop_words() {
        // A query made entirely of stop words yields no symbols.
        assert!(extract_symbols_from_query("the is in for to a an").is_empty());
    }

    #[test]
    fn test_is_camel_case() {
        for word in ["UserService", "processRequest"] {
            assert!(is_camel_case(word), "{word} should be camel case");
        }
        for word in ["user", "U", "process_request"] {
            assert!(!is_camel_case(word), "{word} should not be camel case");
        }
    }
}