use std::collections::HashMap;
use crate::document::NodeId;
use crate::utils::estimate_tokens;
use super::scorer::ContentRelevance;
/// Policy that decides how the token budget is split across scored content.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AllocationStrategy {
    /// Take items in descending score order until the budget is used up.
    Greedy,
    /// Give every item a share of the budget proportional to its score.
    Proportional,
    /// Walk depth levels in order, applying a per-level allowance.
    Hierarchical {
        /// Fraction of the total budget used as the allowance for one level.
        min_per_level: f32,
    },
}

impl Default for AllocationStrategy {
    /// The default is [`AllocationStrategy::Hierarchical`] with a 10% per-level share.
    fn default() -> Self {
        let min_per_level = 0.1;
        AllocationStrategy::Hierarchical { min_per_level }
    }
}
/// Describes how a piece of selected content was shortened.
#[derive(Debug, Clone)]
pub struct TruncationInfo {
/// Byte length (`str::len`) of the content before truncation.
pub original_len: usize,
/// Size recorded for the content after truncation.
/// NOTE(review): presumably meant to be comparable with `original_len`;
/// confirm the unit against the strategy that produced the value.
pub truncated_len: usize,
/// Why the content was shortened.
pub reason: TruncationReason,
}
/// Reason attached to a [`TruncationInfo`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationReason {
/// The content did not fit into the remaining token budget.
BudgetExceeded,
/// NOTE(review): not produced by any code visible in this file; presumably
/// marks removal of a low-relevance tail. Confirm against other users.
LowRelevanceTail,
}
/// One piece of content chosen by the allocator, possibly shortened to fit.
#[derive(Debug, Clone)]
pub struct SelectedContent {
/// Node identifier copied from the source chunk.
pub node_id: NodeId,
/// Title copied from the source chunk.
pub title: String,
/// The selected text; this is the shortened text when `truncation` is `Some`.
pub content: String,
/// Token count recorded for `content`.
pub tokens: usize,
/// Relevance score carried over from the scored input.
pub score: f32,
/// Depth level associated with the content.
pub depth: usize,
/// Truncation details, or `None` when the content was taken whole.
pub truncation: Option<TruncationInfo>,
}
impl SelectedContent {
/// Returns `true` when this entry carries truncation details,
/// i.e. when `truncation` is `Some`.
#[must_use]
pub fn is_truncated(&self) -> bool {
self.truncation.is_some()
}
}
/// Counters describing one allocation run. All fields default to zero.
#[derive(Debug, Clone, Default)]
pub struct AllocationStats {
/// Number of items handed to the strategy, i.e. those that passed the
/// minimum-score filter.
pub items_considered: usize,
/// Number of entries in the final selection.
pub items_selected: usize,
/// Number of selected entries that were shortened.
pub items_truncated: usize,
/// Intended count of items dropped by filtering.
/// NOTE(review): presumably the minimum-score filter; confirm against
/// `BudgetAllocator::allocate`.
pub items_filtered: usize,
/// Mean score of the selected entries; `0.0` when nothing was selected.
pub avg_score: f32,
}
/// Outcome of a budget allocation run.
#[derive(Debug, Clone)]
pub struct AllocationResult {
/// Entries chosen by the strategy, in the order they were selected.
pub selected: Vec<SelectedContent>,
/// Sum of the token counts of all selected entries.
pub tokens_used: usize,
/// Total budget minus `tokens_used`.
pub remaining_budget: usize,
/// Counters collected while allocating.
pub stats: AllocationStats,
}
impl AllocationResult {
/// Returns `true` when no content was selected.
#[must_use]
pub fn is_empty(&self) -> bool {
self.selected.is_empty()
}
/// Returns the number of selected entries.
#[must_use]
pub fn len(&self) -> usize {
self.selected.len()
}
}
/// Selects scored content so that it fits into a fixed token budget.
#[derive(Debug)]
pub struct BudgetAllocator {
// Upper bound, in tokens, that the strategies allocate against.
total_budget: usize,
// Set to one tenth of the budget by `new`; the hierarchical strategy stops
// its top-up pass once usage reaches `total_budget - min_reserve`.
min_reserve: usize,
// Strategy dispatched on by `allocate`.
strategy: AllocationStrategy,
// Items scoring below this value are dropped before allocation.
min_score: f32,
}
impl BudgetAllocator {
/// Creates an allocator for `budget` tokens.
///
/// The reserve is one tenth of the budget, the strategy is
/// `AllocationStrategy::default()`, and the minimum score is `0.0`.
#[must_use]
pub fn new(budget: usize) -> Self {
    let min_reserve = budget / 10;
    let strategy = AllocationStrategy::default();
    Self {
        total_budget: budget,
        min_reserve,
        strategy,
        min_score: 0.0,
    }
}
#[must_use]
pub fn with_strategy(mut self, strategy: AllocationStrategy) -> Self {
self.strategy = strategy;
self
}
#[must_use]
pub fn with_min_score(mut self, min_score: f32) -> Self {
self.min_score = min_score;
self
}
/// Filters `scored_content` by the minimum score and distributes the token
/// budget over the survivors using the configured strategy.
///
/// * `scored_content` - candidate items with their relevance scores.
/// * `max_depth` - deepest level visited by the hierarchical strategy's
///   per-level pass; the other strategies ignore it.
///
/// The returned stats record how many items passed the filter
/// (`items_considered`) and how many were rejected by it (`items_filtered`).
#[must_use]
pub fn allocate(
    &self,
    scored_content: Vec<ContentRelevance>,
    max_depth: usize,
) -> AllocationResult {
    let submitted = scored_content.len();
    // A NaN score never satisfies `>=`, so such items are dropped here too.
    let filtered: Vec<_> = scored_content
        .into_iter()
        .filter(|c| c.score >= self.min_score)
        .collect();
    let stats = AllocationStats {
        items_considered: filtered.len(),
        // `filtered` is a subset of the input, so this cannot underflow.
        items_filtered: submitted - filtered.len(),
        ..Default::default()
    };
    match self.strategy {
        AllocationStrategy::Greedy => self.allocate_greedy(filtered, stats),
        AllocationStrategy::Proportional => self.allocate_proportional(filtered, stats),
        AllocationStrategy::Hierarchical { min_per_level } => {
            self.allocate_hierarchical(filtered, max_depth, min_per_level, stats)
        }
    }
}
/// Greedy strategy: take items in descending score order while they fit.
///
/// The first item that does not fit is truncated into the remaining budget
/// when at least 50 tokens are left, and allocation stops there either way.
///
/// `truncated_len` is the byte length of the truncated text, so it uses the
/// same unit as `original_len`.
fn allocate_greedy(
    &self,
    mut content: Vec<ContentRelevance>,
    mut stats: AllocationStats,
) -> AllocationResult {
    // Highest score first; incomparable scores are treated as equal.
    content.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut selected = Vec::new();
    let mut tokens_used = 0;

    for relevance in content {
        let tokens = relevance.chunk.token_count();

        if tokens_used + tokens <= self.total_budget {
            selected.push(SelectedContent {
                node_id: relevance.chunk.node_id,
                title: relevance.chunk.title,
                content: relevance.chunk.content,
                tokens,
                score: relevance.score,
                depth: relevance.chunk.depth,
                truncation: None,
            });
            tokens_used += tokens;
            continue;
        }

        // This item does not fit whole: try a truncated version, then stop.
        let remaining = self.total_budget.saturating_sub(tokens_used);
        if remaining >= 50 {
            if let Some(truncated) = self.truncate_content(&relevance.chunk.content, remaining)
            {
                let truncated_tokens = estimate_tokens(&truncated);
                let truncated_len = truncated.len();
                selected.push(SelectedContent {
                    node_id: relevance.chunk.node_id,
                    title: relevance.chunk.title,
                    content: truncated,
                    tokens: truncated_tokens,
                    score: relevance.score,
                    depth: relevance.chunk.depth,
                    truncation: Some(TruncationInfo {
                        original_len: relevance.chunk.content.len(),
                        truncated_len,
                        reason: TruncationReason::BudgetExceeded,
                    }),
                });
                tokens_used += truncated_tokens;
                stats.items_truncated += 1;
            }
        }
        break;
    }

    stats.items_selected = selected.len();
    stats.avg_score = if selected.is_empty() {
        0.0
    } else {
        selected.iter().map(|s| s.score).sum::<f32>() / selected.len() as f32
    };

    AllocationResult {
        selected,
        tokens_used,
        // The estimate for truncated text (which includes the appended
        // ellipsis) is not guaranteed to stay within the allowance, so clamp
        // at zero instead of underflowing.
        remaining_budget: self.total_budget.saturating_sub(tokens_used),
        stats,
    }
}
/// Proportional strategy: each item gets a share of the total budget
/// proportional to its score, with a floor of 50 tokens.
///
/// Items are visited in input order. An item that fits its share is taken
/// whole if the overall budget still allows it. Otherwise it is truncated
/// to `min(remaining, share)` tokens, provided that at least 50 tokens and
/// at least half of its share remain. Returns an empty result when the
/// scores sum to exactly zero.
fn allocate_proportional(
    &self,
    content: Vec<ContentRelevance>,
    mut stats: AllocationStats,
) -> AllocationResult {
    let total_score: f32 = content.iter().map(|c| c.score).sum();
    if total_score == 0.0 {
        return AllocationResult {
            selected: Vec::new(),
            tokens_used: 0,
            remaining_budget: self.total_budget,
            stats,
        };
    }

    let mut selected = Vec::new();
    let mut tokens_used = 0;

    for relevance in content {
        let proportion = relevance.score / total_score;
        let allocated_budget = ((self.total_budget as f32 * proportion) as usize).max(50);
        let content_tokens = relevance.chunk.token_count();

        if content_tokens <= allocated_budget {
            // Fits its share; still has to fit the overall budget.
            if tokens_used + content_tokens <= self.total_budget {
                selected.push(SelectedContent {
                    node_id: relevance.chunk.node_id,
                    title: relevance.chunk.title,
                    content: relevance.chunk.content,
                    tokens: content_tokens,
                    score: relevance.score,
                    depth: relevance.chunk.depth,
                    truncation: None,
                });
                tokens_used += content_tokens;
            }
            continue;
        }

        // An earlier truncated item may have pushed usage past the budget,
        // because its estimate includes the appended ellipsis. Clamp at zero
        // rather than underflowing.
        let remaining = self.total_budget.saturating_sub(tokens_used);
        if remaining >= 50 && remaining >= allocated_budget / 2 {
            if let Some(truncated) =
                self.truncate_content(&relevance.chunk.content, remaining.min(allocated_budget))
            {
                let truncated_tokens = estimate_tokens(&truncated);
                let truncated_len = truncated.len();
                selected.push(SelectedContent {
                    node_id: relevance.chunk.node_id,
                    title: relevance.chunk.title,
                    content: truncated,
                    tokens: truncated_tokens,
                    score: relevance.score,
                    depth: relevance.chunk.depth,
                    truncation: Some(TruncationInfo {
                        original_len: relevance.chunk.content.len(),
                        truncated_len,
                        reason: TruncationReason::BudgetExceeded,
                    }),
                });
                tokens_used += truncated_tokens;
                stats.items_truncated += 1;
            }
        }
    }

    stats.items_selected = selected.len();
    stats.avg_score = if selected.is_empty() {
        0.0
    } else {
        selected.iter().map(|s| s.score).sum::<f32>() / selected.len() as f32
    };

    AllocationResult {
        selected,
        tokens_used,
        remaining_budget: self.total_budget.saturating_sub(tokens_used),
        stats,
    }
}
/// Hierarchical strategy: serve depth levels `0..=max_depth` in order, then
/// top up with the best remaining items.
///
/// Pass 1 walks each level's items in descending score order. An item is
/// taken whole when it fits the overall budget and the level has used less
/// than `total_budget * min_per_level` tokens; depth 0 is exempt from the
/// per-level limit. If it cannot be taken whole but the level allowance is
/// not exhausted, it is truncated into what is left, provided that is at
/// least 50 tokens.
///
/// Pass 2 runs while usage is below `total_budget - min_reserve`. It adds
/// whole, not-yet-selected items from any depth in descending score order.
///
/// `truncated_len` is the byte length of the truncated text, so it uses the
/// same unit as `original_len`.
fn allocate_hierarchical(
    &self,
    content: Vec<ContentRelevance>,
    max_depth: usize,
    min_per_level: f32,
    mut stats: AllocationStats,
) -> AllocationResult {
    // Bucket by depth, best score first inside each bucket.
    let mut by_depth: HashMap<usize, Vec<ContentRelevance>> = HashMap::new();
    for c in content {
        by_depth.entry(c.chunk.depth).or_default().push(c);
    }
    for items in by_depth.values_mut() {
        items.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    let per_level_budget = (self.total_budget as f32 * min_per_level) as usize;
    let mut selected = Vec::new();
    let mut tokens_used = 0;

    // Pass 1: per-level allocation.
    for depth in 0..=max_depth {
        if tokens_used >= self.total_budget {
            break;
        }
        if let Some(level_content) = by_depth.get(&depth) {
            let mut level_used = 0;
            for relevance in level_content {
                if tokens_used >= self.total_budget {
                    break;
                }
                let tokens = relevance.chunk.token_count();
                let can_include_full = tokens_used + tokens <= self.total_budget;
                let level_budget_ok = level_used < per_level_budget || depth == 0;

                if can_include_full && level_budget_ok {
                    selected.push(SelectedContent {
                        node_id: relevance.chunk.node_id,
                        title: relevance.chunk.title.clone(),
                        content: relevance.chunk.content.clone(),
                        tokens,
                        score: relevance.score,
                        depth,
                        truncation: None,
                    });
                    tokens_used += tokens;
                    level_used += tokens;
                } else if level_used < per_level_budget {
                    // Both subtractions are guarded: the loop breaks once
                    // `tokens_used >= total_budget`, and this branch requires
                    // `level_used < per_level_budget`.
                    let remaining =
                        (self.total_budget - tokens_used).min(per_level_budget - level_used);
                    if remaining >= 50 {
                        if let Some(truncated) =
                            self.truncate_content(&relevance.chunk.content, remaining)
                        {
                            let truncated_tokens = estimate_tokens(&truncated);
                            let truncated_len = truncated.len();
                            selected.push(SelectedContent {
                                node_id: relevance.chunk.node_id,
                                title: relevance.chunk.title.clone(),
                                content: truncated,
                                tokens: truncated_tokens,
                                score: relevance.score,
                                depth,
                                truncation: Some(TruncationInfo {
                                    original_len: relevance.chunk.content.len(),
                                    truncated_len,
                                    reason: TruncationReason::BudgetExceeded,
                                }),
                            });
                            tokens_used += truncated_tokens;
                            level_used += truncated_tokens;
                            stats.items_truncated += 1;
                        }
                    }
                }
            }
        }
    }

    // Pass 2: top up with the best unselected items until the reserve
    // threshold is reached.
    let fill_limit = self.total_budget.saturating_sub(self.min_reserve);
    if tokens_used < fill_limit {
        // `HashMap` iteration order is unspecified, so items with equal
        // scores may come out in any order after the stable sort below.
        let mut all_remaining: Vec<_> = by_depth
            .values()
            .flat_map(|v| v.iter())
            .filter(|c| !selected.iter().any(|s| s.node_id == c.chunk.node_id))
            .collect();
        all_remaining.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        for relevance in all_remaining {
            if tokens_used >= fill_limit {
                break;
            }
            let tokens = relevance.chunk.token_count();
            if tokens_used + tokens <= self.total_budget {
                selected.push(SelectedContent {
                    node_id: relevance.chunk.node_id,
                    title: relevance.chunk.title.clone(),
                    content: relevance.chunk.content.clone(),
                    tokens,
                    score: relevance.score,
                    depth: relevance.chunk.depth,
                    truncation: None,
                });
                tokens_used += tokens;
            }
        }
    }

    stats.items_selected = selected.len();
    stats.avg_score = if selected.is_empty() {
        0.0
    } else {
        selected.iter().map(|s| s.score).sum::<f32>() / selected.len() as f32
    };

    AllocationResult {
        selected,
        tokens_used,
        // The estimate for truncated text (which includes the appended
        // ellipsis) is not guaranteed to stay within the allowance, so clamp
        // at zero instead of underflowing.
        remaining_budget: self.total_budget.saturating_sub(tokens_used),
        stats,
    }
}
/// Shortens `content` to roughly `max_tokens` tokens, using a window of
/// four bytes per token.
///
/// Returns `None` when `max_tokens` is below 20, and the content unchanged
/// when it already fits the window. Otherwise the text is cut at the last
/// sentence terminator (`.`, `!`, `?`) inside the window, or failing that at
/// the last space, or failing that at the window edge, and `"..."` is
/// appended.
///
/// The window edge is moved back to the nearest UTF-8 character boundary,
/// so multi-byte text never causes a slicing panic.
fn truncate_content(&self, content: &str, max_tokens: usize) -> Option<String> {
    if max_tokens < 20 {
        return None;
    }

    let max_chars = max_tokens.saturating_mul(4);
    if content.len() <= max_chars {
        return Some(content.to_string());
    }

    // `max_chars < content.len()` here. Step back until the cut is on a
    // character boundary; index 0 always is one, so the loop terminates.
    let mut cut = max_chars;
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }
    let window = &content[..cut];

    let kept = if let Some(pos) = window.rfind(|c: char| matches!(c, '.' | '!' | '?')) {
        // The terminator is a single ASCII byte, so `..=pos` is a valid slice.
        &window[..=pos]
    } else if let Some(pos) = window.rfind(' ') {
        &window[..pos]
    } else {
        window
    };

    Some(format!("{}...", kept))
}
}
impl Default for BudgetAllocator {
fn default() -> Self {
Self::new(4000)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::retrieval::content::{ContentChunk, ScoreComponents};
    use indextree::Arena;

    /// Builds a node id backed by a throwaway arena holding one empty node.
    fn make_test_node_id() -> NodeId {
        let mut arena = Arena::new();
        let node = crate::document::TreeNode {
            title: "Test".to_string(),
            structure: String::new(),
            content: String::new(),
            summary: String::new(),
            depth: 0,
            start_index: 0,
            end_index: 0,
            start_page: None,
            end_page: None,
            node_id: None,
            physical_index: None,
            token_count: None,
            references: Vec::new(),
        };
        NodeId(arena.new_node(node))
    }

    /// Wraps `content` in a scored chunk at the given depth.
    fn make_relevance(content: &str, score: f32, depth: usize) -> ContentRelevance {
        let chunk = ContentChunk::new(
            make_test_node_id(),
            "Test".to_string(),
            content.to_string(),
            depth,
        );
        ContentRelevance::new(chunk, score, ScoreComponents::default())
    }

    #[test]
    fn test_allocator_creation() {
        let allocator = BudgetAllocator::new(1000);
        assert_eq!(allocator.total_budget, 1000);
    }

    #[test]
    fn test_greedy_allocation() {
        let allocator = BudgetAllocator::new(100).with_strategy(AllocationStrategy::Greedy);
        let content = vec![
            make_relevance("High score content with enough text", 0.9, 0),
            make_relevance("Low score content", 0.3, 0),
        ];
        let result = allocator.allocate(content, 1);
        assert!(!result.is_empty());
        assert!(result.tokens_used <= 100);
    }

    #[test]
    fn test_min_score_filter() {
        let allocator = BudgetAllocator::new(1000).with_min_score(0.5);
        let content = vec![
            make_relevance("Good content", 0.8, 0),
            make_relevance("Bad content", 0.2, 0),
        ];
        let result = allocator.allocate(content, 1);
        assert_eq!(result.selected.len(), 1);
    }

    #[test]
    fn test_truncation() {
        let allocator = BudgetAllocator::new(50);
        let input = "This is a very long piece of content. It has multiple sentences. We want to test truncation at sentence boundary.";

        // 25 tokens give a 100-byte window; the input is longer than that, so
        // it must be cut at the last sentence end inside the window and get
        // the ellipsis appended.
        let text = allocator
            .truncate_content(input, 25)
            .expect("25 tokens is above the minimum truncation size");
        assert_eq!(
            text,
            "This is a very long piece of content. It has multiple sentences...."
        );
        assert!(text.len() < input.len());

        // Allowances below 20 tokens are rejected outright.
        assert!(allocator.truncate_content(input, 19).is_none());
    }

    #[test]
    fn test_hierarchical_allocation() {
        let allocator = BudgetAllocator::new(200)
            .with_strategy(AllocationStrategy::Hierarchical { min_per_level: 0.2 });
        let content = vec![
            make_relevance("Depth 0 content", 0.9, 0),
            make_relevance("Depth 1 content A", 0.7, 1),
            make_relevance("Depth 1 content B", 0.6, 1),
            make_relevance("Depth 2 content", 0.8, 2),
        ];
        let result = allocator.allocate(content, 2);
        let depths: std::collections::HashSet<usize> =
            result.selected.iter().map(|s| s.depth).collect();
        assert!(depths.len() >= 2);
    }
}