agtrace_runtime/domain/
token.rs

1use crate::domain::model::SessionState;
2
/// Token budget for a single model.
///
/// `total_limit` is the raw token limit; `compaction_buffer_pct` is the
/// percentage (0-100) of that limit held back as a buffer, which
/// [`TokenLimit::effective_limit`] subtracts before reporting usable tokens.
#[derive(Clone, Debug)]
pub struct TokenLimit {
    /// Raw token limit for the model.
    pub total_limit: u64,
    /// Share of `total_limit`, in percent, reserved as a compaction buffer.
    pub compaction_buffer_pct: f64,
}
8
9impl TokenLimit {
10    pub fn new(total_limit: u64, compaction_buffer_pct: f64) -> Self {
11        assert!(
12            (0.0..=100.0).contains(&compaction_buffer_pct),
13            "compaction_buffer_pct must be in range 0-100, got: {}",
14            compaction_buffer_pct
15        );
16
17        Self {
18            total_limit,
19            compaction_buffer_pct,
20        }
21    }
22
23    pub fn effective_limit(&self) -> u64 {
24        if self.compaction_buffer_pct == 0.0 {
25            return self.total_limit;
26        }
27
28        let usable_pct = 100.0 - self.compaction_buffer_pct;
29        let effective = (self.total_limit as f64 * usable_pct / 100.0).floor() as u64;
30
31        effective.max(1)
32    }
33}
34
/// Stateless helper for looking up per-model token limits and computing
/// context-window usage percentages for a session.
#[derive(Debug, Clone, Copy)]
pub struct TokenLimits;
36
37impl TokenLimits {
38    pub fn new() -> Self {
39        Self
40    }
41
42    pub fn get_limit(&self, model: &str) -> Option<TokenLimit> {
43        agtrace_providers::token_limits::resolve_model_limit(model)
44            .map(|spec| TokenLimit::new(spec.max_tokens, spec.compaction_buffer_pct))
45    }
46
47    pub fn get_usage_percentage_from_state(&self, state: &SessionState) -> Option<(f64, f64, f64)> {
48        let limit_total = if let Some(l) = state.context_window_limit {
49            l
50        } else {
51            let model = state.model.as_ref()?;
52            self.get_limit(model)?.total_limit
53        };
54
55        let input_side = state.total_input_side_tokens() as u64;
56        let output_side = state.total_output_side_tokens() as u64;
57        let total = state.total_context_window_tokens() as u64;
58
59        let input_pct = (input_side as f64 / limit_total as f64) * 100.0;
60        let output_pct = (output_side as f64 / limit_total as f64) * 100.0;
61        let total_pct = (total as f64 / limit_total as f64) * 100.0;
62
63        Some((input_pct, output_pct, total_pct))
64    }
65}
66
67impl Default for TokenLimits {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use agtrace_engine::ContextWindowUsage;
    use chrono::Utc;

    #[test]
    fn test_get_limit_exact_match() {
        let limits = TokenLimits::new();
        let limit = limits
            .get_limit("claude-3-5-sonnet-20241022")
            .expect("exact model id should resolve to a limit");
        assert_eq!(limit.total_limit, 200_000);
        assert_eq!(limit.compaction_buffer_pct, 22.5);
        assert_eq!(limit.effective_limit(), 155_000);
    }

    #[test]
    fn test_get_limit_prefix_match() {
        let limits = TokenLimits::new();
        assert!(
            limits.get_limit("claude-3-5-sonnet").is_some(),
            "model id prefix should resolve to a limit"
        );
    }

    #[test]
    fn test_unknown_model() {
        let limits = TokenLimits::new();
        let result = limits.get_limit("unknown-model");
        assert!(result.is_none());
    }

    #[test]
    fn test_get_usage_percentage_from_state() {
        let limits = TokenLimits::new();
        let mut state = SessionState::new("test".to_string(), None, Utc::now());
        state.model = Some("claude-3-5-sonnet-20241022".to_string());
        state.current_usage = ContextWindowUsage::from_raw(1000, 2000, 10000, 500);

        let (input_pct, output_pct, total_pct) = limits
            .get_usage_percentage_from_state(&state)
            .expect("known model should yield usage percentages");

        let eps = 1e-6;
        assert!((input_pct - 6.5).abs() < eps);
        assert!((output_pct - 0.25).abs() < eps);
        assert!((total_pct - 6.75).abs() < eps);
    }

    #[test]
    fn test_get_usage_percentage_from_state_no_cache() {
        let limits = TokenLimits::new();
        let mut state = SessionState::new("test".to_string(), None, Utc::now());
        state.model = Some("claude-3-5-sonnet-20241022".to_string());
        state.context_window_limit = Some(200_000);
        state.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 4_000);

        let (input_pct, output_pct, total_pct) = limits
            .get_usage_percentage_from_state(&state)
            .expect("explicit limit should yield usage percentages");

        assert_eq!(input_pct, 50.0);
        assert_eq!(output_pct, 2.0);
        assert_eq!(total_pct, 52.0);
    }

    #[test]
    fn test_get_usage_percentage_from_state_unknown_model() {
        let limits = TokenLimits::new();
        let mut state = SessionState::new("test".to_string(), None, Utc::now());
        state.model = Some("unknown-model".to_string());
        state.context_window_limit = None;
        state.current_usage = ContextWindowUsage::from_raw(1_000, 0, 0, 500);

        // No explicit limit and no resolvable model: nothing to divide by.
        assert!(limits.get_usage_percentage_from_state(&state).is_none());
    }

    #[test]
    fn test_get_usage_percentage_from_state_explicit_limit_overrides_model() {
        let limits = TokenLimits::new();
        let mut state = SessionState::new("test".to_string(), None, Utc::now());
        state.model = Some("unknown-model".to_string());
        state.context_window_limit = Some(10_000);
        state.current_usage = ContextWindowUsage::from_raw(1_000, 0, 0, 500);

        let (input_pct, output_pct, total_pct) = limits
            .get_usage_percentage_from_state(&state)
            .expect("explicit context_window_limit should be used even for an unknown model");

        let eps = 1e-6;
        assert!((input_pct - 10.0).abs() < eps);
        assert!((output_pct - 5.0).abs() < eps);
        assert!((total_pct - 15.0).abs() < eps);
    }

    #[test]
    fn test_effective_limit() {
        let limit = TokenLimit::new(200_000, 22.5);
        assert_eq!(limit.effective_limit(), 155_000);

        let limit_no_buffer = TokenLimit::new(400_000, 0.0);
        assert_eq!(limit_no_buffer.effective_limit(), 400_000);

        let limit_high_buffer = TokenLimit::new(1000, 99.0);
        assert_eq!(limit_high_buffer.effective_limit(), 10);

        // A fully reserved budget still clamps to one usable token.
        let limit_full_buffer = TokenLimit::new(1000, 100.0);
        assert_eq!(limit_full_buffer.effective_limit(), 1);
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_negative() {
        TokenLimit::new(200_000, -10.0);
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_over_100() {
        TokenLimit::new(200_000, 150.0);
    }
}