use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use super::bm25::{Bm25Index, Bm25Result};
use super::text::{self, SearchMatch};
use crate::ast::parser::parse_file;
use crate::types::{CodeStructure, DefinitionInfo, Language};
use crate::TldrResult;
#[derive(Debug, Clone, Default)]
pub enum SearchMode {
#[default]
Bm25,
Regex(String),
Hybrid {
query: String,
pattern: String,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnrichedResult {
pub name: String,
pub kind: String,
pub file: PathBuf,
pub line_range: (u32, u32),
pub signature: String,
pub callers: Vec<String>,
pub callees: Vec<String>,
pub score: f64,
pub matched_terms: Vec<String>,
#[serde(default, skip_serializing_if = "String::is_empty")]
pub preview: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnrichedSearchReport {
pub query: String,
pub results: Vec<EnrichedResult>,
pub total_files_searched: usize,
pub search_mode: String,
}
#[derive(Debug, Clone)]
struct StructureEntry {
name: String,
kind: String,
line_start: u32,
line_end: u32,
signature: String,
preview: String,
}
#[derive(Debug, Clone)]
pub struct EnrichedSearchOptions {
pub top_k: usize,
pub include_callgraph: bool,
pub search_mode: SearchMode,
}
impl Default for EnrichedSearchOptions {
fn default() -> Self {
Self {
top_k: 10,
include_callgraph: true,
search_mode: SearchMode::default(),
}
}
}
#[derive(Debug, Clone)]
pub struct CallGraphLookup {
pub forward: HashMap<String, Vec<String>>,
pub reverse: HashMap<String, Vec<String>>,
}
#[derive(Debug, Clone, Deserialize)]
struct WarmCallEdge {
#[allow(dead_code)]
from_file: PathBuf,
from_func: String,
#[allow(dead_code)]
to_file: PathBuf,
to_func: String,
}
#[derive(Debug, Clone, Deserialize)]
struct WarmCallGraphCache {
edges: Vec<WarmCallEdge>,
#[allow(dead_code)]
languages: Vec<String>,
#[allow(dead_code)]
timestamp: i64,
}
pub fn read_callgraph_cache(cache_path: &Path) -> TldrResult<CallGraphLookup> {
let content = std::fs::read_to_string(cache_path).map_err(crate::TldrError::IoError)?;
let cache: WarmCallGraphCache = serde_json::from_str(&content).map_err(|e| {
crate::TldrError::SerializationError(format!("Failed to parse call graph cache: {}", e))
})?;
let mut forward: HashMap<String, Vec<String>> = HashMap::new();
let mut reverse: HashMap<String, Vec<String>> = HashMap::new();
for edge in &cache.edges {
forward
.entry(edge.from_func.clone())
.or_default()
.push(edge.to_func.clone());
reverse
.entry(edge.to_func.clone())
.or_default()
.push(edge.from_func.clone());
}
Ok(CallGraphLookup { forward, reverse })
}
#[derive(Debug, Clone)]
pub struct StructureLookup {
pub by_file: HashMap<PathBuf, Vec<DefinitionInfo>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StructureCacheEnvelope {
files: Vec<CachedFileEntry>,
timestamp: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct CachedFileEntry {
path: PathBuf,
definitions: Vec<DefinitionInfo>,
}
pub fn write_structure_cache(structure: &CodeStructure, cache_path: &Path) -> TldrResult<()> {
let envelope = StructureCacheEnvelope {
files: structure
.files
.iter()
.map(|f| CachedFileEntry {
path: f.path.clone(),
definitions: f.definitions.clone(),
})
.collect(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64,
};
let json = serde_json::to_string_pretty(&envelope).map_err(|e| {
crate::TldrError::SerializationError(format!("Failed to serialize structure cache: {}", e))
})?;
if let Some(parent) = cache_path.parent() {
std::fs::create_dir_all(parent).map_err(crate::TldrError::IoError)?;
}
std::fs::write(cache_path, json).map_err(crate::TldrError::IoError)?;
Ok(())
}
pub fn read_structure_cache(cache_path: &Path) -> TldrResult<StructureLookup> {
let content = std::fs::read_to_string(cache_path).map_err(crate::TldrError::IoError)?;
let envelope: StructureCacheEnvelope = serde_json::from_str(&content).map_err(|e| {
crate::TldrError::SerializationError(format!("Failed to parse structure cache: {}", e))
})?;
let mut by_file = HashMap::new();
for entry in envelope.files {
by_file.insert(entry.path, entry.definitions);
}
Ok(StructureLookup { by_file })
}
fn regex_matches_to_bm25_results(matches: &[SearchMatch]) -> Vec<Bm25Result> {
let mut file_counts: HashMap<PathBuf, usize> = HashMap::new();
for m in matches {
*file_counts.entry(m.file.clone()).or_insert(0) += 1;
}
matches
.iter()
.map(|m| {
let file_match_count = file_counts[&m.file] as f64;
Bm25Result {
file_path: m.file.clone(),
score: file_match_count,
line_start: m.line,
line_end: m.line,
snippet: m.content.clone(),
matched_terms: vec![], }
})
.collect()
}
fn do_regex_search(
pattern: &str,
root: &Path,
language: Language,
top_k: usize,
) -> crate::TldrResult<(Vec<SearchMatch>, usize)> {
let extensions: HashSet<String> = language
.extensions()
.iter()
.map(|e| e.to_string())
.collect();
let raw_limit = (top_k * 10).max(200);
let matches = text::search(
pattern,
root,
Some(&extensions),
0, raw_limit,
usize::MAX, None, )?;
let unique_files: HashSet<&PathBuf> = matches.iter().map(|m| &m.file).collect();
let total_files = walkdir::WalkDir::new(root)
.follow_links(false)
.into_iter()
.filter_map(|e| e.ok())
.filter(|e| !e.file_type().is_dir())
.filter(|e| {
e.path()
.extension()
.and_then(|ext| ext.to_str())
.map(|ext| {
let with_dot = format!(".{}", ext);
extensions.contains(&with_dot) || extensions.contains(ext)
})
.unwrap_or(false)
})
.count();
let total = total_files.max(unique_files.len());
Ok((matches, total))
}
pub fn enriched_search(
query: &str,
root: &Path,
language: Language,
options: EnrichedSearchOptions,
) -> TldrResult<EnrichedSearchReport> {
search_with_inner(query, root, language, options, None, None, None)
}
pub fn enriched_search_with_callgraph_cache(
query: &str,
root: &Path,
language: Language,
options: EnrichedSearchOptions,
cache_path: &Path,
) -> TldrResult<EnrichedSearchReport> {
search_with_inner(query, root, language, options, None, None, Some(cache_path))
}
pub fn enriched_search_with_index(
query: &str,
root: &Path,
language: Language,
options: EnrichedSearchOptions,
index: &Bm25Index,
) -> TldrResult<EnrichedSearchReport> {
search_with_inner(query, root, language, options, Some(index), None, None)
}
fn process_file_results(
rel_path: &PathBuf,
results: &[&Bm25Result],
root: &Path,
language: Language,
cached_defs: Option<&[DefinitionInfo]>,
) -> Vec<((PathBuf, String), EnrichedResult)> {
let abs_path = root.join(rel_path);
let entries = if let Some(defs) = cached_defs {
defs.iter()
.map(|d| StructureEntry {
name: d.name.clone(),
kind: d.kind.clone(),
line_start: d.line_start,
line_end: d.line_end,
signature: d.signature.clone(),
preview: String::new(), })
.collect()
} else {
match extract_structure_entries(&abs_path, language) {
Ok(entries) => entries,
Err(_) => {
let mut local_dedup: HashMap<(PathBuf, String), EnrichedResult> = HashMap::new();
for result in results {
let key = (rel_path.clone(), rel_path.display().to_string());
let entry = local_dedup.entry(key).or_insert_with(|| EnrichedResult {
name: rel_path.display().to_string(),
kind: "module".to_string(),
file: rel_path.clone(),
line_range: (result.line_start, result.line_end),
signature: result.snippet.lines().next().unwrap_or("").to_string(),
callers: Vec::new(),
callees: Vec::new(),
score: result.score,
matched_terms: result.matched_terms.clone(),
preview: String::new(),
});
if result.score > entry.score {
entry.score = result.score;
}
}
return local_dedup.into_iter().collect();
}
}
};
let mut local_dedup: HashMap<(PathBuf, String), EnrichedResult> = HashMap::new();
for result in results {
let enclosing = (result.line_start..=result.line_end)
.filter_map(|line| find_enclosing_entry(&entries, line))
.min_by_key(|e| e.line_end - e.line_start);
match enclosing {
Some(entry) => {
let key = (rel_path.clone(), entry.name.clone());
let enriched = local_dedup.entry(key).or_insert_with(|| EnrichedResult {
name: entry.name.clone(),
kind: entry.kind.clone(),
file: rel_path.clone(),
line_range: (entry.line_start, entry.line_end),
signature: entry.signature.clone(),
callers: Vec::new(),
callees: Vec::new(),
score: result.score,
matched_terms: result.matched_terms.clone(),
preview: entry.preview.clone(),
});
if result.score > enriched.score {
enriched.score = result.score;
}
for term in &result.matched_terms {
if !enriched.matched_terms.contains(term) {
enriched.matched_terms.push(term.clone());
}
}
}
None => {
let sig = result
.snippet
.lines()
.find(|l| {
let t = l.trim();
!t.is_empty()
&& !t.starts_with("///")
&& !t.starts_with("//!")
&& !t.starts_with("//")
&& !t.starts_with("/*")
&& !t.starts_with("*")
})
.or_else(|| result.snippet.lines().next())
.unwrap_or("")
.trim()
.to_string();
let key = (
rel_path.clone(),
format!("{}:{}", rel_path.display(), result.line_start),
);
local_dedup.entry(key).or_insert_with(|| EnrichedResult {
name: rel_path
.file_stem()
.map(|s| s.to_string_lossy().to_string())
.unwrap_or_else(|| rel_path.display().to_string()),
kind: "module".to_string(),
file: rel_path.clone(),
line_range: (result.line_start, result.line_end),
signature: sig,
callers: Vec::new(),
callees: Vec::new(),
score: result.score,
matched_terms: result.matched_terms.clone(),
preview: result.snippet.clone(),
});
}
}
}
local_dedup.into_iter().collect()
}
fn enrich_and_deduplicate(
raw_results: &[Bm25Result],
root: &Path,
language: Language,
) -> Vec<EnrichedResult> {
let mut by_file: HashMap<PathBuf, Vec<&Bm25Result>> = HashMap::new();
for result in raw_results {
by_file
.entry(result.file_path.clone())
.or_default()
.push(result);
}
let by_file_vec: Vec<(&PathBuf, &Vec<&Bm25Result>)> = by_file.iter().collect();
let file_results: Vec<Vec<((PathBuf, String), EnrichedResult)>> = if by_file_vec.len() >= 4 {
by_file_vec
.par_iter()
.map(|(rel_path, results)| {
process_file_results(rel_path, results, root, language, None)
})
.collect()
} else {
by_file_vec
.iter()
.map(|(rel_path, results)| {
process_file_results(rel_path, results, root, language, None)
})
.collect()
};
let mut dedup: HashMap<(PathBuf, String), EnrichedResult> = HashMap::new();
for file_entries in file_results {
for (key, entry) in file_entries {
let existing = dedup.entry(key).or_insert(entry.clone());
if entry.score > existing.score {
existing.score = entry.score;
}
for term in &entry.matched_terms {
if !existing.matched_terms.contains(term) {
existing.matched_terms.push(term.clone());
}
}
}
}
dedup.into_values().collect()
}
fn enrich_and_deduplicate_with_cache(
raw_results: &[Bm25Result],
root: &Path,
language: Language,
structure_lookup: &StructureLookup,
) -> Vec<EnrichedResult> {
let mut by_file: HashMap<PathBuf, Vec<&Bm25Result>> = HashMap::new();
for result in raw_results {
by_file
.entry(result.file_path.clone())
.or_default()
.push(result);
}
let by_file_vec: Vec<(&PathBuf, &Vec<&Bm25Result>)> = by_file.iter().collect();
let file_results: Vec<Vec<((PathBuf, String), EnrichedResult)>> = if by_file_vec.len() >= 4 {
by_file_vec
.par_iter()
.map(|(rel_path, results)| {
let cached = structure_lookup
.by_file
.get(*rel_path)
.map(|v| v.as_slice());
process_file_results(rel_path, results, root, language, cached)
})
.collect()
} else {
by_file_vec
.iter()
.map(|(rel_path, results)| {
let cached = structure_lookup
.by_file
.get(*rel_path)
.map(|v| v.as_slice());
process_file_results(rel_path, results, root, language, cached)
})
.collect()
};
let mut dedup: HashMap<(PathBuf, String), EnrichedResult> = HashMap::new();
for file_entries in file_results {
for (key, entry) in file_entries {
let existing = dedup.entry(key).or_insert(entry.clone());
if entry.score > existing.score {
existing.score = entry.score;
}
for term in &entry.matched_terms {
if !existing.matched_terms.contains(term) {
existing.matched_terms.push(term.clone());
}
}
}
}
dedup.into_values().collect()
}
pub fn enriched_search_with_structure_cache(
query: &str,
root: &Path,
language: Language,
options: EnrichedSearchOptions,
structure_lookup: &StructureLookup,
) -> TldrResult<EnrichedSearchReport> {
search_with_inner(
query,
root,
language,
options,
None,
Some(structure_lookup),
None,
)
}
pub fn search_with_inner(
query: &str,
root: &Path,
language: Language,
options: EnrichedSearchOptions,
bm25_index: Option<&Bm25Index>,
structure_cache: Option<&StructureLookup>,
callgraph_cache_path: Option<&Path>,
) -> TldrResult<EnrichedSearchReport> {
let top_k = options.top_k;
let mode_prefix;
let (raw_results, total_files) = match &options.search_mode {
SearchMode::Bm25 => {
mode_prefix = "bm25";
match bm25_index {
Some(idx) => {
let total = idx.document_count();
if idx.is_empty() {
return Ok(EnrichedSearchReport {
query: query.to_string(),
results: Vec::new(),
total_files_searched: 0,
search_mode: if structure_cache.is_some() {
"bm25+cached-structure".to_string()
} else {
"bm25+structure".to_string()
},
});
}
let raw_limit = (top_k * 5).max(50);
(idx.search(query, raw_limit), total)
}
None => {
let index = Bm25Index::from_project(root, language)?;
let total = index.document_count();
if index.is_empty() {
return Ok(EnrichedSearchReport {
query: query.to_string(),
results: Vec::new(),
total_files_searched: 0,
search_mode: if structure_cache.is_some() {
"bm25+cached-structure".to_string()
} else {
"bm25+structure".to_string()
},
});
}
let raw_limit = (top_k * 5).max(50);
(index.search(query, raw_limit), total)
}
}
}
SearchMode::Regex(pattern) => {
mode_prefix = "regex";
let (matches, total) = do_regex_search(pattern, root, language, top_k)?;
if matches.is_empty() {
return Ok(EnrichedSearchReport {
query: pattern.clone(),
results: Vec::new(),
total_files_searched: total,
search_mode: if structure_cache.is_some() {
"regex+cached-structure".to_string()
} else {
"regex+structure".to_string()
},
});
}
(regex_matches_to_bm25_results(&matches), total)
}
SearchMode::Hybrid {
query: hybrid_query,
pattern,
} => {
mode_prefix = "hybrid(bm25+regex)";
let raw_limit = (top_k * 5).max(50);
let (bm25_results, total_files) = match bm25_index {
Some(idx) => {
let total = idx.document_count();
if idx.is_empty() {
return Ok(EnrichedSearchReport {
query: hybrid_query.clone(),
results: Vec::new(),
total_files_searched: 0,
search_mode: "hybrid(bm25+regex)".to_string(),
});
}
(idx.search(hybrid_query, raw_limit), total)
}
None => {
let index = Bm25Index::from_project(root, language)?;
let total = index.document_count();
if index.is_empty() {
return Ok(EnrichedSearchReport {
query: hybrid_query.clone(),
results: Vec::new(),
total_files_searched: 0,
search_mode: "hybrid(bm25+regex)".to_string(),
});
}
(index.search(hybrid_query, raw_limit), total)
}
};
let (regex_matches, _regex_total) = do_regex_search(pattern, root, language, top_k)?;
if regex_matches.is_empty() {
return Ok(EnrichedSearchReport {
query: hybrid_query.clone(),
results: Vec::new(),
total_files_searched: total_files,
search_mode: "hybrid(bm25+regex)".to_string(),
});
}
let regex_results = regex_matches_to_bm25_results(®ex_matches);
let bm25_ranks: HashMap<&Path, usize> = bm25_results
.iter()
.enumerate()
.map(|(i, r)| (r.file_path.as_path(), i + 1))
.collect();
let regex_ranks: HashMap<&Path, usize> = regex_results
.iter()
.enumerate()
.map(|(i, r)| (r.file_path.as_path(), i + 1))
.collect();
let k = 60.0_f64;
let mut fused: Vec<Bm25Result> = Vec::new();
for bm25_result in &bm25_results {
if let Some(®ex_rank) = regex_ranks.get(bm25_result.file_path.as_path()) {
let bm25_rank = bm25_ranks[bm25_result.file_path.as_path()];
let rrf_score = 1.0 / (k + bm25_rank as f64) + 1.0 / (k + regex_rank as f64);
let mut result = bm25_result.clone();
result.score = rrf_score;
fused.push(result);
}
}
fused.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut seen_files: HashSet<PathBuf> = HashSet::new();
fused.retain(|r| seen_files.insert(r.file_path.clone()));
(fused, total_files)
}
};
let report_query = match &options.search_mode {
SearchMode::Bm25 => query.to_string(),
SearchMode::Regex(pattern) => pattern.clone(),
SearchMode::Hybrid {
query: hybrid_query,
..
} => hybrid_query.clone(),
};
let mut enriched = match structure_cache {
Some(lookup) => enrich_and_deduplicate_with_cache(&raw_results, root, language, lookup),
None => enrich_and_deduplicate(&raw_results, root, language),
};
let has_function_results = enriched.iter().any(|r| r.kind != "module");
for result in &mut enriched {
if result.kind == "module" {
result.score *= if has_function_results { 0.2 } else { 0.5 };
}
}
let mut sorted = enriched;
sorted.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.file.cmp(&b.file))
.then_with(|| a.name.cmp(&b.name))
});
sorted.truncate(top_k);
let structure_label = if structure_cache.is_some() {
"cached-structure"
} else {
"structure"
};
match callgraph_cache_path {
Some(path) => {
let lookup = read_callgraph_cache(path)?;
for result in &mut sorted {
if result.kind == "module" {
continue;
}
if let Some(callees) = lookup.forward.get(&result.name) {
result.callees = callees.clone();
result.callees.sort();
}
if let Some(callers) = lookup.reverse.get(&result.name) {
result.callers = callers.clone();
result.callers.sort();
}
}
Ok(EnrichedSearchReport {
query: report_query,
results: sorted,
total_files_searched: total_files,
search_mode: format!("{}+{}+callgraph", mode_prefix, structure_label),
})
}
None if options.include_callgraph => {
let sorted_enriched = try_enrich_with_callgraph(sorted, root, language);
Ok(EnrichedSearchReport {
query: report_query,
results: sorted_enriched,
total_files_searched: total_files,
search_mode: format!("{}+{}+callgraph", mode_prefix, structure_label),
})
}
None => {
Ok(EnrichedSearchReport {
query: report_query,
results: sorted,
total_files_searched: total_files,
search_mode: format!("{}+{}", mode_prefix, structure_label),
})
}
}
}
fn extract_structure_entries(path: &Path, language: Language) -> TldrResult<Vec<StructureEntry>> {
let (tree, source, _) = parse_file(path)?;
let root_node = tree.root_node();
let mut entries = Vec::new();
collect_structure_nodes(root_node, &source, language, &mut entries);
Ok(entries)
}
fn collect_structure_nodes(
node: tree_sitter::Node,
source: &str,
language: Language,
entries: &mut Vec<StructureEntry>,
) {
let kind = node.kind();
let (is_func, is_class) = classify_node(kind, language);
if is_func || is_class {
if let Some(name) = get_definition_name(node, source, language) {
let line_start = node.start_position().row as u32 + 1; let line_end = node.end_position().row as u32 + 1;
let signature = extract_definition_signature(node, source);
let entry_kind = if is_class {
match kind {
"struct_item" | "struct_definition" | "struct_specifier" => "struct",
_ => "class",
}
} else {
if is_inside_class_node(node) {
"method"
} else {
"function"
}
};
let preview = extract_code_preview(node, source, &signature, 5);
entries.push(StructureEntry {
name,
kind: entry_kind.to_string(),
line_start,
line_end,
signature,
preview,
});
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
collect_structure_nodes(child, source, language, entries);
}
}
fn classify_node(kind: &str, _language: Language) -> (bool, bool) {
let is_func = matches!(
kind,
"function_definition"
| "function_declaration"
| "function_item" | "method_definition"
| "method_declaration"
| "arrow_function"
| "function_expression"
| "function" | "func_literal" | "function_type"
);
let is_class = matches!(
kind,
"class_definition"
| "class_declaration"
| "struct_item" | "struct_definition" | "struct_specifier" | "type_spec" | "interface_declaration"
);
(is_func, is_class)
}
fn get_definition_name(
node: tree_sitter::Node,
source: &str,
_language: Language,
) -> Option<String> {
if let Some(name_node) = node.child_by_field_name("name") {
let text = name_node.utf8_text(source.as_bytes()).ok()?;
return Some(text.to_string());
}
if node.kind() == "arrow_function" || node.kind() == "function_expression" {
if let Some(parent) = node.parent() {
if parent.kind() == "variable_declarator" {
if let Some(name_node) = parent.child_by_field_name("name") {
let text = name_node.utf8_text(source.as_bytes()).ok()?;
return Some(text.to_string());
}
}
}
}
None
}
fn is_inside_class_node(node: tree_sitter::Node) -> bool {
let mut current = node.parent();
while let Some(parent) = current {
let kind = parent.kind();
if matches!(
kind,
"class_definition" | "class_declaration" | "class_body" | "impl_item" | "struct_item"
) {
return true;
}
current = parent.parent();
}
false
}
fn extract_definition_signature(node: tree_sitter::Node, source: &str) -> String {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
let ckind = child.kind();
if ckind == "line_comment"
|| ckind == "block_comment"
|| ckind == "comment"
|| ckind == "attribute_item" || ckind == "attribute" || ckind == "decorator" || ckind == "decorator_list"
{
continue;
}
let start_byte = child.start_byte();
let line_from_start = &source[start_byte..];
let sig = line_from_start
.lines()
.next()
.unwrap_or("")
.trim()
.to_string();
if !sig.is_empty() {
return sig;
}
}
let node_text = &source[node.start_byte()..node.end_byte()];
for line in node_text.lines() {
let trimmed = line.trim();
if !trimmed.is_empty()
&& !trimmed.starts_with("///")
&& !trimmed.starts_with("//!")
&& !trimmed.starts_with("//")
&& !trimmed.starts_with("/*")
&& !trimmed.starts_with("*")
&& !trimmed.starts_with("#[")
&& !trimmed.starts_with("@")
&& !trimmed.starts_with("#")
{
return trimmed.to_string();
}
}
source[node.start_byte()..]
.lines()
.next()
.unwrap_or("")
.trim()
.to_string()
}
fn extract_code_preview(
node: tree_sitter::Node,
source: &str,
signature: &str,
max_lines: usize,
) -> String {
let node_text = &source[node.start_byte()..node.end_byte()];
let mut lines: Vec<&str> = Vec::new();
let mut found_sig = false;
for line in node_text.lines() {
let trimmed = line.trim();
if !found_sig {
if trimmed == signature
|| (trimmed.starts_with(&signature[..signature.len().min(20)])
&& !trimmed.starts_with("///")
&& !trimmed.starts_with("//!"))
{
found_sig = true;
lines.push(line);
}
continue;
}
lines.push(line);
if lines.len() >= max_lines {
break;
}
}
if lines.is_empty() {
for line in node_text.lines() {
let trimmed = line.trim();
if trimmed.is_empty() || trimmed.starts_with("///") || trimmed.starts_with("//!") {
continue;
}
lines.push(line);
if lines.len() >= max_lines {
break;
}
}
}
lines.join("\n")
}
fn find_enclosing_entry(entries: &[StructureEntry], line: u32) -> Option<&StructureEntry> {
let mut best: Option<&StructureEntry> = None;
for entry in entries {
if line >= entry.line_start && line <= entry.line_end {
match best {
None => best = Some(entry),
Some(current_best) => {
let current_range = current_best.line_end - current_best.line_start;
let new_range = entry.line_end - entry.line_start;
if new_range < current_range {
best = Some(entry);
}
}
}
}
}
best
}
fn try_enrich_with_callgraph(
mut results: Vec<EnrichedResult>,
root: &Path,
language: Language,
) -> Vec<EnrichedResult> {
use crate::callgraph::{build_forward_graph, build_reverse_graph};
let call_graph = match crate::build_project_call_graph(root, language, None, true) {
Ok(graph) => graph,
Err(_) => return results, };
let forward = build_forward_graph(&call_graph);
let reverse = build_reverse_graph(&call_graph);
for result in &mut results {
if result.kind == "module" {
continue; }
let result_file = result.file.to_string_lossy();
let mut found_callees = false;
for (func_ref, callees) in &forward {
let ref_file = func_ref.file.to_string_lossy();
if func_ref.name == result.name
&& (ref_file.is_empty()
|| result_file.is_empty()
|| ref_file.ends_with(result_file.as_ref())
|| result_file.ends_with(ref_file.as_ref()))
{
result.callees = callees.iter().map(|f| f.name.clone()).collect();
result.callees.sort();
found_callees = true;
break;
}
}
if !found_callees {
for (func_ref, callees) in &forward {
if func_ref.name == result.name {
result.callees = callees.iter().map(|f| f.name.clone()).collect();
result.callees.sort();
break;
}
}
}
let mut found_callers = false;
for (func_ref, callers) in &reverse {
let ref_file = func_ref.file.to_string_lossy();
if func_ref.name == result.name
&& (ref_file.is_empty()
|| result_file.is_empty()
|| ref_file.ends_with(result_file.as_ref())
|| result_file.ends_with(ref_file.as_ref()))
{
result.callers = callers.iter().map(|f| f.name.clone()).collect();
result.callers.sort();
found_callers = true;
break;
}
}
if !found_callers {
for (func_ref, callers) in &reverse {
if func_ref.name == result.name {
result.callers = callers.iter().map(|f| f.name.clone()).collect();
result.callers.sort();
break;
}
}
}
}
results
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
fn opts(top_k: usize) -> EnrichedSearchOptions {
EnrichedSearchOptions {
top_k,
include_callgraph: false,
search_mode: SearchMode::Bm25,
}
}
fn create_test_project() -> (TempDir, PathBuf) {
let dir = TempDir::new().unwrap();
let project = dir.path().join("project");
fs::create_dir(&project).unwrap();
fs::write(
project.join("auth.py"),
r#"
def verify_jwt_token(request):
"""Verify JWT token from request headers."""
token = request.headers.get("Authorization")
if not token:
raise AuthError("Missing token")
claims = decode_token(token)
check_expiry(claims)
return claims
def decode_token(token):
"""Decode a JWT token string."""
import jwt
return jwt.decode(token, key="secret")
def check_expiry(claims):
"""Check if token has expired."""
if claims["exp"] < time.time():
raise AuthError("Token expired")
class AuthMiddleware:
"""Middleware for authentication."""
def __init__(self, app):
self.app = app
def process_request(self, request):
"""Process incoming request for auth."""
verify_jwt_token(request)
return self.app(request)
"#,
)
.unwrap();
fs::write(
project.join("routes.py"),
r#"
def user_routes(app):
"""Register user routes."""
@app.route("/users")
def list_users():
return get_all_users()
def admin_routes(app):
"""Register admin routes."""
@app.route("/admin")
def admin_panel():
return render_admin()
def get_all_users():
"""Fetch all users from database."""
return db.query("SELECT * FROM users")
def render_admin():
"""Render admin panel."""
return template.render("admin.html")
"#,
)
.unwrap();
fs::write(
project.join("utils.py"),
r#"
def format_date(dt):
"""Format a datetime object."""
return dt.strftime("%Y-%m-%d")
def parse_json(text):
"""Parse JSON string."""
import json
return json.loads(text)
"#,
)
.unwrap();
(dir, project)
}
#[test]
fn test_enriched_result_has_required_fields() {
let result = EnrichedResult {
name: "verify_jwt_token".to_string(),
kind: "function".to_string(),
file: PathBuf::from("auth.py"),
line_range: (2, 9),
signature: "def verify_jwt_token(request):".to_string(),
callers: vec!["process_request".to_string()],
callees: vec!["decode_token".to_string(), "check_expiry".to_string()],
score: 0.94,
matched_terms: vec!["verify".to_string(), "jwt".to_string(), "token".to_string()],
preview: String::new(),
};
assert_eq!(result.name, "verify_jwt_token");
assert_eq!(result.kind, "function");
assert_eq!(result.line_range.0, 2);
assert!(result.score > 0.0);
assert_eq!(result.callers.len(), 1);
assert_eq!(result.callees.len(), 2);
}
#[test]
fn test_enriched_result_serializes_to_json() {
let result = EnrichedResult {
name: "test_func".to_string(),
kind: "function".to_string(),
file: PathBuf::from("test.py"),
line_range: (1, 5),
signature: "def test_func():".to_string(),
callers: Vec::new(),
callees: Vec::new(),
score: 0.5,
matched_terms: vec!["test".to_string()],
preview: String::new(),
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("test_func"));
assert!(json.contains("function"));
}
#[test]
fn test_enriched_search_report_has_metadata() {
let report = EnrichedSearchReport {
query: "authentication".to_string(),
results: Vec::new(),
total_files_searched: 42,
search_mode: "bm25+structure".to_string(),
};
assert_eq!(report.query, "authentication");
assert_eq!(report.total_files_searched, 42);
assert_eq!(report.search_mode, "bm25+structure");
}
#[test]
fn test_bm25_index_finds_test_files() {
let (_dir, root) = create_test_project();
let index = Bm25Index::from_project(&root, Language::Python).unwrap();
assert!(
index.document_count() >= 3,
"Should index at least 3 .py files, got {}",
index.document_count()
);
let raw = index.search("jwt token", 10);
assert!(!raw.is_empty(), "BM25 should find results for 'jwt token'");
}
#[test]
fn test_enriched_search_returns_results_for_matching_query() {
let (_dir, root) = create_test_project();
let report =
enriched_search("jwt token verify", &root, Language::Python, opts(10)).unwrap();
assert!(
!report.results.is_empty(),
"Should find results for 'jwt token verify'"
);
assert!(report.total_files_searched > 0);
assert_eq!(report.search_mode, "bm25+structure");
}
#[test]
fn test_enriched_search_empty_query_returns_empty() {
let (_dir, root) = create_test_project();
let report = enriched_search("", &root, Language::Python, opts(10)).unwrap();
assert!(
report.results.is_empty(),
"Empty query should return no results"
);
}
#[test]
fn test_enriched_search_no_match_returns_empty() {
let (_dir, root) = create_test_project();
let report =
enriched_search("xyznonexistent123", &root, Language::Python, opts(10)).unwrap();
assert!(
report.results.is_empty(),
"Non-matching query should return no results"
);
}
#[test]
fn test_enriched_search_results_have_function_names() {
let (_dir, root) = create_test_project();
let report = enriched_search("jwt token", &root, Language::Python, opts(10)).unwrap();
let names: Vec<&str> = report.results.iter().map(|r| r.name.as_str()).collect();
let has_func = names
.iter()
.any(|n| *n == "verify_jwt_token" || *n == "decode_token" || *n == "check_expiry");
assert!(has_func, "Should find function names, got: {:?}", names);
}
#[test]
fn test_enriched_search_results_have_signatures() {
let (_dir, root) = create_test_project();
let report = enriched_search("verify jwt", &root, Language::Python, opts(10)).unwrap();
for result in &report.results {
if result.kind == "function" || result.kind == "method" {
assert!(
!result.signature.is_empty(),
"Function '{}' should have a signature",
result.name
);
}
}
}
#[test]
fn test_enriched_search_results_have_line_ranges() {
let (_dir, root) = create_test_project();
let report = enriched_search("decode token", &root, Language::Python, opts(10)).unwrap();
for result in &report.results {
assert!(
result.line_range.0 > 0,
"Line start should be > 0 (1-indexed)"
);
assert!(
result.line_range.1 >= result.line_range.0,
"Line end should be >= line start"
);
}
}
#[test]
fn test_enriched_search_deduplicates_same_function() {
let (_dir, root) = create_test_project();
let report = enriched_search("token", &root, Language::Python, opts(20)).unwrap();
let count = report
.results
.iter()
.filter(|r| r.name == "verify_jwt_token")
.count();
assert!(
count <= 1,
"verify_jwt_token should appear at most once (deduplication), found {}",
count
);
}
#[test]
fn test_enriched_search_respects_top_k() {
let (_dir, root) = create_test_project();
let report = enriched_search("def", &root, Language::Python, opts(3)).unwrap();
assert!(
report.results.len() <= 3,
"Should respect top_k=3, got {} results",
report.results.len()
);
}
#[test]
fn test_enriched_search_results_sorted_by_score() {
let (_dir, root) = create_test_project();
let report = enriched_search("token", &root, Language::Python, opts(10)).unwrap();
if report.results.len() > 1 {
for i in 0..report.results.len() - 1 {
assert!(
report.results[i].score >= report.results[i + 1].score,
"Results should be sorted by score descending: {} >= {}",
report.results[i].score,
report.results[i + 1].score
);
}
}
}
#[test]
fn test_enriched_search_has_matched_terms() {
let (_dir, root) = create_test_project();
let report = enriched_search("jwt token", &root, Language::Python, opts(10)).unwrap();
for result in &report.results {
assert!(
!result.matched_terms.is_empty(),
"Result '{}' should have at least one matched term",
result.name
);
}
}
#[test]
fn test_enriched_search_finds_classes() {
let (_dir, root) = create_test_project();
let report = enriched_search("AuthMiddleware", &root, Language::Python, opts(10)).unwrap();
let has_class = report.results.iter().any(|r| r.kind == "class");
assert!(
has_class,
"Should find class-level results for 'AuthMiddleware'"
);
}
#[test]
fn test_enriched_search_finds_methods() {
let (_dir, root) = create_test_project();
let report = enriched_search("process_request", &root, Language::Python, opts(10)).unwrap();
let has_method = report.results.iter().any(|r| r.kind == "method");
assert!(
has_method,
"Should find method-level results for 'process_request'"
);
}
#[test]
fn test_extract_structure_entries_finds_functions() {
let (_dir, root) = create_test_project();
let entries = extract_structure_entries(&root.join("auth.py"), Language::Python).unwrap();
let func_names: Vec<&str> = entries.iter().map(|e| e.name.as_str()).collect();
assert!(
func_names.contains(&"verify_jwt_token"),
"Should find verify_jwt_token, got: {:?}",
func_names
);
assert!(
func_names.contains(&"decode_token"),
"Should find decode_token, got: {:?}",
func_names
);
}
#[test]
fn test_extract_structure_entries_finds_classes() {
let (_dir, root) = create_test_project();
let entries = extract_structure_entries(&root.join("auth.py"), Language::Python).unwrap();
let class_names: Vec<&str> = entries
.iter()
.filter(|e| e.kind == "class")
.map(|e| e.name.as_str())
.collect();
assert!(
class_names.contains(&"AuthMiddleware"),
"Should find AuthMiddleware class, got: {:?}",
class_names
);
}
#[test]
fn test_extract_structure_entries_has_line_ranges() {
let (_dir, root) = create_test_project();
let entries = extract_structure_entries(&root.join("auth.py"), Language::Python).unwrap();
for entry in &entries {
assert!(entry.line_start > 0, "Line start should be 1-indexed");
assert!(
entry.line_end >= entry.line_start,
"Line end should be >= line start for {}",
entry.name
);
}
}
#[test]
fn test_extract_structure_entries_has_signatures() {
let (_dir, root) = create_test_project();
let entries = extract_structure_entries(&root.join("auth.py"), Language::Python).unwrap();
let verify = entries
.iter()
.find(|e| e.name == "verify_jwt_token")
.unwrap();
assert!(
verify.signature.contains("def verify_jwt_token"),
"Signature should contain function definition, got: '{}'",
verify.signature
);
}
#[test]
fn test_find_enclosing_entry_returns_innermost() {
let entries = vec![
StructureEntry {
name: "OuterClass".to_string(),
kind: "class".to_string(),
line_start: 1,
line_end: 20,
signature: "class OuterClass:".to_string(),
preview: String::new(),
},
StructureEntry {
name: "inner_method".to_string(),
kind: "method".to_string(),
line_start: 5,
line_end: 10,
signature: "def inner_method(self):".to_string(),
preview: String::new(),
},
];
let result = find_enclosing_entry(&entries, 7);
assert!(result.is_some());
assert_eq!(result.unwrap().name, "inner_method");
}
#[test]
fn test_find_enclosing_entry_returns_none_outside() {
let entries = vec![StructureEntry {
name: "some_func".to_string(),
kind: "function".to_string(),
line_start: 10,
line_end: 20,
signature: "def some_func():".to_string(),
preview: String::new(),
}];
let result = find_enclosing_entry(&entries, 5);
assert!(result.is_none());
}
#[test]
fn test_enriched_search_on_empty_directory() {
let dir = TempDir::new().unwrap();
let empty = dir.path().join("empty_project");
fs::create_dir(&empty).unwrap();
let report = enriched_search("anything", &empty, Language::Python, opts(10)).unwrap();
assert!(report.results.is_empty());
assert_eq!(report.total_files_searched, 0);
}
#[test]
fn test_enriched_search_report_query_preserved() {
let (_dir, root) = create_test_project();
let report = enriched_search(
"authentication middleware",
&root,
Language::Python,
opts(10),
)
.unwrap();
assert_eq!(report.query, "authentication middleware");
}
#[test]
fn test_perf_enriched_search_repeated_calls_under_200ms() {
let (_dir, root) = create_test_project();
let query = "jwt token verify";
let _ = enriched_search(query, &root, Language::Python, opts(10)).unwrap();
let mut durations = Vec::new();
for _ in 0..2 {
let start = std::time::Instant::now();
let report = enriched_search(query, &root, Language::Python, opts(10)).unwrap();
let elapsed = start.elapsed();
durations.push(elapsed);
assert!(!report.results.is_empty(), "Should find results");
}
let index = Bm25Index::from_project(&root, Language::Python).unwrap();
let _cached_report =
enriched_search_with_index(query, &root, Language::Python, opts(10), &index).unwrap();
assert!(
!_cached_report.results.is_empty(),
"Cached search should find results"
);
let start = std::time::Instant::now();
let _cached_report2 =
enriched_search_with_index(query, &root, Language::Python, opts(10), &index).unwrap();
let cached_elapsed = start.elapsed();
assert!(
cached_elapsed.as_millis() < 200,
"Cached enriched_search should complete in < 200ms, took {}ms",
cached_elapsed.as_millis()
);
for d in &durations {
eprintln!(" enriched_search call took: {:?}", d);
}
}
#[test]
fn test_read_callgraph_cache_builds_forward_map() {
let dir = tempfile::TempDir::new().unwrap();
let cache_path = dir.path().join("call_graph.json");
fs::write(
&cache_path,
r#"{
"edges": [
{"from_file": "a.py", "from_func": "foo", "to_file": "a.py", "to_func": "bar"},
{"from_file": "a.py", "from_func": "foo", "to_file": "b.py", "to_func": "baz"}
],
"languages": ["python"],
"timestamp": 1740000000
}"#,
)
.unwrap();
let lookup = read_callgraph_cache(&cache_path).unwrap();
let callees = lookup.forward.get("foo").unwrap();
assert!(callees.contains(&"bar".to_string()));
assert!(callees.contains(&"baz".to_string()));
}
#[test]
fn test_read_callgraph_cache_builds_reverse_map() {
let dir = tempfile::TempDir::new().unwrap();
let cache_path = dir.path().join("call_graph.json");
fs::write(
&cache_path,
r#"{
"edges": [
{"from_file": "a.py", "from_func": "foo", "to_file": "a.py", "to_func": "bar"},
{"from_file": "b.py", "from_func": "qux", "to_file": "a.py", "to_func": "bar"}
],
"languages": ["python"],
"timestamp": 1740000000
}"#,
)
.unwrap();
let lookup = read_callgraph_cache(&cache_path).unwrap();
let callers = lookup.reverse.get("bar").unwrap();
assert!(callers.contains(&"foo".to_string()));
assert!(callers.contains(&"qux".to_string()));
}
#[test]
fn test_read_callgraph_cache_empty_edges() {
let dir = tempfile::TempDir::new().unwrap();
let cache_path = dir.path().join("call_graph.json");
fs::write(
&cache_path,
r#"{
"edges": [],
"languages": ["python"],
"timestamp": 1740000000
}"#,
)
.unwrap();
let lookup = read_callgraph_cache(&cache_path).unwrap();
assert!(lookup.forward.is_empty());
assert!(lookup.reverse.is_empty());
}
#[test]
fn test_read_callgraph_cache_invalid_json_returns_error() {
let dir = tempfile::TempDir::new().unwrap();
let cache_path = dir.path().join("call_graph.json");
fs::write(&cache_path, "not valid json").unwrap();
let result = read_callgraph_cache(&cache_path);
assert!(result.is_err());
}
#[test]
fn test_read_callgraph_cache_missing_file_returns_error() {
let result = read_callgraph_cache(Path::new("/nonexistent/path/call_graph.json"));
assert!(result.is_err());
}
#[test]
fn test_enriched_search_with_callgraph_cache_populates_callers_callees() {
let (_dir, root) = create_test_project();
let cache_dir = root.join(".tldr").join("cache");
fs::create_dir_all(&cache_dir).unwrap();
let cache_path = cache_dir.join("call_graph.json");
fs::write(&cache_path, r#"{
"edges": [
{"from_file": "auth.py", "from_func": "verify_jwt_token", "to_file": "auth.py", "to_func": "decode_token"},
{"from_file": "auth.py", "from_func": "verify_jwt_token", "to_file": "auth.py", "to_func": "check_expiry"},
{"from_file": "auth.py", "from_func": "process_request", "to_file": "auth.py", "to_func": "verify_jwt_token"}
],
"languages": ["python"],
"timestamp": 1740000000
}"#).unwrap();
let options = EnrichedSearchOptions {
top_k: 10,
include_callgraph: true,
search_mode: SearchMode::Bm25,
};
let report = enriched_search_with_callgraph_cache(
"jwt token verify",
&root,
Language::Python,
options,
&cache_path,
)
.unwrap();
assert!(!report.results.is_empty());
assert_eq!(report.search_mode, "bm25+structure+callgraph");
if let Some(verify) = report.results.iter().find(|r| r.name == "verify_jwt_token") {
assert!(
verify.callees.contains(&"decode_token".to_string()),
"verify_jwt_token should call decode_token, got: {:?}",
verify.callees
);
assert!(
verify.callees.contains(&"check_expiry".to_string()),
"verify_jwt_token should call check_expiry, got: {:?}",
verify.callees
);
assert!(
verify.callers.contains(&"process_request".to_string()),
"verify_jwt_token should be called by process_request, got: {:?}",
verify.callers
);
}
}
#[test]
fn test_enriched_search_with_callgraph_cache_sorts_callers_callees() {
let (_dir, root) = create_test_project();
let cache_dir = root.join(".tldr").join("cache");
fs::create_dir_all(&cache_dir).unwrap();
let cache_path = cache_dir.join("call_graph.json");
fs::write(&cache_path, r#"{
"edges": [
{"from_file": "auth.py", "from_func": "verify_jwt_token", "to_file": "auth.py", "to_func": "decode_token"},
{"from_file": "auth.py", "from_func": "verify_jwt_token", "to_file": "auth.py", "to_func": "check_expiry"}
],
"languages": ["python"],
"timestamp": 1740000000
}"#).unwrap();
let options = EnrichedSearchOptions {
top_k: 10,
include_callgraph: true,
search_mode: SearchMode::Bm25,
};
let report = enriched_search_with_callgraph_cache(
"verify jwt token",
&root,
Language::Python,
options,
&cache_path,
)
.unwrap();
if let Some(verify) = report.results.iter().find(|r| r.name == "verify_jwt_token") {
let mut expected = verify.callees.clone();
expected.sort();
assert_eq!(
verify.callees, expected,
"Callees should be sorted alphabetically"
);
}
}
}