Skip to main content

atomr_agents_cache/
inmem.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use atomr_agents_core::Result;
6use parking_lot::RwLock;
7
8use crate::{CacheKey, CachedTurn, LlmCache};
9
10#[derive(Default, Clone)]
11pub struct InMemoryLlmCache {
12    inner: Arc<RwLock<HashMap<CacheKey, CachedTurn>>>,
13}
14
15impl InMemoryLlmCache {
16    pub fn new() -> Self {
17        Self::default()
18    }
19
20    pub fn len(&self) -> usize {
21        self.inner.read().len()
22    }
23}
24
25#[async_trait]
26impl LlmCache for InMemoryLlmCache {
27    async fn get(&self, key: &CacheKey) -> Result<Option<CachedTurn>> {
28        Ok(self.inner.read().get(key).cloned())
29    }
30    async fn put(&self, key: CacheKey, value: CachedTurn) -> Result<()> {
31        self.inner.write().insert(key, value);
32        Ok(())
33    }
34}
35
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_infer_core::batch::{ExecuteBatch, Message, MessageContent, Role, SamplingParams};

    /// Builds a minimal batch for `model` containing a single user message `text`.
    fn batch(model: &str, text: &str) -> ExecuteBatch {
        ExecuteBatch {
            request_id: "r".into(),
            model: model.into(),
            messages: vec![Message {
                role: Role::User,
                content: MessageContent::Text(text.into()),
            }],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        }
    }

    /// Builds a cached turn that differs from others only by its `text`.
    fn turn(text: &str) -> CachedTurn {
        CachedTurn {
            text: text.into(),
            usage: Default::default(),
            finish_reason: None,
        }
    }

    // `CacheKey::from_batch` is synchronous, so this test needs no async runtime.
    #[test]
    fn key_collisions_only_on_identical_payload() {
        let a = CacheKey::from_batch(&batch("m", "hi"));
        let b = CacheKey::from_batch(&batch("m", "hi"));
        let c = CacheKey::from_batch(&batch("m", "different"));
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[tokio::test]
    async fn put_get_round_trip() {
        let c = InMemoryLlmCache::new();
        let k = CacheKey::from_batch(&batch("m", "hello"));
        let v = turn("hi back");
        c.put(k.clone(), v.clone()).await.unwrap();
        let got = c.get(&k).await.unwrap().unwrap();
        assert_eq!(got.text, "hi back");
    }

    #[tokio::test]
    async fn get_on_empty_cache_misses() {
        let c = InMemoryLlmCache::new();
        let k = CacheKey::from_batch(&batch("m", "hello"));
        assert!(c.get(&k).await.unwrap().is_none());
        assert_eq!(c.len(), 0);
    }

    #[tokio::test]
    async fn put_overwrites_existing_key() {
        let c = InMemoryLlmCache::new();
        let k = CacheKey::from_batch(&batch("m", "hello"));
        c.put(k.clone(), turn("first")).await.unwrap();
        c.put(k.clone(), turn("second")).await.unwrap();
        // A second put under the same key replaces the entry instead of adding one.
        assert_eq!(c.len(), 1);
        assert_eq!(c.get(&k).await.unwrap().unwrap().text, "second");
    }

    #[tokio::test]
    async fn clones_share_storage() {
        let original = InMemoryLlmCache::new();
        let handle = original.clone();
        let k = CacheKey::from_batch(&batch("m", "hello"));
        original.put(k.clone(), turn("shared")).await.unwrap();
        // Both handles point at the same `Arc`-backed map.
        assert_eq!(handle.len(), 1);
        assert_eq!(handle.get(&k).await.unwrap().unwrap().text, "shared");
    }
}