1use crate::segment::ContextSegmentType;
8use chrono::{DateTime, Utc};
9use enact_core::kernel::ExecutionId;
10use serde::{Deserialize, Serialize};
11
/// Token budget for a single context segment.
///
/// Serialized with camelCase field names, except `segment_type`, which is
/// serialized as `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SegmentBudget {
    /// The context segment this budget applies to (serialized as `type`).
    #[serde(rename = "type")]
    pub segment_type: ContextSegmentType,

    /// Upper limit of tokens allotted to this segment.
    pub max_tokens: usize,

    /// Tokens currently attributed to this segment.
    pub current_tokens: usize,

    /// Tokens reserved for this segment.
    ///
    /// NOTE(review): not read by any method in this file (e.g. `available()`
    /// ignores it); confirm intended semantics.
    pub reserved_tokens: usize,

    /// Presumably whether this segment may take unused capacity from other
    /// segments. NOTE(review): not read by any logic in this file; confirm.
    pub can_borrow: bool,

    /// Presumably whether this segment may give unused capacity to other
    /// segments. NOTE(review): not read by any logic in this file; confirm.
    pub can_lend: bool,
}
37
38impl SegmentBudget {
39 pub fn new(segment_type: ContextSegmentType, max_tokens: usize) -> Self {
41 Self {
42 segment_type,
43 max_tokens,
44 current_tokens: 0,
45 reserved_tokens: 0,
46 can_borrow: false,
47 can_lend: false,
48 }
49 }
50
51 pub fn available(&self) -> usize {
53 self.max_tokens.saturating_sub(self.current_tokens)
54 }
55
56 pub fn usage_percent(&self) -> u8 {
58 if self.max_tokens == 0 {
59 return 0;
60 }
61 ((self.current_tokens as f64 / self.max_tokens as f64) * 100.0) as u8
62 }
63
64 pub fn is_over_budget(&self) -> bool {
66 self.current_tokens > self.max_tokens
67 }
68}
69
/// Token budget for an execution's whole context, split into per-segment budgets.
///
/// Serialized with camelCase field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ContextBudget {
    /// Execution this budget belongs to.
    pub execution_id: ExecutionId,

    /// Total token capacity, including `output_reserve`.
    pub total_tokens: usize,

    /// Tokens held back from `total_tokens` and excluded from `available_tokens`
    /// (presumably room for the model's output — confirm).
    pub output_reserve: usize,

    /// Tokens usable by segments; `ContextBudget::new` sets this to
    /// `total_tokens - output_reserve`, clamped at zero.
    pub available_tokens: usize,

    /// Sum of `current_tokens` across `segments`, recomputed by the
    /// token-accounting methods.
    pub used_tokens: usize,

    /// Per-segment budgets; lookups return the first entry of a given type.
    pub segments: Vec<SegmentBudget>,

    /// Usage percentage (of `available_tokens`) at or above which the budget
    /// is considered in warning state.
    pub warning_threshold: u8,

    /// Usage percentage (of `available_tokens`) at or above which the budget
    /// is considered critical.
    pub critical_threshold: u8,

    /// When this budget was created or its usage last recalculated.
    pub updated_at: DateTime<Utc>,
}
103
104impl ContextBudget {
105 pub fn new(execution_id: ExecutionId, total_tokens: usize, output_reserve: usize) -> Self {
107 Self {
108 execution_id,
109 total_tokens,
110 output_reserve,
111 available_tokens: total_tokens.saturating_sub(output_reserve),
112 used_tokens: 0,
113 segments: Vec::new(),
114 warning_threshold: 80,
115 critical_threshold: 95,
116 updated_at: Utc::now(),
117 }
118 }
119
120 pub fn preset_gpt4_128k(execution_id: ExecutionId) -> Self {
122 let mut budget = Self::new(execution_id, 128_000, 4_096);
123 budget.segments = vec![
124 SegmentBudget::new(ContextSegmentType::System, 4_000),
125 SegmentBudget::new(ContextSegmentType::History, 60_000),
126 SegmentBudget::new(ContextSegmentType::WorkingMemory, 20_000),
127 SegmentBudget::new(ContextSegmentType::ToolResults, 20_000),
128 SegmentBudget::new(ContextSegmentType::RagContext, 15_000),
129 SegmentBudget::new(ContextSegmentType::UserInput, 2_000),
130 SegmentBudget::new(ContextSegmentType::AgentScratchpad, 2_000),
131 SegmentBudget::new(ContextSegmentType::ChildSummary, 500),
132 SegmentBudget::new(ContextSegmentType::Guidance, 500),
133 ];
134 budget
135 }
136
137 pub fn preset_gpt4_32k(execution_id: ExecutionId) -> Self {
139 let mut budget = Self::new(execution_id, 32_000, 2_048);
140 budget.segments = vec![
141 SegmentBudget::new(ContextSegmentType::System, 2_000),
142 SegmentBudget::new(ContextSegmentType::History, 15_000),
143 SegmentBudget::new(ContextSegmentType::WorkingMemory, 5_000),
144 SegmentBudget::new(ContextSegmentType::ToolResults, 4_000),
145 SegmentBudget::new(ContextSegmentType::RagContext, 3_000),
146 SegmentBudget::new(ContextSegmentType::UserInput, 1_000),
147 SegmentBudget::new(ContextSegmentType::AgentScratchpad, 500),
148 SegmentBudget::new(ContextSegmentType::ChildSummary, 250),
149 SegmentBudget::new(ContextSegmentType::Guidance, 250),
150 ];
151 budget
152 }
153
154 pub fn preset_claude_200k(execution_id: ExecutionId) -> Self {
156 let mut budget = Self::new(execution_id, 200_000, 4_096);
157 budget.segments = vec![
158 SegmentBudget::new(ContextSegmentType::System, 8_000),
159 SegmentBudget::new(ContextSegmentType::History, 100_000),
160 SegmentBudget::new(ContextSegmentType::WorkingMemory, 40_000),
161 SegmentBudget::new(ContextSegmentType::ToolResults, 25_000),
162 SegmentBudget::new(ContextSegmentType::RagContext, 15_000),
163 SegmentBudget::new(ContextSegmentType::UserInput, 4_000),
164 SegmentBudget::new(ContextSegmentType::AgentScratchpad, 2_000),
165 SegmentBudget::new(ContextSegmentType::ChildSummary, 1_000),
166 SegmentBudget::new(ContextSegmentType::Guidance, 1_000),
167 ];
168 budget
169 }
170
171 pub fn preset_default(execution_id: ExecutionId) -> Self {
173 let mut budget = Self::new(execution_id, 8_000, 1_024);
174 budget.segments = vec![
175 SegmentBudget::new(ContextSegmentType::System, 1_000),
176 SegmentBudget::new(ContextSegmentType::History, 3_000),
177 SegmentBudget::new(ContextSegmentType::WorkingMemory, 1_000),
178 SegmentBudget::new(ContextSegmentType::ToolResults, 1_000),
179 SegmentBudget::new(ContextSegmentType::RagContext, 500),
180 SegmentBudget::new(ContextSegmentType::UserInput, 500),
181 SegmentBudget::new(ContextSegmentType::AgentScratchpad, 0),
182 SegmentBudget::new(ContextSegmentType::ChildSummary, 0),
183 SegmentBudget::new(ContextSegmentType::Guidance, 0),
184 ];
185 budget
186 }
187
188 pub fn get_segment(&self, segment_type: ContextSegmentType) -> Option<&SegmentBudget> {
190 self.segments
191 .iter()
192 .find(|s| s.segment_type == segment_type)
193 }
194
195 pub fn get_segment_mut(
197 &mut self,
198 segment_type: ContextSegmentType,
199 ) -> Option<&mut SegmentBudget> {
200 self.segments
201 .iter_mut()
202 .find(|s| s.segment_type == segment_type)
203 }
204
205 pub fn update_segment_usage(&mut self, segment_type: ContextSegmentType, tokens: usize) {
207 if let Some(segment) = self.get_segment_mut(segment_type) {
208 segment.current_tokens = tokens;
209 }
210 self.recalculate_total();
211 }
212
213 pub fn add_tokens(&mut self, segment_type: ContextSegmentType, tokens: usize) {
215 if let Some(segment) = self.get_segment_mut(segment_type) {
216 segment.current_tokens += tokens;
217 }
218 self.recalculate_total();
219 }
220
221 pub fn remove_tokens(&mut self, segment_type: ContextSegmentType, tokens: usize) {
223 if let Some(segment) = self.get_segment_mut(segment_type) {
224 segment.current_tokens = segment.current_tokens.saturating_sub(tokens);
225 }
226 self.recalculate_total();
227 }
228
229 fn recalculate_total(&mut self) {
231 self.used_tokens = self.segments.iter().map(|s| s.current_tokens).sum();
232 self.updated_at = Utc::now();
233 }
234
235 pub fn remaining(&self) -> usize {
237 self.available_tokens.saturating_sub(self.used_tokens)
238 }
239
240 pub fn usage_percent(&self) -> u8 {
242 if self.available_tokens == 0 {
243 return 0;
244 }
245 ((self.used_tokens as f64 / self.available_tokens as f64) * 100.0) as u8
246 }
247
248 pub fn is_warning(&self) -> bool {
250 self.usage_percent() >= self.warning_threshold
251 }
252
253 pub fn is_critical(&self) -> bool {
255 self.usage_percent() >= self.critical_threshold
256 }
257
258 pub fn health(&self) -> BudgetHealth {
260 if self.is_critical() {
261 BudgetHealth::Critical
262 } else if self.is_warning() {
263 BudgetHealth::Warning
264 } else {
265 BudgetHealth::Healthy
266 }
267 }
268}
269
/// Coarse health classification returned by `ContextBudget::health`.
///
/// Serialized in snake_case (`healthy`, `warning`, `critical`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BudgetHealth {
    /// Usage is below the warning threshold.
    Healthy,
    /// Usage is at or above the warning threshold but below the critical one.
    Warning,
    /// Usage is at or above the critical threshold.
    Critical,
}
278
#[cfg(test)]
mod tests {
    use super::*;

    fn test_execution_id() -> ExecutionId {
        ExecutionId::new()
    }

    #[test]
    fn test_segment_budget() {
        let mut budget = SegmentBudget::new(ContextSegmentType::History, 1000);
        assert_eq!(budget.available(), 1000);
        assert_eq!(budget.usage_percent(), 0);

        budget.current_tokens = 500;
        assert_eq!(budget.available(), 500);
        assert_eq!(budget.usage_percent(), 50);
    }

    #[test]
    fn test_segment_over_budget() {
        let mut budget = SegmentBudget::new(ContextSegmentType::History, 100);
        assert!(!budget.is_over_budget());

        // Exactly at the limit is not over budget.
        budget.current_tokens = 100;
        assert!(!budget.is_over_budget());
        assert_eq!(budget.available(), 0);

        budget.current_tokens = 150;
        assert!(budget.is_over_budget());
        assert_eq!(budget.available(), 0);
        assert_eq!(budget.usage_percent(), 150);
    }

    #[test]
    fn test_context_budget_presets() {
        let budget = ContextBudget::preset_gpt4_128k(test_execution_id());
        assert_eq!(budget.total_tokens, 128_000);
        assert_eq!(budget.output_reserve, 4_096);
        assert_eq!(budget.available_tokens, 128_000 - 4_096);
        assert!(!budget.segments.is_empty());
    }

    #[test]
    fn test_token_accounting() {
        let mut budget = ContextBudget::preset_default(test_execution_id());
        assert_eq!(budget.available_tokens, 6_976);

        budget.add_tokens(ContextSegmentType::History, 500);
        budget.add_tokens(ContextSegmentType::System, 200);
        assert_eq!(budget.used_tokens, 700);
        assert_eq!(
            budget
                .get_segment(ContextSegmentType::History)
                .map(|s| s.current_tokens),
            Some(500)
        );

        // Removing more than is present clamps the segment at zero.
        budget.remove_tokens(ContextSegmentType::History, 800);
        assert_eq!(
            budget
                .get_segment(ContextSegmentType::History)
                .map(|s| s.current_tokens),
            Some(0)
        );
        assert_eq!(budget.used_tokens, 200);

        // update_segment_usage overwrites rather than accumulates.
        budget.update_segment_usage(ContextSegmentType::ToolResults, 300);
        budget.update_segment_usage(ContextSegmentType::ToolResults, 100);
        assert_eq!(budget.used_tokens, 300);
        assert_eq!(budget.remaining(), 6_976 - 300);
    }

    #[test]
    fn test_unknown_segment_is_ignored() {
        let mut budget = ContextBudget::new(test_execution_id(), 1_000, 100);
        budget.add_tokens(ContextSegmentType::History, 50);
        assert!(budget.get_segment(ContextSegmentType::History).is_none());
        assert_eq!(budget.used_tokens, 0);
        assert_eq!(budget.remaining(), 900);
    }

    #[test]
    fn test_budget_health() {
        let mut budget = ContextBudget::preset_default(test_execution_id());

        assert_eq!(budget.health(), BudgetHealth::Healthy);

        budget.used_tokens = (budget.available_tokens as f64 * 0.85) as usize;
        assert_eq!(budget.health(), BudgetHealth::Warning);

        budget.used_tokens = (budget.available_tokens as f64 * 0.96) as usize;
        assert_eq!(budget.health(), BudgetHealth::Critical);
    }
}