atomr_accel_agents/
embedding_cache.rs1use std::collections::HashMap;
9use std::collections::VecDeque;
10
11use atomr_core::actor::{Context, Props};
12use atomr_macros::Actor;
13use tokio::sync::oneshot;
14
/// Snapshot of an [`EmbeddingCache`]'s counters, returned for
/// `EmbeddingCacheMsg::Stats`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// `Get` requests that found an entry.
    pub hits: u64,
    /// `Get` requests that found no entry.
    pub misses: u64,
    /// Number of entries held by the cache when the snapshot was taken.
    pub size: usize,
    /// Configured maximum number of entries (`capacity_entries`).
    pub capacity: usize,
}
22
/// Construction parameters for [`EmbeddingCache`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingCacheConfig {
    /// Maximum number of embeddings held at once; inserting a new key into a
    /// full cache evicts the least recently used entry.
    pub capacity_entries: usize,
    /// Expected length of each embedding vector.
    ///
    /// NOTE(review): stored by the cache but not checked against inserted
    /// values anywhere in this file.
    pub embedding_dim: usize,
}
27
28pub enum EmbeddingCacheMsg {
29 Get {
33 key: Vec<u8>,
34 reply: oneshot::Sender<Option<Vec<f32>>>,
35 },
36 Insert {
37 key: Vec<u8>,
38 value: Vec<f32>,
39 reply: oneshot::Sender<()>,
40 },
41 Invalidate {
42 key: Vec<u8>,
43 reply: oneshot::Sender<bool>,
44 },
45 Stats {
46 reply: oneshot::Sender<CacheStats>,
47 },
48}
49
50#[derive(Actor)]
51#[msg(EmbeddingCacheMsg)]
52pub struct EmbeddingCache {
53 config: EmbeddingCacheConfig,
54 cache: HashMap<u64, Vec<f32>>,
55 order: VecDeque<u64>,
57 stats: CacheStats,
58}
59
60fn hash_key(k: &[u8]) -> u64 {
61 use std::hash::{Hash, Hasher};
62 let mut h = std::collections::hash_map::DefaultHasher::new();
63 k.hash(&mut h);
64 h.finish()
65}
66
67impl EmbeddingCache {
68 pub fn props(config: EmbeddingCacheConfig) -> Props<Self> {
69 Props::create(move || EmbeddingCache {
70 config: EmbeddingCacheConfig {
71 capacity_entries: config.capacity_entries,
72 embedding_dim: config.embedding_dim,
73 },
74 cache: HashMap::with_capacity(config.capacity_entries),
75 order: VecDeque::with_capacity(config.capacity_entries),
76 stats: CacheStats {
77 capacity: config.capacity_entries,
78 ..Default::default()
79 },
80 })
81 }
82
83 fn touch(&mut self, k: u64) {
84 if let Some(pos) = self.order.iter().position(|x| *x == k) {
85 self.order.remove(pos);
86 }
87 self.order.push_back(k);
88 }
89}
90
91impl EmbeddingCache {
92 async fn handle_msg(&mut self, _ctx: &mut Context<Self>, msg: EmbeddingCacheMsg) {
95 match msg {
96 EmbeddingCacheMsg::Get { key, reply } => {
97 let h = hash_key(&key);
98 if let Some(v) = self.cache.get(&h).cloned() {
99 self.stats.hits += 1;
100 self.touch(h);
101 let _ = reply.send(Some(v));
102 } else {
103 self.stats.misses += 1;
104 let _ = reply.send(None);
105 }
106 }
107 EmbeddingCacheMsg::Insert { key, value, reply } => {
108 let h = hash_key(&key);
109 if self.cache.len() >= self.config.capacity_entries && !self.cache.contains_key(&h)
110 {
111 if let Some(victim) = self.order.pop_front() {
112 self.cache.remove(&victim);
113 }
114 }
115 self.cache.insert(h, value);
116 self.touch(h);
117 self.stats.size = self.cache.len();
118 let _ = reply.send(());
119 }
120 EmbeddingCacheMsg::Invalidate { key, reply } => {
121 let h = hash_key(&key);
122 let removed = self.cache.remove(&h).is_some();
123 if removed {
124 if let Some(pos) = self.order.iter().position(|x| *x == h) {
125 self.order.remove(pos);
126 }
127 self.stats.size = self.cache.len();
128 }
129 let _ = reply.send(removed);
130 }
131 EmbeddingCacheMsg::Stats { reply } => {
132 let _ = reply.send(self.stats);
133 }
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// How long a test waits for the actor to answer before failing.
    const REPLY_TIMEOUT: Duration = Duration::from_secs(2);

    /// Creates a reply channel, hands the sender to `send` (which should `tell`
    /// the actor a message carrying it), and waits for the answer.
    ///
    /// # Panics
    /// If no reply arrives within `REPLY_TIMEOUT` or the sender is dropped.
    async fn ask<T>(send: impl FnOnce(oneshot::Sender<T>)) -> T {
        let (tx, rx) = oneshot::channel();
        send(tx);
        tokio::time::timeout(REPLY_TIMEOUT, rx)
            .await
            .expect("cache actor did not reply in time")
            .expect("cache actor dropped the reply sender")
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn cache_hit_miss() {
        let sys = ActorSystem::create("embed-test", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(
                EmbeddingCache::props(EmbeddingCacheConfig {
                    capacity_entries: 4,
                    embedding_dim: 8,
                }),
                "cache",
            )
            .unwrap();

        let key = b"hello".to_vec();
        let v = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Get { key: key.clone(), reply });
        })
        .await;
        assert!(v.is_none());

        ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Insert {
                key: key.clone(),
                value: vec![1.0; 8],
                reply,
            });
        })
        .await;

        let v = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Get { key: key.clone(), reply });
        })
        .await;
        assert_eq!(v, Some(vec![1.0; 8]));

        let stats = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Stats { reply });
        })
        .await;
        assert_eq!(
            (stats.hits, stats.misses, stats.size, stats.capacity),
            (1, 1, 1, 4)
        );

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn evicts_least_recently_used() {
        let sys = ActorSystem::create("embed-test-lru", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(
                EmbeddingCache::props(EmbeddingCacheConfig {
                    capacity_entries: 2,
                    embedding_dim: 1,
                }),
                "cache",
            )
            .unwrap();

        for (key, x) in [(b"a", 1.0), (b"b", 2.0)] {
            ask(|reply| {
                actor.tell(EmbeddingCacheMsg::Insert {
                    key: key.to_vec(),
                    value: vec![x],
                    reply,
                });
            })
            .await;
        }

        // Reading "a" leaves "b" as the least recently used entry.
        let a = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Get { key: b"a".to_vec(), reply });
        })
        .await;
        assert_eq!(a, Some(vec![1.0]));

        ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Insert {
                key: b"c".to_vec(),
                value: vec![3.0],
                reply,
            });
        })
        .await;

        let b = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Get { key: b"b".to_vec(), reply });
        })
        .await;
        assert!(b.is_none(), "\"b\" should have been evicted");

        let a = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Get { key: b"a".to_vec(), reply });
        })
        .await;
        assert_eq!(a, Some(vec![1.0]));

        let c = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Get { key: b"c".to_vec(), reply });
        })
        .await;
        assert_eq!(c, Some(vec![3.0]));

        let stats = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Stats { reply });
        })
        .await;
        assert_eq!(
            (stats.hits, stats.misses, stats.size, stats.capacity),
            (3, 1, 2, 2)
        );

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn invalidate_removes_entry_once() {
        let sys = ActorSystem::create("embed-test-invalidate", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(
                EmbeddingCache::props(EmbeddingCacheConfig {
                    capacity_entries: 4,
                    embedding_dim: 2,
                }),
                "cache",
            )
            .unwrap();

        let key = b"doc-1".to_vec();
        ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Insert {
                key: key.clone(),
                value: vec![0.5, 0.25],
                reply,
            });
        })
        .await;

        let first = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Invalidate { key: key.clone(), reply });
        })
        .await;
        assert!(first);

        let second = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Invalidate { key: key.clone(), reply });
        })
        .await;
        assert!(!second);

        let v = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Get { key: key.clone(), reply });
        })
        .await;
        assert!(v.is_none());

        let stats = ask(|reply| {
            actor.tell(EmbeddingCacheMsg::Stats { reply });
        })
        .await;
        assert_eq!(stats.size, 0);

        sys.terminate().await;
    }
}