1use crate::types::LLMChunk;
28use serde_json::Value;
29
/// Maximum number of prompt-cache breakpoints that may be used in a single
/// Anthropic request.
pub const MAX_ANTHROPIC_CACHE_BREAKPOINTS: usize = 4;
33
/// Requested lifetime for prompt-cache entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CacheTtl {
    /// Leave the lifetime to the provider: [`CacheTtl::anthropic_ttl`] yields
    /// `None` for this variant.
    #[default]
    Default,
    /// Extended lifetime, mapped to `"1h"` for Anthropic.
    Extended,
}

impl CacheTtl {
    /// Anthropic TTL string for this lifetime.
    ///
    /// Returns `Some("1h")` for [`CacheTtl::Extended`] and `None` for
    /// [`CacheTtl::Default`], meaning no explicit TTL value is produced.
    pub fn anthropic_ttl(self) -> Option<&'static str> {
        match self {
            Self::Extended => Some("1h"),
            Self::Default => None,
        }
    }
}
54
55#[derive(Debug, Clone, Default)]
63pub struct PromptCachePlan {
64 pub cache_tools: bool,
66 pub cache_system: bool,
70 pub breakpoint_message_ids: Vec<String>,
73 pub ttl: CacheTtl,
75}
76
77impl PromptCachePlan {
78 pub fn disabled() -> Self {
80 Self::default()
81 }
82
83 pub fn is_enabled(&self) -> bool {
85 self.cache_tools || self.cache_system || !self.breakpoint_message_ids.is_empty()
86 }
87
88 pub fn is_breakpoint(&self, message_id: &str) -> bool {
90 self.breakpoint_message_ids
91 .iter()
92 .any(|id| id == message_id)
93 }
94}
95
96pub fn cache_usage_from_openai_usage(usage: &Value) -> Option<LLMChunk> {
103 let cached = usage
104 .get("prompt_tokens_details")
105 .or_else(|| usage.get("input_tokens_details"))
106 .and_then(|details| details.get("cached_tokens"))
107 .and_then(Value::as_u64)
108 .unwrap_or(0);
109 let prompt = usage
111 .get("prompt_tokens")
112 .or_else(|| usage.get("input_tokens"))
113 .and_then(Value::as_u64)
114 .unwrap_or(0);
115 (cached > 0).then_some(LLMChunk::CacheUsage {
116 cache_creation_input_tokens: 0,
117 cache_read_input_tokens: cached,
118 input_tokens: prompt.saturating_sub(cached),
119 })
120}
121
122pub fn cache_usage_from_gemini_usage(usage_metadata: &Value) -> Option<LLMChunk> {
125 let cached = usage_metadata
126 .get("cachedContentTokenCount")
127 .and_then(Value::as_u64)
128 .unwrap_or(0);
129 let prompt = usage_metadata
130 .get("promptTokenCount")
131 .and_then(Value::as_u64)
132 .unwrap_or(0);
133 (cached > 0).then_some(LLMChunk::CacheUsage {
134 cache_creation_input_tokens: 0,
135 cache_read_input_tokens: cached,
136 input_tokens: prompt.saturating_sub(cached),
137 })
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `chunk` is a `CacheUsage` with the expected cache-read and
    /// uncached-input counts. The OpenAI and Gemini converters always set
    /// `cache_creation_input_tokens` to 0, so that is checked too.
    fn assert_cache_usage(chunk: Option<LLMChunk>, expected_read: u64, expected_input: u64) {
        match chunk {
            Some(LLMChunk::CacheUsage {
                cache_creation_input_tokens,
                cache_read_input_tokens,
                input_tokens,
            }) => {
                assert_eq!(cache_creation_input_tokens, 0);
                assert_eq!(cache_read_input_tokens, expected_read);
                assert_eq!(input_tokens, expected_input);
            }
            other => panic!("expected CacheUsage, got {other:?}"),
        }
    }

    #[test]
    fn disabled_plan_is_not_enabled() {
        let plan = PromptCachePlan::disabled();
        assert!(!plan.is_enabled());
        assert!(!plan.is_breakpoint("m1"));
        assert_eq!(plan.ttl, CacheTtl::Default);
    }

    #[test]
    fn plan_with_any_region_is_enabled() {
        assert!(PromptCachePlan {
            cache_tools: true,
            ..Default::default()
        }
        .is_enabled());
        assert!(PromptCachePlan {
            cache_system: true,
            ..Default::default()
        }
        .is_enabled());
        assert!(PromptCachePlan {
            breakpoint_message_ids: vec!["m1".to_string()],
            ..Default::default()
        }
        .is_enabled());
    }

    #[test]
    fn ttl_alone_does_not_enable_plan() {
        assert!(!PromptCachePlan {
            ttl: CacheTtl::Extended,
            ..Default::default()
        }
        .is_enabled());
    }

    #[test]
    fn is_breakpoint_matches_only_listed_ids() {
        let plan = PromptCachePlan {
            breakpoint_message_ids: vec!["a".to_string(), "b".to_string()],
            ..Default::default()
        };
        assert!(plan.is_breakpoint("a"));
        assert!(plan.is_breakpoint("b"));
        assert!(!plan.is_breakpoint("c"));
        assert!(!plan.is_breakpoint(""));
    }

    #[test]
    fn extended_ttl_maps_to_one_hour() {
        assert_eq!(CacheTtl::Extended.anthropic_ttl(), Some("1h"));
        assert_eq!(CacheTtl::Default.anthropic_ttl(), None);
        assert_eq!(CacheTtl::default(), CacheTtl::Default);
    }

    #[test]
    fn openai_cache_usage_reads_prompt_and_input_details() {
        let chat = serde_json::json!({"prompt_tokens_details": {"cached_tokens": 1234}});
        assert_cache_usage(cache_usage_from_openai_usage(&chat), 1234, 0);

        let responses = serde_json::json!({"input_tokens_details": {"cached_tokens": 99}});
        assert_cache_usage(cache_usage_from_openai_usage(&responses), 99, 0);
    }

    #[test]
    fn openai_cache_usage_reports_uncached_remainder() {
        let chat = serde_json::json!({
            "prompt_tokens": 2000,
            "prompt_tokens_details": {"cached_tokens": 1500}
        });
        assert_cache_usage(cache_usage_from_openai_usage(&chat), 1500, 500);

        let responses = serde_json::json!({
            "input_tokens": 300,
            "input_tokens_details": {"cached_tokens": 100}
        });
        assert_cache_usage(cache_usage_from_openai_usage(&responses), 100, 200);
    }

    #[test]
    fn openai_cache_usage_saturates_when_cached_exceeds_prompt() {
        let usage = serde_json::json!({
            "prompt_tokens": 10,
            "prompt_tokens_details": {"cached_tokens": 50}
        });
        assert_cache_usage(cache_usage_from_openai_usage(&usage), 50, 0);
    }

    #[test]
    fn openai_cache_usage_none_when_no_cache_hit() {
        let usage = serde_json::json!({"prompt_tokens_details": {"cached_tokens": 0}});
        assert!(cache_usage_from_openai_usage(&usage).is_none());
        assert!(cache_usage_from_openai_usage(&serde_json::json!({})).is_none());
    }

    #[test]
    fn gemini_cache_usage_reads_cached_content_tokens() {
        let usage = serde_json::json!({"cachedContentTokenCount": 555});
        assert_cache_usage(cache_usage_from_gemini_usage(&usage), 555, 0);
        assert!(cache_usage_from_gemini_usage(&serde_json::json!({})).is_none());
    }

    #[test]
    fn gemini_cache_usage_reports_uncached_remainder() {
        let usage = serde_json::json!({
            "promptTokenCount": 1000,
            "cachedContentTokenCount": 600
        });
        assert_cache_usage(cache_usage_from_gemini_usage(&usage), 600, 400);

        let no_hit = serde_json::json!({
            "promptTokenCount": 1000,
            "cachedContentTokenCount": 0
        });
        assert!(cache_usage_from_gemini_usage(&no_hit).is_none());
    }
}