// atomr_agents_cache/inmem.rs — in-memory implementation of the LLM cache.
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use atomr_agents_core::Result;
use parking_lot::RwLock;

use crate::{CacheKey, CachedTurn, LlmCache};

10#[derive(Default, Clone)]
11pub struct InMemoryLlmCache {
12 inner: Arc<RwLock<HashMap<CacheKey, CachedTurn>>>,
13}
14
15impl InMemoryLlmCache {
16 pub fn new() -> Self {
17 Self::default()
18 }
19
20 pub fn len(&self) -> usize {
21 self.inner.read().len()
22 }
23}
24
25#[async_trait]
26impl LlmCache for InMemoryLlmCache {
27 async fn get(&self, key: &CacheKey) -> Result<Option<CachedTurn>> {
28 Ok(self.inner.read().get(key).cloned())
29 }
30 async fn put(&self, key: CacheKey, value: CachedTurn) -> Result<()> {
31 self.inner.write().insert(key, value);
32 Ok(())
33 }
34}
35
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_infer_core::batch::{ExecuteBatch, Message, MessageContent, Role, SamplingParams};

    /// Builds a minimal single-user-message batch for cache-key tests.
    fn batch(model: &str, text: &str) -> ExecuteBatch {
        ExecuteBatch {
            request_id: "r".into(),
            model: model.into(),
            messages: vec![Message {
                role: Role::User,
                content: MessageContent::Text(text.into()),
            }],
            sampling: SamplingParams::default(),
            stream: false,
            estimated_tokens: 1,
        }
    }

    // Key derivation is pure and synchronous; a plain #[test] avoids
    // spinning up a tokio runtime for no reason.
    #[test]
    fn key_collisions_only_on_identical_payload() {
        let a = CacheKey::from_batch(&batch("m", "hi"));
        let b = CacheKey::from_batch(&batch("m", "hi"));
        let c = CacheKey::from_batch(&batch("m", "different"));
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[tokio::test]
    async fn put_get_round_trip() {
        let c = InMemoryLlmCache::new();
        let k = CacheKey::from_batch(&batch("m", "hello"));
        let v = CachedTurn {
            text: "hi back".into(),
            usage: Default::default(),
            finish_reason: None,
        };
        c.put(k.clone(), v.clone()).await.unwrap();
        let got = c.get(&k).await.unwrap().unwrap();
        assert_eq!(got.text, "hi back");
        // A key derived from a different payload must miss.
        let miss = CacheKey::from_batch(&batch("m", "other"));
        assert!(c.get(&miss).await.unwrap().is_none());
    }
}
77}