use crate::types::{
CacheControl, ContentBlock, MessageContent, MessagesRequest, Role, SystemPrompt, TextBlock,
TypedContentBlock,
};
/// Maximum number of `cache_control` breakpoints allowed in one request; this is the cap
/// applied by `enforce_breakpoint_cap`.
/// NOTE(review): presumably mirrors the Messages API limit of four cache breakpoints per
/// request — confirm against the prompt-caching documentation if the API changes.
pub const MAX_CACHE_BREAKPOINTS: usize = 4;
/// Policy for placing prompt-caching `cache_control` markers on a request.
///
/// Implementations mutate the request in place. The `Send + Sync` supertraits allow a
/// strategy to be shared across threads and used as `Box<dyn CacheControlStrategy>`.
pub trait CacheControlStrategy: Send + Sync {
/// Adds, removes, or leaves `cache_control` markers on `req` according to this policy.
fn apply(&self, req: &mut MessagesRequest);
}
/// Automatic breakpoint placement used when the caller has not placed any markers.
///
/// When applied, it marks the last tool, the last system block (promoting a non-empty
/// system string into a single text block), and the last content block of the
/// second-to-last user message. Requests that already carry any `cache_control`
/// marker are left untouched.
#[derive(Clone)]
pub struct DefaultCacheControlStrategy {
    /// Marker attached at every breakpoint this strategy injects.
    pub marker: CacheControl,
    /// Minimum number of user messages required before a message-level breakpoint is
    /// injected (tool and system breakpoints do not depend on it). Defaults to 2.
    pub min_user_messages: usize,
}
impl Default for DefaultCacheControlStrategy {
    /// An ephemeral marker with no explicit TTL, injected into the message history once
    /// there are at least two user messages.
    fn default() -> Self {
        let marker = ephemeral_marker();
        Self {
            min_user_messages: 2,
            marker,
        }
    }
}
impl CacheControlStrategy for DefaultCacheControlStrategy {
    /// Injects tool, system, and message breakpoints, in that order, unless the caller
    /// already placed at least one `cache_control` marker anywhere in `req`.
    fn apply(&self, req: &mut MessagesRequest) {
        // Caller-placed markers win: never mix automatic and manual placement.
        if !has_any_cache_control(req) {
            let marker = &self.marker;
            inject_tools_cache(req, marker);
            inject_system_cache(req, marker);
            inject_messages_cache(req, marker, self.min_user_messages);
        }
    }
}
/// Strategy that never modifies the request; use it to opt out of automatic caching.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoCacheControlStrategy;
impl CacheControlStrategy for NoCacheControlStrategy {
    /// Leaves `req` unchanged.
    fn apply(&self, _req: &mut MessagesRequest) {}
}
/// Builds an `"ephemeral"` cache marker with no explicit `ttl`.
#[must_use]
pub fn ephemeral_marker() -> CacheControl {
    let kind = String::from("ephemeral");
    CacheControl {
        ttl: None,
        r#type: kind,
    }
}
#[must_use]
pub fn ephemeral_marker_with_ttl(ttl_seconds: u64) -> CacheControl {
CacheControl {
r#type: "ephemeral".to_owned(),
ttl: Some(ttl_seconds),
}
}
fn has_any_cache_control(req: &MessagesRequest) -> bool {
if let Some(tools) = &req.tools
&& tools.iter().any(|t| t.cache_control.is_some())
{
return true;
}
if let Some(SystemPrompt::Blocks(blocks)) = &req.system
&& blocks.iter().any(|b| b.cache_control.is_some())
{
return true;
}
for msg in &req.messages {
if let MessageContent::Blocks(blocks) = &msg.content {
for block in blocks {
if block_has_cache_control(block) {
return true;
}
}
}
}
false
}
/// Reports whether a single content block carries a `cache_control` marker.
///
/// Raw blocks are checked for the presence of the `"cache_control"` key only.
fn block_has_cache_control(block: &ContentBlock) -> bool {
    match block {
        ContentBlock::Raw(fields) => fields.contains_key("cache_control"),
        ContentBlock::Typed(typed) => typed_block_cache_control(typed).is_some(),
    }
}
/// Returns the marker on a typed block, if it has one.
///
/// Thinking and redacted-thinking blocks have no marker slot and always yield `None`.
fn typed_block_cache_control(t: &TypedContentBlock) -> Option<&CacheControl> {
    let slot = match t {
        TypedContentBlock::Thinking { .. } | TypedContentBlock::RedactedThinking { .. } => {
            return None;
        }
        TypedContentBlock::Text { cache_control, .. }
        | TypedContentBlock::Image { cache_control, .. }
        | TypedContentBlock::ToolUse { cache_control, .. }
        | TypedContentBlock::ToolResult { cache_control, .. } => cache_control,
    };
    slot.as_ref()
}
/// Returns a mutable handle to a typed block's marker slot, or `None` for block kinds
/// (thinking / redacted thinking) that cannot carry a marker.
///
/// The inner `Option` is the slot itself, so callers can both read and set the marker.
fn typed_block_cache_control_mut(t: &mut TypedContentBlock) -> Option<&mut Option<CacheControl>> {
    let slot = match t {
        TypedContentBlock::Thinking { .. } | TypedContentBlock::RedactedThinking { .. } => {
            return None;
        }
        TypedContentBlock::Text { cache_control, .. }
        | TypedContentBlock::Image { cache_control, .. }
        | TypedContentBlock::ToolUse { cache_control, .. }
        | TypedContentBlock::ToolResult { cache_control, .. } => cache_control,
    };
    Some(slot)
}
fn inject_tools_cache(req: &mut MessagesRequest, marker: &CacheControl) {
if let Some(tools) = &mut req.tools
&& let Some(last) = tools.last_mut()
&& last.cache_control.is_none()
{
last.cache_control = Some(marker.clone());
}
}
/// Puts `marker` on the system prompt.
///
/// - A non-empty string prompt is promoted to a single `"text"` block carrying `marker`.
/// - An empty string prompt is left as-is (still a string, unmarked).
/// - A block prompt gets `marker` on its last block, unless that block is already marked.
/// - A missing prompt is left alone.
fn inject_system_cache(req: &mut MessagesRequest, marker: &CacheControl) {
    let Some(system) = req.system.as_mut() else {
        return;
    };
    match system {
        SystemPrompt::Blocks(blocks) => {
            if let Some(tail) = blocks.last_mut()
                && tail.cache_control.is_none()
            {
                tail.cache_control = Some(marker.clone());
            }
        }
        SystemPrompt::Text(text) => {
            if text.is_empty() {
                return;
            }
            // Move the string out rather than cloning it; `system` is overwritten below.
            let body = std::mem::take(text);
            *system = SystemPrompt::Blocks(vec![TextBlock {
                r#type: "text".to_owned(),
                text: body,
                cache_control: Some(marker.clone()),
            }]);
        }
    }
}
/// Puts `marker` on the last content block of the second-to-last user message.
///
/// Does nothing when:
/// - the request has fewer than `min_user_messages` user messages, or
/// - it has fewer than two user messages, whatever `min_user_messages` says, because
///   there is no second-to-last user turn to mark.
///
/// A plain-text target message is first promoted to a single typed text block. An empty
/// text message is left untouched, matching how `inject_system_cache` skips an empty
/// system string.
///
/// An existing marker is never overwritten. A last block that is raw (`ContentBlock::Raw`)
/// or cannot carry a marker (thinking blocks) is not marked.
fn inject_messages_cache(
    req: &mut MessagesRequest,
    marker: &CacheControl,
    min_user_messages: usize,
) {
    let user_indices: Vec<usize> = req
        .messages
        .iter()
        .enumerate()
        .filter(|(_, m)| m.role == Role::User)
        .map(|(i, _)| i)
        .collect();
    if user_indices.len() < min_user_messages {
        return;
    }
    // `min_user_messages` is a public field and may be 0 or 1; indexing with
    // `len() - 2` would then underflow and panic, so look the target up fallibly.
    let Some(&target_idx) = user_indices.iter().rev().nth(1) else {
        return;
    };
    let msg = &mut req.messages[target_idx];
    if let MessageContent::Text(text) = &mut msg.content {
        // An empty text block cannot usefully carry a breakpoint.
        if text.is_empty() {
            return;
        }
        // Move the text into the promoted block instead of cloning it.
        let text = std::mem::take(text);
        msg.content = MessageContent::Blocks(vec![ContentBlock::Typed(
            TypedContentBlock::Text {
                text,
                cache_control: None,
            },
        )]);
    }
    if let MessageContent::Blocks(blocks) = &mut msg.content
        && let Some(ContentBlock::Typed(typed)) = blocks.last_mut()
        && let Some(slot) = typed_block_cache_control_mut(typed)
        && slot.is_none()
    {
        *slot = Some(marker.clone());
    }
}
/// Counts every `cache_control` marker in `req` across tools, system blocks, and message
/// content blocks.
///
/// Plain-string system prompts and messages contribute nothing. Raw content blocks count
/// whenever they contain a `"cache_control"` key.
#[must_use]
pub fn count_cache_controls(req: &MessagesRequest) -> usize {
    let in_tools = req
        .tools
        .iter()
        .flatten()
        .filter(|tool| tool.cache_control.is_some())
        .count();
    let in_system = match &req.system {
        Some(SystemPrompt::Blocks(blocks)) => {
            blocks.iter().filter(|b| b.cache_control.is_some()).count()
        }
        _ => 0,
    };
    let in_messages: usize = req
        .messages
        .iter()
        .map(|msg| match &msg.content {
            MessageContent::Blocks(blocks) => {
                blocks.iter().filter(|b| block_has_cache_control(b)).count()
            }
            _ => 0,
        })
        .sum();
    in_tools + in_system + in_messages
}
/// Trims `req` down to at most [`MAX_CACHE_BREAKPOINTS`] markers; see
/// `enforce_breakpoint_cap_with_max` for which markers are removed first.
pub fn enforce_breakpoint_cap(req: &mut MessagesRequest) {
    let cap = MAX_CACHE_BREAKPOINTS;
    enforce_breakpoint_cap_with_max(req, cap);
}
pub fn enforce_breakpoint_cap_with_max(req: &mut MessagesRequest, max_blocks: usize) {
let total = count_cache_controls(req);
if total <= max_blocks {
return;
}
let mut excess = total - max_blocks;
if let Some(SystemPrompt::Blocks(blocks)) = req.system.as_mut() {
let last_cc = blocks.iter().rposition(|b| b.cache_control.is_some());
for (i, block) in blocks.iter_mut().enumerate() {
if excess == 0 {
break;
}
if Some(i) != last_cc && block.cache_control.is_some() {
block.cache_control = None;
excess -= 1;
}
}
}
if excess == 0 {
return;
}
if let Some(tools) = req.tools.as_mut() {
let last_cc = tools.iter().rposition(|t| t.cache_control.is_some());
for (i, t) in tools.iter_mut().enumerate() {
if excess == 0 {
break;
}
if Some(i) != last_cc && t.cache_control.is_some() {
t.cache_control = None;
excess -= 1;
}
}
}
if excess == 0 {
return;
}
for msg in req.messages.iter_mut() {
if excess == 0 {
break;
}
if let MessageContent::Blocks(blocks) = &mut msg.content {
for block in blocks.iter_mut() {
if excess == 0 {
break;
}
clear_cache_control(block, &mut excess);
}
}
}
}
/// Removes the marker from `block`, decrementing `*excess` only if a marker was present.
///
/// Thinking blocks have no marker slot and are left alone. For raw blocks any
/// `"cache_control"` key counts, whatever its value. Callers ensure `*excess > 0`.
fn clear_cache_control(block: &mut ContentBlock, excess: &mut usize) {
    let removed = match block {
        ContentBlock::Raw(fields) => fields.remove("cache_control").is_some(),
        ContentBlock::Typed(typed) => typed_block_cache_control_mut(typed)
            .and_then(Option::take)
            .is_some(),
    };
    if removed {
        *excess -= 1;
    }
}
/// Drops the explicit `ttl` from any long-lived marker that follows a short-lived one.
///
/// Markers are visited in order: tools, then system blocks, then message content blocks.
/// Once a short marker (see `normalize_ttl` / `normalize_block_ttl`) has been seen, every
/// later long marker loses its `ttl`.
/// NOTE(review): presumably enforces the API rule that longer-TTL cache entries must
/// precede shorter ones — confirm against the prompt-caching documentation.
pub fn normalize_ttl_ordering(req: &mut MessagesRequest) {
    let mut seen_short = false;
    for cc in req
        .tools
        .iter_mut()
        .flatten()
        .filter_map(|tool| tool.cache_control.as_mut())
    {
        normalize_ttl(cc, &mut seen_short);
    }
    if let Some(SystemPrompt::Blocks(blocks)) = &mut req.system {
        for cc in blocks.iter_mut().filter_map(|b| b.cache_control.as_mut()) {
            normalize_ttl(cc, &mut seen_short);
        }
    }
    for msg in &mut req.messages {
        let MessageContent::Blocks(blocks) = &mut msg.content else {
            continue;
        };
        blocks
            .iter_mut()
            .for_each(|block| normalize_block_ttl(block, &mut seen_short));
    }
}
/// Applies the TTL-ordering rule to one message content block.
///
/// Typed blocks delegate to `normalize_ttl`. For raw blocks the rule is applied to the
/// `"ttl"` key of an object-valued `"cache_control"`:
/// - a missing or non-`u64` ttl, or one of at most 300, counts as short;
/// - otherwise the key is removed if a short marker was already seen.
///
/// Raw blocks whose `"cache_control"` is not an object are ignored.
fn normalize_block_ttl(block: &mut ContentBlock, seen_short: &mut bool) {
    match block {
        ContentBlock::Typed(typed) => {
            if let Some(cc) = typed_block_cache_control_mut(typed).and_then(Option::as_mut) {
                normalize_ttl(cc, seen_short);
            }
        }
        ContentBlock::Raw(fields) => {
            let Some(serde_json::Value::Object(cc_fields)) = fields.get_mut("cache_control")
            else {
                return;
            };
            let ttl = cc_fields.get("ttl").and_then(serde_json::Value::as_u64);
            if ttl.is_none_or(|secs| secs <= 300) {
                *seen_short = true;
            } else if *seen_short {
                cc_fields.remove("ttl");
            }
        }
    }
}
/// Applies the TTL-ordering rule to a single typed marker.
///
/// A marker with no `ttl`, or a `ttl` of at most 300, is short and sets `*seen_short`. A
/// longer `ttl` is reset to `None` if a short marker was already seen; otherwise it is
/// kept unchanged.
fn normalize_ttl(cc: &mut CacheControl, seen_short: &mut bool) {
    let is_short = cc.ttl.is_none_or(|secs| secs <= 300);
    if is_short {
        *seen_short = true;
    } else if *seen_short {
        cc.ttl = None;
    }
}
// Unit tests for the caching strategies, breakpoint capping, and TTL normalization.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{Message, Tool};
// Builds a user message with plain-string content.
fn user_text(text: &str) -> Message {
Message {
role: Role::User,
content: MessageContent::Text(text.to_owned()),
}
}
// Builds an assistant message with plain-string content.
fn assistant_text(text: &str) -> Message {
Message {
role: Role::Assistant,
content: MessageContent::Text(text.to_owned()),
}
}
// Builds a minimal request around `messages` with no tools and no system prompt.
fn req_with(messages: Vec<Message>) -> MessagesRequest {
MessagesRequest::builder()
.model("claude-sonnet-4-20250514")
.messages(messages)
.max_tokens(1024_u64)
.build()
}
// Builds an unmarked tool with a trivial object schema.
fn tool(name: &str) -> Tool {
Tool {
name: name.to_owned(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: None,
}
}
// Shorthand for an ephemeral marker with an explicit TTL.
fn ephemeral_with_ttl(ttl: u64) -> CacheControl {
ephemeral_marker_with_ttl(ttl)
}
// With a single user message there is no second-to-last user turn to mark.
#[test]
fn default_strategy_no_op_with_only_one_user_message() {
let mut req = req_with(vec![user_text("hi")]);
DefaultCacheControlStrategy::default().apply(&mut req);
assert_eq!(count_cache_controls(&req), 0);
}
// The second-to-last user message is promoted to blocks and marked; the last stays text.
#[test]
fn default_strategy_injects_on_second_to_last_user_message() {
let mut req = req_with(vec![
user_text("first"),
assistant_text("reply"),
user_text("second"),
]);
DefaultCacheControlStrategy::default().apply(&mut req);
let MessageContent::Blocks(blocks) = &req.messages[0].content else {
panic!("expected Blocks");
};
let typed = match &blocks[0] {
ContentBlock::Typed(t) => t,
_ => panic!("expected Typed"),
};
let TypedContentBlock::Text { cache_control, .. } = typed else {
panic!("expected Text");
};
assert!(cache_control.is_some());
assert!(matches!(req.messages[2].content, MessageContent::Text(_)));
}
// Only the last tool receives a breakpoint.
#[test]
fn default_strategy_injects_on_last_tool() {
let mut req = req_with(vec![user_text("hi")]);
req.tools = Some(vec![tool("a"), tool("b")]);
DefaultCacheControlStrategy::default().apply(&mut req);
let tools = req.tools.as_ref().unwrap();
assert!(tools[0].cache_control.is_none());
assert!(tools[1].cache_control.is_some());
}
// A non-empty string system prompt becomes one marked text block with the same text.
#[test]
fn default_strategy_promotes_system_string_to_blocks() {
let mut req = req_with(vec![user_text("hi")]);
req.system = Some(SystemPrompt::Text("You are helpful.".into()));
DefaultCacheControlStrategy::default().apply(&mut req);
let SystemPrompt::Blocks(blocks) = req.system.as_ref().unwrap() else {
panic!("expected promoted Blocks");
};
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0].text, "You are helpful.");
assert!(blocks[0].cache_control.is_some());
}
// Only the last system block receives a breakpoint.
#[test]
fn default_strategy_injects_on_last_system_block() {
let mut req = req_with(vec![user_text("hi")]);
req.system = Some(SystemPrompt::Blocks(vec![
TextBlock {
r#type: "text".into(),
text: "first".into(),
cache_control: None,
},
TextBlock {
r#type: "text".into(),
text: "second".into(),
cache_control: None,
},
]));
DefaultCacheControlStrategy::default().apply(&mut req);
let SystemPrompt::Blocks(blocks) = req.system.as_ref().unwrap() else {
panic!()
};
assert!(blocks[0].cache_control.is_none());
assert!(blocks[1].cache_control.is_some());
}
// A caller-placed marker disables automatic injection entirely.
#[test]
fn default_strategy_skips_when_any_cache_control_present() {
let mut req = req_with(vec![user_text("first"), user_text("second")]);
req.tools = Some(vec![Tool {
name: "t".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_marker()),
}]);
DefaultCacheControlStrategy::default().apply(&mut req);
assert_eq!(count_cache_controls(&req), 1);
}
// An empty string system prompt is neither promoted nor marked.
#[test]
fn default_strategy_skips_empty_system_string() {
let mut req = req_with(vec![user_text("hi")]);
req.system = Some(SystemPrompt::Text(String::new()));
DefaultCacheControlStrategy::default().apply(&mut req);
assert!(matches!(req.system, Some(SystemPrompt::Text(_))));
assert_eq!(count_cache_controls(&req), 0);
}
// The opt-out strategy never adds markers.
#[test]
fn no_cache_control_strategy_is_a_noop() {
let mut req = req_with(vec![user_text("first"), user_text("second")]);
req.tools = Some(vec![tool("a")]);
NoCacheControlStrategy.apply(&mut req);
assert_eq!(count_cache_controls(&req), 0);
}
// Five markers are cut to four; the last system and last tool markers survive.
#[test]
fn enforce_cap_strips_excess() {
let mut req = req_with(vec![user_text("hi")]);
req.tools = Some(vec![
Tool {
name: "a".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_marker()),
},
Tool {
name: "b".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_marker()),
},
Tool {
name: "c".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_marker()),
},
]);
req.system = Some(SystemPrompt::Blocks(vec![
TextBlock {
r#type: "text".into(),
text: "s1".into(),
cache_control: Some(ephemeral_marker()),
},
TextBlock {
r#type: "text".into(),
text: "s2".into(),
cache_control: Some(ephemeral_marker()),
},
]));
enforce_breakpoint_cap(&mut req);
assert_eq!(count_cache_controls(&req), 4);
let SystemPrompt::Blocks(blocks) = req.system.as_ref().unwrap() else {
panic!()
};
assert!(blocks[1].cache_control.is_some());
let tools = req.tools.as_ref().unwrap();
assert!(tools[2].cache_control.is_some());
}
// A request already within the cap is not modified.
#[test]
fn enforce_cap_under_limit_no_change() {
let mut req = req_with(vec![user_text("hi")]);
req.tools = Some(vec![Tool {
name: "a".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_marker()),
}]);
enforce_breakpoint_cap(&mut req);
assert_eq!(count_cache_controls(&req), 1);
}
// A 3600-second TTL after a default (short) marker has its TTL removed.
#[test]
fn normalize_ttl_strips_long_after_short() {
let mut req = req_with(vec![user_text("hi")]);
req.tools = Some(vec![
Tool {
name: "a".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_marker()),
},
Tool {
name: "b".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_with_ttl(3600)),
},
]);
normalize_ttl_ordering(&mut req);
let tools = req.tools.as_ref().unwrap();
assert!(tools[0].cache_control.as_ref().unwrap().ttl.is_none());
assert!(tools[1].cache_control.as_ref().unwrap().ttl.is_none());
}
// Consecutive long TTLs with no preceding short marker are kept.
#[test]
fn normalize_ttl_preserves_long_when_no_short_seen() {
let mut req = req_with(vec![user_text("hi")]);
req.tools = Some(vec![
Tool {
name: "a".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_with_ttl(3600)),
},
Tool {
name: "b".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_with_ttl(3600)),
},
]);
normalize_ttl_ordering(&mut req);
let tools = req.tools.as_ref().unwrap();
assert_eq!(tools[0].cache_control.as_ref().unwrap().ttl, Some(3600));
assert_eq!(tools[1].cache_control.as_ref().unwrap().ttl, Some(3600));
}
// Tools are visited before system blocks, so a short tool marker strips a long system TTL.
#[test]
fn normalize_ttl_walks_in_evaluation_order() {
let mut req = req_with(vec![user_text("hi")]);
req.tools = Some(vec![Tool {
name: "a".into(),
description: None,
input_schema: serde_json::json!({"type":"object"}),
cache_control: Some(ephemeral_marker()),
}]);
req.system = Some(SystemPrompt::Blocks(vec![TextBlock {
r#type: "text".into(),
text: "s".into(),
cache_control: Some(ephemeral_with_ttl(3600)),
}]));
normalize_ttl_ordering(&mut req);
let SystemPrompt::Blocks(blocks) = req.system.as_ref().unwrap() else {
panic!()
};
assert!(blocks[0].cache_control.as_ref().unwrap().ttl.is_none());
}
// Both strategies must remain usable as trait objects.
#[test]
fn dyn_dispatch_compiles() {
let _: Box<dyn CacheControlStrategy> = Box::new(DefaultCacheControlStrategy::default());
let _: Box<dyn CacheControlStrategy> = Box::new(NoCacheControlStrategy);
}
}