Skip to main content

bamboo_llm/
cache.rs

1//! Provider-agnostic prompt caching.
2//!
3//! Prompt caching has two halves that differ per provider but share a single
4//! policy, so the policy lives here and each provider only renders it:
5//!
6//! 1. **Where the cacheable prefix ends.** Anthropic needs explicit
7//!    `cache_control` breakpoints (at most [`MAX_ANTHROPIC_CACHE_BREAKPOINTS`]);
8//!    OpenAI / Gemini / Copilot cache an identical prefix automatically. In every
9//!    case a cache *hit* requires the bytes before the breakpoint to be identical
10//!    to a previous request. That means the engine must keep per-round volatile
11//!    content (task list, recalled memory, plan state) **out** of the cacheable
12//!    prefix and order it last — otherwise the breakpoint moves every round and
13//!    the cache read size swings or drops to zero.
14//!
15//! 2. **How cached-token usage is reported.** Anthropic reports
16//!    `cache_read_input_tokens`; OpenAI-compatible APIs report
17//!    `prompt_tokens_details.cached_tokens` (or `input_tokens_details.cached_tokens`
18//!    on the Responses API); Gemini reports `cachedContentTokenCount`. The
19//!    `cache_usage_from_*` helpers normalize these into [`LLMChunk::CacheUsage`]
20//!    so the same downstream accounting (and the frontend cache badge) works for
21//!    every provider.
22//!
23//! [`PromptCachePlan`] is the provider-agnostic description of (1): the engine
24//! builds it once from the prompt envelope and each provider renders it in its
25//! own dialect.
26
27use crate::types::LLMChunk;
28use serde_json::Value;
29
/// Upper bound on the number of `cache_control` breakpoints Anthropic accepts
/// in a single request.
///
/// Going over this limit is an API error, so any renderer that emits
/// breakpoints must clamp to this budget.
pub const MAX_ANTHROPIC_CACHE_BREAKPOINTS: usize = 4;
33
/// Cache-lifetime preference for providers whose cache TTL is configurable.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CacheTtl {
    /// Keep whatever lifetime the provider applies on its own
    /// (Anthropic: 5 minutes).
    #[default]
    Default,
    /// Ask for the longer lifetime (Anthropic: 1 hour). The request must carry
    /// the `extended-cache-ttl-2025-04-11` beta header.
    Extended,
}

impl CacheTtl {
    /// Returns the `ttl` string to place inside Anthropic's `cache_control`
    /// object, or `None` when the provider default applies.
    pub fn anthropic_ttl(self) -> Option<&'static str> {
        // Exhaustive on purpose: a new variant must decide its own mapping.
        match self {
            Self::Extended => Some("1h"),
            Self::Default => None,
        }
    }
}
54
55/// Provider-agnostic description of which logical regions of a request form a
56/// stable, cacheable prefix.
57///
58/// Breakpoints are identified by **message id** rather than position so the plan
59/// survives provider-side message reshaping — for example, Anthropic merges
60/// consecutive tool-result messages into the preceding user message, which would
61/// invalidate positional indices.
62#[derive(Debug, Clone, Default)]
63pub struct PromptCachePlan {
64    /// Cache the tool-definition block (stable for the whole session).
65    pub cache_tools: bool,
66    /// Cache the system prompt. Only set this when the system prompt is free of
67    /// per-round volatile content (the engine guarantees this by moving volatile
68    /// context blocks to the conversation tail).
69    pub cache_system: bool,
70    /// Ids of messages that end a stable prefix; each becomes a cache
71    /// breakpoint. Order is not significant.
72    pub breakpoint_message_ids: Vec<String>,
73    /// TTL hint for providers that support it.
74    pub ttl: CacheTtl,
75}
76
77impl PromptCachePlan {
78    /// A plan that requests no caching.
79    pub fn disabled() -> Self {
80        Self::default()
81    }
82
83    /// True when the plan asks for at least one cache breakpoint.
84    pub fn is_enabled(&self) -> bool {
85        self.cache_tools || self.cache_system || !self.breakpoint_message_ids.is_empty()
86    }
87
88    /// Whether the given message id is marked as a cache breakpoint.
89    pub fn is_breakpoint(&self, message_id: &str) -> bool {
90        self.breakpoint_message_ids
91            .iter()
92            .any(|id| id == message_id)
93    }
94}
95
96/// Normalize an OpenAI-style `usage` object into a [`LLMChunk::CacheUsage`], if
97/// it reports cached prompt tokens.
98///
99/// OpenAI exposes cached input tokens under `prompt_tokens_details.cached_tokens`
100/// (Chat Completions) or `input_tokens_details.cached_tokens` (Responses API).
101/// Returns `None` when no cache hit is reported, so callers can skip emitting.
102pub fn cache_usage_from_openai_usage(usage: &Value) -> Option<LLMChunk> {
103    let cached = usage
104        .get("prompt_tokens_details")
105        .or_else(|| usage.get("input_tokens_details"))
106        .and_then(|details| details.get("cached_tokens"))
107        .and_then(Value::as_u64)
108        .unwrap_or(0);
109    // Non-cached fresh input = total prompt input minus the cached portion.
110    let prompt = usage
111        .get("prompt_tokens")
112        .or_else(|| usage.get("input_tokens"))
113        .and_then(Value::as_u64)
114        .unwrap_or(0);
115    (cached > 0).then_some(LLMChunk::CacheUsage {
116        cache_creation_input_tokens: 0,
117        cache_read_input_tokens: cached,
118        input_tokens: prompt.saturating_sub(cached),
119    })
120}
121
122/// Normalize a Gemini `usageMetadata` object into a [`LLMChunk::CacheUsage`], if
123/// it reports cached content tokens (`cachedContentTokenCount`).
124pub fn cache_usage_from_gemini_usage(usage_metadata: &Value) -> Option<LLMChunk> {
125    let cached = usage_metadata
126        .get("cachedContentTokenCount")
127        .and_then(Value::as_u64)
128        .unwrap_or(0);
129    let prompt = usage_metadata
130        .get("promptTokenCount")
131        .and_then(Value::as_u64)
132        .unwrap_or(0);
133    (cached > 0).then_some(LLMChunk::CacheUsage {
134        cache_creation_input_tokens: 0,
135        cache_read_input_tokens: cached,
136        input_tokens: prompt.saturating_sub(cached),
137    })
138}
139
/// Unit tests for the cache plan and the provider usage normalizers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn disabled_plan_is_not_enabled() {
        assert!(!PromptCachePlan::disabled().is_enabled());
    }

    #[test]
    fn plan_with_any_region_is_enabled() {
        // Each of the three regions must be enough on its own.
        assert!(PromptCachePlan {
            cache_tools: true,
            ..Default::default()
        }
        .is_enabled());
        assert!(PromptCachePlan {
            cache_system: true,
            ..Default::default()
        }
        .is_enabled());
        assert!(PromptCachePlan {
            breakpoint_message_ids: vec!["m1".to_string()],
            ..Default::default()
        }
        .is_enabled());
    }

    #[test]
    fn ttl_alone_does_not_enable_plan() {
        // A TTL hint without any cacheable region requests no breakpoint.
        assert!(!PromptCachePlan {
            ttl: CacheTtl::Extended,
            ..Default::default()
        }
        .is_enabled());
    }

    #[test]
    fn is_breakpoint_matches_only_listed_ids() {
        let plan = PromptCachePlan {
            breakpoint_message_ids: vec!["a".to_string(), "b".to_string()],
            ..Default::default()
        };
        assert!(plan.is_breakpoint("a"));
        assert!(plan.is_breakpoint("b"));
        assert!(!plan.is_breakpoint("c"));
        // Ids are compared whole, never by prefix or concatenation.
        assert!(!plan.is_breakpoint("ab"));
        assert!(!PromptCachePlan::disabled().is_breakpoint("a"));
    }

    #[test]
    fn extended_ttl_maps_to_one_hour() {
        assert_eq!(CacheTtl::Extended.anthropic_ttl(), Some("1h"));
        assert_eq!(CacheTtl::Default.anthropic_ttl(), None);
    }

    #[test]
    fn default_ttl_is_provider_default() {
        assert_eq!(CacheTtl::default(), CacheTtl::Default);
        assert_eq!(PromptCachePlan::disabled().ttl, CacheTtl::Default);
    }

    #[test]
    fn openai_cache_usage_reads_prompt_and_input_details() {
        let chat = serde_json::json!({"prompt_tokens_details": {"cached_tokens": 1234}});
        match cache_usage_from_openai_usage(&chat) {
            Some(LLMChunk::CacheUsage {
                cache_read_input_tokens,
                ..
            }) => assert_eq!(cache_read_input_tokens, 1234),
            other => panic!("expected CacheUsage, got {other:?}"),
        }

        let responses = serde_json::json!({"input_tokens_details": {"cached_tokens": 99}});
        match cache_usage_from_openai_usage(&responses) {
            Some(LLMChunk::CacheUsage {
                cache_read_input_tokens,
                ..
            }) => assert_eq!(cache_read_input_tokens, 99),
            other => panic!("expected CacheUsage, got {other:?}"),
        }
    }

    #[test]
    fn openai_cache_usage_subtracts_cached_from_prompt() {
        // Chat Completions spelling.
        let chat = serde_json::json!({
            "prompt_tokens": 1000,
            "prompt_tokens_details": {"cached_tokens": 400}
        });
        match cache_usage_from_openai_usage(&chat) {
            Some(LLMChunk::CacheUsage {
                cache_creation_input_tokens,
                cache_read_input_tokens,
                input_tokens,
            }) => {
                assert_eq!(cache_creation_input_tokens, 0);
                assert_eq!(cache_read_input_tokens, 400);
                assert_eq!(input_tokens, 600);
            }
            other => panic!("expected CacheUsage, got {other:?}"),
        }

        // Responses API spelling.
        let responses = serde_json::json!({
            "input_tokens": 50,
            "input_tokens_details": {"cached_tokens": 20}
        });
        match cache_usage_from_openai_usage(&responses) {
            Some(LLMChunk::CacheUsage {
                cache_creation_input_tokens,
                cache_read_input_tokens,
                input_tokens,
            }) => {
                assert_eq!(cache_creation_input_tokens, 0);
                assert_eq!(cache_read_input_tokens, 20);
                assert_eq!(input_tokens, 30);
            }
            other => panic!("expected CacheUsage, got {other:?}"),
        }
    }

    #[test]
    fn openai_cache_usage_input_tokens_saturate_at_zero() {
        // Inconsistent report: more cached tokens than prompt tokens.
        let inconsistent = serde_json::json!({
            "prompt_tokens": 10,
            "prompt_tokens_details": {"cached_tokens": 25}
        });
        match cache_usage_from_openai_usage(&inconsistent) {
            Some(LLMChunk::CacheUsage {
                cache_read_input_tokens,
                input_tokens,
                ..
            }) => {
                assert_eq!(cache_read_input_tokens, 25);
                assert_eq!(input_tokens, 0);
            }
            other => panic!("expected CacheUsage, got {other:?}"),
        }

        // Missing prompt total is treated as zero.
        let no_total = serde_json::json!({"prompt_tokens_details": {"cached_tokens": 7}});
        match cache_usage_from_openai_usage(&no_total) {
            Some(LLMChunk::CacheUsage { input_tokens, .. }) => assert_eq!(input_tokens, 0),
            other => panic!("expected CacheUsage, got {other:?}"),
        }
    }

    #[test]
    fn openai_cache_usage_none_when_no_cache_hit() {
        let usage = serde_json::json!({"prompt_tokens_details": {"cached_tokens": 0}});
        assert!(cache_usage_from_openai_usage(&usage).is_none());
        assert!(cache_usage_from_openai_usage(&serde_json::json!({})).is_none());

        // A prompt total without any cached tokens is still not a cache hit.
        let uncached = serde_json::json!({"prompt_tokens": 321});
        assert!(cache_usage_from_openai_usage(&uncached).is_none());
    }

    #[test]
    fn gemini_cache_usage_reads_cached_content_tokens() {
        let usage = serde_json::json!({"cachedContentTokenCount": 555});
        match cache_usage_from_gemini_usage(&usage) {
            Some(LLMChunk::CacheUsage {
                cache_read_input_tokens,
                ..
            }) => assert_eq!(cache_read_input_tokens, 555),
            other => panic!("expected CacheUsage, got {other:?}"),
        }
        assert!(cache_usage_from_gemini_usage(&serde_json::json!({})).is_none());
    }

    #[test]
    fn gemini_cache_usage_subtracts_cached_from_prompt() {
        let usage = serde_json::json!({
            "promptTokenCount": 800,
            "cachedContentTokenCount": 300
        });
        match cache_usage_from_gemini_usage(&usage) {
            Some(LLMChunk::CacheUsage {
                cache_creation_input_tokens,
                cache_read_input_tokens,
                input_tokens,
            }) => {
                assert_eq!(cache_creation_input_tokens, 0);
                assert_eq!(cache_read_input_tokens, 300);
                assert_eq!(input_tokens, 500);
            }
            other => panic!("expected CacheUsage, got {other:?}"),
        }
    }

    #[test]
    fn gemini_cache_usage_none_when_no_cache_hit() {
        let zero = serde_json::json!({"promptTokenCount": 800, "cachedContentTokenCount": 0});
        assert!(cache_usage_from_gemini_usage(&zero).is_none());

        let absent = serde_json::json!({"promptTokenCount": 800});
        assert!(cache_usage_from_gemini_usage(&absent).is_none());
    }
}