Skip to main content

atomr_accel_agents/
embedding_cache.rs

1//! `EmbeddingCache` — LRU cache of `(input hash) -> Vec<f32>`.
2//!
3//! F4 ships a CPU-resident LRU keyed on a 64-bit hash of the input
4//! bytes. F5 swaps the value type to `GpuRef<f32>` once the agents
5//! crate has a stable model-actor surface to compute embeddings
6//! against.
7
8use std::collections::HashMap;
9use std::collections::VecDeque;
10
11use atomr_core::actor::{Context, Props};
12use atomr_macros::Actor;
13use tokio::sync::oneshot;
14
/// Snapshot of an [`EmbeddingCache`]'s counters, returned in reply to
/// [`EmbeddingCacheMsg::Stats`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of `Get` requests that found a cached embedding.
    pub hits: u64,
    /// Number of `Get` requests that found nothing.
    pub misses: u64,
    /// Number of entries currently stored.
    pub size: usize,
    /// Configured maximum number of entries (`capacity_entries`).
    pub capacity: usize,
}
22
/// Construction parameters for [`EmbeddingCache::props`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmbeddingCacheConfig {
    /// Maximum number of embeddings held before the least-recently-used
    /// entry is evicted.
    pub capacity_entries: usize,
    /// Presumably the length of each embedding vector.
    ///
    /// NOTE(review): this value is stored on the actor, but the message
    /// handlers in this file never check inserted values against it.
    /// Confirm whether callers rely on it being enforced.
    pub embedding_dim: usize,
}
27
28pub enum EmbeddingCacheMsg {
29    /// Try the cache. On miss, returns `None` and the caller is
30    /// responsible for computing + storing the embedding via
31    /// `Insert`. F4 keeps cache and compute decoupled.
32    Get {
33        key: Vec<u8>,
34        reply: oneshot::Sender<Option<Vec<f32>>>,
35    },
36    Insert {
37        key: Vec<u8>,
38        value: Vec<f32>,
39        reply: oneshot::Sender<()>,
40    },
41    Invalidate {
42        key: Vec<u8>,
43        reply: oneshot::Sender<bool>,
44    },
45    Stats {
46        reply: oneshot::Sender<CacheStats>,
47    },
48}
49
/// Actor that caches embeddings in memory with least-recently-used eviction.
///
/// Entries are keyed by `hash_key(key)`, a 64-bit hash of the request's key
/// bytes, rather than by the bytes themselves.
///
/// NOTE(review): two distinct keys whose hashes collide share a single slot,
/// so a `Get` for one can return the embedding inserted for the other.
/// Confirm this is acceptable for the inputs this cache will see, e.g.
/// untrusted text.
#[derive(Actor)]
#[msg(EmbeddingCacheMsg)]
pub struct EmbeddingCache {
    /// Limits supplied through `props`.
    config: EmbeddingCacheConfig,
    /// Hashed key -> cached embedding.
    cache: HashMap<u64, Vec<f32>>,
    /// LRU order: front is least-recent.
    order: VecDeque<u64>,
    /// Running hit/miss counters, plus the current size and capacity.
    stats: CacheStats,
}
59
/// Reduces raw key bytes to the `u64` that the cache map is keyed on.
///
/// Each call hashes with a freshly constructed `DefaultHasher`, so equal
/// byte strings always map to the same value within one build. The standard
/// library does not promise the algorithm stays the same across Rust
/// releases.
fn hash_key(k: &[u8]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(k, &mut state);
    state.finish()
}
66
67impl EmbeddingCache {
68    pub fn props(config: EmbeddingCacheConfig) -> Props<Self> {
69        Props::create(move || EmbeddingCache {
70            config: EmbeddingCacheConfig {
71                capacity_entries: config.capacity_entries,
72                embedding_dim: config.embedding_dim,
73            },
74            cache: HashMap::with_capacity(config.capacity_entries),
75            order: VecDeque::with_capacity(config.capacity_entries),
76            stats: CacheStats {
77                capacity: config.capacity_entries,
78                ..Default::default()
79            },
80        })
81    }
82
83    fn touch(&mut self, k: u64) {
84        if let Some(pos) = self.order.iter().position(|x| *x == k) {
85            self.order.remove(pos);
86        }
87        self.order.push_back(k);
88    }
89}
90
91impl EmbeddingCache {
92    /// `#[derive(Actor)]` delegates to this method via the
93    /// atomr-macros-generated `impl Actor`.
94    async fn handle_msg(&mut self, _ctx: &mut Context<Self>, msg: EmbeddingCacheMsg) {
95        match msg {
96            EmbeddingCacheMsg::Get { key, reply } => {
97                let h = hash_key(&key);
98                if let Some(v) = self.cache.get(&h).cloned() {
99                    self.stats.hits += 1;
100                    self.touch(h);
101                    let _ = reply.send(Some(v));
102                } else {
103                    self.stats.misses += 1;
104                    let _ = reply.send(None);
105                }
106            }
107            EmbeddingCacheMsg::Insert { key, value, reply } => {
108                let h = hash_key(&key);
109                if self.cache.len() >= self.config.capacity_entries && !self.cache.contains_key(&h)
110                {
111                    if let Some(victim) = self.order.pop_front() {
112                        self.cache.remove(&victim);
113                    }
114                }
115                self.cache.insert(h, value);
116                self.touch(h);
117                self.stats.size = self.cache.len();
118                let _ = reply.send(());
119            }
120            EmbeddingCacheMsg::Invalidate { key, reply } => {
121                let h = hash_key(&key);
122                let removed = self.cache.remove(&h).is_some();
123                if removed {
124                    if let Some(pos) = self.order.iter().position(|x| *x == h) {
125                        self.order.remove(pos);
126                    }
127                    self.stats.size = self.cache.len();
128                }
129                let _ = reply.send(removed);
130            }
131            EmbeddingCacheMsg::Stats { reply } => {
132                let _ = reply.send(self.stats);
133            }
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Awaits an actor reply. Fails the test if no reply arrives within 2s
    /// or if the actor drops the reply sender.
    async fn recv<T>(rx: oneshot::Receiver<T>) -> T {
        tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("timed out waiting for the cache actor to reply")
            .expect("cache actor dropped the reply sender")
    }

    // Calls `$make` with a fresh reply sender, sends the resulting message
    // to `$actor`, and awaits the reply through `recv`.
    macro_rules! ask {
        ($actor:expr, $make:expr) => {{
            let (tx, rx) = oneshot::channel();
            $actor.tell(($make)(tx));
            recv(rx).await
        }};
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn cache_hit_miss() {
        let sys = ActorSystem::create("embed-test", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(
                EmbeddingCache::props(EmbeddingCacheConfig {
                    capacity_entries: 4,
                    embedding_dim: 8,
                }),
                "cache",
            )
            .unwrap();

        let key = b"hello".to_vec();

        // Miss
        let miss = ask!(actor, |reply| EmbeddingCacheMsg::Get {
            key: key.clone(),
            reply,
        });
        assert!(miss.is_none());

        // Insert
        ask!(actor, |reply| EmbeddingCacheMsg::Insert {
            key: key.clone(),
            value: vec![1.0; 8],
            reply,
        });

        // Hit
        let hit = ask!(actor, |reply| EmbeddingCacheMsg::Get {
            key: key.clone(),
            reply,
        });
        assert_eq!(hit, Some(vec![1.0; 8]));

        sys.terminate().await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn lru_eviction_invalidate_and_stats() {
        let sys = ActorSystem::create("embed-lru-test", Config::empty())
            .await
            .unwrap();
        let actor = sys
            .actor_of(
                EmbeddingCache::props(EmbeddingCacheConfig {
                    capacity_entries: 2,
                    embedding_dim: 1,
                }),
                "cache",
            )
            .unwrap();

        // Fill the cache with `a` and `b`.
        ask!(actor, |reply| EmbeddingCacheMsg::Insert {
            key: b"a".to_vec(),
            value: vec![1.0],
            reply,
        });
        ask!(actor, |reply| EmbeddingCacheMsg::Insert {
            key: b"b".to_vec(),
            value: vec![2.0],
            reply,
        });

        // Reading `a` makes `b` the least recently used entry.
        let a = ask!(actor, |reply| EmbeddingCacheMsg::Get {
            key: b"a".to_vec(),
            reply,
        });
        assert_eq!(a, Some(vec![1.0]));

        // A third key evicts `b`, not `a`.
        ask!(actor, |reply| EmbeddingCacheMsg::Insert {
            key: b"c".to_vec(),
            value: vec![3.0],
            reply,
        });
        let b = ask!(actor, |reply| EmbeddingCacheMsg::Get {
            key: b"b".to_vec(),
            reply,
        });
        assert!(b.is_none());
        let a = ask!(actor, |reply| EmbeddingCacheMsg::Get {
            key: b"a".to_vec(),
            reply,
        });
        assert_eq!(a, Some(vec![1.0]));
        let c = ask!(actor, |reply| EmbeddingCacheMsg::Get {
            key: b"c".to_vec(),
            reply,
        });
        assert_eq!(c, Some(vec![3.0]));

        let stats = ask!(actor, |reply| EmbeddingCacheMsg::Stats { reply });
        assert_eq!(
            (stats.hits, stats.misses, stats.size, stats.capacity),
            (3, 1, 2, 2)
        );

        // Invalidate reports whether an entry was actually removed.
        let removed = ask!(actor, |reply| EmbeddingCacheMsg::Invalidate {
            key: b"a".to_vec(),
            reply,
        });
        assert!(removed);
        let removed_again = ask!(actor, |reply| EmbeddingCacheMsg::Invalidate {
            key: b"a".to_vec(),
            reply,
        });
        assert!(!removed_again);
        let a = ask!(actor, |reply| EmbeddingCacheMsg::Get {
            key: b"a".to_vec(),
            reply,
        });
        assert!(a.is_none());

        let stats = ask!(actor, |reply| EmbeddingCacheMsg::Stats { reply });
        assert_eq!(stats.size, 1);
        assert_eq!(stats.misses, 2);

        sys.terminate().await;
    }
}