use crate::config::Config;
use crate::session::chat::format_number;
use crate::session::{Message, Session};
use anyhow::Result;
use serde::{Deserialize, Serialize};
/// Kind of content a cache marker covers.
///
/// NOTE(review): not referenced by the runtime logic in this file, which
/// tracks caching via the per-message `cached` flag — presumably used by
/// persistence or callers elsewhere; confirm before removing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum CacheMarkerType {
/// Marker on a system-prompt message.
System,
/// Marker covering tool definitions.
Tools,
/// Marker on conversational content (user/tool/assistant messages).
Content,
}
/// Serializable record of a cache marker placed in a session.
///
/// NOTE(review): never constructed in this file — presumably part of the
/// session persistence format; verify against the session module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheMarker {
/// Index of the marked message in the session's message list.
pub message_index: usize,
/// What kind of content the marker covers.
pub marker_type: CacheMarkerType,
/// Whether the marker was placed automatically rather than by the user.
pub automatic: bool,
/// When the marker was placed — presumably Unix seconds, matching
/// `last_cache_checkpoint_time` usage below; confirm.
pub timestamp: u64,
}
/// Policy object deciding where prompt-cache markers go in a session.
pub struct CacheManager {
/// Maximum number of simultaneous content (user/tool/assistant) cache
/// markers; the oldest are evicted when the budget is exceeded.
max_content_markers: usize,
}
impl Default for CacheManager {
fn default() -> Self {
Self {
// Default budget of two concurrent content markers.
max_content_markers: 2,
}
}
}
impl CacheManager {
/// Creates a manager with the default content-marker budget.
pub fn new() -> Self {
Self::default()
}
/// Adds automatic cache markers to the conversation prefix.
///
/// When the provider supports caching, the leading system message is marked
/// as cached; when tools are in play, the *last* system message is marked as
/// well so the cached prefix extends over all system content.
///
/// No-op when `supports_caching` is false.
pub fn add_automatic_cache_markers(
    &self,
    messages: &mut [Message],
    has_tools: bool,
    supports_caching: bool,
) {
    if !supports_caching {
        return;
    }
    // Setting `cached` is idempotent, so no need to test it first.
    if let Some(first_msg) = messages.first_mut() {
        if first_msg.role == "system" {
            first_msg.cached = true;
        }
    }
    if has_tools {
        // Mark the last system message too (may coincide with the first).
        if let Some(index) = messages.iter().rposition(|m| m.role == "system") {
            messages[index].cached = true;
        }
    }
}
/// Places (or declines to place) a content cache marker on a user/tool
/// message.
///
/// The target is `target_message_index` when given, otherwise the most
/// recent user or tool message. At most `max_content_markers` content
/// markers are kept: the oldest markers are evicted to make room before the
/// new one is added.
///
/// Returns `Ok(true)` when a new marker was placed, `Ok(false)` when the
/// target was already marked, and an error for missing or non-user/tool
/// targets.
pub fn manage_content_cache_markers(
    &self,
    session: &mut Session,
    target_message_index: Option<usize>,
    _automatic: bool,
) -> Result<bool> {
    // Resolve the target: explicit index, or the latest user/tool message.
    let target_index = match target_message_index {
        Some(idx) => idx,
        None => session
            .messages
            .iter()
            .rposition(|msg| msg.role == "user" || msg.role == "tool")
            .ok_or_else(|| anyhow::anyhow!("No user or tool messages found for caching"))?,
    };
    let msg = session
        .messages
        .get(target_index)
        .ok_or_else(|| anyhow::anyhow!("Message index {} not found", target_index))?;
    if msg.role != "user" && msg.role != "tool" {
        return Err(anyhow::anyhow!(
            "Only user and tool messages can be marked for content caching"
        ));
    }
    // Indices of currently marked user/tool messages; `enumerate` already
    // yields these in ascending (chronological) order, so no sort is needed.
    let existing_markers: Vec<usize> = session
        .messages
        .iter()
        .enumerate()
        .filter_map(|(i, msg)| {
            if msg.cached && (msg.role == "user" || msg.role == "tool") {
                Some(i)
            } else {
                None
            }
        })
        .collect();
    if existing_markers.contains(&target_index) {
        return Ok(false);
    }
    // Evict the oldest markers until there is room for the new one.
    // BUGFIX: the previous code only trimmed down to `max_content_markers`
    // *before* adding the new marker, which could leave
    // `max_content_markers + 1` markers active when the count was already
    // over budget.
    let excess = (existing_markers.len() + 1).saturating_sub(self.max_content_markers);
    for &old_index in existing_markers.iter().take(excess) {
        if let Some(old_msg) = session.messages.get_mut(old_index) {
            old_msg.cached = false;
        }
    }
    if let Some(target_msg) = session.messages.get_mut(target_index) {
        target_msg.cached = true;
        return Ok(true);
    }
    Ok(false)
}
/// Applies an automatic cache checkpoint when either the timeout has elapsed
/// or the non-cached token count has crossed the configured threshold.
///
/// The checkpoint targets the most recent tool message, falling back to the
/// most recent user message. Returns `Ok(true)` only when a new marker was
/// actually placed. Checkpoint failures are deliberately non-fatal.
pub fn check_and_apply_auto_cache_threshold(
    &self,
    session: &mut Session,
    config: &Config,
    supports_caching: bool,
    _role: &str,
) -> Result<bool> {
    if !supports_caching || session.messages.is_empty() {
        return Ok(false);
    }
    // Shared by both the timeout and the token-threshold paths: prefer the
    // latest tool message, otherwise the latest user message.
    let target_index = session
        .messages
        .iter()
        .rposition(|msg| msg.role == "tool")
        .or_else(|| session.messages.iter().rposition(|msg| msg.role == "user"));
    let current_time = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let time_since_last_cache =
        current_time.saturating_sub(session.info.last_cache_checkpoint_time);
    if time_since_last_cache >= config.cache_timeout_seconds {
        if let Some(index) = target_index {
            match self.apply_cache_to_message(session, index, supports_caching) {
                Ok(true) => return Ok(true),
                // Reset the clock even on a no-op / error so we don't retry
                // on every subsequent call.
                Ok(false) | Err(_) => {
                    session.info.last_cache_checkpoint_time = current_time;
                    return Ok(false);
                }
            }
        }
    }
    if config.cache_tokens_threshold > 0
        && session.info.current_non_cached_tokens >= config.cache_tokens_threshold
    {
        if let Some(index) = target_index {
            // Errors are swallowed: a failed checkpoint must not abort the turn.
            return Ok(matches!(
                self.apply_cache_to_message(session, index, supports_caching),
                Ok(true)
            ));
        }
    }
    Ok(false)
}
/// Auto-checkpointing specialized for a just-appended tool-result message.
///
/// Unlike `check_and_apply_auto_cache_threshold`, the token threshold is
/// evaluated before the timeout here, and the target index is supplied by
/// the caller. Returns `Ok(true)` only when a new marker was placed.
pub fn check_and_apply_auto_cache_threshold_on_tool_result(
    &self,
    session: &mut Session,
    config: &Config,
    supports_caching: bool,
    tool_message_index: usize,
    _role: &str,
) -> Result<bool> {
    if !supports_caching {
        return Ok(false);
    }
    // The target must exist and actually be a tool message.
    let is_tool_message = session
        .messages
        .get(tool_message_index)
        .map(|m| m.role == "tool")
        .unwrap_or(false);
    if !is_tool_message {
        return Ok(false);
    }
    // Token threshold takes precedence over the timeout in this path.
    if config.cache_tokens_threshold > 0
        && session.info.current_non_cached_tokens >= config.cache_tokens_threshold
    {
        return Ok(matches!(
            self.apply_cache_to_message(session, tool_message_index, supports_caching),
            Ok(true)
        ));
    }
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let elapsed = now.saturating_sub(session.info.last_cache_checkpoint_time);
    if elapsed >= config.cache_timeout_seconds {
        return match self.apply_cache_to_message(session, tool_message_index, supports_caching) {
            Ok(true) => Ok(true),
            // On a no-op or error, advance the checkpoint clock so the
            // timeout doesn't fire again immediately.
            Ok(false) | Err(_) => {
                session.info.last_cache_checkpoint_time = now;
                Ok(false)
            }
        };
    }
    Ok(false)
}
/// Folds one API response's token usage into the session's running totals.
pub fn update_token_tracking(
    &self,
    session: &mut Session,
    input_tokens: u64,
    output_tokens: u64,
    cache_read_tokens: u64,
    cache_write_tokens: u64,
    reasoning_tokens: u64,
) {
    let info = &mut session.info;
    // Lifetime counters.
    info.input_tokens += input_tokens;
    info.output_tokens += output_tokens;
    info.cache_read_tokens += cache_read_tokens;
    info.cache_write_tokens += cache_write_tokens;
    info.reasoning_tokens += reasoning_tokens;
    // Since-last-checkpoint counters: cache reads count toward the total
    // context size, but only freshly processed input counts as non-cached.
    info.current_total_tokens += input_tokens + cache_read_tokens;
    info.current_non_cached_tokens += input_tokens;
}
/// Estimates `(total, non_cached)` token counts over the current messages,
/// using the session module's per-message estimator.
pub fn estimate_current_session_tokens(&self, session: &Session) -> (u64, u64) {
    session
        .messages
        .iter()
        .fold((0u64, 0u64), |(total, non_cached), msg| {
            let tokens = crate::session::estimate_message_tokens(msg) as u64;
            // Cached messages count toward the total but not the non-cached tally.
            let extra = if msg.cached { 0 } else { tokens };
            (total + tokens, non_cached + extra)
        })
}
/// Convenience wrapper over `get_cache_statistics_with_config` without
/// configuration-based tool detection.
pub fn get_cache_statistics(&self, session: &Session) -> CacheStatistics {
self.get_cache_statistics_with_config(session, None)
}
/// Computes cache statistics for a session.
///
/// Counts active cache markers by role, then — for caching-capable models
/// with a cached system prompt — infers whether a tool marker should be
/// reported. When `config` is given, tool presence comes from the MCP
/// server list; otherwise it is guessed from session counters.
pub fn get_cache_statistics_with_config(
&self,
session: &Session,
config: Option<&crate::config::Config>,
) -> CacheStatistics {
let mut content_markers = 0;
let mut system_markers = 0;
let mut tool_markers = 0;
// Tally active markers by role. Tool-result messages (those carrying a
// tool_call_id) count as *content*; tool messages without one count as
// tool-definition markers.
for msg in &session.messages {
if msg.cached {
match msg.role.as_str() {
"system" => system_markers += 1,
"user" => content_markers += 1,
"tool" => {
if msg.tool_call_id.is_some() {
content_markers += 1;
} else {
tool_markers += 1; }
}
"assistant" => content_markers += 1, _ => {}
}
}
}
let has_cached_system = system_markers > 0;
let supports_caching = crate::session::model_supports_caching(&session.info.model);
if has_cached_system && supports_caching {
// Heuristic tool detection. NOTE(review): the config-less fallback
// below is true in almost every state that has a cached system
// message — confirm this over-reporting is intended.
let has_tools = if let Some(cfg) = config {
!cfg.mcp.servers.is_empty()
} else {
session.info.tool_calls > 0 ||
(session.info.input_tokens > 0 && has_cached_system) ||
(session.info.input_tokens == 0 && session.info.cache_read_tokens == 0 && has_cached_system)
};
// Tool definitions are cached implicitly alongside the system prompt,
// so report at least one tool marker in that case.
if has_tools && tool_markers == 0 {
tool_markers = 1; }
}
CacheStatistics {
content_markers,
system_markers,
tool_markers,
total_cache_read_tokens: session.info.cache_read_tokens,
total_cache_write_tokens: session.info.cache_write_tokens,
// Lifetime input = freshly processed tokens + cache reads.
total_input_tokens: session.info.input_tokens + session.info.cache_read_tokens,
total_output_tokens: session.info.output_tokens,
current_non_cached_tokens: session.info.current_non_cached_tokens,
current_total_tokens: session.info.current_total_tokens,
// Share of lifetime input served from cache, as a percentage.
cache_efficiency: if session.info.input_tokens + session.info.cache_read_tokens > 0 {
(session.info.cache_read_tokens as f64
/ (session.info.input_tokens + session.info.cache_read_tokens) as f64)
* 100.0
} else {
0.0
},
}
}
/// Clears every content cache marker (user/tool/assistant messages),
/// leaving system-message markers intact. Returns the number cleared.
pub fn clear_content_cache_markers(&self, session: &mut Session) -> usize {
    let mut cleared = 0;
    for msg in &mut session.messages {
        // The role filter below already excludes "system"; the previous
        // implementation re-checked `role != "system"` inside, which was
        // unconditionally true.
        if msg.cached && matches!(msg.role.as_str(), "user" | "tool" | "assistant") {
            msg.cached = false;
            cleared += 1;
        }
    }
    cleared
}
/// Marks the message at `message_index` as cached, evicting the oldest
/// content markers as needed to stay within `max_content_markers`, then
/// resets the since-last-checkpoint token counters and checkpoint clock.
///
/// Returns `Ok(false)` when caching is unsupported or the message is
/// already cached; errors when the index is out of bounds.
pub fn apply_cache_to_message(
    &self,
    session: &mut Session,
    message_index: usize,
    supports_caching: bool,
) -> Result<bool> {
    if !supports_caching {
        return Ok(false);
    }
    let already_cached = session
        .messages
        .get(message_index)
        .ok_or_else(|| {
            anyhow::anyhow!("Message index {} is out of bounds", message_index)
        })?
        .cached;
    // Already-cached implies it is already in the marker set, so the old
    // separate `contains` check was redundant.
    if already_cached {
        return Ok(false);
    }
    // Content markers currently in place; `enumerate` yields them oldest first.
    let existing_markers: Vec<usize> = session
        .messages
        .iter()
        .enumerate()
        .filter(|(_, msg)| {
            msg.cached && (msg.role == "user" || msg.role == "tool" || msg.role == "assistant")
        })
        .map(|(i, _)| i)
        .collect();
    // Evict enough of the oldest markers to make room for the new one.
    // BUGFIX: previously only a single marker was evicted even when the
    // count already exceeded `max_content_markers`.
    let excess = (existing_markers.len() + 1).saturating_sub(self.max_content_markers);
    for &old_index in existing_markers.iter().take(excess) {
        if let Some(old_msg) = session.messages.get_mut(old_index) {
            old_msg.cached = false;
        }
    }
    if let Some(msg) = session.messages.get_mut(message_index) {
        msg.cached = true;
        // Fresh checkpoint: everything before it is now part of the cached
        // prefix, so the rolling counters restart from zero.
        session.info.current_non_cached_tokens = 0;
        session.info.current_total_tokens = 0;
        session.info.last_cache_checkpoint_time = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        return Ok(true);
    }
    Ok(false)
}
/// Applies a cache checkpoint to the most recent user message.
///
/// Errors when the session contains no user message.
pub fn apply_cache_to_current_user_message(
    &self,
    session: &mut Session,
    supports_caching: bool,
) -> Result<bool> {
    if !supports_caching {
        return Ok(false);
    }
    match session.messages.iter().rposition(|m| m.role == "user") {
        Some(index) => self.apply_cache_to_message(session, index, supports_caching),
        None => Err(anyhow::anyhow!("No user message found to cache")),
    }
}
/// Applies a cache checkpoint to the most recent tool message, falling back
/// to the most recent user message when no tool message exists.
///
/// Errors when neither kind of message is present.
pub fn apply_cache_to_current_tool_message(
    &self,
    session: &mut Session,
    supports_caching: bool,
) -> Result<bool> {
    if !supports_caching {
        return Ok(false);
    }
    let target = session
        .messages
        .iter()
        .rposition(|m| m.role == "tool")
        .or_else(|| session.messages.iter().rposition(|m| m.role == "user"));
    match target {
        Some(index) => self.apply_cache_to_message(session, index, supports_caching),
        None => Err(anyhow::anyhow!("No suitable message found to cache")),
    }
}
/// Returns whether a provider/model pair is known to support prompt caching.
///
/// This is a substring-based allow-list over a lowercased provider name;
/// unknown providers are treated as unsupported.
pub fn validate_cache_support(&self, provider: &str, model: &str) -> bool {
    let provider = provider.to_lowercase();
    match provider.as_str() {
        "openrouter" => model.contains("claude") || model.contains("gemini"),
        "anthropic" => model.contains("claude-3-5") || model.contains("claude-3.5"),
        "google" => model.contains("gemini-1.5"),
        _ => false,
    }
}
}
/// Aggregated cache-related counters for a session, produced by
/// `CacheManager::get_cache_statistics_with_config`.
#[derive(Debug, Clone)]
pub struct CacheStatistics {
    /// Active markers on user/tool/assistant content messages.
    pub content_markers: usize,
    /// Active markers on system messages.
    pub system_markers: usize,
    /// Active markers attributed to tool definitions.
    pub tool_markers: usize,
    /// Lifetime tokens served from cache reads.
    pub total_cache_read_tokens: u64,
    /// Lifetime tokens written into the cache.
    pub total_cache_write_tokens: u64,
    /// Lifetime input tokens (freshly processed + cache reads).
    pub total_input_tokens: u64,
    /// Lifetime output tokens.
    pub total_output_tokens: u64,
    /// Non-cached input tokens since the last checkpoint.
    pub current_non_cached_tokens: u64,
    /// Total input tokens since the last checkpoint.
    pub current_total_tokens: u64,
    /// Percentage (0.0–100.0) of lifetime input served from cache.
    pub cache_efficiency: f64,
}
impl CacheStatistics {
pub fn format_for_display(&self) -> String {
use colored::Colorize;
let mut output = String::new();
output.push_str(&format!("{}\n", "── Cache Statistics ──".bright_cyan()));
if self.content_markers > 0 || self.system_markers > 0 || self.tool_markers > 0 {
output.push_str(&format!(
"Active markers: {} content, {} system, {} tool\n",
self.content_markers.to_string().bright_blue(),
self.system_markers.to_string().bright_green(),
self.tool_markers.to_string().bright_yellow()
));
} else {
output.push_str(&format!("{}\n", "No active cache markers".bright_black()));
}
if self.total_cache_read_tokens > 0 || self.total_cache_write_tokens > 0 {
output.push_str(&format!(
"Total input tokens: {} ({} cache read, {} cache write, {} processed)\n",
format_number(self.total_input_tokens).bright_blue(),
format_number(self.total_cache_read_tokens).bright_magenta(),
format_number(self.total_cache_write_tokens).bright_yellow(),
format_number(self.total_input_tokens - self.total_cache_read_tokens).bright_cyan()
));
output.push_str(&format!(
"Total output tokens: {} (not cacheable)\n",
format_number(self.total_output_tokens).bright_cyan()
));
output.push_str(&format!(
"Overall cache efficiency: {:.1}% (lifetime session average)\n",
self.cache_efficiency.to_string().bright_green()
));
} else {
output.push_str(&format!(
"{}\n",
"No cached tokens recorded yet".bright_black()
));
}
if self.total_input_tokens > 0 {
let session_cached_pct =
(self.total_cache_read_tokens as f64 / self.total_input_tokens as f64) * 100.0;
let session_processed_pct = 100.0 - session_cached_pct;
output.push_str(&format!(
"Session totals: {:.1}% cache read, {:.1}% processed ({}/{} total input tokens)\n",
session_cached_pct.to_string().bright_green(),
session_processed_pct.to_string().bright_yellow(),
format_number(self.total_cache_read_tokens).bright_magenta(),
format_number(self.total_input_tokens).bright_blue()
));
}
output
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::session::{Session, SessionInfo};
// Builds an empty session with zeroed counters and a caching-capable
// (Claude via OpenRouter) model string for exercising the manager.
fn create_test_session() -> Session {
Session {
info: SessionInfo {
name: "test".to_string(),
created_at: 0,
model: "openrouter:anthropic/claude-3.5-sonnet".to_string(),
provider: "openrouter".to_string(),
input_tokens: 0,
output_tokens: 0,
cache_read_tokens: 0,
cache_write_tokens: 0,
reasoning_tokens: 0,
total_cost: 0.0,
duration_seconds: 0,
layer_stats: Vec::new(),
tool_calls: 0,
total_api_time_ms: 0,
total_layer_time_ms: 0,
total_tool_time_ms: 0,
compression_stats: crate::session::CompressionStats::default(),
total_api_calls: 0,
current_non_cached_tokens: 0,
current_total_tokens: 0,
last_cache_checkpoint_time: 0,
cache_next_user_message: false,
spending_threshold_checkpoint: 0.0,
compression_hint_count: 0,
last_compression_hint_shown: 0,
context_tokens_after_last_compression: 0,
predicted_turns_at_last_compression: 0.0,
api_calls_at_last_compression: 0,
output_tokens_at_last_compression: 0,
consecutive_compressions: 0,
},
messages: Vec::new(),
session_file: None,
}
}
// The default manager allows two concurrent content markers.
#[test]
fn test_cache_manager_creation() {
let manager = CacheManager::new();
assert_eq!(manager.max_content_markers, 2);
}
// Placing a third content marker must evict the oldest (index 0),
// leaving markers on indices 2 and 4 only.
#[test]
fn test_two_marker_system() {
let manager = CacheManager::new();
let mut session = create_test_session();
session.add_message("user", "First message");
session.add_message("assistant", "First response");
session.add_message("user", "Second message");
session.add_message("assistant", "Second response");
session.add_message("user", "Third message");
let result = manager.manage_content_cache_markers(&mut session, Some(0), false);
assert!(result.is_ok());
assert!(result.unwrap());
assert!(session.messages[0].cached);
let result = manager.manage_content_cache_markers(&mut session, Some(2), false);
assert!(result.is_ok());
assert!(result.unwrap());
assert!(session.messages[2].cached);
let result = manager.manage_content_cache_markers(&mut session, Some(4), false);
assert!(result.is_ok());
assert!(result.unwrap());
assert!(!session.messages[0].cached); assert!(session.messages[2].cached); assert!(session.messages[4].cached); }
// Spot-checks the provider/model substring allow-list.
#[test]
fn test_cache_support_validation() {
let manager = CacheManager::new();
assert!(manager.validate_cache_support("openrouter", "anthropic/claude-3.5-sonnet"));
assert!(manager.validate_cache_support("openrouter", "claude-3-opus"));
assert!(manager.validate_cache_support("openrouter", "google/gemini-1.5-pro"));
assert!(manager.validate_cache_support("openrouter", "gemini-1.5-flash"));
assert!(!manager.validate_cache_support("openrouter", "openai/gpt-4"));
assert!(manager.validate_cache_support("anthropic", "claude-3.5-sonnet"));
assert!(manager.validate_cache_support("anthropic", "claude-3-5-haiku"));
assert!(!manager.validate_cache_support("anthropic", "claude-3-opus"));
assert!(manager.validate_cache_support("google", "gemini-1.5-pro"));
assert!(!manager.validate_cache_support("google", "gemini-pro"));
assert!(!manager.validate_cache_support("openai", "gpt-4"));
}
// Automatic markers should cache the leading system prompt but not the
// user message that follows it.
#[test]
fn test_automatic_cache_markers() {
let manager = CacheManager::new();
let mut messages = vec![
Message {
role: "system".to_string(),
content: "You are an AI assistant".to_string(),
timestamp: 0,
cached: false,
cache_ttl: None,
tool_call_id: None,
name: None,
tool_calls: None,
images: None,
videos: None,
thinking: None,
id: None,
},
Message {
role: "user".to_string(),
content: "Hello".to_string(),
timestamp: 0,
cached: false,
cache_ttl: None,
tool_call_id: None,
name: None,
tool_calls: None,
images: None,
videos: None,
thinking: None,
id: None,
},
];
manager.add_automatic_cache_markers(&mut messages, true, true);
assert!(messages[0].cached);
assert!(!messages[1].cached);
}
}