1use latch_core::{Message, PromptCacheProvider};
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
5pub struct CacheControl {
6 #[serde(rename = "type")]
7 pub kind: String,
8}
9
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
11pub struct CacheTaggedMessage {
12 pub role: String,
13 pub content: String,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub cache_control: Option<CacheControl>,
16}
17
18impl From<Message> for CacheTaggedMessage {
19 fn from(value: Message) -> Self {
20 Self {
21 role: value.role,
22 content: value.content,
23 cache_control: None,
24 }
25 }
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
29pub struct PromptCachePlan {
30 pub provider: PromptCacheProvider,
31 pub tagged_indexes: Vec<usize>,
32}
33
34pub fn plan_prompt_cache(messages: &[Message], provider: PromptCacheProvider) -> PromptCachePlan {
35 match provider {
36 PromptCacheProvider::Anthropic => PromptCachePlan {
37 provider,
38 tagged_indexes: messages
39 .iter()
40 .enumerate()
41 .filter_map(|(idx, m)| (m.role == "system").then_some(idx))
42 .collect(),
43 },
44 PromptCacheProvider::OpenAiCompatible | PromptCacheProvider::None => PromptCachePlan {
45 provider,
46 tagged_indexes: Vec::new(),
47 },
48 }
49}
50
51pub fn apply_prompt_cache_plan(
52 messages: &[Message],
53 plan: &PromptCachePlan,
54) -> Vec<CacheTaggedMessage> {
55 let mut out: Vec<CacheTaggedMessage> = messages.iter().cloned().map(Into::into).collect();
56 for idx in &plan.tagged_indexes {
57 if let Some(msg) = out.get_mut(*idx) {
58 msg.cache_control = Some(CacheControl {
59 kind: "ephemeral".to_string(),
60 });
61 }
62 }
63 out
64}
65
#[cfg(test)]
mod tests {
    use super::{apply_prompt_cache_plan, plan_prompt_cache, CacheControl, PromptCachePlan};
    use latch_core::{Message, PromptCacheProvider};

    fn msg(role: &str, content: &str) -> Message {
        Message::new(role, content)
    }

    /// The marker `apply_prompt_cache_plan` is expected to attach.
    fn ephemeral() -> Option<CacheControl> {
        Some(CacheControl {
            kind: "ephemeral".to_string(),
        })
    }

    #[test]
    fn anthropic_tags_system_messages() {
        let messages = vec![
            msg("system", "policy"),
            msg("user", "hello"),
            msg("assistant", "hi"),
            msg("system", "persona"),
        ];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic);
        assert_eq!(plan.tagged_indexes, vec![0, 3]);
    }

    #[test]
    fn openai_compatible_has_no_tags() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::OpenAiCompatible);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn no_provider_has_no_tags() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::None);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn apply_sets_ephemeral_marker_only_for_planned_indexes() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic);
        let out = apply_prompt_cache_plan(&messages, &plan);

        assert_eq!(out.len(), messages.len());
        assert_eq!(out[0].cache_control, ephemeral());
        assert!(out[1].cache_control.is_none());
    }

    #[test]
    fn apply_ignores_out_of_range_indexes() {
        let messages = vec![msg("user", "hello")];
        let plan = PromptCachePlan {
            provider: PromptCacheProvider::Anthropic,
            tagged_indexes: vec![0, 7],
        };
        let out = apply_prompt_cache_plan(&messages, &plan);

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].cache_control, ephemeral());
    }
}