use crate::traits::{CacheControl, ChatMessage, ChatRole};
use serde::{Deserialize, Serialize};
/// Configuration for provider prompt caching.
///
/// Decides which chat messages receive a `cache_control` marker when
/// `apply_cache_control` walks a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachePromptConfig {
/// Master switch; when `false`, `apply_cache_control` does nothing.
pub enabled: bool,
/// Minimum content length (bytes, via `str::len`) for a user message
/// to be cached on size alone.
pub min_content_length: usize,
/// When `true`, system messages are always marked cacheable.
pub cache_system_prompt: bool,
/// The last N *user* messages are marked cacheable regardless of length.
pub cache_last_n_messages: usize,
}
impl Default for CachePromptConfig {
fn default() -> Self {
Self {
enabled: true,
min_content_length: 1000,
cache_system_prompt: true,
cache_last_n_messages: 3,
}
}
}
impl CachePromptConfig {
pub fn disabled() -> Self {
Self {
enabled: false,
..Default::default()
}
}
pub fn system_only() -> Self {
Self {
enabled: true,
min_content_length: usize::MAX,
cache_system_prompt: true,
cache_last_n_messages: 0,
}
}
pub fn aggressive() -> Self {
Self {
enabled: true,
min_content_length: 100,
cache_system_prompt: true,
cache_last_n_messages: 10,
}
}
}
/// Token accounting for a single model call (or an aggregate of calls).
///
/// NOTE(review): the math in `cache_hit_rate`/`savings` assumes
/// `input_tokens` is the total *including* cache reads — confirm this
/// matches how the provider reports usage, since some APIs report
/// cached tokens separately from `input_tokens`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
/// Input (prompt) tokens for the call.
pub input_tokens: u64,
/// Tokens generated by the model.
pub output_tokens: u64,
/// Input tokens served from the provider's prompt cache.
pub cache_read_tokens: u64,
/// Input tokens written into the cache on this call.
pub cache_creation_tokens: u64,
}
impl CacheStats {
    /// Price per 1K uncached input tokens (USD).
    ///
    /// Previously these rates were declared twice (once in `savings`, once in
    /// `cost_per_call`), which allowed them to drift apart; they are now
    /// single-sourced here. NOTE(review): rates are hard-coded to one model
    /// tier — confirm they match the provider's current price sheet.
    const NORMAL_COST_PER_1K: f64 = 0.003;
    /// Price per 1K cache-read input tokens (USD).
    const CACHE_COST_PER_1K: f64 = 0.0003;
    /// Price per 1K output tokens (USD).
    const OUTPUT_COST_PER_1K: f64 = 0.015;

    /// Builds a `CacheStats` from raw token counts.
    pub fn new(
        input_tokens: u64,
        output_tokens: u64,
        cache_read_tokens: u64,
        cache_creation_tokens: u64,
    ) -> Self {
        Self {
            input_tokens,
            output_tokens,
            cache_read_tokens,
            cache_creation_tokens,
        }
    }

    /// Fraction of input tokens that were served from cache.
    ///
    /// Returns `0.0` when there were no input tokens. Assumes
    /// `input_tokens` includes cached reads; if a provider reports them
    /// separately this can exceed 1.0.
    pub fn cache_hit_rate(&self) -> f64 {
        if self.input_tokens == 0 {
            0.0
        } else {
            self.cache_read_tokens as f64 / self.input_tokens as f64
        }
    }

    /// Estimated USD saved versus the same call with no cache
    /// (full-price input cost minus the blended cached/uncached cost).
    pub fn savings(&self) -> f64 {
        let normal_cost = self.input_tokens as f64 * Self::NORMAL_COST_PER_1K / 1000.0;
        // saturating_sub guards against a provider reporting more cache
        // reads than total input tokens.
        let uncached_tokens = self.input_tokens.saturating_sub(self.cache_read_tokens);
        let cache_cost = self.cache_read_tokens as f64 * Self::CACHE_COST_PER_1K / 1000.0
            + uncached_tokens as f64 * Self::NORMAL_COST_PER_1K / 1000.0;
        normal_cost - cache_cost
    }

    /// Estimated USD cost of this call: cached input + uncached input + output.
    ///
    /// NOTE(review): `cache_creation_tokens` is not priced here — confirm
    /// whether cache writes should be billed (some providers charge a
    /// premium for cache creation).
    pub fn cost_per_call(&self) -> f64 {
        let uncached_tokens = self.input_tokens.saturating_sub(self.cache_read_tokens);
        self.cache_read_tokens as f64 * Self::CACHE_COST_PER_1K / 1000.0
            + uncached_tokens as f64 * Self::NORMAL_COST_PER_1K / 1000.0
            + self.output_tokens as f64 * Self::OUTPUT_COST_PER_1K / 1000.0
    }

    /// True when strictly more than half of the input tokens came from cache.
    pub fn is_effective(&self) -> bool {
        self.cache_hit_rate() > 0.5
    }

    /// Accumulates `other`'s counters into `self`, for aggregating many calls.
    pub fn merge(&mut self, other: &CacheStats) {
        self.input_tokens += other.input_tokens;
        self.output_tokens += other.output_tokens;
        self.cache_read_tokens += other.cache_read_tokens;
        self.cache_creation_tokens += other.cache_creation_tokens;
    }
}
/// Stamps an ephemeral `cache_control` onto every message that qualifies
/// under `config`.
///
/// System messages qualify when `config.cache_system_prompt` is set. User
/// messages qualify when their content length (bytes) reaches
/// `config.min_content_length`, or when they are among the last
/// `config.cache_last_n_messages` user messages. Other roles never qualify.
/// Messages that already carry a `cache_control` are left untouched.
/// Does nothing when `config.enabled` is `false`.
pub fn apply_cache_control(messages: &mut [ChatMessage], config: &CachePromptConfig) {
    if !config.enabled {
        return;
    }
    // Positions (indices into `messages`) of every user message, in order.
    let mut user_indices = Vec::new();
    for (i, m) in messages.iter().enumerate() {
        if matches!(m.role, ChatRole::User) {
            user_indices.push(i);
        }
    }
    // Suffix of `user_indices` covering the trailing N user messages.
    let keep = config.cache_last_n_messages.min(user_indices.len());
    let last_n: &[usize] = &user_indices[user_indices.len() - keep..];
    for (i, msg) in messages.iter_mut().enumerate() {
        let should_cache = match msg.role {
            ChatRole::System => config.cache_system_prompt,
            ChatRole::User => {
                msg.content.len() >= config.min_content_length || last_n.contains(&i)
            }
            _ => false,
        };
        if should_cache && msg.cache_control.is_none() {
            msg.cache_control = Some(CacheControl::ephemeral());
        }
    }
}
/// Extracts cache-aware token counts from a provider `usage` JSON object.
///
/// Missing or non-integer fields default to zero, so partial usage payloads
/// (for example, from providers without prompt caching) parse cleanly.
pub fn parse_cache_stats(usage: &serde_json::Value) -> CacheStats {
    let count = |key: &str| usage.get(key).and_then(|v| v.as_u64()).unwrap_or(0);
    CacheStats {
        input_tokens: count("input_tokens"),
        output_tokens: count("output_tokens"),
        cache_read_tokens: count("cache_read_input_tokens"),
        cache_creation_tokens: count("cache_creation_input_tokens"),
    }
}
#[cfg(test)]
mod tests {
use super::*;
// --- CachePromptConfig preset constructors ---
#[test]
fn test_default_config() {
let config = CachePromptConfig::default();
assert!(config.enabled);
assert_eq!(config.min_content_length, 1000);
assert!(config.cache_system_prompt);
assert_eq!(config.cache_last_n_messages, 3);
}
#[test]
fn test_disabled_config() {
let config = CachePromptConfig::disabled();
assert!(!config.enabled);
}
#[test]
fn test_system_only_config() {
let config = CachePromptConfig::system_only();
assert!(config.enabled);
assert!(config.cache_system_prompt);
assert_eq!(config.cache_last_n_messages, 0);
assert_eq!(config.min_content_length, usize::MAX);
}
#[test]
fn test_aggressive_config() {
let config = CachePromptConfig::aggressive();
assert!(config.enabled);
assert_eq!(config.min_content_length, 100);
assert_eq!(config.cache_last_n_messages, 10);
}
// --- apply_cache_control behavior ---
// A disabled config must leave every message untouched.
#[test]
fn test_cache_control_disabled() {
let config = CachePromptConfig::disabled();
let mut messages = vec![
ChatMessage::system("System prompt"),
ChatMessage::user("User message"),
];
apply_cache_control(&mut messages, &config);
assert!(messages[0].cache_control.is_none());
assert!(messages[1].cache_control.is_none());
}
// System messages are cached (with type "ephemeral") under the default config.
#[test]
fn test_cache_system_prompt() {
let config = CachePromptConfig::default();
let mut messages = vec![
ChatMessage::system("You are a helpful assistant"),
ChatMessage::user("Hello"),
];
apply_cache_control(&mut messages, &config);
assert!(messages[0].cache_control.is_some());
assert_eq!(
messages[0].cache_control.as_ref().unwrap().cache_type,
"ephemeral"
);
}
// Size threshold: 150 bytes >= 100 is cached, 50 bytes is not
// (last-N caching disabled so only length matters).
#[test]
fn test_cache_large_messages() {
let config = CachePromptConfig {
min_content_length: 100,
cache_last_n_messages: 0,
..Default::default()
};
let large_content = "x".repeat(150);
let small_content = "y".repeat(50);
let mut messages = vec![
ChatMessage::system("System"),
ChatMessage::user(&large_content),
ChatMessage::user(&small_content),
];
apply_cache_control(&mut messages, &config);
assert!(messages[0].cache_control.is_some());
assert!(messages[1].cache_control.is_some());
assert!(messages[2].cache_control.is_none());
}
// Only the last 2 *user* messages are cached; assistant and earlier
// user messages (and the system prompt, disabled here) are not.
#[test]
fn test_cache_last_n_messages() {
let config = CachePromptConfig {
min_content_length: usize::MAX, cache_last_n_messages: 2,
cache_system_prompt: false,
..Default::default()
};
let mut messages = vec![
ChatMessage::system("System"),
ChatMessage::user("First"),
ChatMessage::assistant("Response"),
ChatMessage::user("Second"),
ChatMessage::assistant("Response"),
ChatMessage::user("Third"),
ChatMessage::user("Fourth"),
];
apply_cache_control(&mut messages, &config);
assert!(messages[0].cache_control.is_none());
assert!(messages[1].cache_control.is_none());
assert!(messages[3].cache_control.is_none());
assert!(messages[5].cache_control.is_some()); assert!(messages[6].cache_control.is_some()); }
// Pre-existing cache_control markers must not be overwritten.
#[test]
fn test_preserves_existing_cache_control() {
let config = CachePromptConfig::default();
let mut messages = vec![ChatMessage::system("System")];
messages[0].cache_control = Some(CacheControl::ephemeral());
apply_cache_control(&mut messages, &config);
assert!(messages[0].cache_control.is_some());
}
// --- CacheStats arithmetic ---
// Zero input tokens must not divide by zero.
#[test]
fn test_cache_hit_rate_zero_tokens() {
let stats = CacheStats::default();
assert_eq!(stats.cache_hit_rate(), 0.0);
}
#[test]
fn test_cache_hit_rate_full_cache() {
let stats = CacheStats {
input_tokens: 10000,
output_tokens: 500,
cache_read_tokens: 10000,
cache_creation_tokens: 0,
};
assert_eq!(stats.cache_hit_rate(), 1.0);
}
#[test]
fn test_cache_hit_rate_partial() {
let stats = CacheStats {
input_tokens: 10000,
output_tokens: 500,
cache_read_tokens: 8000,
cache_creation_tokens: 0,
};
assert_eq!(stats.cache_hit_rate(), 0.8);
}
// Expected: 0.03 - (0.0024 + 0.006) = 0.0216, so within (0.02, 0.03).
#[test]
fn test_cache_savings() {
let stats = CacheStats {
input_tokens: 10000,
output_tokens: 500,
cache_read_tokens: 8000,
cache_creation_tokens: 0,
};
let savings = stats.savings();
assert!(savings > 0.02);
assert!(savings < 0.03);
}
// No cache reads means cached and uncached costs are identical.
#[test]
fn test_cache_savings_no_cache() {
let stats = CacheStats {
input_tokens: 10000,
output_tokens: 500,
cache_read_tokens: 0,
cache_creation_tokens: 0,
};
assert_eq!(stats.savings(), 0.0);
}
// is_effective is a strict > 0.5 threshold: 60% yes, 40% no.
#[test]
fn test_is_effective() {
let effective = CacheStats {
input_tokens: 10000,
cache_read_tokens: 6000,
..Default::default()
};
assert!(effective.is_effective());
let ineffective = CacheStats {
input_tokens: 10000,
cache_read_tokens: 4000,
..Default::default()
};
assert!(!ineffective.is_effective());
}
#[test]
fn test_merge_stats() {
let mut stats1 = CacheStats {
input_tokens: 1000,
output_tokens: 100,
cache_read_tokens: 500,
cache_creation_tokens: 200,
};
let stats2 = CacheStats {
input_tokens: 2000,
output_tokens: 200,
cache_read_tokens: 1000,
cache_creation_tokens: 100,
};
stats1.merge(&stats2);
assert_eq!(stats1.input_tokens, 3000);
assert_eq!(stats1.output_tokens, 300);
assert_eq!(stats1.cache_read_tokens, 1500);
assert_eq!(stats1.cache_creation_tokens, 300);
}
// --- parse_cache_stats ---
#[test]
fn test_parse_cache_stats() {
let usage = serde_json::json!({
"input_tokens": 10000,
"output_tokens": 500,
"cache_read_input_tokens": 8000,
"cache_creation_input_tokens": 100
});
let stats = parse_cache_stats(&usage);
assert_eq!(stats.input_tokens, 10000);
assert_eq!(stats.output_tokens, 500);
assert_eq!(stats.cache_read_tokens, 8000);
assert_eq!(stats.cache_creation_tokens, 100);
}
// Absent cache fields must default to zero, not error.
#[test]
fn test_parse_cache_stats_missing_fields() {
let usage = serde_json::json!({
"input_tokens": 5000,
"output_tokens": 200
});
let stats = parse_cache_stats(&usage);
assert_eq!(stats.input_tokens, 5000);
assert_eq!(stats.output_tokens, 200);
assert_eq!(stats.cache_read_tokens, 0);
assert_eq!(stats.cache_creation_tokens, 0);
}
// Expected: 0.0024 + 0.006 + 0.015 = 0.0234, so within (0.02, 0.03).
#[test]
fn test_cost_per_call() {
let stats = CacheStats {
input_tokens: 10000,
output_tokens: 1000,
cache_read_tokens: 8000,
cache_creation_tokens: 0,
};
let cost = stats.cost_per_call();
assert!(cost > 0.02);
assert!(cost < 0.03);
}
// --- serde round-trips ---
#[test]
fn test_cache_stats_serialization() {
let stats = CacheStats {
input_tokens: 1000,
output_tokens: 100,
cache_read_tokens: 800,
cache_creation_tokens: 50,
};
let json = serde_json::to_string(&stats).unwrap();
let deserialized: CacheStats = serde_json::from_str(&json).unwrap();
assert_eq!(stats.input_tokens, deserialized.input_tokens);
assert_eq!(stats.output_tokens, deserialized.output_tokens);
assert_eq!(stats.cache_read_tokens, deserialized.cache_read_tokens);
assert_eq!(
stats.cache_creation_tokens,
deserialized.cache_creation_tokens
);
}
#[test]
fn test_cache_stats_new_constructor() {
let stats = CacheStats::new(5000, 500, 3000, 200);
assert_eq!(stats.input_tokens, 5000);
assert_eq!(stats.output_tokens, 500);
assert_eq!(stats.cache_read_tokens, 3000);
assert_eq!(stats.cache_creation_tokens, 200);
}
// --- edge cases ---
#[test]
fn test_apply_cache_control_empty_messages() {
let config = CachePromptConfig::default();
let mut messages: Vec<ChatMessage> = vec![];
apply_cache_control(&mut messages, &config);
assert!(messages.is_empty());
}
// Assistant-only conversations never receive cache markers.
#[test]
fn test_apply_cache_control_only_assistant_messages() {
let config = CachePromptConfig::default();
let mut messages = vec![
ChatMessage::assistant("I will help you"),
ChatMessage::assistant("Here is the answer"),
];
apply_cache_control(&mut messages, &config);
assert!(messages[0].cache_control.is_none());
assert!(messages[1].cache_control.is_none());
}
#[test]
fn test_parse_cache_stats_empty_json() {
let usage = serde_json::json!({});
let stats = parse_cache_stats(&usage);
assert_eq!(stats.input_tokens, 0);
assert_eq!(stats.output_tokens, 0);
assert_eq!(stats.cache_read_tokens, 0);
assert_eq!(stats.cache_creation_tokens, 0);
}
// Exactly 50% is NOT effective (strict inequality).
#[test]
fn test_is_effective_boundary_at_50_percent() {
let stats = CacheStats {
input_tokens: 10000,
cache_read_tokens: 5000,
..Default::default()
};
assert!(!stats.is_effective());
}
#[test]
fn test_cost_per_call_zero_tokens() {
let stats = CacheStats::default();
assert_eq!(stats.cost_per_call(), 0.0);
}
// All input cached: 10000 * 0.0003 / 1000 = 0.003 exactly (within float eps).
#[test]
fn test_cost_per_call_all_cached() {
let stats = CacheStats {
input_tokens: 10000,
output_tokens: 0,
cache_read_tokens: 10000,
cache_creation_tokens: 0,
};
let cost = stats.cost_per_call();
assert!((cost - 0.003).abs() < 1e-10);
}
#[test]
fn test_config_serialization_roundtrip() {
let config = CachePromptConfig::aggressive();
let json = serde_json::to_string(&config).unwrap();
let deserialized: CachePromptConfig = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.enabled, config.enabled);
assert_eq!(deserialized.min_content_length, config.min_content_length);
assert_eq!(deserialized.cache_system_prompt, config.cache_system_prompt);
assert_eq!(
deserialized.cache_last_n_messages,
config.cache_last_n_messages
);
}
// Smoke test: cache reads exceeding input must not panic
// (saturating_sub path); the numeric result is not asserted.
#[test]
fn test_savings_when_cache_read_exceeds_input() {
let stats = CacheStats {
input_tokens: 5000,
output_tokens: 100,
cache_read_tokens: 8000,
cache_creation_tokens: 0,
};
let _ = stats.savings();
}
#[test]
fn test_merge_into_default() {
let mut stats = CacheStats::default();
let other = CacheStats::new(100, 50, 80, 10);
stats.merge(&other);
assert_eq!(stats.input_tokens, 100);
assert_eq!(stats.output_tokens, 50);
assert_eq!(stats.cache_read_tokens, 80);
assert_eq!(stats.cache_creation_tokens, 10);
}
// Fewer user messages than cache_last_n_messages: the sole user
// message still counts as one of the "last N".
#[test]
fn test_apply_cache_control_single_user_with_last_n() {
let config = CachePromptConfig {
min_content_length: usize::MAX,
cache_last_n_messages: 3,
cache_system_prompt: false,
..Default::default()
};
let mut messages = vec![ChatMessage::user("Short msg")];
apply_cache_control(&mut messages, &config);
assert!(messages[0].cache_control.is_some());
}
}