// aster/context/window_manager.rs
1//! Context Window Manager Module
2//!
3//! Provides dynamic context window management for different LLM models.
4//!
5//! # Context Window Strategy
6//!
7//! - For models with context window ≤50k tokens: reserve 20% for output
8//! - For models with context window >50k tokens: reserve fixed 50k tokens for output
9//!
10//! # Features
11//!
12//! - Model-specific context window sizes
13//! - Token usage tracking (input, output, cache)
14//! - Usage percentage calculation
15//! - Near-limit detection
16
17use crate::context::types::{CacheStats, ContextWindowStats, TokenUsage};
18use std::collections::HashMap;
19use std::sync::LazyLock;
20
/// Threshold for small context windows (50k tokens).
///
/// Windows at or below this size reserve a *percentage* of the window for
/// output; larger windows reserve a *fixed* token budget instead.
const SMALL_CONTEXT_THRESHOLD: usize = 50_000;

/// Fraction (0.20 = 20%) of a small context window reserved for model output
const SMALL_CONTEXT_OUTPUT_RESERVE_PERCENT: f64 = 0.20;

/// Fixed number of tokens reserved for model output in large context windows
const LARGE_CONTEXT_OUTPUT_RESERVE: usize = 50_000;
29
/// Model context window sizes mapping.
///
/// Maps model IDs to their maximum context window sizes in tokens.
/// Unrecognized models fall back to the `"default"` entry.
pub static MODEL_CONTEXT_WINDOWS: LazyLock<HashMap<&'static str, usize>> = LazyLock::new(|| {
    HashMap::from([
        // Claude models
        ("claude-3-5-sonnet-20241022", 200_000),
        ("claude-3-7-sonnet-20250219", 200_000),
        ("claude-4-0-sonnet-20250514", 200_000),
        ("claude-3-opus-20240229", 200_000),
        ("claude-3-sonnet-20240229", 200_000),
        ("claude-3-haiku-20240307", 200_000),
        // OpenAI models
        ("gpt-4o", 128_000),
        ("gpt-4o-mini", 128_000),
        ("gpt-4-turbo", 128_000),
        ("gpt-4", 8_192),
        ("gpt-3.5-turbo", 16_385),
        // Default fallback
        ("default", 200_000),
    ])
});
52
/// Context Window Manager for tracking and managing token usage.
///
/// Tracks cumulative token usage across API calls and provides
/// utilities for calculating available context space.
///
/// Totals accumulate until [`ContextWindowManager::reset`] is called;
/// switching models via `update_model` does NOT clear the counters.
#[derive(Debug, Clone)]
pub struct ContextWindowManager {
    /// Size of the context window (in tokens) for the current model
    context_window_size: usize,
    /// Total input tokens consumed across all calls since the last reset
    total_input_tokens: usize,
    /// Total output tokens generated across all calls since the last reset
    total_output_tokens: usize,
    /// Total tokens written to cache (cache creation)
    total_cache_creation_tokens: usize,
    /// Total tokens read from cache (cache hits)
    total_cache_read_tokens: usize,
    /// Usage reported by the most recent API call, if any
    current_usage: Option<TokenUsage>,
    /// Current model ID (as passed to `new`/`update_model`)
    model_id: String,
}
74
impl Default for ContextWindowManager {
    /// Create a manager for the `"default"` model entry (200k-token window).
    fn default() -> Self {
        Self::new("default")
    }
}
80
81impl ContextWindowManager {
82    /// Create a new ContextWindowManager for the specified model.
83    ///
84    /// # Arguments
85    ///
86    /// * `model_id` - The model identifier (e.g., "claude-3-5-sonnet-20241022")
87    ///
88    /// # Example
89    ///
90    /// ```
91    /// use aster::context::window_manager::ContextWindowManager;
92    ///
93    /// let manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");
94    /// assert_eq!(manager.get_context_window_size(), 200_000);
95    /// ```
96    pub fn new(model_id: &str) -> Self {
97        let context_window_size = Self::get_model_context_window(model_id);
98        Self {
99            context_window_size,
100            total_input_tokens: 0,
101            total_output_tokens: 0,
102            total_cache_creation_tokens: 0,
103            total_cache_read_tokens: 0,
104            current_usage: None,
105            model_id: model_id.to_string(),
106        }
107    }
108
109    /// Get the context window size for a model.
110    ///
111    /// Returns the known context window size for the model, or the default
112    /// if the model is not recognized.
113    ///
114    /// # Arguments
115    ///
116    /// * `model_id` - The model identifier
117    ///
118    /// # Returns
119    ///
120    /// Context window size in tokens
121    pub fn get_model_context_window(model_id: &str) -> usize {
122        MODEL_CONTEXT_WINDOWS
123            .get(model_id)
124            .copied()
125            .unwrap_or_else(|| {
126                // Try to find a partial match
127                for (key, value) in MODEL_CONTEXT_WINDOWS.iter() {
128                    if model_id.contains(key) || key.contains(model_id) {
129                        return *value;
130                    }
131                }
132                // Fall back to default
133                *MODEL_CONTEXT_WINDOWS.get("default").unwrap_or(&200_000)
134            })
135    }
136
137    /// Calculate available context space for input.
138    ///
139    /// Applies the reservation strategy:
140    /// - For context ≤50k: reserve 20% for output
141    /// - For context >50k: reserve fixed 50k for output
142    ///
143    /// # Arguments
144    ///
145    /// * `model_id` - The model identifier
146    ///
147    /// # Returns
148    ///
149    /// Available tokens for input
150    pub fn calculate_available_context(model_id: &str) -> usize {
151        let window_size = Self::get_model_context_window(model_id);
152        Self::calculate_available_from_window(window_size)
153    }
154
155    /// Calculate available context from a given window size.
156    fn calculate_available_from_window(window_size: usize) -> usize {
157        if window_size <= SMALL_CONTEXT_THRESHOLD {
158            // Reserve 20% for output
159            ((window_size as f64) * (1.0 - SMALL_CONTEXT_OUTPUT_RESERVE_PERCENT)) as usize
160        } else {
161            // Reserve fixed 50k for output
162            window_size.saturating_sub(LARGE_CONTEXT_OUTPUT_RESERVE)
163        }
164    }
165
166    /// Calculate output space reservation for a model.
167    ///
168    /// # Arguments
169    ///
170    /// * `model_id` - The model identifier
171    ///
172    /// # Returns
173    ///
174    /// Tokens reserved for output
175    pub fn calculate_output_space(model_id: &str) -> usize {
176        let window_size = Self::get_model_context_window(model_id);
177        Self::calculate_output_from_window(window_size)
178    }
179
180    /// Calculate output space from a given window size.
181    fn calculate_output_from_window(window_size: usize) -> usize {
182        if window_size <= SMALL_CONTEXT_THRESHOLD {
183            // Reserve 20% for output
184            ((window_size as f64) * SMALL_CONTEXT_OUTPUT_RESERVE_PERCENT) as usize
185        } else {
186            // Reserve fixed 50k for output
187            LARGE_CONTEXT_OUTPUT_RESERVE
188        }
189    }
190
191    /// Update the model and recalculate context window size.
192    ///
193    /// # Arguments
194    ///
195    /// * `model_id` - The new model identifier
196    pub fn update_model(&mut self, model_id: &str) {
197        self.model_id = model_id.to_string();
198        self.context_window_size = Self::get_model_context_window(model_id);
199    }
200
201    /// Record token usage from an API call.
202    ///
203    /// Updates cumulative totals and stores the current usage.
204    ///
205    /// # Arguments
206    ///
207    /// * `usage` - Token usage from the API call
208    pub fn record_usage(&mut self, usage: TokenUsage) {
209        self.total_input_tokens += usage.input_tokens;
210        self.total_output_tokens += usage.output_tokens;
211
212        if let Some(cache_creation) = usage.cache_creation_tokens {
213            self.total_cache_creation_tokens += cache_creation;
214        }
215
216        if let Some(cache_read) = usage.cache_read_tokens {
217            self.total_cache_read_tokens += cache_read;
218        }
219
220        self.current_usage = Some(usage);
221    }
222
223    /// Get the current context usage percentage.
224    ///
225    /// Calculates usage based on total input tokens relative to context window.
226    ///
227    /// # Returns
228    ///
229    /// Usage percentage (0.0 - 100.0)
230    pub fn get_usage_percentage(&self) -> f64 {
231        if self.context_window_size == 0 {
232            return 0.0;
233        }
234        (self.total_input_tokens as f64 / self.context_window_size as f64) * 100.0
235    }
236
237    /// Check if context usage is near the limit.
238    ///
239    /// # Arguments
240    ///
241    /// * `threshold` - Percentage threshold (0.0 - 100.0)
242    ///
243    /// # Returns
244    ///
245    /// `true` if usage exceeds the threshold
246    pub fn is_near_limit(&self, threshold: f64) -> bool {
247        self.get_usage_percentage() >= threshold
248    }
249
250    /// Get the context window size.
251    pub fn get_context_window_size(&self) -> usize {
252        self.context_window_size
253    }
254
255    /// Get total input tokens consumed.
256    pub fn get_total_input_tokens(&self) -> usize {
257        self.total_input_tokens
258    }
259
260    /// Get total output tokens generated.
261    pub fn get_total_output_tokens(&self) -> usize {
262        self.total_output_tokens
263    }
264
265    /// Get available context space for the current model.
266    pub fn get_available_context(&self) -> usize {
267        Self::calculate_available_from_window(self.context_window_size)
268    }
269
270    /// Get output space reservation for the current model.
271    pub fn get_output_space(&self) -> usize {
272        Self::calculate_output_from_window(self.context_window_size)
273    }
274
275    /// Get the current model ID.
276    pub fn get_model_id(&self) -> &str {
277        &self.model_id
278    }
279
280    /// Get the most recent API call usage.
281    pub fn get_current_usage(&self) -> Option<&TokenUsage> {
282        self.current_usage.as_ref()
283    }
284
285    /// Get context window statistics.
286    ///
287    /// # Returns
288    ///
289    /// Statistics about context window usage
290    pub fn get_stats(&self) -> ContextWindowStats {
291        ContextWindowStats {
292            total_input_tokens: self.total_input_tokens,
293            total_output_tokens: self.total_output_tokens,
294            context_window_size: self.context_window_size,
295            current_usage: self.current_usage.clone(),
296        }
297    }
298
299    /// Get cache statistics.
300    ///
301    /// # Returns
302    ///
303    /// Statistics about cache usage
304    pub fn get_cache_stats(&self) -> CacheStats {
305        let total_cacheable = self.total_cache_creation_tokens + self.total_cache_read_tokens;
306        let cache_hit_rate = if total_cacheable > 0 {
307            self.total_cache_read_tokens as f64 / total_cacheable as f64
308        } else {
309            0.0
310        };
311
312        CacheStats {
313            total_cache_creation_tokens: self.total_cache_creation_tokens,
314            total_cache_read_tokens: self.total_cache_read_tokens,
315            cache_hit_rate,
316        }
317    }
318
319    /// Reset all statistics.
320    ///
321    /// Clears cumulative token counts and current usage.
322    pub fn reset(&mut self) {
323        self.total_input_tokens = 0;
324        self.total_output_tokens = 0;
325        self.total_cache_creation_tokens = 0;
326        self.total_cache_read_tokens = 0;
327        self.current_usage = None;
328    }
329
330    /// Get remaining available tokens.
331    ///
332    /// Calculates how many more input tokens can be used before
333    /// reaching the available context limit.
334    ///
335    /// # Returns
336    ///
337    /// Remaining available tokens (0 if limit exceeded)
338    pub fn get_remaining_tokens(&self) -> usize {
339        let available = self.get_available_context();
340        available.saturating_sub(self.total_input_tokens)
341    }
342}
343
// Unit tests for ContextWindowManager: construction, window lookup,
// reservation math, usage accounting, cache statistics, and reset.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_with_known_model() {
        let manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");
        assert_eq!(manager.get_context_window_size(), 200_000);
        assert_eq!(manager.get_model_id(), "claude-3-5-sonnet-20241022");
    }

    #[test]
    fn test_new_with_unknown_model() {
        let manager = ContextWindowManager::new("unknown-model");
        // Should fall back to default
        assert_eq!(manager.get_context_window_size(), 200_000);
    }

    #[test]
    fn test_get_model_context_window() {
        assert_eq!(
            ContextWindowManager::get_model_context_window("claude-3-5-sonnet-20241022"),
            200_000
        );
        assert_eq!(
            ContextWindowManager::get_model_context_window("gpt-4o"),
            128_000
        );
        assert_eq!(
            ContextWindowManager::get_model_context_window("gpt-4"),
            8_192
        );
    }

    #[test]
    fn test_calculate_available_context_small_window() {
        // For small context (≤50k), reserve 20%
        // gpt-4 has 8192 tokens
        let available = ContextWindowManager::calculate_available_context("gpt-4");
        // 8192 * 0.8 = 6553.6 ≈ 6553 (float multiply truncates toward zero)
        assert_eq!(available, 6553);
    }

    #[test]
    fn test_calculate_available_context_large_window() {
        // For large context (>50k), reserve fixed 50k
        let available =
            ContextWindowManager::calculate_available_context("claude-3-5-sonnet-20241022");
        // 200000 - 50000 = 150000
        assert_eq!(available, 150_000);
    }

    #[test]
    fn test_calculate_output_space_small_window() {
        // For small context (≤50k), reserve 20%
        let output_space = ContextWindowManager::calculate_output_space("gpt-4");
        // 8192 * 0.2 = 1638.4 ≈ 1638
        assert_eq!(output_space, 1638);
    }

    #[test]
    fn test_calculate_output_space_large_window() {
        // For large context (>50k), reserve fixed 50k
        let output_space =
            ContextWindowManager::calculate_output_space("claude-3-5-sonnet-20241022");
        assert_eq!(output_space, 50_000);
    }

    // Totals must accumulate across multiple record_usage calls.
    #[test]
    fn test_record_usage() {
        let mut manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");

        let usage1 = TokenUsage::new(1000, 500);
        manager.record_usage(usage1);

        assert_eq!(manager.get_total_input_tokens(), 1000);
        assert_eq!(manager.get_total_output_tokens(), 500);

        let usage2 = TokenUsage::new(2000, 1000);
        manager.record_usage(usage2);

        assert_eq!(manager.get_total_input_tokens(), 3000);
        assert_eq!(manager.get_total_output_tokens(), 1500);
    }

    #[test]
    fn test_record_usage_with_cache() {
        let mut manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");

        let usage = TokenUsage::with_cache(1000, 500, 200, 100);
        manager.record_usage(usage);

        let cache_stats = manager.get_cache_stats();
        assert_eq!(cache_stats.total_cache_creation_tokens, 200);
        assert_eq!(cache_stats.total_cache_read_tokens, 100);
    }

    #[test]
    fn test_get_usage_percentage() {
        let mut manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");

        // 200000 context window
        let usage = TokenUsage::new(50000, 0);
        manager.record_usage(usage);

        // 50000 / 200000 = 25%
        let percentage = manager.get_usage_percentage();
        assert!((percentage - 25.0).abs() < 0.01);
    }

    // Threshold comparison is >= (inclusive at exactly the threshold).
    #[test]
    fn test_is_near_limit() {
        let mut manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");

        // Add 70% of context window
        let usage = TokenUsage::new(140000, 0);
        manager.record_usage(usage);

        assert!(manager.is_near_limit(70.0));
        assert!(!manager.is_near_limit(80.0));
    }

    #[test]
    fn test_update_model() {
        let mut manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");
        assert_eq!(manager.get_context_window_size(), 200_000);

        manager.update_model("gpt-4");
        assert_eq!(manager.get_context_window_size(), 8_192);
        assert_eq!(manager.get_model_id(), "gpt-4");
    }

    #[test]
    fn test_get_stats() {
        let mut manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");

        let usage = TokenUsage::new(1000, 500);
        manager.record_usage(usage.clone());

        let stats = manager.get_stats();
        assert_eq!(stats.total_input_tokens, 1000);
        assert_eq!(stats.total_output_tokens, 500);
        assert_eq!(stats.context_window_size, 200_000);
        assert!(stats.current_usage.is_some());
    }

    // reset() clears all counters and the current usage snapshot.
    #[test]
    fn test_reset() {
        let mut manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");

        let usage = TokenUsage::with_cache(1000, 500, 200, 100);
        manager.record_usage(usage);

        manager.reset();

        assert_eq!(manager.get_total_input_tokens(), 0);
        assert_eq!(manager.get_total_output_tokens(), 0);
        assert!(manager.get_current_usage().is_none());

        let cache_stats = manager.get_cache_stats();
        assert_eq!(cache_stats.total_cache_creation_tokens, 0);
        assert_eq!(cache_stats.total_cache_read_tokens, 0);
    }

    #[test]
    fn test_get_remaining_tokens() {
        let mut manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");

        // Available = 200000 - 50000 = 150000
        assert_eq!(manager.get_remaining_tokens(), 150_000);

        let usage = TokenUsage::new(50000, 0);
        manager.record_usage(usage);

        // Remaining = 150000 - 50000 = 100000
        assert_eq!(manager.get_remaining_tokens(), 100_000);
    }

    #[test]
    fn test_cache_hit_rate() {
        let mut manager = ContextWindowManager::new("claude-3-5-sonnet-20241022");

        // First call: cache creation
        let usage1 = TokenUsage::with_cache(1000, 500, 500, 0);
        manager.record_usage(usage1);

        // Second call: cache read
        let usage2 = TokenUsage::with_cache(1000, 500, 0, 500);
        manager.record_usage(usage2);

        let cache_stats = manager.get_cache_stats();
        // Total cacheable = 500 + 500 = 1000
        // Cache read = 500
        // Hit rate = 500 / 1000 = 0.5
        assert!((cache_stats.cache_hit_rate - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_default() {
        let manager = ContextWindowManager::default();
        assert_eq!(manager.get_model_id(), "default");
        assert_eq!(manager.get_context_window_size(), 200_000);
    }

    #[test]
    fn test_boundary_50k() {
        // Test exactly at 50k boundary
        // gpt-3.5-turbo has 16385 tokens (< 50k)
        let available = ContextWindowManager::calculate_available_context("gpt-3.5-turbo");
        let output = ContextWindowManager::calculate_output_space("gpt-3.5-turbo");

        // Should use percentage-based reservation
        // 16385 * 0.8 = 13108
        // 16385 * 0.2 = 3277
        assert_eq!(available, 13108);
        assert_eq!(output, 3277);
    }
}