1use serde::{Deserialize, Serialize};
7use std::ops::Add;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Serialize, Deserialize)]
11pub struct TokenCount(u64);
12
13impl TokenCount {
14 pub const fn new(value: u64) -> Self {
15 Self(value)
16 }
17
18 pub const fn zero() -> Self {
19 Self(0)
20 }
21
22 pub const fn as_u64(self) -> u64 {
23 self.0
24 }
25
26 pub fn saturating_add(self, other: Self) -> Self {
27 Self(self.0.saturating_add(other.0))
28 }
29
30 pub fn saturating_sub(self, other: Self) -> Self {
31 Self(self.0.saturating_sub(other.0))
32 }
33}
34
35impl From<u64> for TokenCount {
36 fn from(value: u64) -> Self {
37 Self(value)
38 }
39}
40
41impl Add for TokenCount {
42 type Output = Self;
43 fn add(self, rhs: Self) -> Self {
44 Self(self.0 + rhs.0)
45 }
46}
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
50pub struct ContextLimit(u64);
51
52impl ContextLimit {
53 pub const fn new(value: u64) -> Self {
54 Self(value)
55 }
56
57 pub const fn as_u64(self) -> u64 {
58 self.0
59 }
60
61 pub fn usage_ratio(self, used: TokenCount) -> f64 {
63 if self.0 == 0 {
64 0.0
65 } else {
66 used.as_u64() as f64 / self.0 as f64
67 }
68 }
69
70 pub fn remaining(self, used: TokenCount) -> TokenCount {
72 TokenCount::new(self.0.saturating_sub(used.as_u64()))
73 }
74
75 pub fn is_exceeded(self, used: TokenCount) -> bool {
77 used.as_u64() >= self.0
78 }
79
80 pub fn is_warning_zone(self, used: TokenCount) -> bool {
82 self.usage_ratio(used) > 0.8
83 }
84
85 pub fn is_danger_zone(self, used: TokenCount) -> bool {
87 self.usage_ratio(used) > 0.9
88 }
89}
90
91impl From<u64> for ContextLimit {
92 fn from(value: u64) -> Self {
93 Self(value)
94 }
95}
96
/// Signed 32-bit counter used for the `fresh_input` field of
/// `ContextWindowUsage`.
///
/// The inner value is public and the derived `Default` is zero.
/// NOTE(review): presumably the part of the input that did not involve the
/// cache; confirm against the code that fills it in.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct FreshInputTokens(pub i32);
100
/// Signed 32-bit counter used for the `cache_creation` field of
/// `ContextWindowUsage`, where it is counted as part of the input tokens.
///
/// The inner value is public and the derived `Default` is zero.
/// NOTE(review): presumably tokens written to a cache; confirm against the
/// code that fills it in.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CacheCreationTokens(pub i32);
104
/// Signed 32-bit counter used for the `cache_read` field of
/// `ContextWindowUsage`, where it is counted as part of the input tokens.
///
/// The inner value is public and the derived `Default` is zero.
/// NOTE(review): presumably tokens served from a cache; confirm against the
/// code that fills it in.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct CacheReadTokens(pub i32);
108
/// Signed 32-bit counter used for the `output` field of
/// `ContextWindowUsage`.
///
/// The inner value is public and the derived `Default` is zero.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct OutputTokens(pub i32);
112
113impl Add for FreshInputTokens {
114 type Output = Self;
115 fn add(self, rhs: Self) -> Self {
116 Self(self.0 + rhs.0)
117 }
118}
119
120impl Add for CacheCreationTokens {
121 type Output = Self;
122 fn add(self, rhs: Self) -> Self {
123 Self(self.0 + rhs.0)
124 }
125}
126
127impl Add for CacheReadTokens {
128 type Output = Self;
129 fn add(self, rhs: Self) -> Self {
130 Self(self.0 + rhs.0)
131 }
132}
133
134impl Add for OutputTokens {
135 type Output = Self;
136 fn add(self, rhs: Self) -> Self {
137 Self(self.0 + rhs.0)
138 }
139}
140
141#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
143pub struct ContextWindowUsage {
144 pub fresh_input: FreshInputTokens,
145 pub cache_creation: CacheCreationTokens,
146 pub cache_read: CacheReadTokens,
147 pub output: OutputTokens,
148}
149
150impl ContextWindowUsage {
151 pub fn new(
152 fresh_input: FreshInputTokens,
153 cache_creation: CacheCreationTokens,
154 cache_read: CacheReadTokens,
155 output: OutputTokens,
156 ) -> Self {
157 Self {
158 fresh_input,
159 cache_creation,
160 cache_read,
161 output,
162 }
163 }
164
165 pub fn from_raw(fresh_input: i32, cache_creation: i32, cache_read: i32, output: i32) -> Self {
166 Self {
167 fresh_input: FreshInputTokens(fresh_input),
168 cache_creation: CacheCreationTokens(cache_creation),
169 cache_read: CacheReadTokens(cache_read),
170 output: OutputTokens(output),
171 }
172 }
173
174 pub fn input_tokens(&self) -> i32 {
176 self.fresh_input.0 + self.cache_creation.0 + self.cache_read.0
177 }
178
179 pub fn output_tokens(&self) -> i32 {
181 self.output.0
182 }
183
184 pub fn context_window_tokens(&self) -> i32 {
186 self.input_tokens() + self.output_tokens()
187 }
188
189 pub fn total_tokens(&self) -> TokenCount {
191 let total = self.context_window_tokens().max(0);
192 TokenCount::new(total as u64)
193 }
194
195 pub fn is_empty(&self) -> bool {
196 self.context_window_tokens() == 0
197 }
198}
199
200impl Add for ContextWindowUsage {
201 type Output = Self;
202 fn add(self, rhs: Self) -> Self {
203 Self {
204 fresh_input: self.fresh_input + rhs.fresh_input,
205 cache_creation: self.cache_creation + rhs.cache_creation,
206 cache_read: self.cache_read + rhs.cache_read,
207 output: self.output + rhs.output,
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Input is the sum of the three input-side counters; the context window
    /// adds the output counter on top.
    #[test]
    fn test_context_window_usage_calculation() {
        let sample = ContextWindowUsage::from_raw(100, 200, 300, 50);

        assert_eq!(sample.input_tokens(), 600);
        assert_eq!(sample.output_tokens(), 50);
        assert_eq!(sample.context_window_tokens(), 650);
    }

    /// A large cache-read counter is part of the context window total.
    #[test]
    fn test_cache_read_always_included() {
        let cache_heavy = ContextWindowUsage::from_raw(10, 20, 5000, 30);
        assert_eq!(cache_heavy.context_window_tokens(), 5060);
    }

    /// Adding two records adds every counter to its counterpart.
    #[test]
    fn test_add_usage() {
        let first = ContextWindowUsage::from_raw(100, 200, 300, 50);
        let second = ContextWindowUsage::from_raw(10, 20, 30, 5);

        let combined = first + second;
        let ContextWindowUsage {
            fresh_input,
            cache_creation,
            cache_read,
            output,
        } = combined;

        assert_eq!(fresh_input.0, 110);
        assert_eq!(cache_creation.0, 220);
        assert_eq!(cache_read.0, 330);
        assert_eq!(output.0, 55);
        assert_eq!(combined.context_window_tokens(), 715);
    }

    /// The default record has no tokens at all.
    #[test]
    fn test_default_is_empty() {
        let blank = ContextWindowUsage::default();
        assert!(blank.is_empty());
        assert_eq!(blank.context_window_tokens(), 0);
    }

    /// `+`, saturating subtraction in both directions, and the raw accessor.
    #[test]
    fn test_token_count_operations() {
        let hundred = TokenCount::new(100);
        let fifty = TokenCount::new(50);

        assert_eq!(hundred + fifty, TokenCount::new(150));
        assert_eq!(hundred.saturating_sub(fifty), TokenCount::new(50));
        assert_eq!(fifty.saturating_sub(hundred), TokenCount::zero());
        assert_eq!(hundred.as_u64(), 100);
    }

    /// Half the limit used: ratio is 0.5 and none of the zone checks fire.
    #[test]
    fn test_context_limit_usage_ratio() {
        let cap = ContextLimit::new(200_000);
        let half = TokenCount::new(100_000);

        assert_eq!(cap.usage_ratio(half), 0.5);
        assert!(!cap.is_exceeded(half));
        assert!(!cap.is_warning_zone(half));
        assert!(!cap.is_danger_zone(half));
    }

    /// Walks usage up through the normal, warning and danger zones, then
    /// past the limit.
    #[test]
    fn test_context_limit_warning_zones() {
        let cap = ContextLimit::new(100_000);

        // (tokens used, expect warning zone, expect danger zone)
        let cases = [
            (50_000_u64, false, false),
            (85_000, true, false),
            (95_000, true, true),
        ];
        for &(tokens, warning, danger) in &cases {
            let used = TokenCount::new(tokens);
            assert_eq!(cap.is_warning_zone(used), warning);
            assert_eq!(cap.is_danger_zone(used), danger);
        }

        let over = TokenCount::new(105_000);
        assert!(cap.is_exceeded(over));
    }

    /// Remaining budget below the limit, and zero once usage is past it.
    #[test]
    fn test_context_limit_remaining() {
        let cap = ContextLimit::new(200_000);

        let under = TokenCount::new(150_000);
        assert_eq!(cap.remaining(under), TokenCount::new(50_000));

        let over = TokenCount::new(250_000);
        assert_eq!(cap.remaining(over), TokenCount::zero());
    }

    /// The total matches the window sum and a negative sum is reported as zero.
    #[test]
    fn test_context_window_usage_total_tokens() {
        let positive = ContextWindowUsage::from_raw(100, 200, 300, 50);
        assert_eq!(positive.total_tokens(), TokenCount::new(650));

        let negative = ContextWindowUsage::from_raw(-100, 0, 0, 0);
        assert_eq!(negative.total_tokens(), TokenCount::zero());
    }
}