atomr_agents_cache/semantic.rs

//! Semantic LLM cache.
//!
//! Embeds the concatenated user-message text; `get_by_text` returns the
//! cached value of the most-similar previously stored prompt when cosine
//! similarity ≥ `threshold`. The trait-level `get` only performs an
//! exact-key lookup, since the hashed `CacheKey` cannot be re-embedded.
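//!
//! Illustrative usage (a sketch; the `MockEmbedder` helper is borrowed from
//! the tests below, and `key` / `turn` are placeholders for a real
//! `CacheKey` / `CachedTurn`):
//!
//! ```ignore
//! use std::sync::Arc;
//! use atomr_agents_embed::MockEmbedder;
//!
//! let cache = SemanticLlmCache::new(Arc::new(MockEmbedder::new(8)), 0.9);
//! // Store a turn together with the prompt text so it can be embedded.
//! cache.put_with_text("summarize this article", key, turn).await?;
//! // A semantically similar prompt can now hit the cache.
//! let hit = cache.get_by_text("please summarize this article").await?;
//! ```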

use std::sync::Arc;

use async_trait::async_trait;
use atomr_agents_core::Result;
use atomr_agents_embed::Embedder;
use parking_lot::RwLock;

use crate::{CacheKey, CachedTurn, LlmCache};

struct Entry {
    embedding: Vec<f32>,
    value: CachedTurn,
    /// Original key so exact-key hits also work.
    key: CacheKey,
    /// The prompt text that was embedded; kept for inspection but not read
    /// back on retrieval.
    #[allow(dead_code)]
    text: String,
}

/// In-memory semantic cache backed by a linear scan over stored prompt
/// embeddings.
pub struct SemanticLlmCache {
    pub embedder: Arc<dyn Embedder>,
    /// Minimum cosine similarity for a semantic hit in `get_by_text`.
    pub threshold: f32,
    inner: Arc<RwLock<Vec<Entry>>>,
}

impl SemanticLlmCache {
    pub fn new(embedder: Arc<dyn Embedder>, threshold: f32) -> Self {
        Self {
            embedder,
            threshold,
            inner: Arc::new(RwLock::new(Vec::new())),
        }
    }

    pub fn len(&self) -> usize {
        self.inner.read().len()
    }

    /// A `get` variant keyed by the prompt text rather than the
    /// hash-based `CacheKey`. Useful when the caller still has the
    /// original prompt text available.
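    ///
    /// A minimal sketch (the `cache` value and error handling are elided):
    ///
    /// ```ignore
    /// if let Some(turn) = cache.get_by_text("summarize this article").await? {
    ///     // Reuse the cached completion instead of calling the model again.
    ///     println!("{}", turn.text);
    /// }
    /// ```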
    pub async fn get_by_text(&self, text: &str) -> Result<Option<CachedTurn>> {
        let q = self.embedder.embed(text).await?;
        let g = self.inner.read();
        let mut best: Option<(f32, CachedTurn)> = None;
        for e in g.iter() {
            let s = cosine(&q, &e.embedding);
            // Keep the best match at or above the similarity threshold.
            if s >= self.threshold && best.as_ref().map_or(true, |(b, _)| s > *b) {
                best = Some((s, e.value.clone()));
            }
        }
        Ok(best.map(|(_, v)| v))
    }

    /// Stores `value` together with the prompt `text`, so the text can be
    /// embedded for later semantic lookups.
    pub async fn put_with_text(
        &self,
        text: impl Into<String>,
        key: CacheKey,
        value: CachedTurn,
    ) -> Result<()> {
        let text = text.into();
        let v = self.embedder.embed(&text).await?;
        self.inner.write().push(Entry {
            embedding: v,
            value,
            key,
            text,
        });
        Ok(())
    }
}

#[async_trait]
impl LlmCache for SemanticLlmCache {
    async fn get(&self, key: &CacheKey) -> Result<Option<CachedTurn>> {
        // Exact-key first.
        if let Some(v) = self
            .inner
            .read()
            .iter()
            .find(|e| &e.key == key)
            .map(|e| e.value.clone())
        {
            return Ok(Some(v));
        }
        // No prompt text available without re-deriving from the key
        // (which is hash-only). Fall back to "miss"; callers that
        // want semantic matching should call `get_by_text` directly.
        Ok(None)
    }

    async fn put(&self, _key: CacheKey, _value: CachedTurn) -> Result<()> {
        // Hashed key alone isn't enough to embed; require `put_with_text`.
        Err(atomr_agents_core::AgentError::Internal(
            "SemanticLlmCache: use put_with_text() so the prompt text can be embedded".into(),
        ))
    }
}

/// Cosine similarity between two vectors; returns 0.0 when the lengths
/// differ or either vector has zero norm.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if na == 0.0 || nb == 0.0 {
        0.0
    } else {
        dot / (na * nb)
    }
}

#[allow(dead_code)]
fn _entry_in_scope(_e: &Entry) {}

#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_embed::MockEmbedder;
    use atomr_infer_core::tokens::TokenUsage;

    fn turn(text: &str) -> CachedTurn {
        CachedTurn {
            text: text.into(),
            usage: TokenUsage::default(),
            finish_reason: None,
        }
    }

    #[tokio::test]
    async fn hits_on_repeated_prompt() {
        let c = SemanticLlmCache::new(Arc::new(MockEmbedder::new(8)), 0.99);
        let key = CacheKey {
            model: "m".into(),
            messages_hash: 1,
            sampling_hash: 1,
        };
        c.put_with_text("hello", key, turn("hi back")).await.unwrap();
        let v = c.get_by_text("hello").await.unwrap().unwrap();
        assert_eq!(v.text, "hi back");
    }
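
    // A small additional test sketch: exercises the exact-key path of
    // `LlmCache::get`. It assumes `CacheKey` equality is structural, so a
    // second key built from the same fields matches the stored one.
    #[tokio::test]
    async fn exact_key_get_hits_after_put_with_text() {
        let c = SemanticLlmCache::new(Arc::new(MockEmbedder::new(8)), 0.99);
        let put_key = CacheKey {
            model: "m".into(),
            messages_hash: 2,
            sampling_hash: 2,
        };
        c.put_with_text("hello", put_key, turn("hi back")).await.unwrap();
        // Rebuild an equal key rather than assuming `CacheKey: Clone`.
        let lookup_key = CacheKey {
            model: "m".into(),
            messages_hash: 2,
            sampling_hash: 2,
        };
        let v = c.get(&lookup_key).await.unwrap().unwrap();
        assert_eq!(v.text, "hi back");
    }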

    #[tokio::test]
    async fn miss_below_threshold() {
        let c = SemanticLlmCache::new(Arc::new(MockEmbedder::new(8)), 0.999);
        let key = CacheKey {
            model: "m".into(),
            messages_hash: 1,
            sampling_hash: 1,
        };
        c.put_with_text("hello", key, turn("hi back")).await.unwrap();
        let v = c.get_by_text("entirely different prompt").await.unwrap();
        assert!(v.is_none());
    }
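
    // A small additional test sketch for the private `cosine` helper:
    // identical vectors score 1.0, orthogonal vectors 0.0, and mismatched
    // lengths or zero-norm vectors fall back to 0.0.
    #[test]
    fn cosine_basic_properties() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [0.0_f32, 1.0, 0.0];
        assert!((cosine(&a, &a) - 1.0).abs() < 1e-6);
        assert!(cosine(&a, &b).abs() < 1e-6);
        assert_eq!(cosine(&a, &[1.0, 0.0]), 0.0);
        assert_eq!(cosine(&[0.0, 0.0], &[0.0, 0.0]), 0.0);
    }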
}