agtrace_engine/domain/
token.rs1use crate::domain::model::SessionState;
2use agtrace_types::ModelLimitResolver;
3
/// Token budget for a single model.
///
/// `total_limit` is the full token capacity, and `compaction_buffer_pct` is
/// the percentage (0-100) of that capacity which [`TokenLimit::effective_limit`]
/// excludes from the usable budget.
#[derive(Debug, Clone)]
pub struct TokenLimit {
    pub total_limit: u64,
    pub compaction_buffer_pct: f64,
}

impl TokenLimit {
    /// Builds a limit from a total capacity and a buffer percentage.
    ///
    /// # Panics
    ///
    /// Panics when `compaction_buffer_pct` lies outside `0.0..=100.0`
    /// (NaN is rejected as well, since it is contained in no range).
    pub fn new(total_limit: u64, compaction_buffer_pct: f64) -> Self {
        let in_range = (0.0..=100.0).contains(&compaction_buffer_pct);
        if !in_range {
            panic!(
                "compaction_buffer_pct must be in range 0-100, got: {}",
                compaction_buffer_pct
            );
        }

        Self {
            total_limit,
            compaction_buffer_pct,
        }
    }

    /// Returns the number of tokens left once the buffer share is removed.
    ///
    /// With a zero buffer the total is returned untouched, which keeps the
    /// value exact instead of round-tripping it through `f64`. Otherwise the
    /// result is rounded down and never falls below 1.
    pub fn effective_limit(&self) -> u64 {
        if self.compaction_buffer_pct != 0.0 {
            let usable_pct = 100.0 - self.compaction_buffer_pct;
            // Multiply before dividing, matching the established rounding behavior.
            let scaled = self.total_limit as f64 * usable_pct;
            let usable_tokens = (scaled / 100.0).floor() as u64;

            std::cmp::max(usable_tokens, 1)
        } else {
            self.total_limit
        }
    }
}
49
50pub struct TokenLimits<R: ModelLimitResolver> {
55 resolver: R,
56}
57
58impl<R: ModelLimitResolver> TokenLimits<R> {
59 pub fn new(resolver: R) -> Self {
61 Self { resolver }
62 }
63
64 pub fn get_limit(&self, model: &str) -> Option<TokenLimit> {
68 self.resolver
69 .resolve_model_limit(model)
70 .map(|spec| TokenLimit::new(spec.max_tokens, spec.compaction_buffer_pct))
71 }
72
73 pub fn get_usage_percentage_from_state(&self, state: &SessionState) -> Option<(f64, f64, f64)> {
79 let limit_total = if let Some(l) = state.context_window_limit {
80 l
81 } else {
82 let model = state.model.as_ref()?;
83 self.get_limit(model)?.total_limit
84 };
85
86 let input_side = state.total_input_side_tokens() as u64;
87 let output_side = state.total_output_side_tokens() as u64;
88 let total = state.total_tokens().as_u64();
89
90 let input_pct = (input_side as f64 / limit_total as f64) * 100.0;
91 let output_pct = (output_side as f64 / limit_total as f64) * 100.0;
92 let total_pct = (total as f64 / limit_total as f64) * 100.0;
93
94 Some((input_pct, output_pct, total_pct))
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ContextWindowUsage;
    use agtrace_types::ModelSpec;
    use chrono::Utc;

    /// Resolver stub that only knows model names starting with
    /// `claude-3-5-sonnet`.
    struct MockResolver;

    impl ModelLimitResolver for MockResolver {
        fn resolve_model_limit(&self, model: &str) -> Option<ModelSpec> {
            if !model.starts_with("claude-3-5-sonnet") {
                return None;
            }

            Some(ModelSpec {
                max_tokens: 200_000,
                compaction_buffer_pct: 22.5,
            })
        }
    }

    /// Builds a fresh session whose model is one `MockResolver` recognizes.
    fn sonnet_session() -> SessionState {
        let mut session = SessionState::new("test".to_string(), None, None, Utc::now());
        session.model = Some("claude-3-5-sonnet-20241022".to_string());
        session
    }

    #[test]
    fn test_get_limit_exact_match() {
        let lookup = TokenLimits::new(MockResolver);
        let maybe_limit = lookup.get_limit("claude-3-5-sonnet-20241022");
        assert!(maybe_limit.is_some());

        let resolved = maybe_limit.unwrap();
        assert_eq!(resolved.total_limit, 200_000);
        assert_eq!(resolved.compaction_buffer_pct, 22.5);
        assert_eq!(resolved.effective_limit(), 155_000);
    }

    #[test]
    fn test_get_limit_prefix_match() {
        let lookup = TokenLimits::new(MockResolver);
        assert!(lookup.get_limit("claude-3-5-sonnet").is_some());
    }

    #[test]
    fn test_unknown_model() {
        let lookup = TokenLimits::new(MockResolver);
        assert!(lookup.get_limit("unknown-model").is_none());
    }

    #[test]
    fn test_get_usage_percentage_from_state() {
        // No explicit context window limit: the resolver's 200_000 applies.
        let mut session = sonnet_session();
        session.current_usage = ContextWindowUsage::from_raw(1000, 0, 12000, 500);

        let percentages = TokenLimits::new(MockResolver)
            .get_usage_percentage_from_state(&session)
            .unwrap();

        let expected = (6.5, 0.25, 6.75);
        let tolerance = 1e-6;
        assert!((percentages.0 - expected.0).abs() < tolerance);
        assert!((percentages.1 - expected.1).abs() < tolerance);
        assert!((percentages.2 - expected.2).abs() < tolerance);
    }

    #[test]
    fn test_get_usage_percentage_from_state_no_cache() {
        // The explicit limit on the session is used for the percentages.
        let mut session = sonnet_session();
        session.context_window_limit = Some(200_000);
        session.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 4_000);

        let percentages = TokenLimits::new(MockResolver)
            .get_usage_percentage_from_state(&session)
            .unwrap();

        assert_eq!(percentages.0, 50.0);
        assert_eq!(percentages.1, 2.0);
        assert_eq!(percentages.2, 52.0);
    }

    #[test]
    fn test_effective_limit() {
        // (total limit, buffer %, expected effective limit)
        let cases = [
            (200_000, 22.5, 155_000),
            (400_000, 0.0, 400_000),
            (1000, 99.0, 10),
        ];

        for &(total, buffer_pct, expected) in cases.iter() {
            assert_eq!(TokenLimit::new(total, buffer_pct).effective_limit(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_negative() {
        let _ = TokenLimit::new(200_000, -10.0);
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_over_100() {
        let _ = TokenLimit::new(200_000, 150.0);
    }
}