agtrace_engine/domain/
token.rs1use crate::domain::model::SessionState;
2use agtrace_types::ModelLimitResolver;
3
/// Token budget for a model: the total limit plus the share of it that is
/// held back as a compaction buffer.
#[derive(Debug, Clone)]
pub struct TokenLimit {
    /// Total token limit before any buffer is subtracted.
    pub total_limit: u64,
    /// Percentage (0-100) of `total_limit` reserved as a compaction buffer.
    pub compaction_buffer_pct: f64,
}

impl TokenLimit {
    /// Creates a limit of `total_limit` tokens with `compaction_buffer_pct`
    /// percent of it reserved.
    ///
    /// # Panics
    ///
    /// Panics if `compaction_buffer_pct` is not within `0.0..=100.0`
    /// (this includes NaN, which is contained in no range).
    pub fn new(total_limit: u64, compaction_buffer_pct: f64) -> Self {
        assert!(
            (0.0..=100.0).contains(&compaction_buffer_pct),
            "compaction_buffer_pct must be in range 0-100, got: {}",
            compaction_buffer_pct
        );

        Self {
            total_limit,
            compaction_buffer_pct,
        }
    }

    /// Returns the number of tokens left once the compaction buffer has been
    /// subtracted from `total_limit`.
    ///
    /// The result is rounded down, is at least 1 whenever `total_limit` is
    /// non-zero, and never exceeds `total_limit`.
    pub fn effective_limit(&self) -> u64 {
        // Without a buffer the answer is exact; skipping the float round-trip
        // avoids precision loss for totals above 2^53.
        if self.compaction_buffer_pct == 0.0 {
            return self.total_limit;
        }

        let usable_pct = 100.0 - self.compaction_buffer_pct;
        let effective = (self.total_limit as f64 * usable_pct / 100.0).floor() as u64;

        // Keep a floor of one token so a large buffer cannot zero out the
        // budget, but cap at `total_limit` so the effective limit can never be
        // larger than the total (a zero total, or float rounding on huge ones).
        effective.max(1).min(self.total_limit)
    }
}
35
36pub struct TokenLimits<R: ModelLimitResolver> {
37 resolver: R,
38}
39
40impl<R: ModelLimitResolver> TokenLimits<R> {
41 pub fn new(resolver: R) -> Self {
42 Self { resolver }
43 }
44
45 pub fn get_limit(&self, model: &str) -> Option<TokenLimit> {
46 self.resolver
47 .resolve_model_limit(model)
48 .map(|spec| TokenLimit::new(spec.max_tokens, spec.compaction_buffer_pct))
49 }
50
51 pub fn get_usage_percentage_from_state(&self, state: &SessionState) -> Option<(f64, f64, f64)> {
52 let limit_total = if let Some(l) = state.context_window_limit {
53 l
54 } else {
55 let model = state.model.as_ref()?;
56 self.get_limit(model)?.total_limit
57 };
58
59 let input_side = state.total_input_side_tokens() as u64;
60 let output_side = state.total_output_side_tokens() as u64;
61 let total = state.total_tokens().as_u64();
62
63 let input_pct = (input_side as f64 / limit_total as f64) * 100.0;
64 let output_pct = (output_side as f64 / limit_total as f64) * 100.0;
65 let total_pct = (total as f64 / limit_total as f64) * 100.0;
66
67 Some((input_pct, output_pct, total_pct))
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ContextWindowUsage;
    use agtrace_types::ModelSpec;
    use chrono::Utc;

    /// Resolver that only knows models whose name starts with
    /// `claude-3-5-sonnet`.
    struct MockResolver;

    impl ModelLimitResolver for MockResolver {
        fn resolve_model_limit(&self, model: &str) -> Option<ModelSpec> {
            model.starts_with("claude-3-5-sonnet").then(|| ModelSpec {
                max_tokens: 200_000,
                compaction_buffer_pct: 22.5,
            })
        }
    }

    /// Limits table backed by the mock resolver.
    fn mock_limits() -> TokenLimits<MockResolver> {
        TokenLimits::new(MockResolver)
    }

    /// Fresh session state whose model is known to the mock resolver.
    fn sonnet_state() -> SessionState {
        let mut session = SessionState::new("test".to_string(), None, None, Utc::now());
        session.model = Some("claude-3-5-sonnet-20241022".to_string());
        session
    }

    #[test]
    fn test_get_limit_exact_match() {
        let resolved = mock_limits().get_limit("claude-3-5-sonnet-20241022");
        assert!(resolved.is_some());

        let resolved = resolved.unwrap();
        assert_eq!(resolved.total_limit, 200_000);
        assert_eq!(resolved.compaction_buffer_pct, 22.5);
        assert_eq!(resolved.effective_limit(), 155_000);
    }

    #[test]
    fn test_get_limit_prefix_match() {
        assert!(mock_limits().get_limit("claude-3-5-sonnet").is_some());
    }

    #[test]
    fn test_unknown_model() {
        assert!(mock_limits().get_limit("unknown-model").is_none());
    }

    #[test]
    fn test_get_usage_percentage_from_state() {
        let mut session = sonnet_state();
        session.current_usage = ContextWindowUsage::from_raw(1000, 2000, 10000, 500);

        let (input_pct, output_pct, total_pct) = mock_limits()
            .get_usage_percentage_from_state(&session)
            .unwrap();

        // Float results are compared with a small tolerance.
        let close_to = |actual: f64, expected: f64| (actual - expected).abs() < 1e-6;
        assert!(close_to(input_pct, 6.5));
        assert!(close_to(output_pct, 0.25));
        assert!(close_to(total_pct, 6.75));
    }

    #[test]
    fn test_get_usage_percentage_from_state_no_cache() {
        let mut session = sonnet_state();
        session.context_window_limit = Some(200_000);
        session.current_usage = ContextWindowUsage::from_raw(100_000, 0, 0, 4_000);

        let (input_pct, output_pct, total_pct) = mock_limits()
            .get_usage_percentage_from_state(&session)
            .unwrap();

        assert_eq!(input_pct, 50.0);
        assert_eq!(output_pct, 2.0);
        assert_eq!(total_pct, 52.0);
    }

    #[test]
    fn test_effective_limit() {
        // (total_limit, compaction_buffer_pct, expected effective limit)
        let cases = [
            (200_000, 22.5, 155_000),
            (400_000, 0.0, 400_000),
            (1000, 99.0, 10),
        ];

        for &(total, buffer_pct, expected) in &cases {
            assert_eq!(TokenLimit::new(total, buffer_pct).effective_limit(), expected);
        }
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_negative() {
        TokenLimit::new(200_000, -10.0);
    }

    #[test]
    #[should_panic(expected = "compaction_buffer_pct must be in range 0-100")]
    fn test_invalid_buffer_pct_over_100() {
        TokenLimit::new(200_000, 150.0);
    }
}