// ai_agents_memory/in_memory.rs

use std::sync::Arc;

use async_trait::async_trait;
use parking_lot::RwLock;

use ai_agents_core::{ChatMessage, MemorySnapshot, Result};

use super::Memory;
9
10pub struct InMemoryStore {
11    messages: Arc<RwLock<Vec<ChatMessage>>>,
12    max_messages: usize,
13}
14
15impl InMemoryStore {
16    pub fn new(max_messages: usize) -> Self {
17        Self {
18            messages: Arc::new(RwLock::new(Vec::new())),
19            max_messages,
20        }
21    }
22
23    pub fn max_messages(&self) -> usize {
24        self.max_messages
25    }
26}
27
28impl Clone for InMemoryStore {
29    fn clone(&self) -> Self {
30        Self {
31            messages: Arc::clone(&self.messages),
32            max_messages: self.max_messages,
33        }
34    }
35}
36
37#[async_trait]
38impl ai_agents_core::Memory for InMemoryStore {
39    async fn add_message(&self, message: ChatMessage) -> Result<()> {
40        let mut messages = self.messages.write();
41        messages.push(message);
42
43        while messages.len() > self.max_messages {
44            messages.remove(0);
45        }
46
47        Ok(())
48    }
49
50    async fn get_messages(&self, limit: Option<usize>) -> Result<Vec<ChatMessage>> {
51        let messages = self.messages.read();
52        match limit {
53            Some(n) => {
54                let start = messages.len().saturating_sub(n);
55                Ok(messages[start..].to_vec())
56            }
57            None => Ok(messages.clone()),
58        }
59    }
60
61    async fn clear(&self) -> Result<()> {
62        self.messages.write().clear();
63        Ok(())
64    }
65
66    fn len(&self) -> usize {
67        self.messages.read().len()
68    }
69
70    async fn restore(&self, snapshot: MemorySnapshot) -> Result<()> {
71        let mut messages = self.messages.write();
72        *messages = snapshot.messages;
73        while messages.len() > self.max_messages {
74            messages.remove(0);
75        }
76        Ok(())
77    }
78
79    async fn evict_oldest(&self, count: usize) -> Result<Vec<ChatMessage>> {
80        let mut messages = self.messages.write();
81        let evict_count = count.min(messages.len());
82        let evicted: Vec<ChatMessage> = messages.drain(..evict_count).collect();
83        Ok(evicted)
84    }
85}
86
87#[async_trait]
88impl Memory for InMemoryStore {}
89
#[cfg(test)]
mod tests {
    use super::*;
    use ai_agents_core::{Memory as CoreMemory, Role};

    /// Builds a user-role message with the given text and no name or timestamp.
    fn make_message(content: &str) -> ChatMessage {
        ChatMessage {
            role: Role::User,
            content: content.to_string(),
            name: None,
            timestamp: None,
        }
    }

    /// Appends `msg0`, `msg1`, … up to `msg{count - 1}` to `store`, in order.
    async fn fill(store: &InMemoryStore, count: usize) {
        for i in 0..count {
            store
                .add_message(make_message(&format!("msg{}", i)))
                .await
                .unwrap();
        }
    }

    /// Projects messages to their text so whole sequences compare in one assert.
    fn contents(messages: &[ChatMessage]) -> Vec<&str> {
        messages.iter().map(|m| m.content.as_str()).collect()
    }

    #[tokio::test]
    async fn test_add_and_get_messages() {
        let store = InMemoryStore::new(10);

        store.add_message(make_message("hello")).await.unwrap();
        store.add_message(make_message("world")).await.unwrap();

        let messages = store.get_messages(None).await.unwrap();
        assert_eq!(contents(&messages), ["hello", "world"]);
    }

    #[tokio::test]
    async fn test_max_messages_limit() {
        let store = InMemoryStore::new(3);
        fill(&store, 5).await;

        let messages = store.get_messages(None).await.unwrap();
        assert_eq!(contents(&messages), ["msg2", "msg3", "msg4"]);
    }

    #[tokio::test]
    async fn test_zero_capacity_keeps_nothing() {
        let store = InMemoryStore::new(0);
        store.add_message(make_message("dropped")).await.unwrap();
        assert!(store.is_empty());
    }

    #[tokio::test]
    async fn test_get_messages_with_limit() {
        let store = InMemoryStore::new(10);
        fill(&store, 5).await;

        let messages = store.get_messages(Some(2)).await.unwrap();
        assert_eq!(contents(&messages), ["msg3", "msg4"]);
    }

    #[tokio::test]
    async fn test_get_messages_limit_edges() {
        let store = InMemoryStore::new(10);
        fill(&store, 3).await;

        // A zero limit yields nothing; a limit past the end yields everything.
        assert!(store.get_messages(Some(0)).await.unwrap().is_empty());
        let all = store.get_messages(Some(10)).await.unwrap();
        assert_eq!(contents(&all), ["msg0", "msg1", "msg2"]);
    }

    #[tokio::test]
    async fn test_clear() {
        let store = InMemoryStore::new(10);

        store.add_message(make_message("test")).await.unwrap();
        assert!(!store.is_empty());

        store.clear().await.unwrap();
        assert!(store.is_empty());
    }

    #[tokio::test]
    async fn test_clone_shares_state() {
        let store1 = InMemoryStore::new(10);
        let store2 = store1.clone();

        store1
            .add_message(make_message("from store1"))
            .await
            .unwrap();

        let messages = store2.get_messages(None).await.unwrap();
        assert_eq!(contents(&messages), ["from store1"]);
    }

    #[tokio::test]
    async fn test_snapshot_restore() {
        let store = InMemoryStore::new(10);
        store.add_message(make_message("msg1")).await.unwrap();
        store.add_message(make_message("msg2")).await.unwrap();

        let snapshot = store.snapshot().await.unwrap();
        assert_eq!(snapshot.messages.len(), 2);

        store.clear().await.unwrap();
        assert!(store.is_empty());

        store.restore(snapshot).await.unwrap();
        let messages = store.get_messages(None).await.unwrap();
        assert_eq!(contents(&messages), ["msg1", "msg2"]);
    }

    #[tokio::test]
    async fn test_restore_truncates_to_max_messages() {
        let big = InMemoryStore::new(10);
        fill(&big, 5).await;
        let snapshot = big.snapshot().await.unwrap();

        // Restoring into a smaller store keeps only the newest entries.
        let small = InMemoryStore::new(3);
        small.restore(snapshot).await.unwrap();

        let messages = small.get_messages(None).await.unwrap();
        assert_eq!(contents(&messages), ["msg2", "msg3", "msg4"]);
    }

    #[tokio::test]
    async fn test_evict_oldest() {
        let store = InMemoryStore::new(10);
        fill(&store, 5).await;

        let evicted = store.evict_oldest(2).await.unwrap();
        assert_eq!(contents(&evicted), ["msg0", "msg1"]);

        let remaining = store.get_messages(None).await.unwrap();
        assert_eq!(contents(&remaining), ["msg2", "msg3", "msg4"]);
    }

    #[tokio::test]
    async fn test_evict_oldest_more_than_stored() {
        let store = InMemoryStore::new(10);
        fill(&store, 3).await;

        let evicted = store.evict_oldest(10).await.unwrap();
        assert_eq!(contents(&evicted), ["msg0", "msg1", "msg2"]);
        assert!(store.is_empty());
    }
}