1use crate::context::token_estimator::TokenEstimator;
21use crate::context::types::{CacheConfig, CacheSavings, CacheStats, TokenUsage};
22use crate::conversation::message::Message;
23
/// Price charged for one million uncached input tokens; divided by 1_000_000
/// to obtain the per-token price used in the cost calculations.
// NOTE(review): the currency is not stated in this file — presumably USD; confirm.
const BASE_INPUT_PRICE_PER_MILLION: f64 = 3.0;

/// Factor applied to the per-token base price for tokens written to the cache
/// (cache-creation tokens).
const CACHE_WRITE_MULTIPLIER: f64 = 1.25;

/// Factor applied to the per-token base price for tokens served from the cache
/// (cache-read tokens).
const CACHE_READ_MULTIPLIER: f64 = 0.1;
33
/// Outcome of scanning a conversation for messages that qualify for caching.
///
/// The `Default` value (no indices, zero tokens) is what an empty conversation
/// produces.
#[derive(Clone, Debug, Default)]
pub struct CacheEligibility {
    /// Positions, within the inspected message slice, of every message that
    /// qualified for caching.
    pub cacheable_indices: Vec<usize>,
    /// Sum of the estimated token counts of the messages listed in
    /// `cacheable_indices`.
    pub cacheable_tokens: usize,
}
42
43pub struct CacheController;
54
55impl CacheController {
56 pub fn get_cache_eligibility(messages: &[Message], config: &CacheConfig) -> CacheEligibility {
88 if messages.is_empty() {
89 return CacheEligibility::default();
90 }
91
92 let len = messages.len();
93 let mut cacheable_indices = Vec::new();
94 let mut cacheable_tokens = 0;
95
96 let start_index = len.saturating_sub(config.cache_recent_messages);
99
100 for (i, message) in messages.iter().enumerate().take(len).skip(start_index) {
102 if Self::is_cacheable(message, config.min_tokens_for_cache) {
103 let tokens = TokenEstimator::estimate_message_tokens(message);
104 cacheable_indices.push(i);
105 cacheable_tokens += tokens;
106 }
107 }
108
109 CacheEligibility {
110 cacheable_indices,
111 cacheable_tokens,
112 }
113 }
114
115 pub fn add_cache_control(
131 messages: &[Message],
132 config: &CacheConfig,
133 ) -> (Vec<Message>, Vec<usize>) {
134 let eligibility = Self::get_cache_eligibility(messages, config);
135 (messages.to_vec(), eligibility.cacheable_indices)
136 }
137
138 pub fn is_cacheable(message: &Message, min_tokens: usize) -> bool {
153 if message.content.is_empty() {
154 return false;
155 }
156
157 let tokens = TokenEstimator::estimate_message_tokens(message);
158 tokens >= min_tokens
159 }
160
161 pub fn calculate_cache_savings(usage: &TokenUsage) -> CacheSavings {
186 let base_price = BASE_INPUT_PRICE_PER_MILLION / 1_000_000.0;
187
188 let base_cost = usage.input_tokens as f64 * base_price;
190
191 let cache_creation_tokens = usage.cache_creation_tokens.unwrap_or(0);
193 let cache_read_tokens = usage.cache_read_tokens.unwrap_or(0);
194
195 let cache_write_cost = cache_creation_tokens as f64 * base_price * CACHE_WRITE_MULTIPLIER;
197
198 let cache_read_cost = cache_read_tokens as f64 * base_price * CACHE_READ_MULTIPLIER;
200
201 let non_cached_tokens = usage.input_tokens.saturating_sub(cache_read_tokens);
203 let non_cached_cost = non_cached_tokens as f64 * base_price;
204
205 let actual_cost = non_cached_cost + cache_write_cost + cache_read_cost;
207
208 CacheSavings::new(base_cost, actual_cost)
209 }
210
211 pub fn calculate_cache_stats(usage: &TokenUsage) -> CacheStats {
221 let cache_creation = usage.cache_creation_tokens.unwrap_or(0);
222 let cache_read = usage.cache_read_tokens.unwrap_or(0);
223
224 let total_cache_tokens = cache_creation + cache_read;
225 let hit_rate = if total_cache_tokens > 0 {
226 cache_read as f64 / total_cache_tokens as f64
227 } else {
228 0.0
229 };
230
231 CacheStats {
232 total_cache_creation_tokens: cache_creation,
233 total_cache_read_tokens: cache_read,
234 cache_hit_rate: hit_rate,
235 }
236 }
237
238 pub fn accumulate_cache_stats<'a>(usages: impl Iterator<Item = &'a TokenUsage>) -> CacheStats {
248 let mut total_creation = 0usize;
249 let mut total_read = 0usize;
250
251 for usage in usages {
252 total_creation += usage.cache_creation_tokens.unwrap_or(0);
253 total_read += usage.cache_read_tokens.unwrap_or(0);
254 }
255
256 let total = total_creation + total_read;
257 let hit_rate = if total > 0 {
258 total_read as f64 / total as f64
259 } else {
260 0.0
261 };
262
263 CacheStats {
264 total_cache_creation_tokens: total_creation,
265 total_cache_read_tokens: total_read,
266 cache_hit_rate: hit_rate,
267 }
268 }
269}
270
// Unit tests for `CacheController`: eligibility selection, cost savings and
// cache statistics.
#[cfg(test)]
mod tests {
    use super::*;

    /// A 4000-character user message; `test_is_cacheable_long_message` pins
    /// it as clearing a 1024-token threshold.
    fn create_long_message() -> Message {
        Message::user().with_text("x".repeat(4000))
    }

    /// A one-word user message that stays below a 1024-token threshold.
    fn create_short_message() -> Message {
        Message::user().with_text("Hello")
    }

    #[test]
    fn test_is_cacheable_empty_message() {
        let message = Message::user();
        assert!(!CacheController::is_cacheable(&message, 1024));
    }

    #[test]
    fn test_is_cacheable_short_message() {
        let message = create_short_message();
        assert!(!CacheController::is_cacheable(&message, 1024));
    }

    #[test]
    fn test_is_cacheable_long_message() {
        let message = create_long_message();
        assert!(CacheController::is_cacheable(&message, 1024));
    }

    #[test]
    fn test_is_cacheable_with_low_threshold() {
        let message = create_short_message();
        assert!(CacheController::is_cacheable(&message, 1));
    }

    #[test]
    fn test_add_cache_control_empty_messages() {
        let messages: Vec<Message> = vec![];
        let config = CacheConfig::default();

        let (result, indices) = CacheController::add_cache_control(&messages, &config);

        assert!(result.is_empty());
        assert!(indices.is_empty());
    }

    #[test]
    fn test_add_cache_control_respects_recent_limit() {
        let messages: Vec<Message> = (0..5).map(|_| create_long_message()).collect();
        let config = CacheConfig {
            cache_recent_messages: 2,
            min_tokens_for_cache: 100,
            ..Default::default()
        };

        let (result, indices) = CacheController::add_cache_control(&messages, &config);

        assert_eq!(result.len(), 5);
        // Only the last two messages are inside the window, and both clear
        // the 100-token threshold, so exactly those two must be reported.
        // (An `all(..)` check alone would also pass on an empty result.)
        assert_eq!(indices, vec![3, 4]);
    }

    #[test]
    fn test_add_cache_control_respects_token_threshold() {
        let messages = vec![create_short_message(), create_long_message()];
        let config = CacheConfig::default();

        let (result, indices) = CacheController::add_cache_control(&messages, &config);

        assert_eq!(result.len(), 2);
        // Whether the long message qualifies depends on the default
        // threshold; the short one alone must never be the only hit.
        assert!(indices.contains(&1) || indices.is_empty());
    }

    #[test]
    fn test_get_cache_eligibility_empty() {
        let messages: Vec<Message> = vec![];
        let config = CacheConfig::default();

        let eligibility = CacheController::get_cache_eligibility(&messages, &config);

        assert!(eligibility.cacheable_indices.is_empty());
        assert_eq!(eligibility.cacheable_tokens, 0);
    }

    #[test]
    fn test_get_cache_eligibility_with_long_messages() {
        let messages: Vec<Message> = (0..3).map(|_| create_long_message()).collect();
        let config = CacheConfig {
            min_tokens_for_cache: 100,
            cache_recent_messages: 10,
            ..Default::default()
        };

        let eligibility = CacheController::get_cache_eligibility(&messages, &config);

        assert_eq!(eligibility.cacheable_indices, vec![0, 1, 2]);
        assert!(eligibility.cacheable_tokens > 0);
    }

    #[test]
    fn test_calculate_cache_savings_no_cache() {
        let usage = TokenUsage::new(1000, 500);

        let savings = CacheController::calculate_cache_savings(&usage);

        // Without cache traffic both costs coincide.
        assert!((savings.base_cost - savings.cache_cost).abs() < 0.0001);
        assert!(savings.savings.abs() < 0.0001);
    }

    #[test]
    fn test_calculate_cache_savings_with_cache_read() {
        let usage = TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_creation_tokens: Some(0),
            cache_read_tokens: Some(800),
            thinking_tokens: None,
        };

        let savings = CacheController::calculate_cache_savings(&usage);

        // Cache reads are billed at a fraction of the base price.
        assert!(savings.savings > 0.0);
        assert!(savings.cache_cost < savings.base_cost);
    }

    #[test]
    fn test_calculate_cache_savings_with_cache_write() {
        let usage = TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_creation_tokens: Some(500),
            cache_read_tokens: Some(0),
            thinking_tokens: None,
        };

        let savings = CacheController::calculate_cache_savings(&usage);

        // Cache writes cost more than the base price, so a write-only
        // request ends up more expensive.
        assert!(savings.savings < 0.0);
    }

    #[test]
    fn test_calculate_cache_savings_mixed() {
        let usage = TokenUsage {
            input_tokens: 10000,
            output_tokens: 1000,
            cache_creation_tokens: Some(1000),
            cache_read_tokens: Some(8000),
            thinking_tokens: None,
        };

        let savings = CacheController::calculate_cache_savings(&usage);

        assert!(savings.savings > 0.0);
        assert!(savings.savings_percentage() > 0.0);
    }

    #[test]
    fn test_calculate_cache_stats_no_cache() {
        let usage = TokenUsage::new(1000, 500);

        let stats = CacheController::calculate_cache_stats(&usage);

        assert_eq!(stats.total_cache_creation_tokens, 0);
        assert_eq!(stats.total_cache_read_tokens, 0);
        assert_eq!(stats.cache_hit_rate, 0.0);
    }

    #[test]
    fn test_calculate_cache_stats_with_cache() {
        let usage = TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_creation_tokens: Some(200),
            cache_read_tokens: Some(800),
            thinking_tokens: None,
        };

        let stats = CacheController::calculate_cache_stats(&usage);

        assert_eq!(stats.total_cache_creation_tokens, 200);
        assert_eq!(stats.total_cache_read_tokens, 800);
        // 800 / (200 + 800)
        assert!((stats.cache_hit_rate - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_accumulate_cache_stats() {
        let usages = [
            TokenUsage {
                input_tokens: 1000,
                output_tokens: 500,
                cache_creation_tokens: Some(100),
                cache_read_tokens: Some(400),
                thinking_tokens: None,
            },
            TokenUsage {
                input_tokens: 2000,
                output_tokens: 1000,
                cache_creation_tokens: Some(200),
                cache_read_tokens: Some(600),
                thinking_tokens: None,
            },
        ];

        let stats = CacheController::accumulate_cache_stats(usages.iter());

        assert_eq!(stats.total_cache_creation_tokens, 300);
        assert_eq!(stats.total_cache_read_tokens, 1000);
        // 1000 / (300 + 1000) ≈ 0.769
        assert!((stats.cache_hit_rate - 0.769).abs() < 0.01);
    }

    #[test]
    fn test_accumulate_cache_stats_empty() {
        let stats = CacheController::accumulate_cache_stats(std::iter::empty::<&TokenUsage>());

        assert_eq!(stats.total_cache_creation_tokens, 0);
        assert_eq!(stats.total_cache_read_tokens, 0);
        assert_eq!(stats.cache_hit_rate, 0.0);
    }

    #[test]
    fn test_cache_savings_percentage() {
        let savings = CacheSavings::new(100.0, 60.0);
        assert!((savings.savings_percentage() - 40.0).abs() < 0.001);
    }

    #[test]
    fn test_cache_savings_percentage_zero_base() {
        let savings = CacheSavings::new(0.0, 0.0);
        assert_eq!(savings.savings_percentage(), 0.0);
    }
}