1use crate::llm;
15use crate::types::{AgentState, ThreadId, ToolExecution};
16use anyhow::{Context, Result};
17use async_trait::async_trait;
18use std::collections::HashMap;
19use std::sync::RwLock;
20
21#[async_trait]
24pub trait MessageStore: Send + Sync {
25 async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()>;
30
31 async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>>;
36
37 async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
42
43 async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
48 Ok(self.get_history(thread_id).await?.len())
49 }
50
51 async fn replace_history(
57 &self,
58 thread_id: &ThreadId,
59 messages: Vec<llm::Message>,
60 ) -> Result<()>;
61}
62
63#[async_trait]
66pub trait StateStore: Send + Sync {
67 async fn save(&self, state: &AgentState) -> Result<()>;
72
73 async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>>;
78
79 async fn delete(&self, thread_id: &ThreadId) -> Result<()>;
84}
85
86#[async_trait]
94pub trait ToolExecutionStore: Send + Sync {
95 async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>>;
100
101 async fn record_execution(&self, execution: ToolExecution) -> Result<()>;
106
107 async fn update_execution(&self, execution: ToolExecution) -> Result<()>;
112
113 async fn get_execution_by_operation_id(
118 &self,
119 operation_id: &str,
120 ) -> Result<Option<ToolExecution>>;
121}
122
123#[derive(Default)]
126pub struct InMemoryStore {
127 messages: RwLock<HashMap<String, Vec<llm::Message>>>,
128 states: RwLock<HashMap<String, AgentState>>,
129}
130
131impl InMemoryStore {
132 #[must_use]
133 pub fn new() -> Self {
134 Self::default()
135 }
136}
137
138#[async_trait]
139impl MessageStore for InMemoryStore {
140 async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
141 self.messages
142 .write()
143 .ok()
144 .context("lock poisoned")?
145 .entry(thread_id.0.clone())
146 .or_default()
147 .push(message);
148 Ok(())
149 }
150
151 async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
152 let messages = self.messages.read().ok().context("lock poisoned")?;
153 Ok(messages.get(&thread_id.0).cloned().unwrap_or_default())
154 }
155
156 async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
157 self.messages
158 .write()
159 .ok()
160 .context("lock poisoned")?
161 .remove(&thread_id.0);
162 Ok(())
163 }
164
165 async fn replace_history(
166 &self,
167 thread_id: &ThreadId,
168 messages: Vec<llm::Message>,
169 ) -> Result<()> {
170 self.messages
171 .write()
172 .ok()
173 .context("lock poisoned")?
174 .insert(thread_id.0.clone(), messages);
175 Ok(())
176 }
177}
178
179#[async_trait]
180impl StateStore for InMemoryStore {
181 async fn save(&self, state: &AgentState) -> Result<()> {
182 self.states
183 .write()
184 .ok()
185 .context("lock poisoned")?
186 .insert(state.thread_id.0.clone(), state.clone());
187 Ok(())
188 }
189
190 async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
191 let states = self.states.read().ok().context("lock poisoned")?;
192 Ok(states.get(&thread_id.0).cloned())
193 }
194
195 async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
196 self.states
197 .write()
198 .ok()
199 .context("lock poisoned")?
200 .remove(&thread_id.0);
201 Ok(())
202 }
203}
204
205#[derive(Default)]
210pub struct InMemoryExecutionStore {
211 executions: RwLock<HashMap<String, ToolExecution>>,
213 operation_index: RwLock<HashMap<String, String>>,
215}
216
217impl InMemoryExecutionStore {
218 #[must_use]
219 pub fn new() -> Self {
220 Self::default()
221 }
222}
223
224#[async_trait]
225impl ToolExecutionStore for InMemoryExecutionStore {
226 async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>> {
227 let executions = self.executions.read().ok().context("lock poisoned")?;
228 Ok(executions.get(tool_call_id).cloned())
229 }
230
231 async fn record_execution(&self, execution: ToolExecution) -> Result<()> {
232 let tool_call_id = execution.tool_call_id.clone();
233 self.executions
234 .write()
235 .ok()
236 .context("lock poisoned")?
237 .insert(tool_call_id, execution);
238 Ok(())
239 }
240
241 async fn update_execution(&self, execution: ToolExecution) -> Result<()> {
242 let tool_call_id = execution.tool_call_id.clone();
243
244 if let Some(ref op_id) = execution.operation_id {
246 self.operation_index
247 .write()
248 .ok()
249 .context("lock poisoned")?
250 .insert(op_id.clone(), tool_call_id.clone());
251 }
252
253 self.executions
254 .write()
255 .ok()
256 .context("lock poisoned")?
257 .insert(tool_call_id, execution);
258 Ok(())
259 }
260
261 async fn get_execution_by_operation_id(
262 &self,
263 operation_id: &str,
264 ) -> Result<Option<ToolExecution>> {
265 let tool_call_id = {
267 let op_index = self.operation_index.read().ok().context("lock poisoned")?;
268 op_index.get(operation_id).cloned()
269 };
270
271 let Some(tool_call_id) = tool_call_id else {
272 return Ok(None);
273 };
274
275 let executions = self.executions.read().ok().context("lock poisoned")?;
276 Ok(executions.get(&tool_call_id).cloned())
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::Message;
    use crate::types::ToolResult;

    /// Append, read back, count and clear a single thread's history.
    #[tokio::test]
    async fn test_in_memory_message_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread = ThreadId::new();

        // A thread that was never written to reads back as empty.
        assert!(store.get_history(&thread).await?.is_empty());

        store.append(&thread, Message::user("Hello")).await?;
        store.append(&thread, Message::assistant("Hi there!")).await?;

        assert_eq!(store.get_history(&thread).await?.len(), 2);
        assert_eq!(store.count(&thread).await?, 2);

        store.clear(&thread).await?;
        assert!(store.get_history(&thread).await?.is_empty());

        Ok(())
    }

    /// Replacing the history swaps the old messages for the new ones wholesale.
    #[tokio::test]
    async fn test_replace_history() -> Result<()> {
        let store = InMemoryStore::new();
        let thread = ThreadId::new();

        for message in [
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
        ] {
            store.append(&thread, message).await?;
        }
        assert_eq!(store.get_history(&thread).await?.len(), 3);

        // Simulate compaction: three messages become a two-message summary.
        let summary = vec![
            Message::user("[Summary] Previous conversation about greetings"),
            Message::assistant("I understand the context. Continuing..."),
        ];
        store.replace_history(&thread, summary).await?;

        assert_eq!(store.get_history(&thread).await?.len(), 2);

        Ok(())
    }

    /// Save, load and delete an agent state snapshot.
    #[tokio::test]
    async fn test_in_memory_state_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        assert!(store.load(&thread_id).await?.is_none());

        store.save(&AgentState::new(thread_id.clone())).await?;

        let loaded = store
            .load(&thread_id)
            .await?
            .expect("state should exist after save");
        assert_eq!(loaded.thread_id, thread_id);

        store.delete(&thread_id).await?;
        assert!(store.load(&thread_id).await?.is_none());

        Ok(())
    }

    /// A recorded in-flight execution can be fetched back by its tool call id.
    #[tokio::test]
    async fn test_execution_store_basic_operations() -> Result<()> {
        let store = InMemoryExecutionStore::new();
        let thread_id = ThreadId::new();

        assert!(store.get_execution("tool_call_123").await?.is_none());

        store
            .record_execution(ToolExecution::new_in_flight(
                "tool_call_123",
                thread_id.clone(),
                "my_tool",
                "My Tool",
                serde_json::json!({"param": "value"}),
                time::OffsetDateTime::now_utc(),
            ))
            .await?;

        let loaded = store
            .get_execution("tool_call_123")
            .await?
            .expect("execution should exist");
        assert_eq!(loaded.tool_call_id, "tool_call_123");
        assert_eq!(loaded.tool_name, "my_tool");
        assert!(loaded.is_in_flight());

        Ok(())
    }

    /// Updating an execution with a result replaces the stored in-flight record.
    #[tokio::test]
    async fn test_execution_store_complete_execution() -> Result<()> {
        let store = InMemoryExecutionStore::new();
        let thread_id = ThreadId::new();

        let mut execution = ToolExecution::new_in_flight(
            "tool_call_456",
            thread_id.clone(),
            "my_tool",
            "My Tool",
            serde_json::json!({}),
            time::OffsetDateTime::now_utc(),
        );
        store.record_execution(execution.clone()).await?;

        execution.complete(ToolResult::success("Done!"));
        store.update_execution(execution).await?;

        let loaded = store
            .get_execution("tool_call_456")
            .await?
            .expect("execution should exist");
        assert!(loaded.is_completed());
        assert!(loaded.result.as_ref().is_some_and(|r| r.success));

        Ok(())
    }

    /// An execution carrying an operation id can be found through that id.
    #[tokio::test]
    async fn test_execution_store_operation_id_lookup() -> Result<()> {
        let store = InMemoryExecutionStore::new();
        let thread_id = ThreadId::new();

        let mut execution = ToolExecution::new_in_flight(
            "tool_call_789",
            thread_id.clone(),
            "async_tool",
            "Async Tool",
            serde_json::json!({}),
            time::OffsetDateTime::now_utc(),
        );
        execution.set_operation_id("op_abc123");
        store.record_execution(execution.clone()).await?;
        store.update_execution(execution).await?;

        let loaded = store
            .get_execution_by_operation_id("op_abc123")
            .await?
            .expect("execution should exist");
        assert_eq!(loaded.tool_call_id, "tool_call_789");
        assert_eq!(loaded.operation_id, Some("op_abc123".to_string()));

        assert!(store
            .get_execution_by_operation_id("nonexistent")
            .await?
            .is_none());

        Ok(())
    }
}