use anyhow::Result;
use crate::execution::CodeContext;
use crate::tools::HierarchicalSearchArgs;
use super::{ContainerGroup, FileGroup, HierarchicalSymbol, estimate_tokens};
/// Enforce the configured token budgets on `file`, working from the
/// innermost level (individual symbols) outward (the whole file), and
/// optionally clustering symbols and attaching surrounding context.
pub fn apply_token_budgets(
    file: &mut FileGroup,
    file_content: &str,
    args: &HierarchicalSearchArgs,
) -> Result<()> {
    // Per-symbol budget first: everything downstream sums symbol tokens.
    apply_symbol_target_tokens(file, args);

    // Optionally merge runs of adjacent symbols into single clusters.
    if args.context_cluster_target_tokens > 0 {
        apply_context_clustering(file, file_content, args)?;
    }

    // Then enforce the per-container and per-file ceilings.
    apply_container_target_tokens(file, args);
    apply_file_target_tokens(file, args);

    // Context additions can push totals back over budget; these helpers
    // re-validate the file budget after adding.
    if args.include_container_context {
        add_container_contexts(file, file_content, args)?;
    }
    if args.include_file_context {
        add_file_context(file, file_content, args)?;
    }
    Ok(())
}
/// Trim every symbol in `file` — both inside containers and at the top
/// level — down to the per-symbol token target.
fn apply_symbol_target_tokens(file: &mut FileGroup, args: &HierarchicalSearchArgs) {
    let budget = args.symbol_target_tokens;
    file.containers
        .iter_mut()
        .for_each(|container| apply_symbol_target_to_container(container, budget));
    file.top_level_symbols
        .iter_mut()
        .for_each(|symbol| trim_symbol_context(symbol, budget));
}
/// Recursively apply the per-symbol token target to a container and all
/// of its nested containers.
fn apply_symbol_target_to_container(container: &mut ContainerGroup, target: u64) {
    container
        .nested_containers
        .iter_mut()
        .for_each(|child| apply_symbol_target_to_container(child, target));
    container
        .symbols
        .iter_mut()
        .for_each(|symbol| trim_symbol_context(symbol, target));
}
/// Shrink a symbol's surrounding context until its token estimate fits
/// within `target`, by dropping context lines from the edges.
///
/// Leading context lines are dropped first, then trailing ones; the
/// symbol's own lines are never removed, so the result may still exceed
/// `target` when the bare symbol alone is too large (in that case
/// `estimated_tokens` is left unchanged).
fn trim_symbol_context(symbol: &mut HierarchicalSymbol, target: u64) {
    if symbol.estimated_tokens <= target {
        return;
    }
    if let Some(ctx) = &mut symbol.context {
        let lines: Vec<&str> = ctx.code.lines().collect();
        // Number of context lines removed from each edge so far.
        let mut trim_before = 0;
        let mut trim_after = 0;
        // Upper bounds: never trim into the symbol body itself.
        let max_trim_before = ctx.lines_before;
        let max_trim_after = ctx.lines_after;
        loop {
            let start_idx = trim_before;
            let end_idx = lines.len().saturating_sub(trim_after);
            if start_idx >= end_idx {
                break; }
            let trimmed_code: String = lines[start_idx..end_idx].join("\n");
            let new_tokens = estimate_tokens(&trimmed_code);
            if new_tokens <= target {
                // Fits: commit the trimmed text and updated bookkeeping.
                ctx.code = trimmed_code;
                ctx.lines_before = ctx.lines_before.saturating_sub(trim_before);
                ctx.lines_after = ctx.lines_after.saturating_sub(trim_after);
                symbol.estimated_tokens = new_tokens;
                break;
            }
            // Still too large: drop one more leading context line while
            // any remain, then trailing lines; give up when both edges
            // are exhausted.
            if trim_before < max_trim_before {
                trim_before += 1;
            } else if trim_after < max_trim_after {
                trim_after += 1;
            } else {
                break; }
        }
    }
}
/// Run symbol clustering over every top-level container of `file`.
fn apply_context_clustering(
    file: &mut FileGroup,
    file_content: &str,
    args: &HierarchicalSearchArgs,
) -> Result<()> {
    let target = args.context_cluster_target_tokens;
    for container in file.containers.iter_mut() {
        apply_clustering_to_container(container, file_content, target)?;
    }
    Ok(())
}
/// Depth-first clustering pass: nested containers are handled before the
/// container's own symbol list.
fn apply_clustering_to_container(
    container: &mut ContainerGroup,
    file_content: &str,
    target: u64,
) -> Result<()> {
    for child in container.nested_containers.iter_mut() {
        apply_clustering_to_container(child, file_content, target)?;
    }
    cluster_container_symbols(container, file_content, target)
}
/// Merge runs of nearby symbols within one container, then refresh the
/// container's token estimate.
fn cluster_container_symbols(
    container: &mut ContainerGroup,
    file_content: &str,
    target: u64,
) -> Result<()> {
    // Nothing to merge with fewer than two symbols.
    if container.symbols.len() < 2 {
        return Ok(());
    }
    // Clustering assumes symbols appear in source order.
    container.symbols.sort_by_key(|s| s.range.start.line);
    let clustered = cluster_symbols(&container.symbols, file_content, target)?;
    container.symbols = clustered;
    // Recompute the estimate: unmerged symbols plus nested containers.
    let nested_total: u64 = container
        .nested_containers
        .iter()
        .map(|n| n.estimated_tokens)
        .sum();
    container.estimated_tokens = unmerged_symbol_tokens(&container.symbols) + nested_total;
    Ok(())
}
/// Scan `symbols` left to right, greedily grouping adjacent runs whose
/// combined estimate fits under `target`, and replace each run with one
/// merged symbol when the merged source text also fits.
fn cluster_symbols(
    symbols: &[HierarchicalSymbol],
    file_content: &str,
    target: u64,
) -> Result<Vec<HierarchicalSymbol>> {
    let mut out = Vec::new();
    let mut idx = 0;
    while idx < symbols.len() {
        let (end, _) = find_cluster_end(symbols, idx, target);
        if end == idx {
            // Run of length one: keep the symbol as-is.
            out.push(symbols[idx].clone());
        } else {
            match merge_cluster_if_possible(symbols, idx, end, file_content, target)? {
                Some(merged) => out.push(merged),
                // Merged text was over budget: keep the run unmerged.
                None => out.extend(symbols[idx..=end].iter().cloned()),
            }
        }
        idx = end + 1;
    }
    Ok(out)
}
/// Starting at `start_idx`, extend a cluster across consecutive adjacent
/// symbols while the running token total stays within `target`.
///
/// Returns the index of the last symbol in the cluster and the cluster's
/// combined token estimate.
fn find_cluster_end(symbols: &[HierarchicalSymbol], start_idx: usize, target: u64) -> (usize, u64) {
    let mut end = start_idx;
    let mut tokens = symbols[start_idx].estimated_tokens;
    for candidate in (start_idx + 1)..symbols.len() {
        // `candidate - 1 == end` here, since any rejection breaks out.
        let previous = &symbols[candidate - 1];
        let next = &symbols[candidate];
        if !is_adjacent_symbol(previous, next) {
            break;
        }
        let combined = tokens + next.estimated_tokens;
        if combined > target {
            break;
        }
        tokens = combined;
        end = candidate;
    }
    (end, tokens)
}
/// Two symbols count as adjacent when the second one starts within five
/// lines of where the first one ends.
fn is_adjacent_symbol(
    current_symbol: &HierarchicalSymbol,
    next_symbol: &HierarchicalSymbol,
) -> bool {
    // Gap threshold of 5 lines between the end of one symbol and the
    // start of the next.
    let latest_adjacent_start = current_symbol.range.end.line + 5;
    next_symbol.range.start.line <= latest_adjacent_start
}
/// Try to merge `symbols[cluster_start..=cluster_end]` into a single
/// symbol covering one contiguous source span.
///
/// The span runs from the first symbol's start (minus its leading
/// context lines) to the last symbol's end (plus its trailing context
/// lines). Returns `Ok(None)` when the merged text's token estimate
/// exceeds `target`, so the caller keeps the symbols separate.
fn merge_cluster_if_possible(
    symbols: &[HierarchicalSymbol],
    cluster_start: usize,
    cluster_end: usize,
    file_content: &str,
    target: u64,
) -> Result<Option<HierarchicalSymbol>> {
    let first = &symbols[cluster_start];
    let last = &symbols[cluster_end];
    // Widen the span by each edge symbol's existing context window.
    let first_ctx_before = first.context.as_ref().map_or(0, |c| c.lines_before);
    let merged_start_line = (first.range.start.line as usize).saturating_sub(first_ctx_before);
    let last_ctx_after = last.context.as_ref().map_or(0, |c| c.lines_after);
    let merged_end_line = (last.range.end.line as usize) + last_ctx_after;
    // NOTE(review): this passes range lines straight to the 1-indexed
    // extractor, so range lines are presumably 1-indexed — confirm
    // against the producer of `range`.
    let merged_code = extract_line_range_with_trailing(
        file_content,
        merged_start_line.max(1), merged_end_line,
    )?;
    let merged_tokens = estimate_tokens(&merged_code);
    if merged_tokens > target {
        return Ok(None);
    }
    // Base the merged symbol on the first of the cluster, then widen its
    // context and range to cover the whole span.
    let mut merged = symbols[cluster_start].clone();
    // Apply the same `.max(1)` clamp used for extraction so that
    // `lines_before` matches the text actually extracted.
    let merged_lines_before =
        (first.range.start.line as usize).saturating_sub(merged_start_line.max(1));
    let merged_lines_after = merged_end_line.saturating_sub(last.range.end.line as usize);
    merged.context = Some(CodeContext {
        code: merged_code,
        lines_before: merged_lines_before,
        lines_after: merged_lines_after,
    });
    merged.estimated_tokens = merged_tokens;
    merged.range.end = last.range.end.clone();
    // Record how many original symbols this merged entry represents,
    // saturating rather than failing on absurdly large clusters.
    merged.clustered_count =
        Some(u32::try_from(cluster_end - cluster_start + 1).unwrap_or(u32::MAX));
    Ok(Some(merged))
}
/// Enforce the per-container token ceiling on every top-level container.
fn apply_container_target_tokens(file: &mut FileGroup, args: &HierarchicalSearchArgs) {
    file.containers
        .iter_mut()
        .for_each(|container| enforce_container_budget(container, args.container_target_tokens));
}
/// Recursively bring `container` within `target` tokens, trimming its
/// lowest-scoring contents when the refreshed estimate is over budget.
fn enforce_container_budget(container: &mut ContainerGroup, target: u64) {
    // Children first, so their estimates are final before we sum them.
    for child in container.nested_containers.iter_mut() {
        enforce_container_budget(child, target);
    }
    let total = container_tokens_with_merged(container);
    container.estimated_tokens = total;
    if total > target {
        trim_container_contents_to_budget(container, target);
    }
}
/// Total token estimate for a container: nested containers, symbols that
/// were not merged away, plus the merged-cluster contribution.
fn container_tokens_with_merged(container: &ContainerGroup) -> u64 {
    let nested = nested_container_tokens(container);
    let symbols = unmerged_symbol_tokens(&container.symbols);
    nested + symbols + merged_container_token_contribution(container)
}
/// Sum of the token estimates of all directly nested containers.
fn nested_container_tokens(container: &ContainerGroup) -> u64 {
    let mut total = 0;
    for nested in &container.nested_containers {
        total += nested.estimated_tokens;
    }
    total
}
/// Sum of token estimates over symbols that were not merged into a
/// cluster (merged symbols are accounted for separately).
fn unmerged_symbol_tokens(symbols: &[HierarchicalSymbol]) -> u64 {
    symbols.iter().fold(0, |total, symbol| {
        if symbol.merged {
            total
        } else {
            total + symbol.estimated_tokens
        }
    })
}
/// Token contribution of merged clusters: counted once per container,
/// and only when at least one of its symbols was actually merged.
fn merged_container_token_contribution(container: &ContainerGroup) -> u64 {
    let has_merged = container.symbols.iter().any(|s| s.merged);
    if has_merged {
        container.merged_container_tokens
    } else {
        0
    }
}
/// Drop the lowest-value contents of `container` until it fits `target`:
/// unmerged symbols first (worst score first), then whole nested
/// containers if still over budget, then refresh the symbol count.
fn trim_container_contents_to_budget(container: &mut ContainerGroup, target: u64) {
    sort_symbols_by_score_desc(&mut container.symbols);
    let mut total = remove_unmerged_symbols_to_budget(
        &mut container.symbols,
        container.estimated_tokens,
        target,
    );
    container.estimated_tokens = total;
    if total > target && !container.nested_containers.is_empty() {
        sort_containers_by_max_score_desc(&mut container.nested_containers);
        total = remove_containers_to_budget(&mut container.nested_containers, total, target);
        container.estimated_tokens = total;
    }
    update_container_symbol_count(container);
}
/// Sort symbols by score, highest first.
///
/// Uses `f64::total_cmp` so the comparator is a valid total order even
/// when a score is NaN. The previous `partial_cmp(..).unwrap_or(Equal)`
/// produces an inconsistent ordering in the presence of NaN (NaN equal
/// to everything, while other pairs still compare unequal), which
/// `sort_by` is allowed to panic on. Under `total_cmp`, NaN sorts above
/// every other value, so NaN-scored symbols come first here.
fn sort_symbols_by_score_desc(symbols: &mut [HierarchicalSymbol]) {
    symbols.sort_by(|a, b| b.score.total_cmp(&a.score));
}
/// Sort containers by their best symbol score, highest first.
///
/// `f64::total_cmp` gives a well-defined total order (NaN included);
/// the previous `partial_cmp(..).unwrap_or(Equal)` hands `sort_by` an
/// inconsistent comparator whenever a NaN score appears, which the
/// standard sort is allowed to panic on.
fn sort_containers_by_max_score_desc(containers: &mut [ContainerGroup]) {
    containers.sort_by(|a, b| {
        let a_max = container_max_score(a);
        let b_max = container_max_score(b);
        b_max.total_cmp(&a_max)
    });
}
/// Highest symbol score in the container. An empty container yields 0.0,
/// as does one whose scores are all negative or NaN — this matches the
/// original fold seed of 0.0 with `f64::max` (which ignores NaN).
fn container_max_score(container: &ContainerGroup) -> f64 {
    let mut best = 0.0_f64;
    for symbol in &container.symbols {
        best = best.max(symbol.score);
    }
    best
}
/// Remove unmerged symbols from the back of `symbols` (lowest score,
/// assuming the list is sorted descending) until `total` fits within
/// `target` or no unmerged symbols remain. Merged cluster symbols are
/// never removed. Returns the updated total.
fn remove_unmerged_symbols_to_budget(
    symbols: &mut Vec<HierarchicalSymbol>,
    mut total: u64,
    target: u64,
) -> u64 {
    while total > target {
        match symbols.iter().rposition(|s| !s.merged) {
            Some(pos) => {
                let removed = symbols.remove(pos);
                total = total.saturating_sub(removed.estimated_tokens);
            }
            // Only merged symbols (or nothing) left — stop trimming.
            None => break,
        }
    }
    total
}
/// Pop whole containers off the back of the list (lowest max score,
/// assuming it is sorted descending) until `total` fits within `target`
/// or the list is empty. Returns the updated total.
fn remove_containers_to_budget(
    containers: &mut Vec<ContainerGroup>,
    mut total: u64,
    target: u64,
) -> u64 {
    while total > target {
        match containers.pop() {
            Some(removed) => total = total.saturating_sub(removed.estimated_tokens),
            None => break,
        }
    }
    total
}
/// Refresh `symbol_count`: the container's own symbols plus the counts
/// of all directly nested containers.
fn update_container_symbol_count(container: &mut ContainerGroup) {
    let nested_count: u64 = container
        .nested_containers
        .iter()
        .map(|n| n.symbol_count)
        .sum();
    container.symbol_count = nested_count + container.symbols.len() as u64;
}
/// Enforce the whole-file token ceiling, dropping the lowest-scoring
/// containers and top-level symbols when the total is over budget.
fn apply_file_target_tokens(file: &mut FileGroup, args: &HierarchicalSearchArgs) {
    let target = args.file_target_tokens;
    let container_tokens: u64 = file.containers.iter().map(|c| c.estimated_tokens).sum();
    let symbol_tokens: u64 = file
        .top_level_symbols
        .iter()
        .map(|s| s.estimated_tokens)
        .sum();
    let current = container_tokens + symbol_tokens;
    file.estimated_tokens = if current > target {
        trim_file_to_budget(file, target, current)
    } else {
        current
    };
}
/// Trim file contents down to the budget: sort both lists best-first,
/// drop containers from the back, then top-level symbols, then refresh
/// the file's symbol count. Returns the resulting token total.
fn trim_file_to_budget(file: &mut FileGroup, target: u64, total: u64) -> u64 {
    sort_containers_by_max_score_desc(&mut file.containers);
    sort_symbols_by_score_desc(&mut file.top_level_symbols);
    let after_containers = remove_containers_to_budget(&mut file.containers, total, target);
    let after_symbols =
        remove_symbols_to_budget(&mut file.top_level_symbols, after_containers, target);
    update_file_symbol_count(file);
    after_symbols
}
/// Pop symbols off the back of the list (lowest score, assuming it is
/// sorted descending) until `total` fits within `target` or the list is
/// empty. Returns the updated total.
fn remove_symbols_to_budget(
    symbols: &mut Vec<HierarchicalSymbol>,
    mut total: u64,
    target: u64,
) -> u64 {
    while total > target {
        match symbols.pop() {
            Some(removed) => total = total.saturating_sub(removed.estimated_tokens),
            None => break,
        }
    }
    total
}
/// Refresh the file's symbol count from its containers plus its loose
/// top-level symbols.
fn update_file_symbol_count(file: &mut FileGroup) {
    let container_symbols: u64 = file.containers.iter().map(|c| c.symbol_count).sum();
    file.symbol_count = container_symbols + file.top_level_symbols.len() as u64;
}
/// Attach each container's full source text as context, recompute the
/// file total, and re-check the file budget (contexts add tokens and can
/// push the file back over).
fn add_container_contexts(
    file: &mut FileGroup,
    file_content: &str,
    args: &HierarchicalSearchArgs,
) -> Result<()> {
    for container in file.containers.iter_mut() {
        add_context_to_container(container, file_content, args.container_target_tokens)?;
    }
    let container_tokens: u64 = file.containers.iter().map(|c| c.estimated_tokens).sum();
    let symbol_tokens: u64 = file
        .top_level_symbols
        .iter()
        .map(|s| s.estimated_tokens)
        .sum();
    file.estimated_tokens = container_tokens + symbol_tokens;
    revalidate_file_budget(file, args);
    Ok(())
}
/// Attach the container's full source text as context and fold its token
/// cost into the container's estimate, re-trimming contents if that
/// pushes the container over `container_target`.
fn add_context_to_container(
    container: &mut ContainerGroup,
    file_content: &str,
    container_target: u64,
) -> Result<()> {
    // Children first, so their estimates are final before we re-sum.
    for nested in &mut container.nested_containers {
        add_context_to_container(nested, file_content, container_target)?;
    }
    container.estimated_tokens = container_tokens_with_merged(container);
    // NOTE(review): despite the field name, `byte_range` is consumed
    // here as a 1-indexed (start_line, end_line) pair — it is passed
    // straight to the line-based extractor. Confirm against the
    // producer of this field.
    let (start_line, end_line) = container.byte_range;
    let container_code = extract_line_range_with_trailing(file_content, start_line, end_line)?;
    let context_tokens = estimate_tokens(&container_code);
    container.container_context = Some(container_code);
    container.estimated_tokens += context_tokens;
    if container.estimated_tokens > container_target {
        trim_container_contents_to_budget(container, container_target);
    }
    Ok(())
}
/// Attach the first 20 lines of the file as header context (imports,
/// module docs, etc.), then re-check the file budget since the header
/// adds tokens.
fn add_file_context(
    file: &mut FileGroup,
    file_content: &str,
    args: &HierarchicalSearchArgs,
) -> Result<()> {
    // Fixed-size header window.
    let header_lines = 20;
    let header = extract_line_range_with_trailing(file_content, 1, header_lines)?;
    file.estimated_tokens += estimate_tokens(&header);
    file.file_context = Some(header);
    revalidate_file_budget(file, args);
    Ok(())
}
/// Re-run file-level trimming when the current estimate exceeds the
/// file token target; otherwise leave the file untouched.
fn revalidate_file_budget(file: &mut FileGroup, args: &HierarchicalSearchArgs) {
    let target = args.file_target_tokens;
    if file.estimated_tokens > target {
        file.estimated_tokens = trim_file_to_budget(file, target, file.estimated_tokens);
    }
}
/// Extract lines `start_line..=end_line` (1-indexed, inclusive) from
/// `content`, preserving each extracted line's trailing newline when
/// present in the source.
///
/// `end_line` may run past the end of the file; the result is then
/// truncated at the last line.
///
/// # Errors
/// Returns an error when `start_line` is 0, when `end_line < start_line`,
/// or when `start_line` is past the last line of `content`.
pub fn extract_line_range_with_trailing(
    content: &str,
    start_line: usize,
    end_line: usize,
) -> Result<String> {
    if start_line == 0 {
        anyhow::bail!("start_line must be >= 1 (1-indexed), got 0");
    }
    if end_line < start_line {
        anyhow::bail!("Invalid line range: end_line ({end_line}) < start_line ({start_line})");
    }
    // Byte offset of the start of each line. '\n' is ASCII, so scanning
    // raw bytes is UTF-8-safe and every offset is a char boundary.
    let mut line_starts: Vec<usize> = vec![0];
    for (i, byte) in content.bytes().enumerate() {
        // A final trailing '\n' terminates the last line rather than
        // starting a new one; the previous version pushed an offset at
        // content.len() here, overcounting lines by one and letting
        // `start_line == real line count + 1` silently return "".
        if byte == b'\n' && i + 1 < content.len() {
            line_starts.push(i + 1);
        }
    }
    let line_count = line_starts.len();
    if start_line > line_count {
        anyhow::bail!("start_line {start_line} exceeds file length {line_count} lines");
    }
    let start_byte = line_starts[start_line - 1];
    // End at the start of the line after `end_line` (keeping that line's
    // '\n'), or at end-of-file when the range runs past the last line.
    let end_byte = if end_line >= line_count {
        content.len()
    } else {
        line_starts[end_line]
    };
    Ok(content[start_byte..end_byte].to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    // A plain interior range keeps each line's trailing newline.
    #[test]
    fn test_extract_line_range_basic() {
        let content = "line1\nline2\nline3\nline4\n";
        let result = extract_line_range_with_trailing(content, 2, 3).unwrap();
        assert_eq!(result, "line2\nline3\n");
    }
    // An end_line past EOF is clamped to the end of the file; the last
    // line has no trailing newline here and none is added.
    #[test]
    fn test_extract_line_range_to_end() {
        let content = "line1\nline2\nline3";
        let result = extract_line_range_with_trailing(content, 2, 10).unwrap();
        assert_eq!(result, "line2\nline3");
    }
    // start_line == end_line extracts exactly one line.
    #[test]
    fn test_extract_line_range_single_line() {
        let content = "line1\nline2\nline3\n";
        let result = extract_line_range_with_trailing(content, 2, 2).unwrap();
        assert_eq!(result, "line2\n");
    }
    // Lines are 1-indexed: start_line of 0 is rejected.
    #[test]
    fn test_extract_line_range_invalid_start() {
        let content = "line1\nline2\n";
        let result = extract_line_range_with_trailing(content, 0, 1);
        assert!(result.is_err());
    }
    // A reversed range (end before start) is rejected.
    #[test]
    fn test_extract_line_range_invalid_order() {
        let content = "line1\nline2\n";
        let result = extract_line_range_with_trailing(content, 3, 1);
        assert!(result.is_err());
    }
    // Extracting through the final line keeps the file's trailing newline.
    #[test]
    fn test_extract_preserves_trailing_newline() {
        let content = "fn foo() {\n bar();\n}\n";
        let result = extract_line_range_with_trailing(content, 1, 3).unwrap();
        assert!(result.ends_with('\n'));
        assert_eq!(result, "fn foo() {\n bar();\n}\n");
    }
}
}