use std::collections::HashSet;
use crate::pagination::encode_cursor;
use crate::tools::HierarchicalSearchArgs;
use super::{ContainerGroup, FileGroup};
/// Trims grouped results to respect the per-file container cap, the
/// per-container symbol cap, and the global symbol budget from `args`.
/// Files after the file that exhausts the budget are reduced to stubs.
/// Surviving (non-stub) files get their derived metadata rebuilt.
/// Returns `true` when anything was cut.
pub fn enforce_response_limits(files: &mut [FileGroup], args: &HierarchicalSearchArgs) -> bool {
    let budget = args.max_total_symbols as u64;
    let mut truncated = false;
    let mut used: u64 = 0;
    let mut budget_exhausted = false;
    for file in files.iter_mut() {
        if budget_exhausted {
            // An earlier file blew the budget: keep only a stub entry.
            convert_to_stub(file);
            truncated = true;
            continue;
        }
        // Per-file container cap.
        if file.containers.len() > args.max_containers_per_file {
            file.containers.truncate(args.max_containers_per_file);
            truncated = true;
        }
        // Per-container symbol cap, applied recursively.
        for container in file.containers.iter_mut() {
            truncated |= enforce_container_limits(container, args.max_symbols_per_container);
        }
        // Global symbol budget across all files.
        let in_this_file = count_symbols_in_file(file);
        used += in_this_file;
        if used > budget {
            let overshoot = used - budget;
            truncate_file_symbols(file, in_this_file.saturating_sub(overshoot));
            truncated = true;
            budget_exhausted = true;
        }
    }
    for file in files.iter_mut() {
        if !file.is_stub {
            refresh_file_metadata(file);
        }
    }
    truncated
}
/// Drops all of a file's content, leaving only a stub placeholder entry.
fn convert_to_stub(file: &mut FileGroup) {
    file.is_stub = true;
    file.containers.clear();
    file.top_level_symbols.clear();
}
/// Recursively caps the symbol list of this container and every nested
/// container at `max_symbols`. Returns `true` if any list was shortened.
fn enforce_container_limits(container: &mut ContainerGroup, max_symbols: usize) -> bool {
    // Fold (not `any`) so every nested container is visited even after
    // the first truncation is observed.
    let mut any_cut = container
        .nested_containers
        .iter_mut()
        .map(|nested| enforce_container_limits(nested, max_symbols))
        .fold(false, |acc, cut| acc | cut);
    if container.symbols.len() > max_symbols {
        container.symbols.truncate(max_symbols);
        any_cut = true;
    }
    any_cut
}
fn count_symbols_in_file(file: &FileGroup) -> u64 {
let container_symbols: u64 = file.containers.iter().map(count_symbols_in_container).sum();
container_symbols + file.top_level_symbols.len() as u64
}
/// Symbols directly in this container plus those in all nested
/// containers, counted recursively.
fn count_symbols_in_container(container: &ContainerGroup) -> u64 {
    container
        .nested_containers
        .iter()
        .fold(container.symbols.len() as u64, |acc, nested| {
            acc + count_symbols_in_container(nested)
        })
}
/// Keeps only the `target` highest-scoring symbols in this file (across
/// containers, nested containers, and top-level symbols), dropping the
/// rest. Ties are broken by original encounter order so the result is
/// stable. Container metadata is refreshed afterwards.
///
/// Entry layout: (score, lives_in_container, container_path, symbol_index, encounter_order).
fn truncate_file_symbols(file: &mut FileGroup, target: u64) {
    let mut all_symbols: Vec<(f64, bool, Vec<usize>, usize, usize)> = Vec::new();
    let mut original_order: usize = 0;
    for (cidx, container) in file.containers.iter().enumerate() {
        collect_container_symbols_v5(container, &[cidx], &mut all_symbols, &mut original_order);
    }
    for (sidx, sym) in file.top_level_symbols.iter().enumerate() {
        all_symbols.push((sym.score, false, Vec::new(), sidx, original_order));
        original_order += 1;
    }
    // Highest score first; incomparable (NaN) scores are treated as equal
    // and fall back to encounter order.
    all_symbols.sort_by(|a, b| {
        b.0.partial_cmp(&a.0)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.4.cmp(&b.4))
    });
    let to_keep: HashSet<(bool, Vec<usize>, usize)> = all_symbols
        .into_iter()
        .take(usize::try_from(target).unwrap_or(usize::MAX))
        .map(|(_, is_container, path, sidx, _)| (is_container, path, sidx))
        .collect();
    for (cidx, container) in file.containers.iter_mut().enumerate() {
        prune_container_symbols_v5(container, &[cidx], &to_keep);
    }
    // Retain in place with one reusable probe key instead of
    // drain/filter/collect, which allocated a fresh `vec![]` per symbol
    // and rebuilt the whole Vec (clippy: manual_retain).
    let mut probe: (bool, Vec<usize>, usize) = (false, Vec::new(), 0);
    let mut next_index = 0usize;
    file.top_level_symbols.retain(|_| {
        probe.2 = next_index;
        next_index += 1;
        to_keep.contains(&probe)
    });
    for container in &mut file.containers {
        refresh_container_metadata(container);
    }
}
/// Flattens every symbol under `container` (depth-first) into `out` as
/// (score, is_container_symbol, container_path, symbol_index, encounter_order),
/// advancing `original_order` once per symbol pushed.
fn collect_container_symbols_v5(
    container: &ContainerGroup,
    container_path: &[usize],
    out: &mut Vec<(f64, bool, Vec<usize>, usize, usize)>,
    original_order: &mut usize,
) {
    for (symbol_idx, symbol) in container.symbols.iter().enumerate() {
        out.push((
            symbol.score,
            true,
            container_path.to_vec(),
            symbol_idx,
            *original_order,
        ));
        *original_order += 1;
    }
    for (child_idx, child) in container.nested_containers.iter().enumerate() {
        let mut child_path = container_path.to_vec();
        child_path.push(child_idx);
        collect_container_symbols_v5(child, &child_path, out, original_order);
    }
}
/// Removes every symbol of `container` (and, recursively, of its nested
/// containers) whose key is absent from `to_keep`. Relative order of the
/// surviving symbols is preserved.
fn prune_container_symbols_v5(
    container: &mut ContainerGroup,
    container_path: &[usize],
    to_keep: &HashSet<(bool, Vec<usize>, usize)>,
) {
    // Retain in place with one reusable probe key instead of
    // drain/filter/collect, which cloned the path Vec once per symbol
    // and rebuilt the whole Vec (clippy: manual_retain).
    let mut probe: (bool, Vec<usize>, usize) = (true, container_path.to_vec(), 0);
    let mut next_index = 0usize;
    container.symbols.retain(|_| {
        probe.2 = next_index;
        next_index += 1;
        to_keep.contains(&probe)
    });
    for (nested_idx, nested) in container.nested_containers.iter_mut().enumerate() {
        let mut nested_path = container_path.to_vec();
        nested_path.push(nested_idx);
        prune_container_symbols_v5(nested, &nested_path, to_keep);
    }
}
/// Recomputes this container's derived fields (symbol_count,
/// children_count, children_names, estimated_tokens) bottom-up, so the
/// nested containers' counts are fresh before they are aggregated here.
fn refresh_container_metadata(container: &mut ContainerGroup) {
    container
        .nested_containers
        .iter_mut()
        .for_each(refresh_container_metadata);

    // Aggregate nested counts and tokens in a single pass.
    let mut nested_symbol_total: u64 = 0;
    let mut nested_token_total: u64 = 0;
    for nested in &container.nested_containers {
        nested_symbol_total += nested.symbol_count;
        nested_token_total += nested.estimated_tokens;
    }

    let own_symbols = container.symbols.len() as u64;
    container.symbol_count = own_symbols + nested_symbol_total;
    // Direct children only: own symbols plus immediate nested containers.
    container.children_count = own_symbols + container.nested_containers.len() as u64;

    // Symbol names first, then nested-container names.
    let mut names: Vec<String> =
        Vec::with_capacity(container.symbols.len() + container.nested_containers.len());
    names.extend(container.symbols.iter().map(|s| s.name.clone()));
    names.extend(container.nested_containers.iter().map(|n| n.name.clone()));
    container.children_names = names;

    let own_tokens: u64 = container.symbols.iter().map(|s| s.estimated_tokens).sum();
    container.estimated_tokens = own_tokens + nested_token_total;
}
/// Recomputes the file's symbol_count and estimated_tokens, refreshing
/// every container first so the file-level sums read fresh values.
fn refresh_file_metadata(file: &mut FileGroup) {
    for container in file.containers.iter_mut() {
        refresh_container_metadata(container);
    }
    // Seed with the top-level contribution, then add each container's.
    let mut symbols = file.top_level_symbols.len() as u64;
    let mut tokens: u64 = file
        .top_level_symbols
        .iter()
        .map(|s| s.estimated_tokens)
        .sum();
    for container in &file.containers {
        symbols += container.symbol_count;
        tokens += container.estimated_tokens;
    }
    file.symbol_count = symbols;
    file.estimated_tokens = tokens;
}
/// Slices one page out of `files` according to the pagination args and
/// returns (page, next-page cursor, whether more files remain).
///
/// NOTE(review): a page size of 0 (possible when `pagination.size` or
/// `max_files` is 0) yields an empty page whose cursor never advances —
/// confirm callers guarantee a positive page size.
pub fn paginate_files(
    files: Vec<FileGroup>,
    args: &HierarchicalSearchArgs,
) -> (Vec<FileGroup>, Option<String>, bool) {
    let total_files = files.len();
    let start_offset = args.pagination.offset;
    let page_size = args.pagination.size.min(args.max_files);
    if start_offset >= total_files {
        // Offset past the end: nothing to return, no cursor.
        return (Vec::new(), None, false);
    }
    let end_offset = total_files.min(start_offset + page_size);
    let mut files = files;
    let page: Vec<FileGroup> = files.drain(start_offset..end_offset).collect();
    let has_more = end_offset < total_files;
    let cursor = has_more.then(|| encode_cursor(end_offset));
    (page, cursor, has_more)
}