use crate::segment::ContextSegmentType;
use chrono::{DateTime, Utc};
use enact_core::kernel::ExecutionId;
use serde::{Deserialize, Serialize};
/// Token budget for a single context segment.
///
/// Serialized with camelCase field names; `segment_type` is written as `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SegmentBudget {
/// Which context segment this budget applies to (serialized as `type`).
#[serde(rename = "type")]
pub segment_type: ContextSegmentType,
/// Token limit for the segment; the basis for `available`, `usage_percent`
/// and `is_over_budget`.
pub max_tokens: usize,
/// Tokens currently counted against the segment. Starts at 0 in `new`.
pub current_tokens: usize,
/// Starts at 0 in `new`.
/// NOTE(review): no method of `SegmentBudget` reads this field; confirm the
/// intended semantics with its consumers.
pub reserved_tokens: usize,
/// Starts as `false` in `new`.
/// NOTE(review): presumably allows taking capacity from other segments — confirm.
pub can_borrow: bool,
/// Starts as `false` in `new`.
/// NOTE(review): presumably allows giving capacity to other segments — confirm.
pub can_lend: bool,
}
impl SegmentBudget {
    /// Creates an empty budget for `segment_type` limited to `max_tokens`.
    ///
    /// The budget starts with zero current and reserved tokens and with both
    /// borrowing and lending disabled.
    pub fn new(segment_type: ContextSegmentType, max_tokens: usize) -> Self {
        Self {
            segment_type,
            max_tokens,
            current_tokens: 0,
            reserved_tokens: 0,
            can_borrow: false,
            can_lend: false,
        }
    }

    /// Returns how many tokens can still be added before `max_tokens` is reached.
    ///
    /// Saturates at zero when the segment is over budget. `reserved_tokens` is
    /// not taken into account.
    pub fn available(&self) -> usize {
        self.max_tokens.saturating_sub(self.current_tokens)
    }

    /// Returns `current_tokens` as a whole percentage of `max_tokens`, rounded down.
    ///
    /// Returns 0 when `max_tokens` is 0 and saturates at `u8::MAX` when the
    /// ratio exceeds 255%.
    pub fn usage_percent(&self) -> u8 {
        if self.max_tokens == 0 {
            return 0;
        }
        // Integer arithmetic keeps the floor exact. Going through `f64` turns
        // e.g. 29/100 into 28.999999999999996, which truncates to 28.
        // Widening to u128 is lossless and leaves room for the `* 100`.
        let percent = (self.current_tokens as u128 * 100) / self.max_tokens as u128;
        // Clamped to the u8 range first, so the narrowing cast cannot truncate.
        percent.min(u128::from(u8::MAX)) as u8
    }

    /// Returns `true` when `current_tokens` is strictly greater than `max_tokens`.
    pub fn is_over_budget(&self) -> bool {
        self.current_tokens > self.max_tokens
    }
}
/// Overall token budget for one execution, split into per-segment budgets.
///
/// Serialized with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ContextBudget {
/// Execution identifier supplied when the budget was created.
pub execution_id: ExecutionId,
/// Total token capacity supplied when the budget was created.
pub total_tokens: usize,
/// Tokens subtracted from `total_tokens` to obtain `available_tokens`
/// (presumably held back for model output — confirm).
pub output_reserve: usize,
/// `total_tokens` minus `output_reserve`, saturating at zero; computed in `new`.
pub available_tokens: usize,
/// Sum of every segment's `current_tokens`, refreshed whenever segment
/// usage is changed through the methods of this type.
pub used_tokens: usize,
/// Per-segment budgets; empty after `new`, filled in by the presets.
pub segments: Vec<SegmentBudget>,
/// Usage percentage at or above which `is_warning` is true (80 after `new`).
pub warning_threshold: u8,
/// Usage percentage at or above which `is_critical` is true (95 after `new`).
pub critical_threshold: u8,
/// Set at construction and again each time the totals are recalculated.
pub updated_at: DateTime<Utc>,
}
impl ContextBudget {
pub fn new(execution_id: ExecutionId, total_tokens: usize, output_reserve: usize) -> Self {
Self {
execution_id,
total_tokens,
output_reserve,
available_tokens: total_tokens.saturating_sub(output_reserve),
used_tokens: 0,
segments: Vec::new(),
warning_threshold: 80,
critical_threshold: 95,
updated_at: Utc::now(),
}
}
pub fn preset_gpt4_128k(execution_id: ExecutionId) -> Self {
let mut budget = Self::new(execution_id, 128_000, 4_096);
budget.segments = vec![
SegmentBudget::new(ContextSegmentType::System, 4_000),
SegmentBudget::new(ContextSegmentType::History, 60_000),
SegmentBudget::new(ContextSegmentType::WorkingMemory, 20_000),
SegmentBudget::new(ContextSegmentType::ToolResults, 20_000),
SegmentBudget::new(ContextSegmentType::RagContext, 15_000),
SegmentBudget::new(ContextSegmentType::UserInput, 2_000),
SegmentBudget::new(ContextSegmentType::AgentScratchpad, 2_000),
SegmentBudget::new(ContextSegmentType::ChildSummary, 500),
SegmentBudget::new(ContextSegmentType::Guidance, 500),
];
budget
}
pub fn preset_gpt4_32k(execution_id: ExecutionId) -> Self {
let mut budget = Self::new(execution_id, 32_000, 2_048);
budget.segments = vec![
SegmentBudget::new(ContextSegmentType::System, 2_000),
SegmentBudget::new(ContextSegmentType::History, 15_000),
SegmentBudget::new(ContextSegmentType::WorkingMemory, 5_000),
SegmentBudget::new(ContextSegmentType::ToolResults, 4_000),
SegmentBudget::new(ContextSegmentType::RagContext, 3_000),
SegmentBudget::new(ContextSegmentType::UserInput, 1_000),
SegmentBudget::new(ContextSegmentType::AgentScratchpad, 500),
SegmentBudget::new(ContextSegmentType::ChildSummary, 250),
SegmentBudget::new(ContextSegmentType::Guidance, 250),
];
budget
}
pub fn preset_claude_200k(execution_id: ExecutionId) -> Self {
let mut budget = Self::new(execution_id, 200_000, 4_096);
budget.segments = vec![
SegmentBudget::new(ContextSegmentType::System, 8_000),
SegmentBudget::new(ContextSegmentType::History, 100_000),
SegmentBudget::new(ContextSegmentType::WorkingMemory, 40_000),
SegmentBudget::new(ContextSegmentType::ToolResults, 25_000),
SegmentBudget::new(ContextSegmentType::RagContext, 15_000),
SegmentBudget::new(ContextSegmentType::UserInput, 4_000),
SegmentBudget::new(ContextSegmentType::AgentScratchpad, 2_000),
SegmentBudget::new(ContextSegmentType::ChildSummary, 1_000),
SegmentBudget::new(ContextSegmentType::Guidance, 1_000),
];
budget
}
pub fn preset_default(execution_id: ExecutionId) -> Self {
let mut budget = Self::new(execution_id, 8_000, 1_024);
budget.segments = vec![
SegmentBudget::new(ContextSegmentType::System, 1_000),
SegmentBudget::new(ContextSegmentType::History, 3_000),
SegmentBudget::new(ContextSegmentType::WorkingMemory, 1_000),
SegmentBudget::new(ContextSegmentType::ToolResults, 1_000),
SegmentBudget::new(ContextSegmentType::RagContext, 500),
SegmentBudget::new(ContextSegmentType::UserInput, 500),
SegmentBudget::new(ContextSegmentType::AgentScratchpad, 0),
SegmentBudget::new(ContextSegmentType::ChildSummary, 0),
SegmentBudget::new(ContextSegmentType::Guidance, 0),
];
budget
}
pub fn get_segment(&self, segment_type: ContextSegmentType) -> Option<&SegmentBudget> {
self.segments
.iter()
.find(|s| s.segment_type == segment_type)
}
pub fn get_segment_mut(
&mut self,
segment_type: ContextSegmentType,
) -> Option<&mut SegmentBudget> {
self.segments
.iter_mut()
.find(|s| s.segment_type == segment_type)
}
pub fn update_segment_usage(&mut self, segment_type: ContextSegmentType, tokens: usize) {
if let Some(segment) = self.get_segment_mut(segment_type) {
segment.current_tokens = tokens;
}
self.recalculate_total();
}
pub fn add_tokens(&mut self, segment_type: ContextSegmentType, tokens: usize) {
if let Some(segment) = self.get_segment_mut(segment_type) {
segment.current_tokens += tokens;
}
self.recalculate_total();
}
pub fn remove_tokens(&mut self, segment_type: ContextSegmentType, tokens: usize) {
if let Some(segment) = self.get_segment_mut(segment_type) {
segment.current_tokens = segment.current_tokens.saturating_sub(tokens);
}
self.recalculate_total();
}
fn recalculate_total(&mut self) {
self.used_tokens = self.segments.iter().map(|s| s.current_tokens).sum();
self.updated_at = Utc::now();
}
pub fn remaining(&self) -> usize {
self.available_tokens.saturating_sub(self.used_tokens)
}
pub fn usage_percent(&self) -> u8 {
if self.available_tokens == 0 {
return 0;
}
((self.used_tokens as f64 / self.available_tokens as f64) * 100.0) as u8
}
pub fn is_warning(&self) -> bool {
self.usage_percent() >= self.warning_threshold
}
pub fn is_critical(&self) -> bool {
self.usage_percent() >= self.critical_threshold
}
pub fn health(&self) -> BudgetHealth {
if self.is_critical() {
BudgetHealth::Critical
} else if self.is_warning() {
BudgetHealth::Warning
} else {
BudgetHealth::Healthy
}
}
}
/// Coarse classification of budget usage, as returned by `ContextBudget::health`.
///
/// Serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BudgetHealth {
/// Usage is below the warning threshold.
Healthy,
/// Usage has reached the warning threshold but not the critical one.
Warning,
/// Usage has reached the critical threshold.
Critical,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces a fresh execution id for the budgets built in these tests.
    fn test_execution_id() -> ExecutionId {
        ExecutionId::new()
    }

    /// Availability and percentage follow `current_tokens` on one segment.
    #[test]
    fn test_segment_budget() {
        let mut history = SegmentBudget::new(ContextSegmentType::History, 1000);
        assert_eq!(history.available(), 1000);
        assert_eq!(history.usage_percent(), 0);

        history.current_tokens = 500;
        assert_eq!(history.available(), 500);
        assert_eq!(history.usage_percent(), 50);
    }

    /// The 128k preset carries its advertised totals and at least one segment.
    #[test]
    fn test_context_budget_presets() {
        let preset = ContextBudget::preset_gpt4_128k(test_execution_id());
        assert_eq!(preset.total_tokens, 128_000);
        assert_eq!(preset.output_reserve, 4_096);
        assert!(!preset.segments.is_empty());
    }

    /// Health moves from healthy to warning to critical as usage grows.
    #[test]
    fn test_budget_health() {
        let mut ctx = ContextBudget::preset_default(test_execution_id());
        assert_eq!(ctx.health(), BudgetHealth::Healthy);

        // Fraction of the available tokens in use, paired with the expected state.
        let stages = [
            (0.85, BudgetHealth::Warning),
            (0.96, BudgetHealth::Critical),
        ];
        for &(fraction, expected) in stages.iter() {
            ctx.used_tokens = (ctx.available_tokens as f64 * fraction) as usize;
            assert_eq!(ctx.health(), expected);
        }
    }
}