Skip to main content

latch_cache/
lib.rs

1use latch_core::{Message, PromptCacheProvider};
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
5pub struct CacheControl {
6    #[serde(rename = "type")]
7    pub kind: String,
8}
9
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
11pub struct CacheTaggedMessage {
12    pub role: String,
13    pub content: String,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub cache_control: Option<CacheControl>,
16}
17
18impl From<Message> for CacheTaggedMessage {
19    fn from(value: Message) -> Self {
20        Self {
21            role: value.role,
22            content: value.content,
23            cache_control: None,
24        }
25    }
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
29pub struct PromptCachePlan {
30    pub provider: PromptCacheProvider,
31    pub tagged_indexes: Vec<usize>,
32}
33
34pub fn plan_prompt_cache(
35    messages: &[Message],
36    provider: PromptCacheProvider,
37    cache_roles: &[String],
38    min_content_chars: usize,
39) -> PromptCachePlan {
40    match provider {
41        PromptCacheProvider::Anthropic => {
42            let tagged_indexes = messages
43                .iter()
44                .enumerate()
45                .filter_map(|(idx, m)| {
46                    if !cache_roles.contains(&m.role) {
47                        return None;
48                    }
49                    if m.content.chars().count() < min_content_chars {
50                        return None;
51                    }
52                    Some(idx)
53                })
54                .collect();
55
56            PromptCachePlan {
57                provider,
58                tagged_indexes,
59            }
60        }
61        PromptCacheProvider::OpenAiCompatible | PromptCacheProvider::None => PromptCachePlan {
62            provider,
63            tagged_indexes: Vec::new(),
64        },
65    }
66}
67
68/// Backward-compatible convenience function with default parameters.
69pub fn plan_prompt_cache_default(
70    messages: &[Message],
71    provider: PromptCacheProvider,
72) -> PromptCachePlan {
73    plan_prompt_cache(messages, provider, &["system".to_string()], 0)
74}
75
76pub fn apply_prompt_cache_plan(
77    messages: &[Message],
78    plan: &PromptCachePlan,
79) -> Vec<CacheTaggedMessage> {
80    let mut out: Vec<CacheTaggedMessage> = messages.iter().cloned().map(Into::into).collect();
81    for idx in &plan.tagged_indexes {
82        if let Some(msg) = out.get_mut(*idx) {
83            msg.cache_control = Some(CacheControl {
84                kind: "ephemeral".to_string(),
85            });
86        }
87    }
88    out
89}
90
#[cfg(test)]
mod tests {
    use super::{
        apply_prompt_cache_plan, plan_prompt_cache, plan_prompt_cache_default, CacheControl,
        PromptCachePlan,
    };
    use latch_core::{Message, PromptCacheProvider};

    /// Shorthand for building a `latch_core::Message`.
    fn msg(role: impl Into<String>, content: impl Into<String>) -> Message {
        Message::new(role, content)
    }

    /// The marker `apply_prompt_cache_plan` attaches to tagged messages.
    fn ephemeral() -> Option<CacheControl> {
        Some(CacheControl {
            kind: "ephemeral".to_string(),
        })
    }

    #[test]
    fn anthropic_tags_system_messages_with_sufficient_length() {
        let messages = vec![
            msg("system", "a".repeat(100)), // Long enough
            msg("user", "hello"),
            msg("assistant", "hi"),
            msg("system", "short"), // Too short
        ];
        let plan = plan_prompt_cache(
            &messages,
            PromptCacheProvider::Anthropic,
            &["system".to_string()],
            100,
        );
        assert_eq!(plan.tagged_indexes, vec![0]);
    }

    #[test]
    fn anthropic_ignores_long_messages_from_uncached_roles() {
        let messages = vec![msg("user", "b".repeat(500)), msg("system", "a".repeat(100))];
        let plan = plan_prompt_cache(
            &messages,
            PromptCacheProvider::Anthropic,
            &["system".to_string()],
            100,
        );
        assert_eq!(plan.tagged_indexes, vec![1]);
    }

    #[test]
    fn anthropic_tags_multiple_roles() {
        let messages = vec![
            msg("system", "a".repeat(100)),
            msg("user", "b".repeat(100)),
        ];
        let cache_roles = vec!["system".to_string(), "user".to_string()];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic, &cache_roles, 100);
        assert_eq!(plan.tagged_indexes, vec![0, 1]);
    }

    #[test]
    fn min_content_chars_counts_chars_not_bytes() {
        // "é" is two bytes in UTF-8 but a single `char`.
        let messages = vec![msg("system", "é".repeat(100))];
        let roles = ["system".to_string()];

        let at_limit = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic, &roles, 100);
        assert_eq!(at_limit.tagged_indexes, vec![0]);

        let above_limit =
            plan_prompt_cache(&messages, PromptCacheProvider::Anthropic, &roles, 101);
        assert!(above_limit.tagged_indexes.is_empty());
    }

    #[test]
    fn zero_threshold_tags_empty_content() {
        let messages = vec![msg("system", "")];
        let plan = plan_prompt_cache(
            &messages,
            PromptCacheProvider::Anthropic,
            &["system".to_string()],
            0,
        );
        assert_eq!(plan.tagged_indexes, vec![0]);
    }

    #[test]
    fn empty_conversation_yields_empty_plan() {
        let plan = plan_prompt_cache(
            &[],
            PromptCacheProvider::Anthropic,
            &["system".to_string()],
            0,
        );
        assert_eq!(plan.provider, PromptCacheProvider::Anthropic);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn openai_compatible_has_no_tags() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        let plan = plan_prompt_cache(
            &messages,
            PromptCacheProvider::OpenAiCompatible,
            &["system".to_string()],
            0,
        );
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn none_provider_has_no_tags() {
        let messages = vec![msg("system", "a".repeat(100))];
        let plan = plan_prompt_cache(
            &messages,
            PromptCacheProvider::None,
            &["system".to_string()],
            0,
        );
        assert_eq!(plan.provider, PromptCacheProvider::None);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn apply_sets_ephemeral_marker_only_for_planned_indexes() {
        let messages = vec![msg("system", "a".repeat(100)), msg("user", "hello")];
        let plan = plan_prompt_cache(
            &messages,
            PromptCacheProvider::Anthropic,
            &["system".to_string()],
            100,
        );
        let out = apply_prompt_cache_plan(&messages, &plan);

        assert_eq!(out[0].cache_control, ephemeral());
        assert!(out[1].cache_control.is_none());
    }

    #[test]
    fn apply_preserves_order_and_ignores_out_of_range_indexes() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        // Duplicate and out-of-range positions must not panic or duplicate output.
        let plan = PromptCachePlan {
            provider: PromptCacheProvider::Anthropic,
            tagged_indexes: vec![1, 1, 42],
        };
        let out = apply_prompt_cache_plan(&messages, &plan);

        assert_eq!(out.len(), 2);
        assert_eq!(
            (out[0].role.as_str(), out[0].content.as_str()),
            ("system", "policy")
        );
        assert_eq!(
            (out[1].role.as_str(), out[1].content.as_str()),
            ("user", "hello")
        );
        assert!(out[0].cache_control.is_none());
        assert_eq!(out[1].cache_control, ephemeral());
    }

    #[test]
    fn plan_prompt_cache_default_works() {
        let messages = vec![
            msg("system", "a".repeat(100)),
            msg("user", "hello"),
        ];
        let plan = plan_prompt_cache_default(&messages, PromptCacheProvider::Anthropic);
        assert_eq!(plan.provider, PromptCacheProvider::Anthropic);
        assert_eq!(plan.tagged_indexes, vec![0]);
    }

    #[test]
    fn default_plan_ignores_content_length() {
        // The default threshold is 0, so even a one-char system message is tagged.
        let messages = vec![msg("user", "a".repeat(100)), msg("system", "x")];
        let plan = plan_prompt_cache_default(&messages, PromptCacheProvider::Anthropic);
        assert_eq!(plan.tagged_indexes, vec![1]);
    }
}