use crate::soch_ql::SochValue;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Coefficients that control how [`TokenEstimator`] converts value sizes into
/// approximate token counts.
///
/// Each per-type factor scales a length (digits, bytes, hex characters) before
/// it is divided by `bytes_per_token`; `safety_margin` is applied on top.
#[derive(Clone, Debug)]
pub struct TokenEstimatorConfig {
    /// Scale applied to the decimal digit count of integer values.
    pub int_factor: f32,
    /// Scale applied to the length of a float formatted with two decimals.
    pub float_factor: f32,
    /// Scale applied to the byte length of text values.
    pub string_factor: f32,
    /// Scale applied to the hex-encoded length of binary values.
    pub hex_factor: f32,
    /// Average number of bytes covered by one token.
    pub bytes_per_token: f32,
    /// Multiplier applied (then rounded up) to raw estimates.
    pub safety_margin: f32,
    /// Tokens charged for each separator between values in a row.
    pub separator_tokens: usize,
    /// Tokens charged once per row for its terminator.
    pub newline_tokens: usize,
    /// Fixed token cost of a table header before names and counts are added.
    pub header_tokens: usize,
}

impl Default for TokenEstimatorConfig {
    /// Baseline tuning: four bytes per token with a 15% safety margin.
    fn default() -> Self {
        TokenEstimatorConfig {
            int_factor: 1.0,
            float_factor: 1.2,
            string_factor: 1.1,
            hex_factor: 2.5,
            bytes_per_token: 4.0,
            safety_margin: 1.15,
            separator_tokens: 1,
            newline_tokens: 1,
            header_tokens: 10,
        }
    }
}

impl TokenEstimatorConfig {
    /// `gpt4` preset: 3.8 bytes per token, 15% margin, all other fields default.
    pub fn gpt4() -> Self {
        let base = Self::default();
        Self {
            bytes_per_token: 3.8,
            safety_margin: 1.15,
            ..base
        }
    }

    /// `claude` preset: 4.2 bytes per token, 15% margin, all other fields default.
    pub fn claude() -> Self {
        let base = Self::default();
        Self {
            bytes_per_token: 4.2,
            safety_margin: 1.15,
            ..base
        }
    }

    /// Pessimistic preset: larger per-type factors, fewer bytes per token and
    /// a 25% margin, so estimates come out higher than the default's.
    pub fn conservative() -> Self {
        let base = Self::default();
        Self {
            int_factor: 1.2,
            float_factor: 1.4,
            string_factor: 1.3,
            hex_factor: 3.0,
            bytes_per_token: 3.5,
            safety_margin: 1.25,
            ..base
        }
    }
}
/// Length-based, tokenizer-free estimator of how many tokens values, rows,
/// table headers and plain text will occupy once serialized.
///
/// All estimates are derived from byte/digit lengths and the coefficients in
/// [`TokenEstimatorConfig`]; no real tokenizer is consulted.
#[derive(Debug, Clone)]
pub struct TokenEstimator {
    config: TokenEstimatorConfig,
}

impl TokenEstimator {
    /// Creates an estimator using [`TokenEstimatorConfig::default`].
    pub fn new() -> Self {
        Self::with_config(TokenEstimatorConfig::default())
    }

    /// Creates an estimator with a caller-supplied configuration.
    pub fn with_config(config: TokenEstimatorConfig) -> Self {
        Self { config }
    }

    /// Estimated tokens for a single value, including the safety margin.
    pub fn estimate_value(&self, value: &SochValue) -> usize {
        self.apply_margin(self.estimate_value_raw(value))
    }

    /// `ceil(len * factor / bytes_per_token)`: the shared length-to-token rule.
    fn scaled_tokens(&self, len: usize, factor: f32) -> usize {
        ((len as f32 * factor) / self.config.bytes_per_token).ceil() as usize
    }

    /// Applies `safety_margin` to a raw estimate, rounding up.
    fn apply_margin(&self, raw: usize) -> usize {
        ((raw as f32) * self.config.safety_margin).ceil() as usize
    }

    /// Margin-free estimate for a single value.
    fn estimate_value_raw(&self, value: &SochValue) -> usize {
        match value {
            SochValue::Null | SochValue::Bool(_) => 1,
            SochValue::Int(n) => {
                // Exact decimal digit count: `log10().ceil()` undercounts powers of
                // ten (and yields 0 for ±1). `unsigned_abs` cannot overflow on MIN.
                let digits = n
                    .unsigned_abs()
                    .checked_ilog10()
                    .map_or(1, |d| d as usize + 1);
                let sign = usize::from(*n < 0);
                self.scaled_tokens(digits + sign, self.config.int_factor)
            }
            SochValue::UInt(n) => {
                let digits = n.checked_ilog10().map_or(1, |d| d as usize + 1);
                self.scaled_tokens(digits, self.config.int_factor)
            }
            SochValue::Float(f) => {
                self.scaled_tokens(format!("{:.2}", f).len(), self.config.float_factor)
            }
            SochValue::Text(s) => self.scaled_tokens(s.len(), self.config.string_factor),
            SochValue::Binary(b) => {
                // Two characters per byte plus a two-character prefix.
                self.scaled_tokens(2 + b.len() * 2, self.config.hex_factor)
            }
            SochValue::Array(items) => {
                // Elements contribute their RAW estimates; the margin is applied
                // once by `estimate_value` for the whole array, not per level.
                let element_tokens: usize =
                    items.iter().map(|v| self.estimate_value_raw(v)).sum();
                let separators = items.len().saturating_sub(1) * self.config.separator_tokens;
                // Two tokens for the opening/closing brackets.
                2 + element_tokens + separators
            }
        }
    }

    /// Estimated tokens for one row: the per-value estimates (each with margin),
    /// `separator_tokens` between adjacent values, and one `newline_tokens`
    /// charge. An empty row costs 0.
    pub fn estimate_row(&self, values: &[SochValue]) -> usize {
        if values.is_empty() {
            return 0;
        }
        let value_tokens: usize = values.iter().map(|v| self.estimate_value(v)).sum();
        let separator_tokens = (values.len() - 1) * self.config.separator_tokens;
        value_tokens + separator_tokens + self.config.newline_tokens
    }

    /// Estimated tokens for a table header: the fixed `header_tokens`, the
    /// table and column names at `bytes_per_token`, and one token per decimal
    /// digit of `row_count` (at least one). No safety margin is applied.
    pub fn estimate_header(&self, table: &str, columns: &[String], row_count: usize) -> usize {
        let table_tokens = self.scaled_tokens(table.len(), 1.0);
        let count_tokens = row_count.checked_ilog10().map_or(1, |d| d as usize + 1);
        let col_tokens: usize = columns
            .iter()
            .map(|c| self.scaled_tokens(c.len(), 1.0))
            .sum();
        self.config.header_tokens + table_tokens + count_tokens + col_tokens
    }

    /// Header estimate plus the sum of all row estimates.
    pub fn estimate_table(
        &self,
        table: &str,
        columns: &[String],
        rows: &[Vec<SochValue>],
    ) -> usize {
        let header = self.estimate_header(table, columns, rows.len());
        let row_tokens: usize = rows.iter().map(|r| self.estimate_row(r)).sum();
        header + row_tokens
    }

    /// Estimated tokens for free text: byte length over `bytes_per_token`,
    /// rounded up, then the safety margin (rounded up).
    pub fn estimate_text(&self, text: &str) -> usize {
        self.apply_margin(self.scaled_tokens(text.len(), 1.0))
    }

    /// Truncates `text` to roughly `max_tokens`, appending `"..."` when cut.
    pub fn truncate_to_tokens(&self, text: &str, max_tokens: usize) -> String {
        truncate_to_tokens(text, max_tokens, self, "...")
    }
}

impl Default for TokenEstimator {
    fn default() -> Self {
        Self::new()
    }
}
/// Result of planning how a token budget is split across sections.
#[derive(Clone, Debug)]
pub struct BudgetAllocation {
    /// Names of sections granted everything they requested.
    pub full_sections: Vec<String>,
    /// `(name, requested_tokens, allocated_tokens)` for sections granted less
    /// than they requested.
    pub truncated_sections: Vec<(String, usize, usize)>,
    /// Names of sections that received no tokens.
    pub dropped_sections: Vec<String>,
    /// Sum of tokens granted to full and truncated sections.
    pub tokens_allocated: usize,
    /// Budget left after the allocations above.
    pub tokens_remaining: usize,
    /// One decision record per section, in the order sections were processed.
    pub explain: Vec<AllocationDecision>,
}
/// Explanation of the outcome for a single section during allocation.
#[derive(Clone, Debug)]
pub struct AllocationDecision {
    /// Name of the section this decision concerns.
    pub section: String,
    /// The section's priority, copied for reporting.
    pub priority: i32,
    /// Tokens the section asked for (its estimate).
    pub requested: usize,
    /// Tokens actually granted (0 when dropped).
    pub allocated: usize,
    /// Whether the section was granted in full, truncated, or dropped.
    pub outcome: AllocationOutcome,
    /// Human-readable justification for the outcome.
    pub reason: String,
}
/// What happened to a section when the budget was divided.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AllocationOutcome {
    /// Granted every requested token.
    Full,
    /// Granted some, but fewer than requested, tokens.
    Truncated,
    /// Granted nothing.
    Dropped,
}
/// A named chunk of context competing for part of the token budget.
#[derive(Clone, Debug)]
pub struct BudgetSection {
    /// Identifier reported back in the allocation result.
    pub name: String,
    /// Ordering key; sections with a lower value are considered first.
    pub priority: i32,
    /// Tokens the section would need in full.
    pub estimated_tokens: usize,
    /// Smallest acceptable grant when truncating; `None` disallows truncation
    /// under the greedy and (optional-section) strict strategies.
    pub minimum_tokens: Option<usize>,
    /// Marks sections the strict strategy handles before all others.
    pub required: bool,
    /// Relative share used by the proportional strategy.
    pub weight: f32,
}

impl Default for BudgetSection {
    /// Empty name, priority 0, no tokens, no minimum, optional, weight 1.0.
    fn default() -> Self {
        BudgetSection {
            name: String::new(),
            priority: 0,
            estimated_tokens: 0,
            minimum_tokens: None,
            required: false,
            weight: 1.0,
        }
    }
}
/// How `TokenBudgetEnforcer::allocate_sections` divides the budget.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AllocationStrategy {
    /// Fill sections in ascending `priority` order, truncating when a
    /// minimum allows it.
    #[default]
    GreedyPriority,
    /// Split the budget by section `weight`, capped at each section's request.
    Proportional,
    /// Serve `required` sections first, then optional ones, each group in
    /// ascending `priority` order.
    StrictPriority,
}
/// Tracks a token budget and plans how it is shared among sections.
pub struct TokenBudgetEnforcer {
    /// Total tokens available before the reservation is subtracted.
    budget: usize,
    /// Running total of tokens claimed through `try_allocate`.
    allocated: AtomicUsize,
    /// Estimator exposed to callers via `estimator()`.
    estimator: TokenEstimator,
    /// Tokens held back from every allocation.
    reserved: usize,
    /// Strategy used by `allocate_sections`.
    strategy: AllocationStrategy,
}
/// Construction options for `TokenBudgetEnforcer::new`.
///
/// NOTE(review): `strict` and `default_priority` are not read by
/// `TokenBudgetEnforcer::new`; confirm whether other code consumes them.
#[derive(Clone, Debug)]
pub struct TokenBudgetConfig {
    /// Total tokens the enforcer manages.
    pub total_budget: usize,
    /// Tokens held back from allocation.
    pub reserved_tokens: usize,
    pub strict: bool,
    pub default_priority: i32,
    /// Strategy used when allocating sections.
    pub strategy: AllocationStrategy,
}

impl Default for TokenBudgetConfig {
    /// 4096-token budget, 100 reserved, non-strict, default priority 10,
    /// greedy strategy.
    fn default() -> Self {
        TokenBudgetConfig {
            total_budget: 4096,
            reserved_tokens: 100,
            strict: false,
            default_priority: 10,
            strategy: AllocationStrategy::GreedyPriority,
        }
    }
}
impl TokenBudgetEnforcer {
pub fn new(config: TokenBudgetConfig) -> Self {
Self {
budget: config.total_budget,
allocated: AtomicUsize::new(0),
estimator: TokenEstimator::new(),
reserved: config.reserved_tokens,
strategy: config.strategy,
}
}
pub fn with_budget(budget: usize) -> Self {
Self {
budget,
allocated: AtomicUsize::new(0),
estimator: TokenEstimator::new(),
reserved: 0,
strategy: AllocationStrategy::GreedyPriority,
}
}
pub fn with_estimator(budget: usize, estimator: TokenEstimator) -> Self {
Self {
budget,
allocated: AtomicUsize::new(0),
estimator,
reserved: 0,
strategy: AllocationStrategy::GreedyPriority,
}
}
pub fn with_strategy(mut self, strategy: AllocationStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn reserve(&mut self, tokens: usize) {
self.reserved = tokens;
}
pub fn available(&self) -> usize {
let allocated = self.allocated.load(Ordering::Acquire);
self.budget.saturating_sub(self.reserved + allocated)
}
pub fn total_budget(&self) -> usize {
self.budget
}
pub fn allocated(&self) -> usize {
self.allocated.load(Ordering::Acquire)
}
pub fn try_allocate(&self, tokens: usize) -> bool {
loop {
let current = self.allocated.load(Ordering::Acquire);
let new_total = current + tokens;
if new_total + self.reserved > self.budget {
return false;
}
if self
.allocated
.compare_exchange(current, new_total, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
return true;
}
}
}
pub fn allocate_sections(&self, sections: &[BudgetSection]) -> BudgetAllocation {
match self.strategy {
AllocationStrategy::GreedyPriority => self.allocate_greedy(sections),
AllocationStrategy::Proportional => self.allocate_proportional(sections),
AllocationStrategy::StrictPriority => self.allocate_strict(sections),
}
}
fn allocate_greedy(&self, sections: &[BudgetSection]) -> BudgetAllocation {
let mut sorted: Vec<_> = sections.iter().collect();
sorted.sort_by_key(|s| s.priority);
let mut allocation = BudgetAllocation {
full_sections: Vec::new(),
truncated_sections: Vec::new(),
dropped_sections: Vec::new(),
tokens_allocated: 0,
tokens_remaining: self.budget.saturating_sub(self.reserved),
explain: Vec::new(),
};
for section in sorted {
let remaining = allocation.tokens_remaining;
if section.estimated_tokens <= remaining {
allocation.full_sections.push(section.name.clone());
allocation.tokens_allocated += section.estimated_tokens;
allocation.tokens_remaining -= section.estimated_tokens;
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: section.estimated_tokens,
outcome: AllocationOutcome::Full,
reason: format!("Fits in remaining budget ({} tokens)", remaining),
});
} else if let Some(min) = section.minimum_tokens {
if min <= remaining {
let truncated_to = remaining;
allocation.truncated_sections.push((
section.name.clone(),
section.estimated_tokens,
truncated_to,
));
allocation.tokens_allocated += truncated_to;
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: truncated_to,
outcome: AllocationOutcome::Truncated,
reason: format!(
"Truncated from {} to {} tokens (min: {})",
section.estimated_tokens, truncated_to, min
),
});
allocation.tokens_remaining = 0;
} else {
allocation.dropped_sections.push(section.name.clone());
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: 0,
outcome: AllocationOutcome::Dropped,
reason: format!("Minimum {} exceeds remaining {} tokens", min, remaining),
});
}
} else {
allocation.dropped_sections.push(section.name.clone());
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: 0,
outcome: AllocationOutcome::Dropped,
reason: format!(
"Requested {} exceeds remaining {} (no truncation allowed)",
section.estimated_tokens, remaining
),
});
}
}
allocation
}
fn allocate_proportional(&self, sections: &[BudgetSection]) -> BudgetAllocation {
let available = self.budget.saturating_sub(self.reserved);
let total_weight: f32 = sections.iter().map(|s| s.weight).sum();
if total_weight == 0.0 {
return self.allocate_greedy(sections);
}
let mut allocation = BudgetAllocation {
full_sections: Vec::new(),
truncated_sections: Vec::new(),
dropped_sections: Vec::new(),
tokens_allocated: 0,
tokens_remaining: available,
explain: Vec::new(),
};
let mut allocations: Vec<(usize, usize, bool)> = sections
.iter()
.map(|s| {
let proportional = ((available as f32) * s.weight / total_weight).floor() as usize;
let capped = proportional.min(s.estimated_tokens);
let min = s.minimum_tokens.unwrap_or(0);
(
capped.max(min),
s.estimated_tokens,
capped < s.estimated_tokens,
)
})
.collect();
let mut total: usize = allocations.iter().map(|(a, _, _)| *a).sum();
while total > available {
let max_idx = allocations
.iter()
.enumerate()
.filter(|(i, (a, _, _))| *a > sections[*i].minimum_tokens.unwrap_or(0))
.max_by_key(|(_, (a, _, _))| *a)
.map(|(i, _)| i);
match max_idx {
Some(idx) => {
let reduce = (total - available)
.min(allocations[idx].0 - sections[idx].minimum_tokens.unwrap_or(0));
allocations[idx].0 -= reduce;
total -= reduce;
}
None => break, }
}
for (i, section) in sections.iter().enumerate() {
let (allocated, requested, truncated) = allocations[i];
if allocated == 0 {
allocation.dropped_sections.push(section.name.clone());
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested,
allocated: 0,
outcome: AllocationOutcome::Dropped,
reason: "No budget available after proportional allocation".to_string(),
});
} else if truncated {
allocation
.truncated_sections
.push((section.name.clone(), requested, allocated));
allocation.tokens_allocated += allocated;
allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested,
allocated,
outcome: AllocationOutcome::Truncated,
reason: format!(
"Proportional allocation: {:.1}% of budget (weight {:.1})",
(allocated as f32 / available as f32) * 100.0,
section.weight
),
});
} else {
allocation.full_sections.push(section.name.clone());
allocation.tokens_allocated += allocated;
allocation.tokens_remaining = allocation.tokens_remaining.saturating_sub(allocated);
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested,
allocated,
outcome: AllocationOutcome::Full,
reason: format!(
"Full allocation within proportional budget (weight {:.1})",
section.weight
),
});
}
}
allocation
}
fn allocate_strict(&self, sections: &[BudgetSection]) -> BudgetAllocation {
let mut sorted: Vec<_> = sections.iter().collect();
sorted.sort_by_key(|s| (if s.required { 0 } else { 1 }, s.priority));
let mut allocation = BudgetAllocation {
full_sections: Vec::new(),
truncated_sections: Vec::new(),
dropped_sections: Vec::new(),
tokens_allocated: 0,
tokens_remaining: self.budget.saturating_sub(self.reserved),
explain: Vec::new(),
};
for section in sorted.iter().filter(|s| s.required) {
let remaining = allocation.tokens_remaining;
let min = section.minimum_tokens.unwrap_or(section.estimated_tokens);
if section.estimated_tokens <= remaining {
allocation.full_sections.push(section.name.clone());
allocation.tokens_allocated += section.estimated_tokens;
allocation.tokens_remaining -= section.estimated_tokens;
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: section.estimated_tokens,
outcome: AllocationOutcome::Full,
reason: "Required section - full allocation".to_string(),
});
} else if min <= remaining {
allocation.truncated_sections.push((
section.name.clone(),
section.estimated_tokens,
remaining,
));
allocation.tokens_allocated += remaining;
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: remaining,
outcome: AllocationOutcome::Truncated,
reason: "Required section - truncated to fit".to_string(),
});
allocation.tokens_remaining = 0;
}
}
for section in sorted.iter().filter(|s| !s.required) {
let remaining = allocation.tokens_remaining;
if remaining == 0 {
allocation.dropped_sections.push(section.name.clone());
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: 0,
outcome: AllocationOutcome::Dropped,
reason: "No budget remaining after required sections".to_string(),
});
continue;
}
if section.estimated_tokens <= remaining {
allocation.full_sections.push(section.name.clone());
allocation.tokens_allocated += section.estimated_tokens;
allocation.tokens_remaining -= section.estimated_tokens;
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: section.estimated_tokens,
outcome: AllocationOutcome::Full,
reason: "Optional section - fits in remaining budget".to_string(),
});
} else if let Some(min) = section.minimum_tokens {
if min <= remaining {
allocation.truncated_sections.push((
section.name.clone(),
section.estimated_tokens,
remaining,
));
allocation.tokens_allocated += remaining;
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: remaining,
outcome: AllocationOutcome::Truncated,
reason: "Optional section - truncated to fit".to_string(),
});
allocation.tokens_remaining = 0;
} else {
allocation.dropped_sections.push(section.name.clone());
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: 0,
outcome: AllocationOutcome::Dropped,
reason: format!("Minimum {} exceeds remaining {}", min, remaining),
});
}
} else {
allocation.dropped_sections.push(section.name.clone());
allocation.explain.push(AllocationDecision {
section: section.name.clone(),
priority: section.priority,
requested: section.estimated_tokens,
allocated: 0,
outcome: AllocationOutcome::Dropped,
reason: format!(
"Requested {} exceeds remaining {}",
section.estimated_tokens, remaining
),
});
}
}
allocation
}
pub fn reset(&self) {
self.allocated.store(0, Ordering::Release);
}
pub fn estimator(&self) -> &TokenEstimator {
&self.estimator
}
}
impl BudgetAllocation {
pub fn explain_text(&self) -> String {
let mut output = String::new();
output.push_str("=== CONTEXT BUDGET ALLOCATION ===\n\n");
output.push_str(&format!(
"Total Allocated: {} tokens\n",
self.tokens_allocated
));
output.push_str(&format!("Remaining: {} tokens\n\n", self.tokens_remaining));
output.push_str("SECTIONS:\n");
for decision in &self.explain {
let status = match decision.outcome {
AllocationOutcome::Full => "✓ FULL",
AllocationOutcome::Truncated => "◐ TRUNCATED",
AllocationOutcome::Dropped => "✗ DROPPED",
};
output.push_str(&format!(
" [{:^12}] {} (priority {})\n",
status, decision.section, decision.priority
));
output.push_str(&format!(
" Requested: {}, Allocated: {}\n",
decision.requested, decision.allocated
));
output.push_str(&format!(" Reason: {}\n", decision.reason));
}
output
}
pub fn explain_json(&self) -> String {
serde_json::to_string_pretty(&ExplainOutput {
tokens_allocated: self.tokens_allocated,
tokens_remaining: self.tokens_remaining,
full_sections: self.full_sections.clone(),
truncated_sections: self.truncated_sections.clone(),
dropped_sections: self.dropped_sections.clone(),
decisions: self
.explain
.iter()
.map(|d| ExplainDecision {
section: d.section.clone(),
priority: d.priority,
requested: d.requested,
allocated: d.allocated,
outcome: format!("{:?}", d.outcome),
reason: d.reason.clone(),
})
.collect(),
})
.unwrap_or_else(|_| "{}".to_string())
}
}
/// Serialization shape for `BudgetAllocation::explain_json`; field order
/// determines the JSON key order.
#[derive(serde::Serialize)]
struct ExplainOutput {
    tokens_allocated: usize,
    tokens_remaining: usize,
    full_sections: Vec<String>,
    /// `(name, requested, allocated)` triples.
    truncated_sections: Vec<(String, usize, usize)>,
    dropped_sections: Vec<String>,
    decisions: Vec<ExplainDecision>,
}
/// Serialization shape of one allocation decision inside [`ExplainOutput`].
#[derive(serde::Serialize)]
struct ExplainDecision {
    section: String,
    priority: i32,
    requested: usize,
    allocated: usize,
    /// Debug name of the outcome (`"Full"`, `"Truncated"`, `"Dropped"`).
    outcome: String,
    reason: String,
}
/// Shortens `text` so the kept prefix fits in `max_tokens` minus the
/// estimate for `suffix`, then appends `suffix`.
///
/// * Text whose estimate already fits `max_tokens` is returned unchanged.
/// * If the suffix alone consumes the budget, only `suffix` is returned.
/// * Otherwise the longest prefix ending on a UTF-8 character boundary that
///   fits is kept. It is then shortened to the last whitespace inside it,
///   unless the cut already falls directly before whitespace (a word end).
///
/// Relies on `estimator.estimate_text` never decreasing as its input grows,
/// which holds for the length-based [`TokenEstimator`].
pub fn truncate_to_tokens(
    text: &str,
    max_tokens: usize,
    estimator: &TokenEstimator,
    suffix: &str,
) -> String {
    if estimator.estimate_text(text) <= max_tokens {
        return text.to_string();
    }
    let suffix_tokens = estimator.estimate_text(suffix);
    let target_tokens = max_tokens.saturating_sub(suffix_tokens);
    if target_tokens == 0 {
        return suffix.to_string();
    }

    // Candidate cut points are exactly the char boundaries, so slicing never
    // panics and the search always terminates. `text.len()` is omitted: the
    // whole text is already known not to fit.
    let boundaries: Vec<usize> = text.char_indices().map(|(i, _)| i).collect();
    let fitting =
        boundaries.partition_point(|&end| estimator.estimate_text(&text[..end]) <= target_tokens);
    let cut = fitting.checked_sub(1).map_or(0, |k| boundaries[k]);

    // Prefer ending on a word boundary, but don't discard a complete word when
    // the cut already sits right before whitespace.
    let end = if text[cut..].starts_with(char::is_whitespace) {
        cut
    } else {
        text[..cut].rfind(char::is_whitespace).unwrap_or(cut)
    };
    format!("{}{}", &text[..end], suffix)
}
/// Returns clones of the leading rows whose cumulative `estimate_row` total
/// stays within `max_tokens`.
///
/// Stops at the first row that would exceed the budget, even if later rows
/// are small enough to fit.
pub fn truncate_rows(
    rows: &[Vec<SochValue>],
    max_tokens: usize,
    estimator: &TokenEstimator,
) -> Vec<Vec<SochValue>> {
    let mut spent = 0;
    rows.iter()
        .take_while(|row| {
            let cost = estimator.estimate_row(row);
            let fits = spent + cost <= max_tokens;
            if fits {
                spent += cost;
            }
            fits
        })
        .cloned()
        .collect()
}
#[cfg(test)]
mod tests {
    //! Unit tests for token estimation, budget allocation and truncation.
    //! (The `allocate_sections` calls previously read `§ions` — a
    //! mis-encoded `&sections` — and did not compile.)
    use super::*;

    #[test]
    fn test_estimate_value_int() {
        let est = TokenEstimator::new();
        assert!(est.estimate_value(&SochValue::Int(0)) >= 1);
        assert!(est.estimate_value(&SochValue::Int(42)) >= 1);
        let small = est.estimate_value(&SochValue::Int(42));
        let large = est.estimate_value(&SochValue::Int(1_000_000_000));
        assert!(large >= small);
    }

    #[test]
    fn test_estimate_value_text() {
        let est = TokenEstimator::new();
        let short = est.estimate_value(&SochValue::Text("hello".to_string()));
        let long = est.estimate_value(&SochValue::Text(
            "hello world this is a longer string".to_string(),
        ));
        assert!(long > short);
    }

    #[test]
    #[allow(clippy::approx_constant)] // 3.14 is sample data, not an attempt at PI
    fn test_estimate_row() {
        let est = TokenEstimator::new();
        let row = vec![
            SochValue::Int(1),
            SochValue::Text("Alice".to_string()),
            SochValue::Float(3.14),
        ];
        let tokens = est.estimate_row(&row);
        assert!(tokens >= 3);
    }

    #[test]
    fn test_estimate_table() {
        let est = TokenEstimator::new();
        let columns = vec!["id".to_string(), "name".to_string()];
        let rows = vec![
            vec![SochValue::Int(1), SochValue::Text("Alice".to_string())],
            vec![SochValue::Int(2), SochValue::Text("Bob".to_string())],
        ];
        let tokens = est.estimate_table("users", &columns, &rows);
        assert!(tokens > est.estimate_row(&rows[0]) * 2);
    }

    #[test]
    fn test_budget_enforcer_allocation() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);
        assert!(enforcer.try_allocate(500));
        assert_eq!(enforcer.allocated(), 500);
        assert_eq!(enforcer.available(), 500);
        assert!(enforcer.try_allocate(400));
        assert_eq!(enforcer.allocated(), 900);
        assert!(!enforcer.try_allocate(200));
        assert_eq!(enforcer.allocated(), 900);
    }

    #[test]
    fn test_budget_enforcer_reset() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);
        enforcer.try_allocate(800);
        assert_eq!(enforcer.allocated(), 800);
        enforcer.reset();
        assert_eq!(enforcer.allocated(), 0);
    }

    #[test]
    fn test_allocate_sections() {
        let enforcer = TokenBudgetEnforcer::with_budget(1000);
        let sections = vec![
            BudgetSection {
                name: "A".to_string(),
                priority: 0,
                estimated_tokens: 300,
                minimum_tokens: None,
                required: true,
                weight: 1.0,
            },
            BudgetSection {
                name: "B".to_string(),
                priority: 1,
                estimated_tokens: 400,
                minimum_tokens: Some(200),
                required: false,
                weight: 1.0,
            },
            BudgetSection {
                name: "C".to_string(),
                priority: 2,
                estimated_tokens: 500,
                minimum_tokens: None,
                required: false,
                weight: 1.0,
            },
        ];
        let allocation = enforcer.allocate_sections(&sections);
        assert!(allocation.full_sections.contains(&"A".to_string()));
        assert!(allocation.dropped_sections.contains(&"C".to_string()));
        assert!(allocation.tokens_allocated <= 1000);
    }

    #[test]
    fn test_allocate_by_priority() {
        let enforcer = TokenBudgetEnforcer::with_budget(500);
        let sections = vec![
            BudgetSection {
                name: "LowPriority".to_string(),
                priority: 10,
                estimated_tokens: 200,
                minimum_tokens: None,
                required: false,
                weight: 1.0,
            },
            BudgetSection {
                name: "HighPriority".to_string(),
                priority: 0,
                estimated_tokens: 400,
                minimum_tokens: None,
                required: true,
                weight: 1.0,
            },
        ];
        let allocation = enforcer.allocate_sections(&sections);
        assert!(
            allocation
                .full_sections
                .contains(&"HighPriority".to_string())
        );
        assert!(
            allocation
                .dropped_sections
                .contains(&"LowPriority".to_string())
        );
    }

    #[test]
    fn test_truncate_to_tokens() {
        let est = TokenEstimator::new();
        let text = "This is a long text that needs to be truncated to fit within the token budget";
        let truncated = truncate_to_tokens(text, 10, &est, "...");
        assert!(truncated.len() < text.len());
        assert!(truncated.ends_with("..."));
        assert!(est.estimate_text(&truncated) <= 10);
    }

    #[test]
    fn test_truncate_rows() {
        let est = TokenEstimator::new();
        let rows: Vec<Vec<SochValue>> = (0..100)
            .map(|i| vec![SochValue::Int(i), SochValue::Text(format!("row{}", i))])
            .collect();
        let truncated = truncate_rows(&rows, 50, &est);
        assert!(truncated.len() < rows.len());
        let total: usize = truncated.iter().map(|r| est.estimate_row(r)).sum();
        assert!(total <= 50);
    }

    #[test]
    fn test_reserved_budget() {
        let mut enforcer = TokenBudgetEnforcer::with_budget(1000);
        enforcer.reserve(200);
        assert_eq!(enforcer.available(), 800);
        assert!(enforcer.try_allocate(700));
        assert_eq!(enforcer.available(), 100);
        assert!(!enforcer.try_allocate(200));
    }

    #[test]
    fn test_estimator_configs() {
        let default = TokenEstimator::new();
        let gpt4 = TokenEstimator::with_config(TokenEstimatorConfig::gpt4());
        let conservative = TokenEstimator::with_config(TokenEstimatorConfig::conservative());
        let text = "Hello, this is a test string for comparing token estimation across different configurations.";
        let default_est = default.estimate_text(text);
        let gpt4_est = gpt4.estimate_text(text);
        let conservative_est = conservative.estimate_text(text);
        assert!(conservative_est >= default_est);
        assert!(default_est > 0);
        assert!(gpt4_est > 0);
        assert!(conservative_est > 0);
    }

    #[test]
    fn test_section_with_truncation() {
        let enforcer = TokenBudgetEnforcer::with_budget(600);
        let sections = vec![
            BudgetSection {
                name: "Required".to_string(),
                priority: 0,
                estimated_tokens: 500,
                minimum_tokens: None,
                required: true,
                weight: 1.0,
            },
            BudgetSection {
                name: "Optional".to_string(),
                priority: 1,
                estimated_tokens: 300,
                minimum_tokens: Some(50),
                required: false,
                weight: 1.0,
            },
        ];
        let allocation = enforcer.allocate_sections(&sections);
        assert!(allocation.full_sections.contains(&"Required".to_string()));
        assert!(
            allocation
                .truncated_sections
                .iter()
                .any(|(n, _, _)| n == "Optional")
        );
    }
}