// atomr_agents_cache/semantic.rs
use std::sync::Arc;

use async_trait::async_trait;
use atomr_agents_core::Result;
use atomr_agents_embed::Embedder;
use parking_lot::RwLock;

use crate::{CacheKey, CachedTurn, LlmCache};
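
/// One cached entry: the prompt's embedding plus the exact `CacheKey`,
/// the cached turn it maps to, and the original prompt text it was embedded from.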
struct Entry {
    embedding: Vec<f32>,
    value: CachedTurn,
    key: CacheKey,
    text: String,
}

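/// An in-memory cache of LLM turns that can be queried semantically:
/// `get_by_text` embeds the query and returns the closest stored entry whose
/// cosine similarity meets `threshold`, while the `LlmCache` impl still
/// supports exact `CacheKey` lookups.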
pub struct SemanticLlmCache {
    pub embedder: Arc<dyn Embedder>,
    pub threshold: f32,
    inner: Arc<RwLock<Vec<Entry>>>,
}

impl SemanticLlmCache {
    pub fn new(embedder: Arc<dyn Embedder>, threshold: f32) -> Self {
        Self {
            embedder,
            threshold,
            inner: Arc::new(RwLock::new(Vec::new())),
        }
    }

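    /// Number of entries currently held in the cache.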
    pub fn len(&self) -> usize {
        self.inner.read().len()
    }

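    /// Embeds `text` and linearly scans all entries, returning the cached turn
    /// with the highest cosine similarity, provided that similarity is at
    /// least `self.threshold`. Returns `Ok(None)` when nothing clears the
    /// threshold.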
    pub async fn get_by_text(&self, text: &str) -> Result<Option<CachedTurn>> {
        let q = self.embedder.embed(text).await?;
        let g = self.inner.read();
        let mut best: Option<(f32, CachedTurn)> = None;
        for e in g.iter() {
            let s = cosine(&q, &e.embedding);
            if s >= self.threshold {
                if best.as_ref().map(|(b, _)| s > *b).unwrap_or(true) {
                    best = Some((s, e.value.clone()));
                }
            }
        }
        Ok(best.map(|(_, v)| v))
    }

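    /// Embeds `text` and appends a new entry. There is no deduplication or
    /// eviction: repeated inserts for the same prompt simply add more entries.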
    pub async fn put_with_text(
        &self,
        text: impl Into<String>,
        key: CacheKey,
        value: CachedTurn,
    ) -> Result<()> {
        let text = text.into();
        let v = self.embedder.embed(&text).await?;
        self.inner.write().push(Entry {
            embedding: v,
            value,
            key,
            text,
        });
        Ok(())
    }
}

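// Exact-key `LlmCache` integration. `get` is a plain linear search on
// `CacheKey`; `put` is rejected because the trait's signature carries no
// prompt text to embed, so callers must use `put_with_text` instead.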
#[async_trait]
impl LlmCache for SemanticLlmCache {
    async fn get(&self, key: &CacheKey) -> Result<Option<CachedTurn>> {
        if let Some(v) = self
            .inner
            .read()
            .iter()
            .find(|e| &e.key == key)
            .map(|e| e.value.clone())
        {
            return Ok(Some(v));
        }
        Ok(None)
    }

    async fn put(&self, _key: CacheKey, _value: CachedTurn) -> Result<()> {
        Err(atomr_agents_core::AgentError::Internal(
            "SemanticLlmCache: use put_with_text() so the prompt text can be embedded".into(),
        ))
    }
}

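/// Cosine similarity between two vectors. Returns 0.0 for mismatched lengths
/// or zero-magnitude inputs, so degenerate embeddings never count as a hit.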
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if na == 0.0 || nb == 0.0 {
        0.0
    } else {
        dot / (na * nb)
    }
}

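// No-op helper: it does nothing with the `Entry` it receives and has no
// runtime effect.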
#[allow(dead_code)]
fn _entry_in_scope(_e: &Entry) {}

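// Tests below use `MockEmbedder`, which is assumed to produce deterministic,
// text-dependent embeddings of the requested dimension, so identical prompts
// hit while unrelated prompts fall below a strict threshold.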
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_agents_embed::MockEmbedder;
    use atomr_infer_core::tokens::TokenUsage;

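    // Builds a minimal `CachedTurn` carrying only the response text.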
    fn turn(text: &str) -> CachedTurn {
        CachedTurn {
            text: text.into(),
            usage: TokenUsage::default(),
            finish_reason: None,
        }
    }

    #[tokio::test]
    async fn hits_on_near_duplicate_prompt() {
        let c = SemanticLlmCache::new(Arc::new(MockEmbedder::new(8)), 0.99);
        let key = CacheKey {
            model: "m".into(),
            messages_hash: 1,
            sampling_hash: 1,
        };
        c.put_with_text("hello", key, turn("hi back")).await.unwrap();
        let v = c.get_by_text("hello").await.unwrap().unwrap();
        assert_eq!(v.text, "hi back");
    }

    #[tokio::test]
    async fn miss_below_threshold() {
        let c = SemanticLlmCache::new(Arc::new(MockEmbedder::new(8)), 0.999);
        let key = CacheKey {
            model: "m".into(),
            messages_hash: 1,
            sampling_hash: 1,
        };
        c.put_with_text("hello", key, turn("hi back")).await.unwrap();
        let v = c.get_by_text("entirely different prompt").await.unwrap();
        assert!(v.is_none());
    }
}