ai_agents_memory/
in_memory.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use parking_lot::RwLock;
5
6use ai_agents_core::{ChatMessage, MemorySnapshot, Result};
7
8use super::Memory;
9
10pub struct InMemoryStore {
11 messages: Arc<RwLock<Vec<ChatMessage>>>,
12 max_messages: usize,
13}
14
15impl InMemoryStore {
16 pub fn new(max_messages: usize) -> Self {
17 Self {
18 messages: Arc::new(RwLock::new(Vec::new())),
19 max_messages,
20 }
21 }
22
23 pub fn max_messages(&self) -> usize {
24 self.max_messages
25 }
26}
27
28impl Clone for InMemoryStore {
29 fn clone(&self) -> Self {
30 Self {
31 messages: Arc::clone(&self.messages),
32 max_messages: self.max_messages,
33 }
34 }
35}
36
37#[async_trait]
38impl ai_agents_core::Memory for InMemoryStore {
39 async fn add_message(&self, message: ChatMessage) -> Result<()> {
40 let mut messages = self.messages.write();
41 messages.push(message);
42
43 while messages.len() > self.max_messages {
44 messages.remove(0);
45 }
46
47 Ok(())
48 }
49
50 async fn get_messages(&self, limit: Option<usize>) -> Result<Vec<ChatMessage>> {
51 let messages = self.messages.read();
52 match limit {
53 Some(n) => {
54 let start = messages.len().saturating_sub(n);
55 Ok(messages[start..].to_vec())
56 }
57 None => Ok(messages.clone()),
58 }
59 }
60
61 async fn clear(&self) -> Result<()> {
62 self.messages.write().clear();
63 Ok(())
64 }
65
66 fn len(&self) -> usize {
67 self.messages.read().len()
68 }
69
70 async fn restore(&self, snapshot: MemorySnapshot) -> Result<()> {
71 let mut messages = self.messages.write();
72 *messages = snapshot.messages;
73 while messages.len() > self.max_messages {
74 messages.remove(0);
75 }
76 Ok(())
77 }
78
79 async fn evict_oldest(&self, count: usize) -> Result<Vec<ChatMessage>> {
80 let mut messages = self.messages.write();
81 let evict_count = count.min(messages.len());
82 let evicted: Vec<ChatMessage> = messages.drain(..evict_count).collect();
83 Ok(evicted)
84 }
85}
86
87#[async_trait]
88impl Memory for InMemoryStore {}
89
#[cfg(test)]
mod tests {
    use super::*;
    use ai_agents_core::{Memory as CoreMemory, Role};

    /// Builds a plain user message with the given text and no name or timestamp.
    fn make_message(content: &str) -> ChatMessage {
        ChatMessage {
            role: Role::User,
            content: content.to_string(),
            name: None,
            timestamp: None,
        }
    }

    /// Adds `count` messages named `msg0`, `msg1`, ... to `store`, in that order.
    async fn fill(store: &InMemoryStore, count: usize) {
        for i in 0..count {
            store
                .add_message(make_message(&format!("msg{}", i)))
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_add_and_get_messages() {
        let store = InMemoryStore::new(10);

        store.add_message(make_message("hello")).await.unwrap();
        store.add_message(make_message("world")).await.unwrap();

        let messages = store.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "hello");
        assert_eq!(messages[1].content, "world");
    }

    #[tokio::test]
    async fn test_max_messages_limit() {
        let store = InMemoryStore::new(3);
        fill(&store, 5).await;

        let messages = store.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content, "msg2");
        assert_eq!(messages[1].content, "msg3");
        assert_eq!(messages[2].content, "msg4");
    }

    #[tokio::test]
    async fn test_zero_capacity_keeps_nothing() {
        let store = InMemoryStore::new(0);
        store.add_message(make_message("dropped")).await.unwrap();
        assert!(store.is_empty());
    }

    #[tokio::test]
    async fn test_get_messages_with_limit() {
        let store = InMemoryStore::new(10);
        fill(&store, 5).await;

        let messages = store.get_messages(Some(2)).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "msg3");
        assert_eq!(messages[1].content, "msg4");
    }

    #[tokio::test]
    async fn test_get_messages_limit_exceeds_len() {
        let store = InMemoryStore::new(10);
        fill(&store, 2).await;

        let messages = store.get_messages(Some(10)).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "msg0");
        assert_eq!(messages[1].content, "msg1");
    }

    #[tokio::test]
    async fn test_clear() {
        let store = InMemoryStore::new(10);

        store.add_message(make_message("test")).await.unwrap();
        assert!(!store.is_empty());

        store.clear().await.unwrap();
        assert!(store.is_empty());
    }

    #[tokio::test]
    async fn test_clone_shares_state() {
        let store1 = InMemoryStore::new(10);
        let store2 = store1.clone();

        store1
            .add_message(make_message("from store1"))
            .await
            .unwrap();

        let messages = store2.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content, "from store1");
    }

    #[tokio::test]
    async fn test_snapshot_restore() {
        let store = InMemoryStore::new(10);
        store.add_message(make_message("msg1")).await.unwrap();
        store.add_message(make_message("msg2")).await.unwrap();

        let snapshot = store.snapshot().await.unwrap();
        assert_eq!(snapshot.messages.len(), 2);

        store.clear().await.unwrap();
        assert!(store.is_empty());

        store.restore(snapshot).await.unwrap();
        let messages = store.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "msg1");
    }

    #[tokio::test]
    async fn test_restore_truncates_to_max_messages() {
        let source = InMemoryStore::new(10);
        fill(&source, 5).await;
        let snapshot = source.snapshot().await.unwrap();

        // Restoring into a smaller store must keep only the newest entries.
        let small = InMemoryStore::new(2);
        small.restore(snapshot).await.unwrap();

        let messages = small.get_messages(None).await.unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, "msg3");
        assert_eq!(messages[1].content, "msg4");
    }

    #[tokio::test]
    async fn test_evict_oldest() {
        let store = InMemoryStore::new(10);
        fill(&store, 5).await;

        let evicted = store.evict_oldest(2).await.unwrap();
        assert_eq!(evicted.len(), 2);
        assert_eq!(evicted[0].content, "msg0");
        assert_eq!(evicted[1].content, "msg1");

        let remaining = store.get_messages(None).await.unwrap();
        assert_eq!(remaining.len(), 3);
        assert_eq!(remaining[0].content, "msg2");
    }

    #[tokio::test]
    async fn test_evict_more_than_available() {
        let store = InMemoryStore::new(10);
        store.add_message(make_message("only")).await.unwrap();

        let evicted = store.evict_oldest(5).await.unwrap();
        assert_eq!(evicted.len(), 1);
        assert_eq!(evicted[0].content, "only");
        assert!(store.is_empty());
    }
}
215}