1use crate::budget::{BudgetHealth, ContextBudget};
8use crate::compactor::{CompactionResult, CompactionStrategyType, Compactor};
9use crate::segment::{ContextSegment, ContextSegmentType};
10use crate::token_counter::TokenCounter;
11use chrono::{DateTime, Utc};
12use enact_core::kernel::ExecutionId;
13use serde::{Deserialize, Serialize};
14use std::time::Instant;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
19pub enum ContextWindowError {
20 #[error("Token counter error: {0}")]
21 TokenCounter(String),
22
23 #[error("Budget exceeded: need {needed} tokens, only {available} available")]
24 BudgetExceeded { needed: usize, available: usize },
25
26 #[error("Segment budget exceeded for {segment_type:?}: need {needed}, max {max}")]
27 SegmentBudgetExceeded {
28 segment_type: ContextSegmentType,
29 needed: usize,
30 max: usize,
31 },
32
33 #[error("Compaction failed: {0}")]
34 CompactionFailed(String),
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
41#[serde(rename_all = "camelCase")]
42pub struct ContextWindowState {
43 pub execution_id: ExecutionId,
45
46 pub segments: Vec<ContextSegment>,
48
49 pub budget: ContextBudget,
51
52 pub compaction_history: Vec<CompactionResult>,
54
55 pub compaction_count: u32,
57
58 pub total_tokens_saved: usize,
60
61 pub health: BudgetHealth,
63
64 pub updated_at: DateTime<Utc>,
66}
67
68pub struct ContextWindow {
70 execution_id: ExecutionId,
72
73 segments: Vec<ContextSegment>,
75
76 budget: ContextBudget,
78
79 token_counter: TokenCounter,
81
82 compaction_history: Vec<CompactionResult>,
84
85 next_sequence: u64,
87}
88
89impl ContextWindow {
90 pub fn new(budget: ContextBudget) -> Result<Self, ContextWindowError> {
92 let token_counter =
93 TokenCounter::new().map_err(|e| ContextWindowError::TokenCounter(e.to_string()))?;
94
95 Ok(Self {
96 execution_id: budget.execution_id.clone(),
97 segments: Vec::new(),
98 budget,
99 token_counter,
100 compaction_history: Vec::new(),
101 next_sequence: 0,
102 })
103 }
104
105 pub fn with_preset_gpt4_128k(execution_id: ExecutionId) -> Result<Self, ContextWindowError> {
107 Self::new(ContextBudget::preset_gpt4_128k(execution_id))
108 }
109
110 pub fn with_preset_claude_200k(execution_id: ExecutionId) -> Result<Self, ContextWindowError> {
112 Self::new(ContextBudget::preset_claude_200k(execution_id))
113 }
114
115 pub fn with_preset_default(execution_id: ExecutionId) -> Result<Self, ContextWindowError> {
117 Self::new(ContextBudget::preset_default(execution_id))
118 }
119
120 pub fn execution_id(&self) -> &ExecutionId {
122 &self.execution_id
123 }
124
125 pub fn segments(&self) -> &[ContextSegment] {
127 &self.segments
128 }
129
130 pub fn budget(&self) -> &ContextBudget {
132 &self.budget
133 }
134
135 pub fn budget_mut(&mut self) -> &mut ContextBudget {
137 &mut self.budget
138 }
139
140 pub fn count_tokens(&self, text: &str) -> usize {
142 self.token_counter.count(text)
143 }
144
145 pub fn add_segment_auto(
147 &mut self,
148 mut segment: ContextSegment,
149 ) -> Result<(), ContextWindowError> {
150 if segment.token_count == 0 {
152 segment.token_count = self.token_counter.count(&segment.content);
153 }
154
155 self.add_segment(segment)
156 }
157
158 pub fn add_segment(&mut self, mut segment: ContextSegment) -> Result<(), ContextWindowError> {
160 if let Some(seg_budget) = self.budget.get_segment(segment.segment_type) {
162 let new_usage = seg_budget.current_tokens + segment.token_count;
163 if new_usage > seg_budget.max_tokens {
164 return Err(ContextWindowError::SegmentBudgetExceeded {
165 segment_type: segment.segment_type,
166 needed: segment.token_count,
167 max: seg_budget.max_tokens - seg_budget.current_tokens,
168 });
169 }
170 }
171
172 let new_total = self.budget.used_tokens + segment.token_count;
174 if new_total > self.budget.available_tokens {
175 return Err(ContextWindowError::BudgetExceeded {
176 needed: segment.token_count,
177 available: self.budget.remaining(),
178 });
179 }
180
181 segment.sequence = self.next_sequence;
183 self.next_sequence += 1;
184
185 self.budget
187 .add_tokens(segment.segment_type, segment.token_count);
188
189 self.segments.push(segment);
191
192 Ok(())
193 }
194
195 pub fn remove_segment(&mut self, segment_id: &str) -> bool {
197 if let Some(pos) = self.segments.iter().position(|s| s.id == segment_id) {
198 let segment = self.segments.remove(pos);
199 self.budget
200 .remove_tokens(segment.segment_type, segment.token_count);
201 true
202 } else {
203 false
204 }
205 }
206
207 pub fn segments_of_type(&self, segment_type: ContextSegmentType) -> Vec<&ContextSegment> {
209 self.segments
210 .iter()
211 .filter(|s| s.segment_type == segment_type)
212 .collect()
213 }
214
215 pub fn used_tokens(&self) -> usize {
217 self.budget.used_tokens
218 }
219
220 pub fn remaining_tokens(&self) -> usize {
222 self.budget.remaining()
223 }
224
225 pub fn needs_compaction(&self) -> bool {
227 self.budget.is_warning()
228 }
229
230 pub fn is_critical(&self) -> bool {
232 self.budget.is_critical()
233 }
234
235 pub fn health(&self) -> BudgetHealth {
237 self.budget.health()
238 }
239
240 pub fn compact(
242 &mut self,
243 compactor: &Compactor,
244 ) -> Result<CompactionResult, ContextWindowError> {
245 let start = Instant::now();
246 let tokens_before = self.budget.used_tokens;
247
248 let result = match compactor.strategy().strategy_type {
249 CompactionStrategyType::Truncate => {
250 compactor.compact_truncate(&mut self.segments, tokens_before)
251 }
252 CompactionStrategyType::SlidingWindow => {
253 compactor.compact_sliding_window(&mut self.segments)
254 }
255 _ => {
256 return Err(ContextWindowError::CompactionFailed(format!(
258 "Strategy {:?} not implemented",
259 compactor.strategy().strategy_type
260 )));
261 }
262 };
263
264 let duration_ms = start.elapsed().as_millis() as u64;
265
266 match result {
267 Ok(tokens_removed) => {
268 self.recalculate_budget();
270
271 let tokens_after = self.budget.used_tokens;
272 let segments_compacted = (tokens_removed > 0) as usize;
273
274 let compaction_result = CompactionResult::success(
275 self.execution_id.clone(),
276 compactor.strategy().strategy_type,
277 tokens_before,
278 tokens_after,
279 segments_compacted,
280 duration_ms,
281 );
282
283 self.compaction_history.push(compaction_result.clone());
284 Ok(compaction_result)
285 }
286 Err(e) => {
287 let compaction_result = CompactionResult::failure(
288 self.execution_id.clone(),
289 compactor.strategy().strategy_type,
290 tokens_before,
291 e.to_string(),
292 duration_ms,
293 );
294
295 self.compaction_history.push(compaction_result.clone());
296 Err(ContextWindowError::CompactionFailed(e.to_string()))
297 }
298 }
299 }
300
301 fn recalculate_budget(&mut self) {
303 for seg_budget in &mut self.budget.segments {
305 seg_budget.current_tokens = 0;
306 }
307
308 for segment in &self.segments {
310 self.budget
311 .add_tokens(segment.segment_type, segment.token_count);
312 }
313 }
314
315 pub fn build_context(&self) -> String {
317 let mut parts: Vec<&str> = Vec::new();
318
319 let mut sorted: Vec<&ContextSegment> = self.segments.iter().collect();
321 sorted.sort_by_key(|s| s.sequence);
322
323 for segment in sorted {
324 parts.push(&segment.content);
325 }
326
327 parts.join("\n\n")
328 }
329
330 pub fn state(&self) -> ContextWindowState {
332 ContextWindowState {
333 execution_id: self.execution_id.clone(),
334 segments: self.segments.clone(),
335 budget: self.budget.clone(),
336 compaction_history: self.compaction_history.clone(),
337 compaction_count: self.compaction_history.len() as u32,
338 total_tokens_saved: self.compaction_history.iter().map(|r| r.tokens_saved).sum(),
339 health: self.budget.health(),
340 updated_at: Utc::now(),
341 }
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    fn test_execution_id() -> ExecutionId {
        ExecutionId::new()
    }

    /// Fresh window on the default preset.
    fn default_window() -> ContextWindow {
        ContextWindow::new(ContextBudget::preset_default(test_execution_id())).unwrap()
    }

    #[test]
    fn test_create_window() {
        let window = default_window();

        assert_eq!(window.used_tokens(), 0);
        assert!(window.remaining_tokens() > 0);
    }

    #[test]
    fn test_add_segment() {
        let mut window = default_window();

        let segment = ContextSegment::system("You are a helpful assistant.", 10);
        window.add_segment(segment).unwrap();

        assert_eq!(window.segments().len(), 1);
        assert_eq!(window.used_tokens(), 10);
    }

    #[test]
    fn test_add_segment_assigns_increasing_sequences() {
        let mut window = default_window();
        window.add_segment(ContextSegment::system("first", 1)).unwrap();
        window.add_segment(ContextSegment::system("second", 1)).unwrap();

        let sequences: Vec<u64> = window.segments().iter().map(|s| s.sequence).collect();
        assert_eq!(sequences, vec![0, 1]);
    }

    #[test]
    fn test_remove_segment_releases_tokens() {
        let mut window = default_window();
        window
            .add_segment(ContextSegment::system("You are a helpful assistant.", 10))
            .unwrap();

        let id = window.segments()[0].id.to_string();
        assert!(window.remove_segment(&id));
        assert!(window.segments().is_empty());
        assert_eq!(window.used_tokens(), 0);
        // A second removal of the same id finds nothing.
        assert!(!window.remove_segment(&id));
    }

    #[test]
    fn test_build_context_joins_in_sequence_order() {
        let mut window = default_window();
        window.add_segment(ContextSegment::system("alpha", 1)).unwrap();
        window.add_segment(ContextSegment::system("beta", 1)).unwrap();

        assert_eq!(window.build_context(), "alpha\n\nbeta");
    }

    #[test]
    fn test_health_tracking() {
        let window = default_window();

        assert_eq!(window.health(), BudgetHealth::Healthy);
        assert!(!window.needs_compaction());
    }
}