1use latch_core::{Message, PromptCacheProvider};
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
5pub struct CacheControl {
6 #[serde(rename = "type")]
7 pub kind: String,
8}
9
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
11pub struct CacheTaggedMessage {
12 pub role: String,
13 pub content: String,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub cache_control: Option<CacheControl>,
16}
17
18impl From<Message> for CacheTaggedMessage {
19 fn from(value: Message) -> Self {
20 Self {
21 role: value.role,
22 content: value.content,
23 cache_control: None,
24 }
25 }
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
29pub struct PromptCachePlan {
30 pub provider: PromptCacheProvider,
31 pub tagged_indexes: Vec<usize>,
32}
33
34pub fn plan_prompt_cache(
35 messages: &[Message],
36 provider: PromptCacheProvider,
37 cache_roles: &[String],
38 min_content_chars: usize,
39) -> PromptCachePlan {
40 match provider {
41 PromptCacheProvider::Anthropic => {
42 let tagged_indexes = messages
43 .iter()
44 .enumerate()
45 .filter_map(|(idx, m)| {
46 if !cache_roles.contains(&m.role) {
47 return None;
48 }
49 if m.content.chars().count() < min_content_chars {
50 return None;
51 }
52 Some(idx)
53 })
54 .collect();
55
56 PromptCachePlan {
57 provider,
58 tagged_indexes,
59 }
60 }
61 PromptCacheProvider::OpenAiCompatible | PromptCacheProvider::None => PromptCachePlan {
62 provider,
63 tagged_indexes: Vec::new(),
64 },
65 }
66}
67
68pub fn plan_prompt_cache_default(
70 messages: &[Message],
71 provider: PromptCacheProvider,
72) -> PromptCachePlan {
73 plan_prompt_cache(messages, provider, &["system".to_string()], 0)
74}
75
76pub fn apply_prompt_cache_plan(
77 messages: &[Message],
78 plan: &PromptCachePlan,
79) -> Vec<CacheTaggedMessage> {
80 let mut out: Vec<CacheTaggedMessage> = messages.iter().cloned().map(Into::into).collect();
81 for idx in &plan.tagged_indexes {
82 if let Some(msg) = out.get_mut(*idx) {
83 msg.cache_control = Some(CacheControl {
84 kind: "ephemeral".to_string(),
85 });
86 }
87 }
88 out
89}
90
#[cfg(test)]
mod tests {
    use super::{
        apply_prompt_cache_plan, plan_prompt_cache, plan_prompt_cache_default, PromptCachePlan,
    };
    use latch_core::{Message, PromptCacheProvider};

    /// Shorthand for building a test message.
    fn msg(role: impl Into<String>, content: impl Into<String>) -> Message {
        Message::new(role, content)
    }

    /// Role list containing only `"system"`.
    fn system_only() -> Vec<String> {
        vec!["system".to_string()]
    }

    #[test]
    fn anthropic_tags_system_messages_with_sufficient_length() {
        let messages = vec![
            msg("system", "a".repeat(100)),
            msg("user", "hello"),
            msg("assistant", "hi"),
            msg("system", "short"),
        ];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic, &system_only(), 100);
        assert_eq!(plan.tagged_indexes, vec![0]);
    }

    #[test]
    fn anthropic_tags_multiple_roles() {
        let messages = vec![
            msg("system", "a".repeat(100)),
            msg("user", "b".repeat(100)),
        ];
        let cache_roles = vec!["system".to_string(), "user".to_string()];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic, &cache_roles, 100);
        assert_eq!(plan.tagged_indexes, vec![0, 1]);
    }

    #[test]
    fn anthropic_length_threshold_counts_chars_not_bytes() {
        // "é" is one char but two UTF-8 bytes.
        let too_short = vec![msg("system", "é".repeat(60))];
        let plan = plan_prompt_cache(&too_short, PromptCacheProvider::Anthropic, &system_only(), 100);
        assert!(plan.tagged_indexes.is_empty());

        let long_enough = vec![msg("system", "é".repeat(100))];
        let plan = plan_prompt_cache(&long_enough, PromptCacheProvider::Anthropic, &system_only(), 100);
        assert_eq!(plan.tagged_indexes, vec![0]);
    }

    #[test]
    fn anthropic_empty_input_yields_empty_plan() {
        let plan = plan_prompt_cache(&[], PromptCacheProvider::Anthropic, &system_only(), 0);
        assert_eq!(plan.provider, PromptCacheProvider::Anthropic);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn openai_compatible_has_no_tags() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::OpenAiCompatible, &system_only(), 0);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn none_provider_has_no_tags() {
        let messages = vec![msg("system", "a".repeat(100))];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::None, &system_only(), 0);
        assert_eq!(plan.provider, PromptCacheProvider::None);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn apply_sets_ephemeral_marker_only_for_planned_indexes() {
        let messages = vec![msg("system", "a".repeat(100)), msg("user", "hello")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic, &system_only(), 100);
        let out = apply_prompt_cache_plan(&messages, &plan);

        assert_eq!(
            out[0]
                .cache_control
                .as_ref()
                .map(|cc| cc.kind.as_str())
                .unwrap_or(""),
            "ephemeral"
        );
        assert!(out[1].cache_control.is_none());
    }

    #[test]
    fn apply_preserves_role_and_content() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        let plan = plan_prompt_cache_default(&messages, PromptCacheProvider::Anthropic);
        let out = apply_prompt_cache_plan(&messages, &plan);

        assert_eq!(out.len(), 2);
        assert_eq!(out[0].role, "system");
        assert_eq!(out[0].content, "policy");
        assert_eq!(out[1].role, "user");
        assert_eq!(out[1].content, "hello");
    }

    #[test]
    fn apply_ignores_out_of_range_indexes() {
        let messages = vec![msg("system", "policy")];
        let plan = PromptCachePlan {
            provider: PromptCacheProvider::Anthropic,
            tagged_indexes: vec![0, 7],
        };
        let out = apply_prompt_cache_plan(&messages, &plan);

        assert_eq!(out.len(), 1);
        assert!(out[0].cache_control.is_some());
    }

    #[test]
    fn plan_prompt_cache_default_works() {
        let messages = vec![
            msg("system", "a".repeat(100)),
            msg("user", "hello"),
        ];
        let plan = plan_prompt_cache_default(&messages, PromptCacheProvider::Anthropic);
        assert_eq!(plan.provider, PromptCacheProvider::Anthropic);
        assert_eq!(plan.tagged_indexes, vec![0]);
    }

    #[test]
    fn plan_prompt_cache_default_has_no_length_minimum() {
        let messages = vec![msg("system", "")];
        let plan = plan_prompt_cache_default(&messages, PromptCacheProvider::Anthropic);
        assert_eq!(plan.tagged_indexes, vec![0]);
    }
}