use crate::llm::types::LLMChunk;
use serde_json::Value;
/// Upper bound on the number of cache breakpoints for an Anthropic request.
///
/// NOTE(review): the value `4` presumably mirrors Anthropic's documented
/// per-request limit on `cache_control` breakpoints — confirm against the
/// current API documentation before changing it. Nothing in the visible part
/// of this file enforces the limit; callers are expected to consult it.
pub const MAX_ANTHROPIC_CACHE_BREAKPOINTS: usize = 4;
/// Requested lifetime for a prompt-cache entry.
///
/// [`CacheTtl::Default`] is the value produced by `Default::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CacheTtl {
    /// No explicit TTL is requested; [`CacheTtl::anthropic_ttl`] yields `None`.
    #[default]
    Default,
    /// Extended lifetime; [`CacheTtl::anthropic_ttl`] yields `Some("1h")`.
    Extended,
}

impl CacheTtl {
    /// Returns the TTL string associated with this setting for Anthropic.
    ///
    /// * `Extended` maps to `Some("1h")`.
    /// * `Default` maps to `None`, i.e. there is no explicit TTL value to send.
    ///
    /// NOTE(review): presumably the returned string is placed in the `ttl`
    /// field of an Anthropic `cache_control` object — confirm at the call site.
    pub fn anthropic_ttl(self) -> Option<&'static str> {
        // Kept as an exhaustive match so a new variant forces a decision here.
        match self {
            Self::Extended => Some("1h"),
            Self::Default => None,
        }
    }
}
/// Describes which regions of a request are selected for prompt caching.
///
/// The derived `Default` has every flag off, no breakpoint ids and the default
/// [`CacheTtl`]; that value is what [`PromptCachePlan::disabled`] returns.
#[derive(Debug, Clone, Default)]
pub struct PromptCachePlan {
    /// Region flag consulted by [`PromptCachePlan::is_enabled`].
    /// Presumably selects the tool definitions for caching — confirm with the
    /// request builders that consume this plan.
    pub cache_tools: bool,
    /// Region flag consulted by [`PromptCachePlan::is_enabled`].
    /// Presumably selects the system prompt for caching — confirm with the
    /// request builders that consume this plan.
    pub cache_system: bool,
    /// Ids of the messages treated as cache breakpoints; queried through
    /// [`PromptCachePlan::is_breakpoint`].
    pub breakpoint_message_ids: Vec<String>,
    /// Cache lifetime setting carried along with the plan.
    pub ttl: CacheTtl,
}

impl PromptCachePlan {
    /// Builds a plan that selects nothing for caching (the `Default` value).
    pub fn disabled() -> Self {
        Default::default()
    }

    /// Reports whether the plan selects at least one region: either flag is
    /// set, or at least one breakpoint message id is listed.
    pub fn is_enabled(&self) -> bool {
        let has_breakpoints = !self.breakpoint_message_ids.is_empty();
        has_breakpoints || self.cache_tools || self.cache_system
    }

    /// Reports whether `message_id` is one of the listed breakpoint ids.
    ///
    /// The comparison is an exact string match.
    pub fn is_breakpoint(&self, message_id: &str) -> bool {
        self.breakpoint_message_ids
            .iter()
            .map(String::as_str)
            .any(|listed| listed == message_id)
    }
}
/// Extracts prompt-cache read statistics from an OpenAI-style `usage` object.
///
/// The cached token count is looked up under `prompt_tokens_details` first and
/// `input_tokens_details` second (the tests in this file exercise these as the
/// chat and responses payload shapes respectively). The first of those objects
/// that carries an unsigned integer `cached_tokens` supplies the count.
///
/// A details entry that is present but unusable (for example `null`, or an
/// object without a numeric `cached_tokens`) does not stop the lookup; the
/// next key is still tried.
///
/// # Returns
///
/// `Some(LLMChunk::CacheUsage { .. })` with `cache_read_input_tokens` set to the
/// cached count and `cache_creation_input_tokens` set to `0` when the count is
/// greater than zero; `None` when no count is found or the count is zero.
pub fn cache_usage_from_openai_usage(usage: &Value) -> Option<LLMChunk> {
    // Lookup order matters: `prompt_tokens_details` wins when both are usable.
    const DETAIL_KEYS: [&str; 2] = ["prompt_tokens_details", "input_tokens_details"];

    let cached = DETAIL_KEYS
        .iter()
        .copied()
        .find_map(|key| usage.get(key)?.get("cached_tokens")?.as_u64())
        .unwrap_or(0);

    // A zero count carries no information, so it is reported as "no chunk".
    (cached > 0).then_some(LLMChunk::CacheUsage {
        cache_creation_input_tokens: 0,
        cache_read_input_tokens: cached,
    })
}
/// Extracts prompt-cache read statistics from a Gemini-style usage metadata
/// object.
///
/// Reads the top-level `cachedContentTokenCount` field as an unsigned integer.
///
/// # Returns
///
/// `Some(LLMChunk::CacheUsage { .. })` with `cache_read_input_tokens` set to
/// that count and `cache_creation_input_tokens` set to `0` when the count is
/// greater than zero; `None` when the field is missing, is not an unsigned
/// integer, or is zero.
pub fn cache_usage_from_gemini_usage(usage_metadata: &Value) -> Option<LLMChunk> {
    let cached_tokens = usage_metadata
        .get("cachedContentTokenCount")?
        .as_u64()?;

    // A zero count is treated the same as an absent one.
    if cached_tokens == 0 {
        return None;
    }

    Some(LLMChunk::CacheUsage {
        cache_creation_input_tokens: 0,
        cache_read_input_tokens: cached_tokens,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pulls `cache_read_input_tokens` out of a conversion result, failing the
    /// calling test if the result is not a `CacheUsage` chunk.
    #[track_caller]
    fn expect_cache_read_tokens(chunk: Option<LLMChunk>) -> u64 {
        match chunk {
            Some(LLMChunk::CacheUsage {
                cache_read_input_tokens,
                ..
            }) => cache_read_input_tokens,
            other => panic!("expected CacheUsage, got {other:?}"),
        }
    }

    /// The plan returned by `disabled()` selects nothing.
    #[test]
    fn disabled_plan_is_not_enabled() {
        let plan = PromptCachePlan::disabled();
        assert!(!plan.is_enabled());
    }

    /// A single selected region is enough to enable the plan.
    #[test]
    fn plan_with_any_region_is_enabled() {
        let system_only = PromptCachePlan {
            cache_system: true,
            ..Default::default()
        };
        assert!(system_only.is_enabled());

        let breakpoint_only = PromptCachePlan {
            breakpoint_message_ids: vec!["m1".to_string()],
            ..Default::default()
        };
        assert!(breakpoint_only.is_enabled());
    }

    /// Only ids present in the list are reported as breakpoints.
    #[test]
    fn is_breakpoint_matches_only_listed_ids() {
        let plan = PromptCachePlan {
            breakpoint_message_ids: vec!["a".to_string(), "b".to_string()],
            ..Default::default()
        };
        for listed in ["a", "b"] {
            assert!(plan.is_breakpoint(listed));
        }
        assert!(!plan.is_breakpoint("c"));
    }

    /// `Extended` renders as `"1h"`; `Default` renders as no TTL.
    #[test]
    fn extended_ttl_maps_to_one_hour() {
        assert_eq!(CacheTtl::Extended.anthropic_ttl(), Some("1h"));
        assert_eq!(CacheTtl::Default.anthropic_ttl(), None);
    }

    /// Both the `prompt_tokens_details` and `input_tokens_details` shapes are read.
    #[test]
    fn openai_cache_usage_reads_prompt_and_input_details() {
        let chat = serde_json::json!({"prompt_tokens_details": {"cached_tokens": 1234}});
        let chat_tokens = expect_cache_read_tokens(cache_usage_from_openai_usage(&chat));
        assert_eq!(chat_tokens, 1234);

        let responses = serde_json::json!({"input_tokens_details": {"cached_tokens": 99}});
        let responses_tokens = expect_cache_read_tokens(cache_usage_from_openai_usage(&responses));
        assert_eq!(responses_tokens, 99);
    }

    /// A zero count and an empty usage object both produce no chunk.
    #[test]
    fn openai_cache_usage_none_when_no_cache_hit() {
        let zero_hit = serde_json::json!({"prompt_tokens_details": {"cached_tokens": 0}});
        assert!(cache_usage_from_openai_usage(&zero_hit).is_none());

        let empty = serde_json::json!({});
        assert!(cache_usage_from_openai_usage(&empty).is_none());
    }

    /// `cachedContentTokenCount` is surfaced as the cache-read count.
    #[test]
    fn gemini_cache_usage_reads_cached_content_tokens() {
        let usage = serde_json::json!({"cachedContentTokenCount": 555});
        let tokens = expect_cache_read_tokens(cache_usage_from_gemini_usage(&usage));
        assert_eq!(tokens, 555);

        let empty = serde_json::json!({});
        assert!(cache_usage_from_gemini_usage(&empty).is_none());
    }
}