use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::Instant;
use anyhow::{Context, Result};
use serde::Serialize;
use sqry_core::graph::unified::concurrent::GraphSnapshot;
use sqry_core::graph::unified::node::{NodeId, NodeKind};
use sqry_core::query::results::QueryResults;
use crate::engine::{canonicalize_in_workspace, engine_for_workspace};
use crate::execution::types::{CodeContext, PositionData, RangeData, ToolExecution};
use crate::execution::utils::duration_to_ms;
use crate::tools::{HierarchicalSearchArgs, Visibility};
mod grouping;
mod merging;
mod token_budget;
mod truncation;
use grouping::{FileContentCache, build_container_tree};
use merging::apply_auto_merge;
use token_budget::apply_token_budgets;
use truncation::{enforce_response_limits, paginate_files};
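/// Top-level payload for `hierarchical_search`: the matched files with their
/// container trees, plus overall totals and pagination state.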
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HierarchicalSearchData {
pub query: String,
pub files: Vec<FileGroup>,
pub total_symbols: u64,
pub total_files: u64,
pub truncated: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_page_token: Option<String>,
}
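/// One matched file: its language, token estimate, container tree, and any
/// symbols that live outside a container. `is_stub` marks files whose symbol
/// details were elided to fit response limits.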
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FileGroup {
pub path: String,
pub language: String,
pub estimated_tokens: u64,
pub symbol_count: u64,
pub containers: Vec<ContainerGroup>,
pub top_level_symbols: Vec<HierarchicalSymbol>,
pub max_score: f64,
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub is_stub: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub file_context: Option<String>,
}
impl Default for FileGroup {
fn default() -> Self {
Self {
path: String::new(),
language: String::new(),
estimated_tokens: 0,
symbol_count: 0,
containers: Vec::new(),
top_level_symbols: Vec::new(),
max_score: 0.0,
is_stub: false,
file_context: None,
}
}
}
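/// A container (e.g. a class or module) grouping matched symbols, possibly
/// nested. `merged_container_tokens` is internal bookkeeping and is never
/// serialized.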
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ContainerGroup {
pub name: String,
pub qualified_name: String,
pub kind: String,
pub estimated_tokens: u64,
pub depth: u32,
pub parent_path: Vec<String>,
pub byte_range: (usize, usize),
pub symbols: Vec<HierarchicalSymbol>,
pub nested_containers: Vec<ContainerGroup>,
pub symbol_count: u64,
pub children_count: u64,
pub children_names: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub container_context: Option<String>,
#[serde(skip)]
pub merged_container_tokens: u64,
}
impl Default for ContainerGroup {
fn default() -> Self {
Self {
name: String::new(),
qualified_name: String::new(),
kind: String::new(),
estimated_tokens: 0,
depth: 1,
parent_path: Vec::new(),
byte_range: (0, 0),
symbols: Vec::new(),
nested_containers: Vec::new(),
symbol_count: 0,
children_count: 0,
children_names: Vec::new(),
container_context: None,
merged_container_tokens: 0,
}
}
}
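/// A single matched symbol with its range, relevance score, and optional
/// context. `merged`, `original_level`, and `clustered_count` are populated
/// when auto-merge folds related symbols together.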
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct HierarchicalSymbol {
pub name: String,
pub qualified_name: String,
pub kind: String,
pub range: RangeData,
pub score: f64,
pub estimated_tokens: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<CodeContext>,
#[serde(skip_serializing_if = "Option::is_none")]
pub signature: Option<String>,
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub merged: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub original_level: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub clustered_count: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub macro_metadata: Option<super::types::MacroMetadataResponse>,
}
impl Default for HierarchicalSymbol {
fn default() -> Self {
Self {
name: String::new(),
qualified_name: String::new(),
kind: String::new(),
range: RangeData {
start: PositionData {
line: 0,
character: 0,
},
end: PositionData {
line: 0,
character: 0,
},
},
score: 0.0,
estimated_tokens: 0,
context: None,
signature: None,
merged: false,
original_level: None,
clustered_count: None,
macro_metadata: None,
}
}
}
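// Converts ranked query results into (node, score) pairs. Rank `i` of `n`
// maps linearly onto (0.5, 1.0]: the first hit scores 1.0 and the last
// approaches 0.5.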
fn query_results_to_scored_nodes(results: &QueryResults) -> Vec<(NodeId, f64)> {
let total = f64::from(u32::try_from(results.len().max(1)).unwrap_or(u32::MAX));
results
.node_ids()
.iter()
.enumerate()
.map(|(idx, &node_id)| {
let idx = f64::from(u32::try_from(idx).unwrap_or(u32::MAX));
let score = round_relevance_score(1.0 - (idx / total) * 0.5);
(node_id, score)
})
.collect()
}
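/// Rounds a relevance score to three decimal places so serialized scores stay
/// stable.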
pub(crate) fn round_relevance_score(score: f64) -> f64 {
(score * 1000.0).round() / 1000.0
}
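// Maps graph node kinds to the snake_case strings used in kind filters and
// serialized output.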
fn node_kind_to_string(kind: NodeKind) -> &'static str {
match kind {
NodeKind::Function => "function",
NodeKind::Method => "method",
NodeKind::Class => "class",
NodeKind::Interface => "interface",
NodeKind::Trait => "trait",
NodeKind::Module => "module",
NodeKind::Variable => "variable",
NodeKind::Constant => "constant",
NodeKind::Type => "type",
NodeKind::Struct => "struct",
NodeKind::Enum => "enum",
NodeKind::EnumVariant => "enum_variant",
NodeKind::Macro => "macro",
NodeKind::Parameter => "parameter",
NodeKind::Property => "property",
NodeKind::Import => "import",
NodeKind::Export => "export",
NodeKind::Component => "component",
NodeKind::Service => "service",
NodeKind::Resource => "resource",
NodeKind::Endpoint => "endpoint",
NodeKind::Test => "test",
NodeKind::CallSite => "call_site",
NodeKind::StyleRule => "style_rule",
NodeKind::StyleAtRule => "style_at_rule",
NodeKind::StyleVariable => "style_variable",
NodeKind::Lifetime => "lifetime",
NodeKind::TypeParameter => "type_parameter",
NodeKind::Annotation => "annotation",
NodeKind::AnnotationValue => "annotation_value",
NodeKind::LambdaTarget => "lambda_target",
NodeKind::JavaModule => "java_module",
NodeKind::EnumConstant => "enum_constant",
NodeKind::Other => "other",
}
}
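/// Runs a hierarchical search: executes the semantic query against the code
/// graph, scores and filters the hits, groups them by file into container
/// trees, then applies auto-merge, per-file token budgets, response limits
/// (or `expand_files` expansion), and pagination.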
#[allow(clippy::too_many_lines)]
pub fn execute_hierarchical_search(
args: &HierarchicalSearchArgs,
) -> Result<ToolExecution<HierarchicalSearchData>> {
let start = Instant::now();
let explicit_path = if args.path.is_empty() || args.path == "." {
None
} else {
Some(PathBuf::from(&args.path))
};
let engine = engine_for_workspace(explicit_path.as_ref())?;
let workspace_root = engine.workspace_root();
let search_root = canonicalize_in_workspace(&args.path, workspace_root)?;
tracing::debug!(
query = %args.query,
path = %search_root.display(),
max_results = args.max_results,
max_total_symbols = args.max_total_symbols,
"Executing hierarchical_search"
);
let graph = engine.ensure_graph()?;
let snapshot = graph.snapshot();
let executor = engine.executor();
let query = normalized_query(&args.query)?;
let query_results = executor
.execute_on_graph(query, &search_root)
.context("Failed to execute semantic query")?;
let scored_nodes = query_results_to_scored_nodes(&query_results);
if scored_nodes.is_empty() {
return Ok(empty_hierarchical_execution(
&args.query,
            true,
            workspace_root,
duration_to_ms(start.elapsed()),
));
}
let filtered_nodes = apply_filters(&snapshot, &scored_nodes, args);
let filtered_nodes: Vec<(NodeId, f64)> =
filtered_nodes.into_iter().take(args.max_results).collect();
let candidates_scanned = filtered_nodes.len();
let nodes_by_file = group_nodes_by_file(&snapshot, filtered_nodes, workspace_root);
let file_content_cache = FileContentCache::new();
let mut files = build_file_groups(
nodes_by_file,
&snapshot,
workspace_root,
args,
&file_content_cache,
)?;
apply_auto_merge_if_enabled(&mut files, args, &file_content_cache, workspace_root)?;
apply_token_budgets_for_files(&mut files, args, &file_content_cache, workspace_root)?;
sort_files_deterministic(&mut files);
let total_files = files.len() as u64;
let (limit_truncated, files) = apply_expand_or_limits(files, args);
let total_symbols = total_symbols_after_truncation(&files);
let (paginated_files, next_token, page_truncated) = paginate_files(files, args);
let truncated = limit_truncated || page_truncated;
let elapsed = duration_to_ms(start.elapsed());
Ok(ToolExecution {
data: HierarchicalSearchData {
query: args.query.clone(),
files: paginated_files,
total_symbols,
total_files,
truncated,
next_page_token: next_token.clone(),
},
used_index: false,
used_graph: true,
graph_metadata: None,
execution_ms: elapsed,
next_page_token: next_token,
total: Some(total_symbols),
truncated: Some(truncated),
candidates_scanned: Some(candidates_scanned as u64),
workspace_path: crate::execution::symbol_utils::path_to_forward_slash(workspace_root),
})
}
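// Trims the query and rejects empty input.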
fn normalized_query(query: &str) -> Result<&str> {
let trimmed = query.trim();
if trimmed.is_empty() {
anyhow::bail!("query cannot be empty");
}
Ok(trimmed)
}
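// Builds the zero-result response returned when the query matches nothing.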
fn empty_hierarchical_execution(
query: &str,
used_index: bool,
workspace_root: &Path,
execution_ms: u64,
) -> ToolExecution<HierarchicalSearchData> {
ToolExecution {
data: HierarchicalSearchData {
query: query.to_string(),
files: Vec::new(),
total_symbols: 0,
total_files: 0,
truncated: false,
next_page_token: None,
},
used_index,
used_graph: false,
graph_metadata: None,
execution_ms,
next_page_token: None,
total: Some(0),
truncated: Some(false),
candidates_scanned: Some(0),
workspace_path: crate::execution::symbol_utils::path_to_forward_slash(workspace_root),
}
}
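// Keeps only nodes that pass the language, kind, visibility, and minimum
// score filters.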
fn apply_filters(
snapshot: &GraphSnapshot,
nodes: &[(NodeId, f64)],
args: &HierarchicalSearchArgs,
) -> Vec<(NodeId, f64)> {
nodes
.iter()
.filter(|(node_id, score)| {
matches_language_node(snapshot, *node_id, args)
&& matches_kind_node(snapshot, *node_id, args)
&& matches_visibility_node(snapshot, *node_id, args)
&& matches_score(*score, args.score_min)
})
.copied()
.collect()
}
fn matches_language_node(
snapshot: &GraphSnapshot,
node_id: NodeId,
args: &HierarchicalSearchArgs,
) -> bool {
if args.filters.languages.is_empty() {
return true;
}
let Some(entry) = snapshot.get_node(node_id) else {
return false;
};
let lang = snapshot.files().language_for_file(entry.file).map_or_else(
|| "unknown".to_string(),
|l| l.to_string().to_ascii_lowercase(),
);
args.filters
.languages
.iter()
.any(|l| l.eq_ignore_ascii_case(&lang))
}
fn matches_kind_node(
snapshot: &GraphSnapshot,
node_id: NodeId,
args: &HierarchicalSearchArgs,
) -> bool {
if args.filters.kinds.is_empty() {
return true;
}
let Some(entry) = snapshot.get_node(node_id) else {
return false;
};
let kind = node_kind_to_string(entry.kind);
args.filters
.kinds
.iter()
.any(|k| k.eq_ignore_ascii_case(kind))
}
fn matches_visibility_node(
snapshot: &GraphSnapshot,
node_id: NodeId,
args: &HierarchicalSearchArgs,
) -> bool {
let Some(vis) = &args.filters.visibility else {
return true;
};
let Some(entry) = snapshot.get_node(node_id) else {
return false;
};
let visibility = entry
.visibility
.and_then(|id| snapshot.strings().resolve(id))
.map(|s| s.to_ascii_lowercase());
match vis {
Visibility::Public => visibility.as_deref() == Some("public"),
Visibility::Private => visibility.as_deref() == Some("private"),
}
}
fn matches_score(score: f64, min_score: Option<f64>) -> bool {
if let Some(min) = min_score {
score >= min
} else {
true
}
}
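// Buckets scored nodes by their workspace-relative file path.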
fn group_nodes_by_file(
snapshot: &GraphSnapshot,
nodes: Vec<(NodeId, f64)>,
workspace_root: &Path,
) -> HashMap<String, Vec<(NodeId, f64)>> {
let files = snapshot.files();
let mut by_file: HashMap<String, Vec<(NodeId, f64)>> = HashMap::new();
for (node_id, score) in nodes {
let Some(entry) = snapshot.get_node(node_id) else {
continue;
};
let file_path = files
.resolve(entry.file)
.map(|p| {
crate::execution::symbol_utils::relative_path_forward_slash(&p, workspace_root)
})
.unwrap_or_default();
by_file.entry(file_path).or_default().push((node_id, score));
}
by_file
}
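// Indexes every graph node by file path so the container tree can account
// for all symbols in a file, not just the query matches.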
fn preindex_nodes_by_file(snapshot: &GraphSnapshot) -> HashMap<String, Vec<NodeId>> {
let files = snapshot.files();
let file_paths: HashMap<_, _> = files.iter().collect();
let mut by_file: HashMap<String, Vec<NodeId>> = HashMap::new();
for (node_id, entry) in snapshot.iter_nodes() {
let Some(relative_path) = file_paths.get(&entry.file) else {
continue;
};
let file_path_str = crate::execution::symbol_utils::path_to_forward_slash(relative_path);
by_file.entry(file_path_str).or_default().push(node_id);
}
by_file
}
fn build_file_groups(
nodes_by_file: HashMap<String, Vec<(NodeId, f64)>>,
snapshot: &GraphSnapshot,
workspace_root: &Path,
args: &HierarchicalSearchArgs,
file_cache: &FileContentCache,
) -> Result<Vec<FileGroup>> {
let all_nodes_by_file = preindex_nodes_by_file(snapshot);
let mut files = Vec::new();
for (file_path, file_nodes) in nodes_by_file {
let file_group = build_file_group(
&file_path,
&file_nodes,
&all_nodes_by_file,
snapshot,
workspace_root,
args,
file_cache,
)?;
files.push(file_group);
}
Ok(files)
}
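// Assembles one `FileGroup`: derives the file's language and best score,
// builds its container tree, and totals symbol and token counts.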
fn build_file_group(
file_path: &str,
nodes: &[(NodeId, f64)],
all_nodes_by_file: &HashMap<String, Vec<NodeId>>,
snapshot: &GraphSnapshot,
workspace_root: &Path,
args: &HierarchicalSearchArgs,
file_cache: &FileContentCache,
) -> Result<FileGroup> {
let files = snapshot.files();
let language = nodes
.first()
.and_then(|(node_id, _)| snapshot.get_node(*node_id))
.and_then(|entry| files.language_for_file(entry.file))
.map_or_else(
|| "unknown".to_string(),
|l| l.to_string().to_ascii_lowercase(),
);
let max_score = nodes
.iter()
.map(|(_, score)| *score)
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map_or(0.0, round_relevance_score);
let full_path = workspace_root.join(file_path);
let empty_vec = Vec::new();
let all_file_nodes = all_nodes_by_file.get(file_path).unwrap_or(&empty_vec);
let (containers, top_level_symbols) = build_container_tree(
nodes,
all_file_nodes,
snapshot,
workspace_root,
args,
file_cache,
&full_path,
)?;
let container_symbols: u64 = containers.iter().map(|c| c.symbol_count).sum();
let top_level_count = top_level_symbols.len() as u64;
let symbol_count = container_symbols + top_level_count;
let container_tokens: u64 = containers.iter().map(|c| c.estimated_tokens).sum();
let top_level_tokens: u64 = top_level_symbols.iter().map(|s| s.estimated_tokens).sum();
let estimated_tokens = container_tokens + top_level_tokens;
Ok(FileGroup {
path: file_path.to_string(),
language,
estimated_tokens,
symbol_count,
containers,
top_level_symbols,
max_score,
is_stub: false,
file_context: None,
})
}
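// Runs auto-merge per file when enabled; requests that name `expand_files`
// skip it so those files keep their unmerged detail.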
fn apply_auto_merge_if_enabled(
files: &mut [FileGroup],
args: &HierarchicalSearchArgs,
file_cache: &FileContentCache,
workspace_root: &Path,
) -> Result<()> {
if args.auto_merge && !args.expand_files.is_empty() {
tracing::debug!("Skipping auto-merge for expand_files request");
return Ok(());
}
if !args.auto_merge {
return Ok(());
}
for file in files {
let full_path = workspace_root.join(&file.path);
let content = file_cache.get(&full_path).with_context(|| {
format!(
"Failed to read file for auto-merge: {path}",
path = file.path.as_str()
)
})?;
apply_auto_merge(file, &content, args).with_context(|| {
format!(
"Auto-merge failed for file: {path}",
path = file.path.as_str()
)
})?;
}
Ok(())
}
fn apply_token_budgets_for_files(
files: &mut [FileGroup],
args: &HierarchicalSearchArgs,
file_cache: &FileContentCache,
workspace_root: &Path,
) -> Result<()> {
for file in files {
let full_path = workspace_root.join(&file.path);
let content = file_cache.get(&full_path).with_context(|| {
format!(
"Failed to read file for token budget: {path}",
path = file.path.as_str()
)
})?;
apply_token_budgets(file, &content, args).with_context(|| {
format!(
"Token budget enforcement failed for file: {path}",
path = file.path.as_str()
)
})?;
}
Ok(())
}
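// Without `expand_files`, enforce the normal response limits; otherwise
// return only the requested files and bypass the limits entirely.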
fn apply_expand_or_limits(
files: Vec<FileGroup>,
args: &HierarchicalSearchArgs,
) -> (bool, Vec<FileGroup>) {
if args.expand_files.is_empty() {
let mut files = files;
let limit_truncated = enforce_response_limits(files.as_mut_slice(), args);
(limit_truncated, files)
} else {
let expanded: Vec<FileGroup> = files
.into_iter()
.filter(|f| args.expand_files.contains(&f.path))
.collect();
(false, expanded)
}
}
fn total_symbols_after_truncation(files: &[FileGroup]) -> u64 {
files
.iter()
.filter(|f| !f.is_stub)
.map(|f| f.symbol_count)
.sum()
}
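// Orders files by descending score, breaking ties (and NaN comparisons) by
// path so output is deterministic.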
fn sort_files_deterministic(files: &mut [FileGroup]) {
files.sort_by(|a, b| match b.max_score.partial_cmp(&a.max_score) {
Some(std::cmp::Ordering::Equal) | None => a.path.cmp(&b.path),
Some(ord) => ord,
});
}
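/// Estimates the token count of `content`: roughly one token per four
/// characters, padded by 20% (rounding up) and floored at one token for
/// non-empty input.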
pub fn estimate_tokens(content: &str) -> u64 {
if content.is_empty() {
return 0;
}
let char_count = u64::try_from(content.len()).unwrap_or(u64::MAX);
    // Roughly four characters per token as the base estimate.
    let base_estimate = char_count.div_ceil(4);
    // Pad by 20% (multiply by 6/5, rounding up) to bias toward overestimating.
    let adjusted = base_estimate.saturating_mul(6).div_ceil(5);
adjusted.max(1)
}
#[cfg(test)]
mod tests {
use super::round_relevance_score;
#[test]
    #[allow(clippy::float_cmp)]
    fn round_relevance_score_stabilizes_serialized_output() {
assert_eq!(round_relevance_score(0.923_076_923_076_923_2), 0.923);
assert_eq!(round_relevance_score(0.980_769_230_769_230_8), 0.981);
assert_eq!(round_relevance_score(1.0), 1.0);
}
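
    // Companion check for the `estimate_tokens` heuristic: four characters
    // give one base token, which the 20% margin rounds up to two; 100
    // characters give 25 base tokens, padded to 30.
    #[test]
    fn estimate_tokens_applies_char_heuristic_with_margin() {
        use super::estimate_tokens;
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abcd"), 2);
        assert_eq!(estimate_tokens(&"a".repeat(100)), 30);
    }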
}