Skip to main content

latch_cache/
lib.rs

1use latch_core::{Message, PromptCacheProvider};
2use serde::{Deserialize, Serialize};
3
4#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
5pub struct CacheControl {
6    #[serde(rename = "type")]
7    pub kind: String,
8}
9
10#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
11pub struct CacheTaggedMessage {
12    pub role: String,
13    pub content: String,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub cache_control: Option<CacheControl>,
16}
17
18impl From<Message> for CacheTaggedMessage {
19    fn from(value: Message) -> Self {
20        Self {
21            role: value.role,
22            content: value.content,
23            cache_control: None,
24        }
25    }
26}
27
28#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
29pub struct PromptCachePlan {
30    pub provider: PromptCacheProvider,
31    pub tagged_indexes: Vec<usize>,
32}
33
34pub fn plan_prompt_cache(messages: &[Message], provider: PromptCacheProvider) -> PromptCachePlan {
35    match provider {
36        PromptCacheProvider::Anthropic => PromptCachePlan {
37            provider,
38            tagged_indexes: messages
39                .iter()
40                .enumerate()
41                .filter_map(|(idx, m)| (m.role == "system").then_some(idx))
42                .collect(),
43        },
44        PromptCacheProvider::OpenAiCompatible | PromptCacheProvider::None => PromptCachePlan {
45            provider,
46            tagged_indexes: Vec::new(),
47        },
48    }
49}
50
51pub fn apply_prompt_cache_plan(
52    messages: &[Message],
53    plan: &PromptCachePlan,
54) -> Vec<CacheTaggedMessage> {
55    let mut out: Vec<CacheTaggedMessage> = messages.iter().cloned().map(Into::into).collect();
56    for idx in &plan.tagged_indexes {
57        if let Some(msg) = out.get_mut(*idx) {
58            msg.cache_control = Some(CacheControl {
59                kind: "ephemeral".to_string(),
60            });
61        }
62    }
63    out
64}
65
#[cfg(test)]
mod tests {
    use super::{apply_prompt_cache_plan, plan_prompt_cache, CacheTaggedMessage, PromptCachePlan};
    use latch_core::{Message, PromptCacheProvider};

    /// Shorthand for building a [`Message`] from a role and body.
    fn msg(role: &str, content: &str) -> Message {
        Message::new(role, content)
    }

    #[test]
    fn anthropic_tags_system_messages() {
        let messages = vec![
            msg("system", "policy"),
            msg("user", "hello"),
            msg("assistant", "hi"),
            msg("system", "persona"),
        ];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic);
        assert_eq!(plan.tagged_indexes, vec![0, 3]);
    }

    #[test]
    fn anthropic_without_system_messages_has_no_tags() {
        let messages = vec![msg("user", "hello"), msg("assistant", "hi")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn plan_records_requested_provider() {
        let messages = vec![msg("system", "policy")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic);
        assert_eq!(plan.provider, PromptCacheProvider::Anthropic);
    }

    #[test]
    fn openai_compatible_has_no_tags() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::OpenAiCompatible);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn none_provider_has_no_tags() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::None);
        assert!(plan.tagged_indexes.is_empty());
    }

    #[test]
    fn apply_sets_ephemeral_marker_only_for_planned_indexes() {
        let messages = vec![msg("system", "policy"), msg("user", "hello")];
        let plan = plan_prompt_cache(&messages, PromptCacheProvider::Anthropic);
        let out = apply_prompt_cache_plan(&messages, &plan);

        assert_eq!(
            out[0]
                .cache_control
                .as_ref()
                .map(|cc| cc.kind.as_str())
                .unwrap_or(""),
            "ephemeral"
        );
        assert!(out[1].cache_control.is_none());
    }

    #[test]
    fn apply_ignores_out_of_range_indexes() {
        let messages = vec![msg("system", "policy")];
        let plan = PromptCachePlan {
            provider: PromptCacheProvider::Anthropic,
            tagged_indexes: vec![0, 7],
        };
        let out = apply_prompt_cache_plan(&messages, &plan);
        assert_eq!(out.len(), 1);
        assert!(out[0].cache_control.is_some());
    }

    #[test]
    fn apply_on_empty_input_is_empty() {
        let plan = plan_prompt_cache(&[], PromptCacheProvider::Anthropic);
        assert!(plan.tagged_indexes.is_empty());
        assert!(apply_prompt_cache_plan(&[], &plan).is_empty());
    }

    #[test]
    fn from_message_copies_fields_without_cache_marker() {
        let tagged = CacheTaggedMessage::from(msg("user", "hello"));
        assert_eq!(tagged.role, "user");
        assert_eq!(tagged.content, "hello");
        assert!(tagged.cache_control.is_none());
    }
}