// aster/context/cache_controller.rs
//! Prompt Caching Controller Module
//!
//! Provides prompt caching support for reducing API costs and latency.
//! This module implements cache control markers for eligible message blocks
//! and calculates cache cost savings.
//!
//! # Features
//!
//! - Add cache control markers to eligible messages
//! - Check cache eligibility based on token thresholds
//! - Calculate cache cost savings
//! - Track cache hit rates
//!
//! # Pricing Model
//!
//! Based on Anthropic's prompt caching pricing:
//! - Cache write: 1.25x base input price
//! - Cache read: 0.1x base input price (90% discount)

use crate::context::token_estimator::TokenEstimator;
use crate::context::types::{CacheConfig, CacheSavings, CacheStats, TokenUsage};
use crate::conversation::message::Message;

/// Base input price per million tokens (used for cost calculations).
/// This is a reference value; actual pricing may vary by model.
const BASE_INPUT_PRICE_PER_MILLION: f64 = 3.0;

/// Cache write multiplier (1.25x base price).
const CACHE_WRITE_MULTIPLIER: f64 = 1.25;

/// Cache read multiplier (0.1x base price - 90% discount).
const CACHE_READ_MULTIPLIER: f64 = 0.1;

/// Result of cache eligibility check with indices of cacheable messages.
#[derive(Debug, Clone, Default)]
pub struct CacheEligibility {
    /// Indices of messages that are eligible for caching.
    pub cacheable_indices: Vec<usize>,
    /// Total estimated tokens in cacheable messages.
    pub cacheable_tokens: usize,
}

/// Prompt Caching Controller
///
/// Manages cache control markers for messages and calculates cache savings.
///
/// # Note on Cache Control Implementation
///
/// Cache control markers are typically added at the API request level by the
/// provider implementation, not stored in the message content itself. This
/// controller identifies which messages are eligible for caching and provides
/// the information needed for providers to add appropriate cache control headers.
pub struct CacheController;

55impl CacheController {
56    /// Identify messages eligible for cache control.
57    ///
58    /// This method analyzes messages and returns information about which
59    /// messages are eligible for caching based on the provided configuration.
60    ///
61    /// # Arguments
62    ///
63    /// * `messages` - The messages to analyze
64    /// * `config` - Cache configuration specifying thresholds and options
65    ///
66    /// # Returns
67    ///
68    /// `CacheEligibility` containing indices of cacheable messages
69    ///
70    /// # Cache Eligibility Rules
71    ///
72    /// Messages are eligible for caching if:
73    /// 1. They meet the minimum token threshold
74    /// 2. They are within the most recent N messages (as configured)
75    ///
76    /// # Example
77    ///
78    /// ```ignore
79    /// use aster::context::cache_controller::CacheController;
80    /// use aster::context::types::CacheConfig;
81    ///
82    /// let messages = vec![/* ... */];
83    /// let config = CacheConfig::default();
84    /// let eligibility = CacheController::get_cache_eligibility(&messages, &config);
85    /// println!("Cacheable messages: {:?}", eligibility.cacheable_indices);
86    /// ```
87    pub fn get_cache_eligibility(messages: &[Message], config: &CacheConfig) -> CacheEligibility {
88        if messages.is_empty() {
89            return CacheEligibility::default();
90        }
91
92        let len = messages.len();
93        let mut cacheable_indices = Vec::new();
94        let mut cacheable_tokens = 0;
95
96        // Determine which messages are eligible for caching
97        // Only cache the most recent N messages as configured
98        let start_index = len.saturating_sub(config.cache_recent_messages);
99
100        // Check eligibility for each message in the range
101        for (i, message) in messages.iter().enumerate().take(len).skip(start_index) {
102            if Self::is_cacheable(message, config.min_tokens_for_cache) {
103                let tokens = TokenEstimator::estimate_message_tokens(message);
104                cacheable_indices.push(i);
105                cacheable_tokens += tokens;
106            }
107        }
108
109        CacheEligibility {
110            cacheable_indices,
111            cacheable_tokens,
112        }
113    }
114
115    /// Add cache control markers to eligible messages.
116    ///
117    /// This method returns a new vector of messages with cache eligibility
118    /// information. The actual cache control markers should be added by
119    /// the provider when making API requests.
120    ///
121    /// # Arguments
122    ///
123    /// * `messages` - The messages to potentially mark for caching
124    /// * `config` - Cache configuration specifying thresholds and options
125    ///
126    /// # Returns
127    ///
128    /// A tuple of (messages, cacheable_indices) where cacheable_indices
129    /// contains the indices of messages that should have cache control applied
130    pub fn add_cache_control(
131        messages: &[Message],
132        config: &CacheConfig,
133    ) -> (Vec<Message>, Vec<usize>) {
134        let eligibility = Self::get_cache_eligibility(messages, config);
135        (messages.to_vec(), eligibility.cacheable_indices)
136    }
137
138    /// Check if a message is eligible for caching.
139    ///
140    /// A message is cacheable if:
141    /// 1. It has content
142    /// 2. Its estimated token count meets the minimum threshold
143    ///
144    /// # Arguments
145    ///
146    /// * `message` - The message to check
147    /// * `min_tokens` - Minimum token threshold for caching
148    ///
149    /// # Returns
150    ///
151    /// `true` if the message is eligible for caching
152    pub fn is_cacheable(message: &Message, min_tokens: usize) -> bool {
153        if message.content.is_empty() {
154            return false;
155        }
156
157        let tokens = TokenEstimator::estimate_message_tokens(message);
158        tokens >= min_tokens
159    }
160
161    /// Calculate cache cost savings based on token usage.
162    ///
163    /// Uses Anthropic's prompt caching pricing model:
164    /// - Cache write: 1.25x base input price
165    /// - Cache read: 0.1x base input price (90% discount)
166    ///
167    /// # Arguments
168    ///
169    /// * `usage` - Token usage statistics including cache metrics
170    ///
171    /// # Returns
172    ///
173    /// `CacheSavings` containing base cost, actual cost, and savings
174    ///
175    /// # Calculation
176    ///
177    /// ```text
178    /// base_cost = input_tokens * base_price
179    /// cache_write_cost = cache_creation_tokens * (base_price * 1.25)
180    /// cache_read_cost = cache_read_tokens * (base_price * 0.1)
181    /// actual_cost = (input_tokens - cache_read_tokens) * base_price
182    ///             + cache_write_cost + cache_read_cost
183    /// savings = base_cost - actual_cost
184    /// ```
185    pub fn calculate_cache_savings(usage: &TokenUsage) -> CacheSavings {
186        let base_price = BASE_INPUT_PRICE_PER_MILLION / 1_000_000.0;
187
188        // Calculate what the cost would be without caching
189        let base_cost = usage.input_tokens as f64 * base_price;
190
191        // Calculate actual cost with caching
192        let cache_creation_tokens = usage.cache_creation_tokens.unwrap_or(0);
193        let cache_read_tokens = usage.cache_read_tokens.unwrap_or(0);
194
195        // Cache write cost (1.25x base price)
196        let cache_write_cost = cache_creation_tokens as f64 * base_price * CACHE_WRITE_MULTIPLIER;
197
198        // Cache read cost (0.1x base price - 90% discount)
199        let cache_read_cost = cache_read_tokens as f64 * base_price * CACHE_READ_MULTIPLIER;
200
201        // Non-cached input tokens cost
202        let non_cached_tokens = usage.input_tokens.saturating_sub(cache_read_tokens);
203        let non_cached_cost = non_cached_tokens as f64 * base_price;
204
205        // Total actual cost
206        let actual_cost = non_cached_cost + cache_write_cost + cache_read_cost;
207
208        CacheSavings::new(base_cost, actual_cost)
209    }
210
211    /// Calculate cache statistics from token usage.
212    ///
213    /// # Arguments
214    ///
215    /// * `usage` - Token usage statistics
216    ///
217    /// # Returns
218    ///
219    /// `CacheStats` with totals and hit rate
220    pub fn calculate_cache_stats(usage: &TokenUsage) -> CacheStats {
221        let cache_creation = usage.cache_creation_tokens.unwrap_or(0);
222        let cache_read = usage.cache_read_tokens.unwrap_or(0);
223
224        let total_cache_tokens = cache_creation + cache_read;
225        let hit_rate = if total_cache_tokens > 0 {
226            cache_read as f64 / total_cache_tokens as f64
227        } else {
228            0.0
229        };
230
231        CacheStats {
232            total_cache_creation_tokens: cache_creation,
233            total_cache_read_tokens: cache_read,
234            cache_hit_rate: hit_rate,
235        }
236    }
237
238    /// Accumulate cache statistics from multiple usages.
239    ///
240    /// # Arguments
241    ///
242    /// * `usages` - Iterator of token usage statistics
243    ///
244    /// # Returns
245    ///
246    /// Aggregated `CacheStats`
247    pub fn accumulate_cache_stats<'a>(usages: impl Iterator<Item = &'a TokenUsage>) -> CacheStats {
248        let mut total_creation = 0usize;
249        let mut total_read = 0usize;
250
251        for usage in usages {
252            total_creation += usage.cache_creation_tokens.unwrap_or(0);
253            total_read += usage.cache_read_tokens.unwrap_or(0);
254        }
255
256        let total = total_creation + total_read;
257        let hit_rate = if total > 0 {
258            total_read as f64 / total as f64
259        } else {
260            0.0
261        };
262
263        CacheStats {
264            total_cache_creation_tokens: total_creation,
265            total_cache_read_tokens: total_read,
266            cache_hit_rate: hit_rate,
267        }
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: a user message with arbitrary text content.
    fn create_message_with_tokens(text: &str) -> Message {
        Message::user().with_text(text)
    }

    // Helper: a message large enough to exceed the default caching threshold.
    fn create_long_message() -> Message {
        // Create a message with enough content to exceed default threshold (1024 tokens)
        let long_text = "x".repeat(4000); // ~1143 tokens at 3.5 chars/token
        Message::user().with_text(long_text)
    }

    // Helper: a message well below any realistic caching threshold.
    fn create_short_message() -> Message {
        Message::user().with_text("Hello")
    }

    #[test]
    fn test_is_cacheable_empty_message() {
        let message = Message::user();
        assert!(!CacheController::is_cacheable(&message, 1024));
    }

    #[test]
    fn test_is_cacheable_short_message() {
        let message = create_short_message();
        assert!(!CacheController::is_cacheable(&message, 1024));
    }

    #[test]
    fn test_is_cacheable_long_message() {
        let message = create_long_message();
        assert!(CacheController::is_cacheable(&message, 1024));
    }

    #[test]
    fn test_is_cacheable_with_low_threshold() {
        let message = create_short_message();
        // With a very low threshold, even short messages should be cacheable
        assert!(CacheController::is_cacheable(&message, 1));
    }

    #[test]
    fn test_add_cache_control_empty_messages() {
        let messages: Vec<Message> = vec![];
        let config = CacheConfig::default();
        let (result, indices) = CacheController::add_cache_control(&messages, &config);
        assert!(result.is_empty());
        assert!(indices.is_empty());
    }

    #[test]
    fn test_add_cache_control_respects_recent_limit() {
        // Create 5 long messages
        let messages: Vec<Message> = (0..5).map(|_| create_long_message()).collect();

        let config = CacheConfig {
            cache_recent_messages: 2,
            min_tokens_for_cache: 100, // Lower threshold for testing
            ..Default::default()
        };

        let (result, indices) = CacheController::add_cache_control(&messages, &config);

        // All 5 messages should be returned
        assert_eq!(result.len(), 5);
        // Only the last 2 messages should be cacheable (indices 3 and 4)
        assert!(indices.iter().all(|&i| i >= 3));
    }

    #[test]
    fn test_add_cache_control_respects_token_threshold() {
        let messages = vec![create_short_message(), create_long_message()];

        let config = CacheConfig::default();
        let (result, indices) = CacheController::add_cache_control(&messages, &config);

        // Both messages should be returned
        assert_eq!(result.len(), 2);
        // Only the long message (index 1) should be cacheable
        assert!(indices.contains(&1) || indices.is_empty());
    }

    #[test]
    fn test_get_cache_eligibility_empty() {
        let messages: Vec<Message> = vec![];
        let config = CacheConfig::default();
        let eligibility = CacheController::get_cache_eligibility(&messages, &config);
        assert!(eligibility.cacheable_indices.is_empty());
        assert_eq!(eligibility.cacheable_tokens, 0);
    }

    #[test]
    fn test_get_cache_eligibility_with_long_messages() {
        let messages: Vec<Message> = (0..3).map(|_| create_long_message()).collect();

        let config = CacheConfig {
            min_tokens_for_cache: 100,
            cache_recent_messages: 10,
            ..Default::default()
        };

        let eligibility = CacheController::get_cache_eligibility(&messages, &config);

        // All 3 messages should be cacheable
        assert_eq!(eligibility.cacheable_indices.len(), 3);
        assert!(eligibility.cacheable_tokens > 0);
    }

    #[test]
    fn test_calculate_cache_savings_no_cache() {
        let usage = TokenUsage::new(1000, 500);
        let savings = CacheController::calculate_cache_savings(&usage);

        // Without caching, base_cost should equal cache_cost
        assert!((savings.base_cost - savings.cache_cost).abs() < 0.0001);
        assert!(savings.savings.abs() < 0.0001);
    }

    #[test]
    fn test_calculate_cache_savings_with_cache_read() {
        let usage = TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_creation_tokens: Some(0),
            cache_read_tokens: Some(800), // 80% cache hit
            thinking_tokens: None,
        };

        let savings = CacheController::calculate_cache_savings(&usage);

        // With cache read, actual cost should be lower
        assert!(savings.savings > 0.0);
        assert!(savings.cache_cost < savings.base_cost);
    }

    #[test]
    fn test_calculate_cache_savings_with_cache_write() {
        let usage = TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_creation_tokens: Some(500), // Writing to cache
            cache_read_tokens: Some(0),
            thinking_tokens: None,
        };

        let savings = CacheController::calculate_cache_savings(&usage);

        // Cache write is more expensive (1.25x), so savings should be negative
        assert!(savings.savings < 0.0);
    }

    #[test]
    fn test_calculate_cache_savings_mixed() {
        let usage = TokenUsage {
            input_tokens: 10000,
            output_tokens: 1000,
            cache_creation_tokens: Some(1000), // Some cache write
            cache_read_tokens: Some(8000),     // Mostly cache read
            thinking_tokens: None,
        };

        let savings = CacheController::calculate_cache_savings(&usage);

        // With high cache read ratio, should have positive savings
        assert!(savings.savings > 0.0);
        assert!(savings.savings_percentage() > 0.0);
    }

    #[test]
    fn test_calculate_cache_stats_no_cache() {
        let usage = TokenUsage::new(1000, 500);
        let stats = CacheController::calculate_cache_stats(&usage);

        assert_eq!(stats.total_cache_creation_tokens, 0);
        assert_eq!(stats.total_cache_read_tokens, 0);
        assert_eq!(stats.cache_hit_rate, 0.0);
    }

    #[test]
    fn test_calculate_cache_stats_with_cache() {
        let usage = TokenUsage {
            input_tokens: 1000,
            output_tokens: 500,
            cache_creation_tokens: Some(200),
            cache_read_tokens: Some(800),
            thinking_tokens: None,
        };

        let stats = CacheController::calculate_cache_stats(&usage);

        assert_eq!(stats.total_cache_creation_tokens, 200);
        assert_eq!(stats.total_cache_read_tokens, 800);
        assert!((stats.cache_hit_rate - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_accumulate_cache_stats() {
        let usages = [
            TokenUsage {
                input_tokens: 1000,
                output_tokens: 500,
                cache_creation_tokens: Some(100),
                cache_read_tokens: Some(400),
                thinking_tokens: None,
            },
            TokenUsage {
                input_tokens: 2000,
                output_tokens: 1000,
                cache_creation_tokens: Some(200),
                cache_read_tokens: Some(600),
                thinking_tokens: None,
            },
        ];

        let stats = CacheController::accumulate_cache_stats(usages.iter());

        assert_eq!(stats.total_cache_creation_tokens, 300);
        assert_eq!(stats.total_cache_read_tokens, 1000);
        // Hit rate: 1000 / (300 + 1000) = 0.769...
        assert!((stats.cache_hit_rate - 0.769).abs() < 0.01);
    }

    #[test]
    fn test_cache_savings_percentage() {
        let savings = CacheSavings::new(100.0, 60.0);
        assert!((savings.savings_percentage() - 40.0).abs() < 0.001);
    }

    #[test]
    fn test_cache_savings_percentage_zero_base() {
        let savings = CacheSavings::new(0.0, 0.0);
        assert_eq!(savings.savings_percentage(), 0.0);
    }
}