use serde::{Deserialize, Serialize};
/// Cache marker serialized as an internally tagged object keyed by `"type"`.
///
/// Wire format produced by the serde attributes below:
/// - `{"type":"ephemeral"}` when no TTL is set;
/// - `{"type":"ephemeral","ttl":"<ttl>"}` when a TTL is set.
///
/// The TTL is an opaque string (e.g. `"1h"`); this type does not validate it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum CacheControl {
    /// Ephemeral cache entry, optionally carrying an explicit time-to-live.
    Ephemeral {
        /// Optional TTL string; omitted from the serialized object when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        ttl: Option<String>,
    },
}
impl CacheControl {
    /// Creates an ephemeral marker with no TTL.
    pub fn ephemeral() -> Self {
        Self::Ephemeral { ttl: None }
    }

    /// Creates an ephemeral marker carrying `ttl` verbatim.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...).
    /// The value is not validated.
    pub fn ephemeral_with_ttl(ttl: impl Into<String>) -> Self {
        let ttl = Some(ttl.into());
        Self::Ephemeral { ttl }
    }

    /// Returns `true` when a TTL string is present, even an empty one.
    pub fn has_ttl(&self) -> bool {
        self.ttl().is_some()
    }

    /// Borrows the TTL string, if one was set.
    pub fn ttl(&self) -> Option<&str> {
        // Single-variant enum, so this destructuring pattern is irrefutable.
        let Self::Ephemeral { ttl } = self;
        ttl.as_deref()
    }
}
impl Default for CacheControl {
fn default() -> Self {
Self::ephemeral()
}
}
/// How long a prompt-cache entry is retained.
///
/// Serializes as the JSON string `"in_memory"` (the default) or `"24h"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PromptCacheRetention {
    /// Default retention; serialized as `"in_memory"`.
    #[default]
    InMemory,
    /// Extended retention; serialized as `"24h"`, which overrides the
    /// snake_case name.
    #[serde(rename = "24h")]
    Extended24h,
}
/// Formats the retention policy with the same string as its serde wire name
/// (`"in_memory"` or `"24h"`).
///
/// Uses `Formatter::pad` rather than `write!`, so width, alignment and
/// precision flags (e.g. `{:>10}`) are honoured the same way they are for `str`.
impl std::fmt::Display for PromptCacheRetention {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::InMemory => "in_memory",
            Self::Extended24h => "24h",
        };
        f.pad(name)
    }
}
/// Category of a `CacheWarning`, serialized in snake_case
/// (e.g. `"breakpoint_limit_exceeded"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CacheWarningType {
    /// A `cache_control` marker was set on a context that cannot carry one.
    UnsupportedContext,
    /// More cache breakpoints were supplied than the allowed maximum.
    BreakpointLimitExceeded,
    /// The named provider does not support `cache_control`.
    UnsupportedProvider,
}
/// Formats the warning category with the same snake_case string serde uses.
///
/// Uses `Formatter::pad` rather than `write!`, so width, alignment and
/// precision flags are honoured the same way they are for `str`.
impl std::fmt::Display for CacheWarningType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::UnsupportedContext => "unsupported_context",
            Self::BreakpointLimitExceeded => "breakpoint_limit_exceeded",
            Self::UnsupportedProvider => "unsupported_provider",
        };
        f.pad(name)
    }
}
/// A non-fatal diagnostic about a cache directive, pairing a category with a
/// human-readable message.
///
/// Derives `PartialEq`/`Eq` like every other type in this module, so warnings
/// can be compared directly (e.g. in tests or for de-duplication).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheWarning {
    /// Machine-readable category of the warning.
    pub warning_type: CacheWarningType,
    /// Free-form human-readable description. The built-in constructors phrase
    /// it as what will be ignored.
    pub message: String,
}
impl CacheWarning {
    /// Builds a warning of the given category with an arbitrary message.
    pub fn new(warning_type: CacheWarningType, message: impl Into<String>) -> Self {
        let message = message.into();
        Self { warning_type, message }
    }

    /// Warning for a `cache_control` marker set on `context_type`, which
    /// cannot carry one.
    pub fn unsupported_context(context_type: &str) -> Self {
        let message = format!("cache_control cannot be set on {context_type}. It will be ignored.");
        Self::new(CacheWarningType::UnsupportedContext, message)
    }

    /// Warning for a breakpoint beyond the allowed maximum.
    ///
    /// `count` is the number of breakpoints found; `max` is the limit.
    pub fn breakpoint_limit_exceeded(count: usize, max: usize) -> Self {
        let message = format!(
            "Maximum {max} cache breakpoints exceeded (found {count}). This breakpoint will be ignored."
        );
        Self::new(CacheWarningType::BreakpointLimitExceeded, message)
    }

    /// Warning that `provider` does not support `cache_control`.
    pub fn unsupported_provider(provider: &str) -> Self {
        let message =
            format!("cache_control is not supported by provider '{provider}'. It will be ignored.");
        Self::new(CacheWarningType::UnsupportedProvider, message)
    }
}
/// Renders a warning as `[<category>] <message>`.
impl std::fmt::Display for CacheWarning {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { warning_type, message } = self;
        write!(f, "[{warning_type}] {message}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_control_ephemeral() {
        let cache = CacheControl::ephemeral();
        assert!(!cache.has_ttl());
        assert_eq!(cache.ttl(), None);
    }

    #[test]
    fn test_cache_control_with_ttl() {
        let cache = CacheControl::ephemeral_with_ttl("1h");
        assert!(cache.has_ttl());
        assert_eq!(cache.ttl(), Some("1h"));
    }

    #[test]
    fn test_cache_control_default_matches_ephemeral() {
        assert_eq!(CacheControl::default(), CacheControl::ephemeral());
    }

    #[test]
    fn test_cache_control_serialization() {
        let cache = CacheControl::ephemeral();
        let json = serde_json::to_string(&cache).unwrap();
        assert_eq!(json, r#"{"type":"ephemeral"}"#);
        let cache = CacheControl::ephemeral_with_ttl("1h");
        let json = serde_json::to_string(&cache).unwrap();
        assert_eq!(json, r#"{"type":"ephemeral","ttl":"1h"}"#);
    }

    #[test]
    fn test_cache_control_deserialization() {
        // A missing `ttl` key must come back as `None`, mirroring how
        // `skip_serializing_if` omits it on the way out.
        let cache: CacheControl = serde_json::from_str(r#"{"type":"ephemeral"}"#).unwrap();
        assert_eq!(cache, CacheControl::ephemeral());
        let cache: CacheControl =
            serde_json::from_str(r#"{"type":"ephemeral","ttl":"1h"}"#).unwrap();
        assert_eq!(cache, CacheControl::ephemeral_with_ttl("1h"));
    }

    #[test]
    fn test_prompt_cache_retention_serialization() {
        let retention = PromptCacheRetention::InMemory;
        let json = serde_json::to_string(&retention).unwrap();
        assert_eq!(json, r#""in_memory""#);
        let retention = PromptCacheRetention::Extended24h;
        let json = serde_json::to_string(&retention).unwrap();
        assert_eq!(json, r#""24h""#);
    }

    #[test]
    fn test_prompt_cache_retention_round_trip_matches_display() {
        assert_eq!(PromptCacheRetention::default(), PromptCacheRetention::InMemory);
        for retention in [PromptCacheRetention::InMemory, PromptCacheRetention::Extended24h] {
            let json = serde_json::to_string(&retention).unwrap();
            // `Display` is hand-written; pin it to the serde wire name so the two
            // cannot drift apart.
            assert_eq!(json, format!("\"{}\"", retention));
            let back: PromptCacheRetention = serde_json::from_str(&json).unwrap();
            assert_eq!(back, retention);
        }
    }

    #[test]
    fn test_cache_warning_type_display_matches_serde() {
        for kind in [
            CacheWarningType::UnsupportedContext,
            CacheWarningType::BreakpointLimitExceeded,
            CacheWarningType::UnsupportedProvider,
        ] {
            let json = serde_json::to_string(&kind).unwrap();
            assert_eq!(json, format!("\"{}\"", kind));
            let back: CacheWarningType = serde_json::from_str(&json).unwrap();
            assert_eq!(back, kind);
        }
    }

    #[test]
    fn test_cache_warning_constructors() {
        let warning = CacheWarning::unsupported_context("system message");
        assert_eq!(warning.warning_type, CacheWarningType::UnsupportedContext);
        assert_eq!(
            warning.message,
            "cache_control cannot be set on system message. It will be ignored."
        );

        let warning = CacheWarning::unsupported_provider("acme");
        assert_eq!(warning.warning_type, CacheWarningType::UnsupportedProvider);
        assert_eq!(
            warning.message,
            "cache_control is not supported by provider 'acme'. It will be ignored."
        );
    }

    #[test]
    fn test_cache_warning_display() {
        let warning = CacheWarning::breakpoint_limit_exceeded(5, 4);
        let display = warning.to_string();
        assert!(display.contains("breakpoint_limit_exceeded"));
        assert!(display.contains("Maximum 4"));
        assert_eq!(
            display,
            "[breakpoint_limit_exceeded] Maximum 4 cache breakpoints exceeded (found 5). \
             This breakpoint will be ignored."
        );
    }
}