1use crate::llm;
15use crate::types::{AgentState, ThreadId};
16use anyhow::{Context, Result};
17use async_trait::async_trait;
18use std::collections::HashMap;
19use std::sync::RwLock;
20
21#[async_trait]
24pub trait MessageStore: Send + Sync {
25 async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()>;
30
31 async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>>;
36
37 async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
42
43 async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
48 Ok(self.get_history(thread_id).await?.len())
49 }
50
51 async fn replace_history(
57 &self,
58 thread_id: &ThreadId,
59 messages: Vec<llm::Message>,
60 ) -> Result<()>;
61}
62
63#[async_trait]
66pub trait StateStore: Send + Sync {
67 async fn save(&self, state: &AgentState) -> Result<()>;
72
73 async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>>;
78
79 async fn delete(&self, thread_id: &ThreadId) -> Result<()>;
84}
85
86#[derive(Default)]
89pub struct InMemoryStore {
90 messages: RwLock<HashMap<String, Vec<llm::Message>>>,
91 states: RwLock<HashMap<String, AgentState>>,
92}
93
94impl InMemoryStore {
95 #[must_use]
96 pub fn new() -> Self {
97 Self::default()
98 }
99}
100
101#[async_trait]
102impl MessageStore for InMemoryStore {
103 async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
104 self.messages
105 .write()
106 .ok()
107 .context("lock poisoned")?
108 .entry(thread_id.0.clone())
109 .or_default()
110 .push(message);
111 Ok(())
112 }
113
114 async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
115 let messages = self.messages.read().ok().context("lock poisoned")?;
116 Ok(messages.get(&thread_id.0).cloned().unwrap_or_default())
117 }
118
119 async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
120 self.messages
121 .write()
122 .ok()
123 .context("lock poisoned")?
124 .remove(&thread_id.0);
125 Ok(())
126 }
127
128 async fn replace_history(
129 &self,
130 thread_id: &ThreadId,
131 messages: Vec<llm::Message>,
132 ) -> Result<()> {
133 self.messages
134 .write()
135 .ok()
136 .context("lock poisoned")?
137 .insert(thread_id.0.clone(), messages);
138 Ok(())
139 }
140}
141
142#[async_trait]
143impl StateStore for InMemoryStore {
144 async fn save(&self, state: &AgentState) -> Result<()> {
145 self.states
146 .write()
147 .ok()
148 .context("lock poisoned")?
149 .insert(state.thread_id.0.clone(), state.clone());
150 Ok(())
151 }
152
153 async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
154 let states = self.states.read().ok().context("lock poisoned")?;
155 Ok(states.get(&thread_id.0).cloned())
156 }
157
158 async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
159 self.states
160 .write()
161 .ok()
162 .context("lock poisoned")?
163 .remove(&thread_id.0);
164 Ok(())
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::Message;

    /// Append / read / count / clear round-trip on a single thread.
    #[tokio::test]
    async fn test_in_memory_message_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // A thread that was never written to reads back as empty.
        assert!(store.get_history(&thread_id).await?.is_empty());
        assert_eq!(store.count(&thread_id).await?, 0);

        store.append(&thread_id, Message::user("Hello")).await?;
        store
            .append(&thread_id, Message::assistant("Hi there!"))
            .await?;

        assert_eq!(store.get_history(&thread_id).await?.len(), 2);
        assert_eq!(store.count(&thread_id).await?, 2);

        store.clear(&thread_id).await?;
        assert!(store.get_history(&thread_id).await?.is_empty());
        assert_eq!(store.count(&thread_id).await?, 0);

        Ok(())
    }

    /// Writes to one thread must not be visible through another.
    /// Assumes `ThreadId::new()` yields a distinct id on each call.
    #[tokio::test]
    async fn test_threads_are_isolated() -> Result<()> {
        let store = InMemoryStore::new();
        let first = ThreadId::new();
        let second = ThreadId::new();

        store.append(&first, Message::user("Hello")).await?;
        assert_eq!(store.count(&first).await?, 1);
        assert_eq!(store.count(&second).await?, 0);

        // Clearing the other thread leaves the first one intact.
        store.clear(&second).await?;
        assert_eq!(store.count(&first).await?, 1);

        Ok(())
    }

    /// `replace_history` swaps the whole history, e.g. for a summary.
    #[tokio::test]
    async fn test_replace_history() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        store.append(&thread_id, Message::user("Hello")).await?;
        store
            .append(&thread_id, Message::assistant("Hi there!"))
            .await?;
        store
            .append(&thread_id, Message::user("How are you?"))
            .await?;
        assert_eq!(store.get_history(&thread_id).await?.len(), 3);

        let new_history = vec![
            Message::user("[Summary] Previous conversation about greetings"),
            Message::assistant("I understand the context. Continuing..."),
        ];
        store.replace_history(&thread_id, new_history).await?;
        assert_eq!(store.get_history(&thread_id).await?.len(), 2);
        assert_eq!(store.count(&thread_id).await?, 2);

        // Replacing with nothing empties the thread.
        store.replace_history(&thread_id, Vec::new()).await?;
        assert!(store.get_history(&thread_id).await?.is_empty());
        assert_eq!(store.count(&thread_id).await?, 0);

        // Replacing on a thread that was never appended to creates its history.
        let fresh = ThreadId::new();
        store
            .replace_history(&fresh, vec![Message::user("Hello")])
            .await?;
        assert_eq!(store.count(&fresh).await?, 1);

        Ok(())
    }

    /// Save / load / delete round-trip for agent state.
    #[tokio::test]
    async fn test_in_memory_state_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        assert!(store.load(&thread_id).await?.is_none());

        let state = AgentState::new(thread_id.clone());
        store.save(&state).await?;

        let loaded = store
            .load(&thread_id)
            .await?
            .context("saved state should be loadable")?;
        assert_eq!(loaded.thread_id, thread_id);

        store.delete(&thread_id).await?;
        assert!(store.load(&thread_id).await?.is_none());

        // Deleting a thread that has no saved state is not an error.
        store.delete(&thread_id).await?;

        Ok(())
    }
}