agtrace_engine/domain/
token.rs

1use crate::domain::model::SessionState;
2use agtrace_types::ModelLimitResolver;
3
/// Token budget of a model: the full window size plus the share of it that is
/// held back as a compaction buffer.
#[derive(Debug, Clone)]
pub struct TokenLimit {
    /// Total number of tokens in the window.
    pub total_limit: u64,
    /// Percentage of `total_limit` reserved as a buffer; `new` enforces 0–100.
    pub compaction_buffer_pct: f64,
}

impl TokenLimit {
    /// Builds a limit from a total token count and a buffer percentage.
    ///
    /// # Panics
    ///
    /// Panics when `compaction_buffer_pct` is not within `0.0..=100.0`
    /// (a NaN value is rejected as well, since it is contained in no range).
    pub fn new(total_limit: u64, compaction_buffer_pct: f64) -> Self {
        let allowed = 0.0..=100.0;
        assert!(
            allowed.contains(&compaction_buffer_pct),
            "compaction_buffer_pct must be in range 0-100, got: {}",
            compaction_buffer_pct
        );

        TokenLimit {
            total_limit,
            compaction_buffer_pct,
        }
    }

    /// Number of tokens left once the compaction buffer is subtracted.
    ///
    /// A buffer of exactly `0.0` returns `total_limit` untouched. Any other
    /// buffer yields `floor(total_limit * (100 - pct) / 100)`, raised to at
    /// least 1. Note the floor of 1 only applies on the non-zero-buffer path,
    /// so a `total_limit` of 0 gives 0 without a buffer and 1 with one.
    pub fn effective_limit(&self) -> u64 {
        let buffer_pct = self.compaction_buffer_pct;

        if buffer_pct == 0.0 {
            self.total_limit
        } else {
            // Multiply first, divide second: keeps the same rounding as the
            // percentage formula is written.
            let usable_pct = 100.0 - buffer_pct;
            let scaled = self.total_limit as f64 * usable_pct;
            let usable_tokens = (scaled / 100.0).floor() as u64;

            std::cmp::max(usable_tokens, 1)
        }
    }
}
35
36pub struct TokenLimits<R: ModelLimitResolver> {
37    resolver: R,
38}
39
40impl<R: ModelLimitResolver> TokenLimits<R> {
41    pub fn new(resolver: R) -> Self {
42        Self { resolver }
43    }
44
45    pub fn get_limit(&self, model: &str) -> Option<TokenLimit> {
46        self.resolver
47            .resolve_model_limit(model)
48            .map(|spec| TokenLimit::new(spec.max_tokens, spec.compaction_buffer_pct))
49    }
50
51    pub fn get_usage_percentage_from_state(&self, state: &SessionState) -> Option<(f64, f64, f64)> {
52        let limit_total = if let Some(l) = state.context_window_limit {
53            l
54        } else {
55            let model = state.model.as_ref()?;
56            self.get_limit(model)?.total_limit
57        };
58
59        let input_side = state.total_input_side_tokens() as u64;
60        let output_side = state.total_output_side_tokens() as u64;
61        let total = state.total_tokens().as_u64();
62
63        let input_pct = (input_side as f64 / limit_total as f64) * 100.0;
64        let output_pct = (output_side as f64 / limit_total as f64) * 100.0;
65        let total_pct = (total as f64 / limit_total as f64) * 100.0;
66
67        Some((input_pct, output_pct, total_pct))
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ContextWindowUsage;
    use agtrace_types::ModelSpec;
    use chrono::Utc;

    /// Dated model identifier used by most tests; matched by the mock via prefix.
    const SONNET_DATED: &str = "claude-3-5-sonnet-20241022";

    /// Resolver stub: names starting with `claude-3-5-sonnet` resolve to a
    /// 200_000-token limit with a 22.5% buffer; every other name is unknown.
    struct MockResolver;

    impl ModelLimitResolver for MockResolver {
        fn resolve_model_limit(&self, model: &str) -> Option<ModelSpec> {
            if !model.starts_with("claude-3-5-sonnet") {
                return None;
            }

            Some(ModelSpec {
                max_tokens: 200_000,
                compaction_buffer_pct: 22.5,
            })
        }
    }

    /// Fresh session state whose model is `SONNET_DATED`.
    fn sonnet_session() -> SessionState {
        let mut session = SessionState::new("test".to_string(), None, None, Utc::now());
        session.model = Some(SONNET_DATED.to_string());
        session
    }

    /// Asserts that `actual` lies within 1e-6 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        let eps = 1e-6;
        assert!((actual - expected).abs() < eps);
    }

    #[test]
    fn test_get_limit_exact_match() {
        let resolved = TokenLimits::new(MockResolver)
            .get_limit(SONNET_DATED)
            .expect("dated sonnet model should resolve");

        assert_eq!(resolved.total_limit, 200_000);
        assert_eq!(resolved.compaction_buffer_pct, 22.5);
        assert_eq!(resolved.effective_limit(), 155_000);
    }

    #[test]
    fn test_get_limit_prefix_match() {
        let resolved = TokenLimits::new(MockResolver).get_limit("claude-3-5-sonnet");
        assert!(resolved.is_some());
    }

    #[test]
    fn test_unknown_model() {
        assert!(TokenLimits::new(MockResolver)
            .get_limit("unknown-model")
            .is_none());
    }

    /// Limit comes from the model lookup (200_000); percentages are compared
    /// with a tolerance.
    #[test]
    fn test_get_usage_percentage_from_state() {
        let mut session = sonnet_session();
        session.current_usage = ContextWindowUsage::from_raw(1000, 2000, 10000, 500);

        let (input_pct, output_pct, total_pct) = TokenLimits::new(MockResolver)
            .get_usage_percentage_from_state(&session)
            .unwrap();

        assert_close(input_pct, 6.5);
        assert_close(output_pct, 0.25);
        assert_close(total_pct, 6.75);
    }

    /// Limit comes from the explicit `context_window_limit` on the state.
    #[test]
    fn test_get_usage_percentage_from_state_no_cache() {
        let mut session = sonnet_session();
        session.context_window_limit = Some(200_000);
        session.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 4_000);

        let percentages = TokenLimits::new(MockResolver)
            .get_usage_percentage_from_state(&session)
            .unwrap();

        assert_eq!(percentages, (50.0, 2.0, 52.0));
    }

    #[test]
    fn test_effective_limit() {
        // (total_limit, compaction_buffer_pct, expected effective limit)
        let cases: [(u64, f64, u64); 3] = [
            (200_000, 22.5, 155_000),
            (400_000, 0.0, 400_000),
            (1000, 99.0, 10),
        ];

        for &(total, buffer_pct, expected) in &cases {
            assert_eq!(TokenLimit::new(total, buffer_pct).effective_limit(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_negative() {
        let _ = TokenLimit::new(200_000, -10.0);
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_over_100() {
        let _ = TokenLimit::new(200_000, 150.0);
    }
}