1use serde::{Deserialize, Serialize};
7use std::ops::Add;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Serialize, Deserialize)]
11pub struct TokenCount(u64);
12
13impl TokenCount {
14 pub const fn new(value: u64) -> Self {
15 Self(value)
16 }
17
18 pub const fn zero() -> Self {
19 Self(0)
20 }
21
22 pub const fn as_u64(self) -> u64 {
23 self.0
24 }
25
26 pub fn saturating_add(self, other: Self) -> Self {
27 Self(self.0.saturating_add(other.0))
28 }
29
30 pub fn saturating_sub(self, other: Self) -> Self {
31 Self(self.0.saturating_sub(other.0))
32 }
33}
34
35impl From<u64> for TokenCount {
36 fn from(value: u64) -> Self {
37 Self(value)
38 }
39}
40
41impl Add for TokenCount {
42 type Output = Self;
43 fn add(self, rhs: Self) -> Self {
44 Self(self.0 + rhs.0)
45 }
46}
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
50pub struct ContextLimit(u64);
51
52impl ContextLimit {
53 pub const fn new(value: u64) -> Self {
54 Self(value)
55 }
56
57 pub const fn as_u64(self) -> u64 {
58 self.0
59 }
60
61 pub fn usage_ratio(self, used: TokenCount) -> f64 {
63 if self.0 == 0 {
64 0.0
65 } else {
66 used.as_u64() as f64 / self.0 as f64
67 }
68 }
69
70 pub fn remaining(self, used: TokenCount) -> TokenCount {
72 TokenCount::new(self.0.saturating_sub(used.as_u64()))
73 }
74
75 pub fn is_exceeded(self, used: TokenCount) -> bool {
77 used.as_u64() >= self.0
78 }
79
80 pub fn is_warning_zone(self, used: TokenCount) -> bool {
82 self.usage_ratio(used) > 0.8
83 }
84
85 pub fn is_danger_zone(self, used: TokenCount) -> bool {
87 self.usage_ratio(used) > 0.9
88 }
89}
90
91impl From<u64> for ContextLimit {
92 fn from(value: u64) -> Self {
93 Self(value)
94 }
95}
96
/// Newtype over an `i32` count of fresh input tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FreshInputTokens(pub i32);

/// Newtype over an `i32` count of cache-creation tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheCreationTokens(pub i32);

/// Newtype over an `i32` count of cache-read tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheReadTokens(pub i32);

/// Newtype over an `i32` count of output tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct OutputTokens(pub i32);

/// Component-wise `+`. The sum saturates at the `i32` bounds so that summing
/// many counts can neither panic nor wrap around to a wrong sign.
impl Add for FreshInputTokens {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        Self(self.0.saturating_add(rhs.0))
    }
}

/// Component-wise `+`. The sum saturates at the `i32` bounds so that summing
/// many counts can neither panic nor wrap around to a wrong sign.
impl Add for CacheCreationTokens {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        Self(self.0.saturating_add(rhs.0))
    }
}

/// Component-wise `+`. The sum saturates at the `i32` bounds so that summing
/// many counts can neither panic nor wrap around to a wrong sign.
impl Add for CacheReadTokens {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        Self(self.0.saturating_add(rhs.0))
    }
}

/// Component-wise `+`. The sum saturates at the `i32` bounds so that summing
/// many counts can neither panic nor wrap around to a wrong sign.
impl Add for OutputTokens {
    type Output = Self;

    fn add(self, rhs: Self) -> Self {
        Self(self.0.saturating_add(rhs.0))
    }
}
140
141#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
143pub struct ContextWindowUsage {
144 pub fresh_input: FreshInputTokens,
145 pub cache_creation: CacheCreationTokens,
146 pub cache_read: CacheReadTokens,
147 pub output: OutputTokens,
148}
149
150impl ContextWindowUsage {
151 pub fn new(
152 fresh_input: FreshInputTokens,
153 cache_creation: CacheCreationTokens,
154 cache_read: CacheReadTokens,
155 output: OutputTokens,
156 ) -> Self {
157 Self {
158 fresh_input,
159 cache_creation,
160 cache_read,
161 output,
162 }
163 }
164
165 pub fn from_raw(fresh_input: i32, cache_creation: i32, cache_read: i32, output: i32) -> Self {
166 Self {
167 fresh_input: FreshInputTokens(fresh_input),
168 cache_creation: CacheCreationTokens(cache_creation),
169 cache_read: CacheReadTokens(cache_read),
170 output: OutputTokens(output),
171 }
172 }
173
174 pub fn input_tokens(&self) -> i32 {
176 self.fresh_input.0 + self.cache_creation.0 + self.cache_read.0
177 }
178
179 pub fn output_tokens(&self) -> i32 {
181 self.output.0
182 }
183
184 pub fn total_tokens(&self) -> TokenCount {
186 let total = (self.input_tokens() + self.output_tokens()).max(0);
187 TokenCount::new(total as u64)
188 }
189
190 pub fn is_empty(&self) -> bool {
191 self.total_tokens() == TokenCount::zero()
192 }
193}
194
195impl Add for ContextWindowUsage {
196 type Output = Self;
197 fn add(self, rhs: Self) -> Self {
198 Self {
199 fresh_input: self.fresh_input + rhs.fresh_input,
200 cache_creation: self.cache_creation + rhs.cache_creation,
201 cache_read: self.cache_read + rhs.cache_read,
202 output: self.output + rhs.output,
203 }
204 }
205}
206
207#[cfg(test)]
208mod tests {
209 use super::*;
210
211 #[test]
212 fn test_context_window_usage_calculation() {
213 let usage = ContextWindowUsage::from_raw(100, 200, 300, 50);
214
215 assert_eq!(usage.input_tokens(), 600);
216 assert_eq!(usage.output_tokens(), 50);
217 assert_eq!(usage.total_tokens(), TokenCount::new(650));
218 }
219
220 #[test]
221 fn test_cache_read_always_included() {
222 let usage = ContextWindowUsage::from_raw(10, 20, 5000, 30);
223 assert_eq!(usage.total_tokens(), TokenCount::new(5060));
224 }
225
226 #[test]
227 fn test_add_usage() {
228 let usage1 = ContextWindowUsage::from_raw(100, 200, 300, 50);
229 let usage2 = ContextWindowUsage::from_raw(10, 20, 30, 5);
230
231 let total = usage1 + usage2;
232
233 assert_eq!(total.fresh_input.0, 110);
234 assert_eq!(total.cache_creation.0, 220);
235 assert_eq!(total.cache_read.0, 330);
236 assert_eq!(total.output.0, 55);
237 assert_eq!(total.total_tokens(), TokenCount::new(715));
238 }
239
240 #[test]
241 fn test_default_is_empty() {
242 let usage = ContextWindowUsage::default();
243 assert!(usage.is_empty());
244 assert_eq!(usage.total_tokens(), TokenCount::zero());
245 }
246
247 #[test]
248 fn test_token_count_operations() {
249 let count1 = TokenCount::new(100);
250 let count2 = TokenCount::new(50);
251
252 assert_eq!(count1 + count2, TokenCount::new(150));
253 assert_eq!(count1.saturating_sub(count2), TokenCount::new(50));
254 assert_eq!(count2.saturating_sub(count1), TokenCount::zero());
255 assert_eq!(count1.as_u64(), 100);
256 }
257
258 #[test]
259 fn test_context_limit_usage_ratio() {
260 let limit = ContextLimit::new(200_000);
261 let used = TokenCount::new(100_000);
262
263 assert_eq!(limit.usage_ratio(used), 0.5);
264 assert!(!limit.is_exceeded(used));
265 assert!(!limit.is_warning_zone(used));
266 assert!(!limit.is_danger_zone(used));
267 }
268
269 #[test]
270 fn test_context_limit_warning_zones() {
271 let limit = ContextLimit::new(100_000);
272
273 let used_normal = TokenCount::new(50_000);
274 assert!(!limit.is_warning_zone(used_normal));
275 assert!(!limit.is_danger_zone(used_normal));
276
277 let used_warning = TokenCount::new(85_000);
278 assert!(limit.is_warning_zone(used_warning));
279 assert!(!limit.is_danger_zone(used_warning));
280
281 let used_danger = TokenCount::new(95_000);
282 assert!(limit.is_warning_zone(used_danger));
283 assert!(limit.is_danger_zone(used_danger));
284
285 let used_exceeded = TokenCount::new(105_000);
286 assert!(limit.is_exceeded(used_exceeded));
287 }
288
289 #[test]
290 fn test_context_limit_remaining() {
291 let limit = ContextLimit::new(200_000);
292 let used = TokenCount::new(150_000);
293
294 assert_eq!(limit.remaining(used), TokenCount::new(50_000));
295
296 let used_over = TokenCount::new(250_000);
297 assert_eq!(limit.remaining(used_over), TokenCount::zero());
298 }
299
300 #[test]
301 fn test_context_window_usage_total_tokens() {
302 let usage = ContextWindowUsage::from_raw(100, 200, 300, 50);
303 assert_eq!(usage.total_tokens(), TokenCount::new(650));
304
305 let usage_negative = ContextWindowUsage::from_raw(-100, 0, 0, 0);
306 assert_eq!(usage_negative.total_tokens(), TokenCount::zero());
307 }
308}