use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::LlmError;
use crate::types::{ChatMessage, ContentPart, MessageContent};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheControl {
pub r#type: CacheType,
pub ttl: Option<u32>,
pub cache_key: Option<String>,
}
impl Default for CacheControl {
fn default() -> Self {
Self {
r#type: CacheType::Ephemeral,
ttl: None,
cache_key: None,
}
}
}
impl CacheControl {
pub const fn ephemeral() -> Self {
Self {
r#type: CacheType::Ephemeral,
ttl: None,
cache_key: None,
}
}
pub const fn with_ttl(mut self, ttl_seconds: u32) -> Self {
self.ttl = Some(ttl_seconds);
self
}
pub fn with_key<S: Into<String>>(mut self, key: S) -> Self {
self.cache_key = Some(key.into());
self
}
pub fn to_json(&self) -> serde_json::Value {
let mut json = serde_json::json!({
"type": self.r#type
});
if let Some(ttl) = self.ttl {
json["ttl"] = serde_json::Value::Number(ttl.into());
}
if let Some(ref key) = self.cache_key {
json["cache_key"] = serde_json::Value::String(key.clone());
}
json
}
}
/// Kind of cache entry requested by a [`CacheControl`].
///
/// Variant names are (de)serialized in lowercase, e.g. `Ephemeral` <-> `"ephemeral"`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CacheType {
    /// Ephemeral cache entry; serialized as `"ephemeral"`.
    Ephemeral,
}
/// Builder that converts a [`ChatMessage`] into request JSON with optional cache controls.
///
/// Controls can be attached to the whole message, or to individual content parts by
/// index. They are written into the JSON when the builder's `build` method runs.
pub struct CacheAwareMessageBuilder {
    /// The message being converted.
    message: ChatMessage,
    /// Optional message-level cache control.
    cache_control: Option<CacheControl>,
    /// Per-part cache controls, keyed by index into the message's content parts.
    content_cache_controls: HashMap<usize, CacheControl>,
}
impl CacheAwareMessageBuilder {
pub fn new(message: ChatMessage) -> Self {
Self {
message,
cache_control: None,
content_cache_controls: HashMap::new(),
}
}
pub fn with_cache_control(mut self, cache_control: CacheControl) -> Self {
self.cache_control = Some(cache_control);
self
}
pub fn with_content_cache_control(
mut self,
content_index: usize,
cache_control: CacheControl,
) -> Self {
self.content_cache_controls
.insert(content_index, cache_control);
self
}
pub fn build(self) -> Result<serde_json::Value, LlmError> {
let mut message_json = self.convert_message_to_json()?;
if let Some(cache_control) = self.cache_control {
message_json["cache_control"] = cache_control.to_json();
}
if !self.content_cache_controls.is_empty()
&& let Some(content) = message_json.get_mut("content")
{
match content {
serde_json::Value::Array(content_array) => {
for (index, cache_control) in self.content_cache_controls {
if let Some(content_item) = content_array.get_mut(index)
&& let Some(content_obj) = content_item.as_object_mut()
{
content_obj
.insert("cache_control".to_string(), cache_control.to_json());
}
}
}
serde_json::Value::String(_) => {
}
_ => {}
}
}
Ok(message_json)
}
fn convert_message_to_json(&self) -> Result<serde_json::Value, LlmError> {
let mut message_json = serde_json::json!({
"role": match self.message.role {
crate::types::MessageRole::System => "system",
crate::types::MessageRole::User => "user",
crate::types::MessageRole::Assistant => "assistant",
crate::types::MessageRole::Developer => "user", crate::types::MessageRole::Tool => "tool",
}
});
match &self.message.content {
MessageContent::Text(text) => {
message_json["content"] = serde_json::Value::String(text.clone());
}
MessageContent::MultiModal(parts) => {
let mut content_parts = Vec::new();
for part in parts {
match part {
ContentPart::Text { text } => {
content_parts.push(serde_json::json!({
"type": "text",
"text": text
}));
}
ContentPart::Image { image_url, detail } => {
let mut image_part = serde_json::json!({
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg", "data": image_url
}
});
if let Some(detail) = detail {
image_part["detail"] = serde_json::Value::String(detail.clone());
}
content_parts.push(image_part);
}
ContentPart::Audio { audio_url, format } => {
content_parts.push(serde_json::json!({
"type": "audio",
"source": {
"type": "base64",
"media_type": format,
"data": audio_url
}
}));
}
}
}
message_json["content"] = serde_json::Value::Array(content_parts);
}
}
Ok(message_json)
}
}
/// Prompt-cache counters, e.g. as read by `CacheStatistics::from_response` from a
/// response's `"usage"` object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatistics {
    /// Number of cache hits.
    pub cache_hits: u32,
    /// Number of cache misses.
    pub cache_misses: u32,
    /// Number of cached tokens.
    pub cached_tokens: u32,
    /// Number of tokens written when creating cache entries.
    pub cache_creation_tokens: u32,
    /// Number of tokens read from the cache.
    pub cache_read_tokens: u32,
}
impl CacheStatistics {
    /// Returns statistics with every counter set to zero.
    pub const fn empty() -> Self {
        Self {
            cache_hits: 0,
            cache_misses: 0,
            cached_tokens: 0,
            cache_creation_tokens: 0,
            cache_read_tokens: 0,
        }
    }

    /// Reads cache counters from the `"usage"` object of `response`.
    ///
    /// Returns `None` in either of these cases:
    /// - `"usage"` is absent.
    /// - Any of `cache_hits`, `cache_misses`, `cached_tokens`, `cache_creation_tokens`
    ///   or `cache_read_tokens` is missing or not a non-negative integer.
    ///
    /// Values larger than `u32::MAX` saturate at `u32::MAX` instead of being
    /// silently truncated.
    pub fn from_response(response: &serde_json::Value) -> Option<Self> {
        let usage = response.get("usage")?;
        let counter = |key: &str| -> Option<u32> {
            let raw = usage.get(key)?.as_u64()?;
            Some(u32::try_from(raw).unwrap_or(u32::MAX))
        };
        Some(Self {
            cache_hits: counter("cache_hits")?,
            cache_misses: counter("cache_misses")?,
            cached_tokens: counter("cached_tokens")?,
            cache_creation_tokens: counter("cache_creation_tokens")?,
            cache_read_tokens: counter("cache_read_tokens")?,
        })
    }

    /// Fraction of requests served from cache, `hits / (hits + misses)`.
    ///
    /// Returns `0.0` when there were no requests.
    pub fn cache_efficiency(&self) -> f64 {
        // Sum in u64: adding two u32 counters can exceed u32::MAX, which would
        // panic in debug builds and wrap in release builds.
        let total_requests = u64::from(self.cache_hits) + u64::from(self.cache_misses);
        if total_requests == 0 {
            0.0
        } else {
            f64::from(self.cache_hits) / total_requests as f64
        }
    }

    /// Returns the `cache_read_tokens` counter.
    pub const fn token_savings(&self) -> u32 {
        self.cache_read_tokens
    }
}
/// Convenience constructors for common cache-aware message shapes.
pub mod patterns {
    use super::*;

    /// Builds a plain-text [`ChatMessage`] with default metadata and no tool data.
    fn plain_text_message(role: crate::types::MessageRole, text: String) -> ChatMessage {
        ChatMessage {
            role,
            content: MessageContent::Text(text),
            metadata: crate::types::MessageMetadata::default(),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    /// Returns a builder for a system message containing `content`, with an
    /// ephemeral message-level cache control already attached.
    pub fn cached_system_message<S: Into<String>>(content: S) -> CacheAwareMessageBuilder {
        let system = plain_text_message(crate::types::MessageRole::System, content.into());
        CacheAwareMessageBuilder::new(system).with_cache_control(CacheControl::ephemeral())
    }

    /// Returns a builder for a user message of the form
    /// `"Document:\n<document>\n\nQuery: <query>"`, with an ephemeral
    /// message-level cache control already attached.
    pub fn cached_document_message<S: Into<String>>(
        document: S,
        query: S,
    ) -> CacheAwareMessageBuilder {
        let document: String = document.into();
        let query: String = query.into();
        let prompt = format!("Document:\n{}\n\nQuery: {}", document, query);
        let user = plain_text_message(crate::types::MessageRole::User, prompt);
        CacheAwareMessageBuilder::new(user).with_cache_control(CacheControl::ephemeral())
    }

    /// Converts `context_messages` followed by `new_message` into JSON, in order.
    ///
    /// Only the first context message gets an ephemeral cache control. Any message
    /// whose build fails is skipped rather than reported.
    ///
    /// NOTE(review): if the goal is to cache the whole context prefix, the cache
    /// breakpoint may belong on the last context message instead of the first.
    /// Confirm this against the target API's caching rules.
    pub fn cached_conversation_context(
        context_messages: Vec<ChatMessage>,
        new_message: ChatMessage,
    ) -> Vec<serde_json::Value> {
        let context_builders = context_messages
            .into_iter()
            .enumerate()
            .map(|(position, message)| {
                let builder = CacheAwareMessageBuilder::new(message);
                match position {
                    0 => builder.with_cache_control(CacheControl::ephemeral()),
                    _ => builder,
                }
            });
        context_builders
            .chain(std::iter::once(CacheAwareMessageBuilder::new(new_message)))
            .filter_map(|builder| builder.build().ok())
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_control_creation() {
        let control = CacheControl::ephemeral().with_ttl(3600).with_key("test-key");
        assert_eq!(control.r#type, CacheType::Ephemeral);
        assert_eq!(control.ttl, Some(3600));
        assert_eq!(control.cache_key.as_deref(), Some("test-key"));
    }

    #[test]
    fn test_cache_control_json() {
        let rendered = CacheControl::ephemeral().with_ttl(1800).to_json();
        assert_eq!(rendered["type"], "ephemeral");
        assert_eq!(rendered["ttl"], 1800);
    }

    #[test]
    fn test_cache_statistics() {
        let counters = CacheStatistics {
            cache_hits: 8,
            cache_misses: 2,
            cached_tokens: 1000,
            cache_creation_tokens: 100,
            cache_read_tokens: 800,
        };
        assert_eq!(counters.token_savings(), 800);
        assert_eq!(counters.cache_efficiency(), 0.8);
    }

    #[test]
    fn test_cached_system_message() {
        let rendered = patterns::cached_system_message("You are a helpful assistant.")
            .build()
            .unwrap();
        assert_eq!(rendered["role"], "system");
        assert_eq!(rendered["content"], "You are a helpful assistant.");
        assert!(rendered["cache_control"].is_object());
    }
}