1use crate::traits::{CacheControl, ChatMessage, ChatRole};
39use serde::{Deserialize, Serialize};
40
/// Settings that decide which messages `apply_cache_control` marks with a
/// cache-control entry.
///
/// The type is serializable so it can be stored in or loaded from
/// configuration data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachePromptConfig {
    /// Master switch. When `false`, `apply_cache_control` returns without
    /// touching any message.
    pub enabled: bool,

    /// A user-role message whose `content.len()` is at least this value is
    /// marked, wherever it sits in the conversation.
    pub min_content_length: usize,

    /// Whether system-role messages are marked.
    pub cache_system_prompt: bool,

    /// How many of the trailing user-role messages are marked regardless of
    /// their length. `0` disables the recency rule.
    pub cache_last_n_messages: usize,
}
87
88impl Default for CachePromptConfig {
89 fn default() -> Self {
90 Self {
91 enabled: true,
92 min_content_length: 1000,
93 cache_system_prompt: true,
94 cache_last_n_messages: 3,
95 }
96 }
97}
98
99impl CachePromptConfig {
100 pub fn disabled() -> Self {
102 Self {
103 enabled: false,
104 ..Default::default()
105 }
106 }
107
108 pub fn system_only() -> Self {
110 Self {
111 enabled: true,
112 min_content_length: usize::MAX,
113 cache_system_prompt: true,
114 cache_last_n_messages: 0,
115 }
116 }
117
118 pub fn aggressive() -> Self {
122 Self {
123 enabled: true,
124 min_content_length: 100,
125 cache_system_prompt: true,
126 cache_last_n_messages: 10,
127 }
128 }
129}
130
/// Token counters for one call, or for several calls combined with
/// `CacheStats::merge`.
///
/// `parse_cache_stats` fills these from a usage JSON object; the JSON key
/// each field is read from is noted on the field.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Input token count (JSON key `input_tokens`). The rate and cost
    /// helpers treat this as the total that `cache_read_tokens` is a part of.
    pub input_tokens: u64,

    /// Output token count (JSON key `output_tokens`).
    pub output_tokens: u64,

    /// Input tokens served from the cache (JSON key
    /// `cache_read_input_tokens`).
    pub cache_read_tokens: u64,

    /// Tokens written to the cache (JSON key `cache_creation_input_tokens`).
    /// Carried and merged, but not used by any cost estimate in this file.
    pub cache_creation_tokens: u64,
}
172
173impl CacheStats {
174 pub fn new(
176 input_tokens: u64,
177 output_tokens: u64,
178 cache_read_tokens: u64,
179 cache_creation_tokens: u64,
180 ) -> Self {
181 Self {
182 input_tokens,
183 output_tokens,
184 cache_read_tokens,
185 cache_creation_tokens,
186 }
187 }
188
189 pub fn cache_hit_rate(&self) -> f64 {
193 if self.input_tokens == 0 {
194 0.0
195 } else {
196 self.cache_read_tokens as f64 / self.input_tokens as f64
197 }
198 }
199
200 pub fn savings(&self) -> f64 {
209 const NORMAL_COST_PER_1K: f64 = 0.003;
210 const CACHE_COST_PER_1K: f64 = 0.0003;
211
212 let normal_cost = self.input_tokens as f64 * NORMAL_COST_PER_1K / 1000.0;
214
215 let uncached_tokens = self.input_tokens.saturating_sub(self.cache_read_tokens);
217 let cache_cost = self.cache_read_tokens as f64 * CACHE_COST_PER_1K / 1000.0
218 + uncached_tokens as f64 * NORMAL_COST_PER_1K / 1000.0;
219
220 normal_cost - cache_cost
221 }
222
223 pub fn cost_per_call(&self) -> f64 {
225 const NORMAL_COST_PER_1K: f64 = 0.003;
226 const CACHE_COST_PER_1K: f64 = 0.0003;
227 const OUTPUT_COST_PER_1K: f64 = 0.015; let uncached_tokens = self.input_tokens.saturating_sub(self.cache_read_tokens);
230
231 self.cache_read_tokens as f64 * CACHE_COST_PER_1K / 1000.0
232 + uncached_tokens as f64 * NORMAL_COST_PER_1K / 1000.0
233 + self.output_tokens as f64 * OUTPUT_COST_PER_1K / 1000.0
234 }
235
236 pub fn is_effective(&self) -> bool {
238 self.cache_hit_rate() > 0.5
239 }
240
241 pub fn merge(&mut self, other: &CacheStats) {
243 self.input_tokens += other.input_tokens;
244 self.output_tokens += other.output_tokens;
245 self.cache_read_tokens += other.cache_read_tokens;
246 self.cache_creation_tokens += other.cache_creation_tokens;
247 }
248}
249
250pub fn apply_cache_control(messages: &mut [ChatMessage], config: &CachePromptConfig) {
283 if !config.enabled {
284 return;
285 }
286
287 let user_indices: Vec<usize> = messages
289 .iter()
290 .enumerate()
291 .filter(|(_, m)| matches!(m.role, ChatRole::User))
292 .map(|(i, _)| i)
293 .collect();
294
295 let last_n_start = user_indices
297 .len()
298 .saturating_sub(config.cache_last_n_messages);
299 let last_n_indices: std::collections::HashSet<usize> =
300 user_indices.into_iter().skip(last_n_start).collect();
301
302 for (i, msg) in messages.iter_mut().enumerate() {
303 let should_cache = match msg.role {
304 ChatRole::System => config.cache_system_prompt,
305 ChatRole::User => {
306 msg.content.len() >= config.min_content_length || last_n_indices.contains(&i)
308 }
309 _ => false, };
311
312 if should_cache && msg.cache_control.is_none() {
313 msg.cache_control = Some(CacheControl::ephemeral());
314 }
315 }
316}
317
318pub fn parse_cache_stats(usage: &serde_json::Value) -> CacheStats {
332 CacheStats {
333 input_tokens: usage["input_tokens"].as_u64().unwrap_or(0),
334 output_tokens: usage["output_tokens"].as_u64().unwrap_or(0),
335 cache_read_tokens: usage["cache_read_input_tokens"].as_u64().unwrap_or(0),
336 cache_creation_tokens: usage["cache_creation_input_tokens"].as_u64().unwrap_or(0),
337 }
338}
339
// Unit tests for the cache configuration presets, the message-marking rules
// of `apply_cache_control`, and the `CacheStats` arithmetic and parsing.
#[cfg(test)]
mod tests {
    use super::*;

    // --- CachePromptConfig presets -------------------------------------

    // Default preset: enabled, threshold 1000, system prompt on, window 3.
    #[test]
    fn test_default_config() {
        let config = CachePromptConfig::default();
        assert!(config.enabled);
        assert_eq!(config.min_content_length, 1000);
        assert!(config.cache_system_prompt);
        assert_eq!(config.cache_last_n_messages, 3);
    }

    // The disabled preset must turn the master switch off.
    #[test]
    fn test_disabled_config() {
        let config = CachePromptConfig::disabled();
        assert!(!config.enabled);
    }

    // system_only keeps caching on but neutralises both user-message rules.
    #[test]
    fn test_system_only_config() {
        let config = CachePromptConfig::system_only();
        assert!(config.enabled);
        assert!(config.cache_system_prompt);
        assert_eq!(config.cache_last_n_messages, 0);
        assert_eq!(config.min_content_length, usize::MAX);
    }

    // aggressive lowers the length threshold and widens the recency window.
    #[test]
    fn test_aggressive_config() {
        let config = CachePromptConfig::aggressive();
        assert!(config.enabled);
        assert_eq!(config.min_content_length, 100);
        assert_eq!(config.cache_last_n_messages, 10);
    }

    // --- apply_cache_control -------------------------------------------

    // With caching disabled no message may be marked.
    #[test]
    fn test_cache_control_disabled() {
        let config = CachePromptConfig::disabled();
        let mut messages = vec![
            ChatMessage::system("System prompt"),
            ChatMessage::user("User message"),
        ];

        apply_cache_control(&mut messages, &config);

        assert!(messages[0].cache_control.is_none());
        assert!(messages[1].cache_control.is_none());
    }

    // The system message is marked, and the marker is of type "ephemeral".
    #[test]
    fn test_cache_system_prompt() {
        let config = CachePromptConfig::default();
        let mut messages = vec![
            ChatMessage::system("You are a helpful assistant"),
            ChatMessage::user("Hello"),
        ];

        apply_cache_control(&mut messages, &config);

        assert!(messages[0].cache_control.is_some());
        assert_eq!(
            messages[0].cache_control.as_ref().unwrap().cache_type,
            "ephemeral"
        );
    }

    // Length rule in isolation (recency window is 0): only the user message
    // at or above the 100-length threshold is marked.
    #[test]
    fn test_cache_large_messages() {
        let config = CachePromptConfig {
            min_content_length: 100,
            cache_last_n_messages: 0,
            ..Default::default()
        };

        let large_content = "x".repeat(150);
        let small_content = "y".repeat(50);

        let mut messages = vec![
            ChatMessage::system("System"),
            ChatMessage::user(&large_content),
            ChatMessage::user(&small_content),
        ];

        apply_cache_control(&mut messages, &config);

        // System prompt is marked by the default `cache_system_prompt: true`.
        assert!(messages[0].cache_control.is_some());
        // 150 >= 100, so the large user message is marked.
        assert!(messages[1].cache_control.is_some());
        // 50 < 100 and the recency window is empty, so this one is not.
        assert!(messages[2].cache_control.is_none());
    }

    // Recency rule in isolation: the length threshold is unreachable and the
    // system prompt is off, so only the last two user messages are marked.
    #[test]
    fn test_cache_last_n_messages() {
        let config = CachePromptConfig {
            min_content_length: usize::MAX,
            cache_last_n_messages: 2,
            cache_system_prompt: false,
            ..Default::default()
        };

        let mut messages = vec![
            ChatMessage::system("System"),
            ChatMessage::user("First"),
            ChatMessage::assistant("Response"),
            ChatMessage::user("Second"),
            ChatMessage::assistant("Response"),
            ChatMessage::user("Third"),
            ChatMessage::user("Fourth"),
        ];

        apply_cache_control(&mut messages, &config);

        // System prompt caching is switched off in this config.
        assert!(messages[0].cache_control.is_none());
        // "First" and "Second" fall outside the last-two window.
        assert!(messages[1].cache_control.is_none());
        assert!(messages[3].cache_control.is_none());
        // "Third" and "Fourth" are the last two user messages.
        assert!(messages[5].cache_control.is_some());
        assert!(messages[6].cache_control.is_some());
    }

    // A marker set by the caller must still be present afterwards.
    // NOTE(review): the pre-set value is the same `ephemeral()` marker the
    // function would write, so this cannot tell "kept" from "overwritten".
    #[test]
    fn test_preserves_existing_cache_control() {
        let config = CachePromptConfig::default();
        let mut messages = vec![ChatMessage::system("System")];

        messages[0].cache_control = Some(CacheControl::ephemeral());

        apply_cache_control(&mut messages, &config);

        assert!(messages[0].cache_control.is_some());
    }

    // --- CacheStats::cache_hit_rate -------------------------------------

    // Zero input tokens must give 0.0 rather than a division by zero.
    #[test]
    fn test_cache_hit_rate_zero_tokens() {
        let stats = CacheStats::default();
        assert_eq!(stats.cache_hit_rate(), 0.0);
    }

    // All input tokens cached gives a rate of exactly 1.0.
    #[test]
    fn test_cache_hit_rate_full_cache() {
        let stats = CacheStats {
            input_tokens: 10000,
            output_tokens: 500,
            cache_read_tokens: 10000,
            cache_creation_tokens: 0,
        };
        assert_eq!(stats.cache_hit_rate(), 1.0);
    }

    // 8000 of 10000 input tokens cached gives 0.8.
    #[test]
    fn test_cache_hit_rate_partial() {
        let stats = CacheStats {
            input_tokens: 10000,
            output_tokens: 500,
            cache_read_tokens: 8000,
            cache_creation_tokens: 0,
        };
        assert_eq!(stats.cache_hit_rate(), 0.8);
    }

    // --- CacheStats::savings --------------------------------------------

    // Uncached cost: 10000 * 0.003 / 1000 = 0.03.
    // Cached cost: 8000 * 0.0003 / 1000 + 2000 * 0.003 / 1000 = 0.0084.
    // Savings are therefore about 0.0216, checked here as a range.
    #[test]
    fn test_cache_savings() {
        let stats = CacheStats {
            input_tokens: 10000,
            output_tokens: 500,
            cache_read_tokens: 8000,
            cache_creation_tokens: 0,
        };

        let savings = stats.savings();

        assert!(savings > 0.02);
        assert!(savings < 0.03);
    }

    // Without any cache reads both cost figures are identical, so the
    // difference is exactly zero.
    #[test]
    fn test_cache_savings_no_cache() {
        let stats = CacheStats {
            input_tokens: 10000,
            output_tokens: 500,
            cache_read_tokens: 0,
            cache_creation_tokens: 0,
        };

        assert_eq!(stats.savings(), 0.0);
    }

    // --- CacheStats::is_effective / merge -------------------------------

    // 60% hit rate counts as effective, 40% does not.
    #[test]
    fn test_is_effective() {
        let effective = CacheStats {
            input_tokens: 10000,
            cache_read_tokens: 6000,
            ..Default::default()
        };
        assert!(effective.is_effective());

        let ineffective = CacheStats {
            input_tokens: 10000,
            cache_read_tokens: 4000,
            ..Default::default()
        };
        assert!(!ineffective.is_effective());
    }

    // merge adds each of the four counters field by field.
    #[test]
    fn test_merge_stats() {
        let mut stats1 = CacheStats {
            input_tokens: 1000,
            output_tokens: 100,
            cache_read_tokens: 500,
            cache_creation_tokens: 200,
        };

        let stats2 = CacheStats {
            input_tokens: 2000,
            output_tokens: 200,
            cache_read_tokens: 1000,
            cache_creation_tokens: 100,
        };

        stats1.merge(&stats2);

        assert_eq!(stats1.input_tokens, 3000);
        assert_eq!(stats1.output_tokens, 300);
        assert_eq!(stats1.cache_read_tokens, 1500);
        assert_eq!(stats1.cache_creation_tokens, 300);
    }

    // --- parse_cache_stats ----------------------------------------------

    // All four usage keys present: each maps onto its counter.
    #[test]
    fn test_parse_cache_stats() {
        let usage = serde_json::json!({
            "input_tokens": 10000,
            "output_tokens": 500,
            "cache_read_input_tokens": 8000,
            "cache_creation_input_tokens": 100
        });

        let stats = parse_cache_stats(&usage);

        assert_eq!(stats.input_tokens, 10000);
        assert_eq!(stats.output_tokens, 500);
        assert_eq!(stats.cache_read_tokens, 8000);
        assert_eq!(stats.cache_creation_tokens, 100);
    }

    // Absent cache keys fall back to zero.
    #[test]
    fn test_parse_cache_stats_missing_fields() {
        let usage = serde_json::json!({
            "input_tokens": 5000,
            "output_tokens": 200
        });

        let stats = parse_cache_stats(&usage);

        assert_eq!(stats.input_tokens, 5000);
        assert_eq!(stats.output_tokens, 200);
        assert_eq!(stats.cache_read_tokens, 0);
        assert_eq!(stats.cache_creation_tokens, 0);
    }

    // --- CacheStats::cost_per_call --------------------------------------

    // Cached input: 8000 * 0.0003 / 1000 = 0.0024.
    // Uncached input: 2000 * 0.003 / 1000 = 0.006.
    // Output: 1000 * 0.015 / 1000 = 0.015.
    // Total is about 0.0234, checked here as a range.
    #[test]
    fn test_cost_per_call() {
        let stats = CacheStats {
            input_tokens: 10000,
            output_tokens: 1000,
            cache_read_tokens: 8000,
            cache_creation_tokens: 0,
        };

        let cost = stats.cost_per_call();

        assert!(cost > 0.02);
        assert!(cost < 0.03);
    }

    // --- Serialization and construction ---------------------------------

    // A JSON round trip must preserve every counter.
    #[test]
    fn test_cache_stats_serialization() {
        let stats = CacheStats {
            input_tokens: 1000,
            output_tokens: 100,
            cache_read_tokens: 800,
            cache_creation_tokens: 50,
        };

        let json = serde_json::to_string(&stats).unwrap();
        let deserialized: CacheStats = serde_json::from_str(&json).unwrap();

        assert_eq!(stats.input_tokens, deserialized.input_tokens);
        assert_eq!(stats.output_tokens, deserialized.output_tokens);
        assert_eq!(stats.cache_read_tokens, deserialized.cache_read_tokens);
        assert_eq!(
            stats.cache_creation_tokens,
            deserialized.cache_creation_tokens
        );
    }

    // `new` stores its arguments in declaration order.
    #[test]
    fn test_cache_stats_new_constructor() {
        let stats = CacheStats::new(5000, 500, 3000, 200);
        assert_eq!(stats.input_tokens, 5000);
        assert_eq!(stats.output_tokens, 500);
        assert_eq!(stats.cache_read_tokens, 3000);
        assert_eq!(stats.cache_creation_tokens, 200);
    }

    // --- Edge cases -----------------------------------------------------

    // An empty message list is accepted and stays empty.
    #[test]
    fn test_apply_cache_control_empty_messages() {
        let config = CachePromptConfig::default();
        let mut messages: Vec<ChatMessage> = vec![];
        apply_cache_control(&mut messages, &config);
        assert!(messages.is_empty());
    }

    // Assistant messages are never marked, whatever the configuration.
    #[test]
    fn test_apply_cache_control_only_assistant_messages() {
        let config = CachePromptConfig::default();
        let mut messages = vec![
            ChatMessage::assistant("I will help you"),
            ChatMessage::assistant("Here is the answer"),
        ];
        apply_cache_control(&mut messages, &config);
        assert!(messages[0].cache_control.is_none());
        assert!(messages[1].cache_control.is_none());
    }

    // An empty JSON object yields all-zero counters.
    #[test]
    fn test_parse_cache_stats_empty_json() {
        let usage = serde_json::json!({});
        let stats = parse_cache_stats(&usage);
        assert_eq!(stats.input_tokens, 0);
        assert_eq!(stats.output_tokens, 0);
        assert_eq!(stats.cache_read_tokens, 0);
        assert_eq!(stats.cache_creation_tokens, 0);
    }

    // The effectiveness check is a strict `> 0.5`, so exactly 50% fails it.
    #[test]
    fn test_is_effective_boundary_at_50_percent() {
        let stats = CacheStats {
            input_tokens: 10000,
            cache_read_tokens: 5000,
            ..Default::default()
        };
        assert!(!stats.is_effective());
    }

    // No tokens at all costs nothing.
    #[test]
    fn test_cost_per_call_zero_tokens() {
        let stats = CacheStats::default();
        assert_eq!(stats.cost_per_call(), 0.0);
    }

    // Fully cached input with no output: 10000 * 0.0003 / 1000 = 0.003.
    #[test]
    fn test_cost_per_call_all_cached() {
        let stats = CacheStats {
            input_tokens: 10000,
            output_tokens: 0,
            cache_read_tokens: 10000,
            cache_creation_tokens: 0,
        };
        let cost = stats.cost_per_call();
        assert!((cost - 0.003).abs() < 1e-10);
    }

    // A JSON round trip must preserve every configuration field.
    #[test]
    fn test_config_serialization_roundtrip() {
        let config = CachePromptConfig::aggressive();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: CachePromptConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.enabled, config.enabled);
        assert_eq!(deserialized.min_content_length, config.min_content_length);
        assert_eq!(deserialized.cache_system_prompt, config.cache_system_prompt);
        assert_eq!(
            deserialized.cache_last_n_messages,
            config.cache_last_n_messages
        );
    }

    // Cache reads larger than the input count must not panic (the uncached
    // count is computed with a saturating subtraction).
    // NOTE(review): the result is discarded, so only "does not panic" is
    // verified here, not the returned value.
    #[test]
    fn test_savings_when_cache_read_exceeds_input() {
        let stats = CacheStats {
            input_tokens: 5000,
            output_tokens: 100,
            cache_read_tokens: 8000,
            cache_creation_tokens: 0,
        };
        let _ = stats.savings();
    }

    // Merging into an all-zero value copies the other side's counters.
    #[test]
    fn test_merge_into_default() {
        let mut stats = CacheStats::default();
        let other = CacheStats::new(100, 50, 80, 10);
        stats.merge(&other);
        assert_eq!(stats.input_tokens, 100);
        assert_eq!(stats.output_tokens, 50);
        assert_eq!(stats.cache_read_tokens, 80);
        assert_eq!(stats.cache_creation_tokens, 10);
    }

    // A recency window wider than the number of user messages still marks
    // the messages that do exist.
    #[test]
    fn test_apply_cache_control_single_user_with_last_n() {
        let config = CachePromptConfig {
            min_content_length: usize::MAX,
            cache_last_n_messages: 3,
            cache_system_prompt: false,
            ..Default::default()
        };
        let mut messages = vec![ChatMessage::user("Short msg")];
        apply_cache_control(&mut messages, &config);
        assert!(messages[0].cache_control.is_some());
    }
}