Skip to main content

agent_sdk/
stores.rs

1//! Storage traits for message history and agent state.
2//!
//! The SDK uses three storage abstractions:
//!
//! - [`MessageStore`] - Stores conversation message history per thread
//! - [`StateStore`] - Stores agent state checkpoints for recovery
//! - [`ToolExecutionStore`] - Tracks tool executions for idempotency
//!
//! # Built-in Implementations
//!
//! [`InMemoryStore`] implements [`MessageStore`] and [`StateStore`], and
//! [`InMemoryExecutionStore`] implements [`ToolExecutionStore`]. Both are
//! suitable for testing and single-process deployments. For production,
//! implement custom stores backed by your database (e.g., Postgres, Redis).
13
14use crate::llm;
15use crate::types::{AgentState, ThreadId, ToolExecution};
16use anyhow::{Context, Result};
17use async_trait::async_trait;
18use std::collections::HashMap;
19use std::sync::RwLock;
20
21/// Trait for storing and retrieving conversation messages.
22/// Implement this trait to persist messages to your storage backend.
23#[async_trait]
24pub trait MessageStore: Send + Sync {
25    /// Append a message to the thread's history
26    ///
27    /// # Errors
28    /// Returns an error if the message cannot be stored.
29    async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()>;
30
31    /// Get all messages for a thread
32    ///
33    /// # Errors
34    /// Returns an error if the history cannot be retrieved.
35    async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>>;
36
37    /// Clear all messages for a thread
38    ///
39    /// # Errors
40    /// Returns an error if the messages cannot be cleared.
41    async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
42
43    /// Get the message count for a thread
44    ///
45    /// # Errors
46    /// Returns an error if the count cannot be retrieved.
47    async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
48        Ok(self.get_history(thread_id).await?.len())
49    }
50
51    /// Replace the entire message history for a thread.
52    /// Used for context compaction to replace old messages with a summary.
53    ///
54    /// # Errors
55    /// Returns an error if the history cannot be replaced.
56    async fn replace_history(
57        &self,
58        thread_id: &ThreadId,
59        messages: Vec<llm::Message>,
60    ) -> Result<()>;
61}
62
63/// Trait for storing agent state checkpoints.
64/// Implement this to enable conversation recovery and resume.
65#[async_trait]
66pub trait StateStore: Send + Sync {
67    /// Save the current agent state
68    ///
69    /// # Errors
70    /// Returns an error if the state cannot be saved.
71    async fn save(&self, state: &AgentState) -> Result<()>;
72
73    /// Load the most recent state for a thread
74    ///
75    /// # Errors
76    /// Returns an error if the state cannot be loaded.
77    async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>>;
78
79    /// Delete state for a thread
80    ///
81    /// # Errors
82    /// Returns an error if the state cannot be deleted.
83    async fn delete(&self, thread_id: &ThreadId) -> Result<()>;
84}
85
86/// Store for tracking tool executions (idempotency).
87///
88/// This trait enables write-ahead execution tracking to ensure tool idempotency.
89/// The pattern is:
90/// 1. Record execution intent BEFORE calling the tool (`record_execution`)
91/// 2. Update with result AFTER completion (`update_execution`)
92/// 3. On retry, check if execution exists and return cached result
93#[async_trait]
94pub trait ToolExecutionStore: Send + Sync {
95    /// Get an execution by `tool_call_id`.
96    ///
97    /// # Errors
98    /// Returns an error if the execution cannot be retrieved.
99    async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>>;
100
101    /// Record a new execution (write-ahead, before calling tool).
102    ///
103    /// # Errors
104    /// Returns an error if the execution cannot be recorded.
105    async fn record_execution(&self, execution: ToolExecution) -> Result<()>;
106
107    /// Update an existing execution (after completion or to set `operation_id`).
108    ///
109    /// # Errors
110    /// Returns an error if the execution cannot be updated.
111    async fn update_execution(&self, execution: ToolExecution) -> Result<()>;
112
113    /// Get execution by `operation_id` (for async tool resume).
114    ///
115    /// # Errors
116    /// Returns an error if the execution cannot be retrieved.
117    async fn get_execution_by_operation_id(
118        &self,
119        operation_id: &str,
120    ) -> Result<Option<ToolExecution>>;
121}
122
123/// In-memory implementation of `MessageStore` and `StateStore`.
124/// Useful for testing and simple use cases.
125#[derive(Default)]
126pub struct InMemoryStore {
127    messages: RwLock<HashMap<String, Vec<llm::Message>>>,
128    states: RwLock<HashMap<String, AgentState>>,
129}
130
131impl InMemoryStore {
132    #[must_use]
133    pub fn new() -> Self {
134        Self::default()
135    }
136}
137
138#[async_trait]
139impl MessageStore for InMemoryStore {
140    async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
141        self.messages
142            .write()
143            .ok()
144            .context("lock poisoned")?
145            .entry(thread_id.0.clone())
146            .or_default()
147            .push(message);
148        Ok(())
149    }
150
151    async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
152        let messages = self.messages.read().ok().context("lock poisoned")?;
153        Ok(messages.get(&thread_id.0).cloned().unwrap_or_default())
154    }
155
156    async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
157        self.messages
158            .write()
159            .ok()
160            .context("lock poisoned")?
161            .remove(&thread_id.0);
162        Ok(())
163    }
164
165    async fn replace_history(
166        &self,
167        thread_id: &ThreadId,
168        messages: Vec<llm::Message>,
169    ) -> Result<()> {
170        self.messages
171            .write()
172            .ok()
173            .context("lock poisoned")?
174            .insert(thread_id.0.clone(), messages);
175        Ok(())
176    }
177}
178
179#[async_trait]
180impl StateStore for InMemoryStore {
181    async fn save(&self, state: &AgentState) -> Result<()> {
182        self.states
183            .write()
184            .ok()
185            .context("lock poisoned")?
186            .insert(state.thread_id.0.clone(), state.clone());
187        Ok(())
188    }
189
190    async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
191        let states = self.states.read().ok().context("lock poisoned")?;
192        Ok(states.get(&thread_id.0).cloned())
193    }
194
195    async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
196        self.states
197            .write()
198            .ok()
199            .context("lock poisoned")?
200            .remove(&thread_id.0);
201        Ok(())
202    }
203}
204
205/// In-memory implementation of `ToolExecutionStore`.
206///
207/// Useful for testing and simple use cases where durability is not required.
208/// For production, implement a custom store backed by a database.
209#[derive(Default)]
210pub struct InMemoryExecutionStore {
211    /// Executions indexed by `tool_call_id`
212    executions: RwLock<HashMap<String, ToolExecution>>,
213    /// Index from `operation_id` to `tool_call_id` for async tool lookup
214    operation_index: RwLock<HashMap<String, String>>,
215}
216
217impl InMemoryExecutionStore {
218    #[must_use]
219    pub fn new() -> Self {
220        Self::default()
221    }
222}
223
224#[async_trait]
225impl ToolExecutionStore for InMemoryExecutionStore {
226    async fn get_execution(&self, tool_call_id: &str) -> Result<Option<ToolExecution>> {
227        let executions = self.executions.read().ok().context("lock poisoned")?;
228        Ok(executions.get(tool_call_id).cloned())
229    }
230
231    async fn record_execution(&self, execution: ToolExecution) -> Result<()> {
232        let tool_call_id = execution.tool_call_id.clone();
233        self.executions
234            .write()
235            .ok()
236            .context("lock poisoned")?
237            .insert(tool_call_id, execution);
238        Ok(())
239    }
240
241    async fn update_execution(&self, execution: ToolExecution) -> Result<()> {
242        let tool_call_id = execution.tool_call_id.clone();
243
244        // Update operation_id index if present
245        if let Some(ref op_id) = execution.operation_id {
246            self.operation_index
247                .write()
248                .ok()
249                .context("lock poisoned")?
250                .insert(op_id.clone(), tool_call_id.clone());
251        }
252
253        self.executions
254            .write()
255            .ok()
256            .context("lock poisoned")?
257            .insert(tool_call_id, execution);
258        Ok(())
259    }
260
261    async fn get_execution_by_operation_id(
262        &self,
263        operation_id: &str,
264    ) -> Result<Option<ToolExecution>> {
265        // Get tool_call_id and drop lock before acquiring another
266        let tool_call_id = {
267            let op_index = self.operation_index.read().ok().context("lock poisoned")?;
268            op_index.get(operation_id).cloned()
269        };
270
271        let Some(tool_call_id) = tool_call_id else {
272            return Ok(None);
273        };
274
275        let executions = self.executions.read().ok().context("lock poisoned")?;
276        Ok(executions.get(&tool_call_id).cloned())
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::Message;
    use crate::types::ToolResult;

    /// Append, read back, count and clear messages for a single thread.
    #[tokio::test]
    async fn test_in_memory_message_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // A fresh thread has no history.
        assert!(store.get_history(&thread_id).await?.is_empty());

        // Append one user and one assistant message.
        for message in [Message::user("Hello"), Message::assistant("Hi there!")] {
            store.append(&thread_id, message).await?;
        }

        // Both messages come back, and `count` agrees.
        assert_eq!(store.get_history(&thread_id).await?.len(), 2);
        assert_eq!(store.count(&thread_id).await?, 2);

        // Clearing empties the thread again.
        store.clear(&thread_id).await?;
        assert!(store.get_history(&thread_id).await?.is_empty());

        Ok(())
    }

    /// Replacing the history swaps out all previously appended messages.
    #[tokio::test]
    async fn test_replace_history() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // Seed the thread with three messages.
        let seed = [
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
        ];
        for message in seed {
            store.append(&thread_id, message).await?;
        }
        assert_eq!(store.get_history(&thread_id).await?.len(), 3);

        // Swap in a two-message compacted history.
        let compacted = vec![
            Message::user("[Summary] Previous conversation about greetings"),
            Message::assistant("I understand the context. Continuing..."),
        ];
        store.replace_history(&thread_id, compacted).await?;

        // Only the compacted messages remain.
        assert_eq!(store.get_history(&thread_id).await?.len(), 2);

        Ok(())
    }

    /// Save, load and delete an agent state checkpoint.
    #[tokio::test]
    async fn test_in_memory_state_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // Nothing is stored for a new thread.
        assert!(store.load(&thread_id).await?.is_none());

        // Save a checkpoint for the thread.
        let checkpoint = AgentState::new(thread_id.clone());
        store.save(&checkpoint).await?;

        // The checkpoint can be loaded and belongs to the same thread.
        let restored = store
            .load(&thread_id)
            .await?
            .expect("state should exist after save");
        assert_eq!(restored.thread_id, thread_id);

        // After deletion the thread has no state again.
        store.delete(&thread_id).await?;
        assert!(store.load(&thread_id).await?.is_none());

        Ok(())
    }

    /// A recorded execution can be fetched by its tool call id.
    #[tokio::test]
    async fn test_execution_store_basic_operations() -> Result<()> {
        let store = InMemoryExecutionStore::new();

        // Unknown ids yield nothing.
        assert!(store.get_execution("tool_call_123").await?.is_none());

        // Record an in-flight execution.
        let pending = ToolExecution::new_in_flight(
            "tool_call_123",
            ThreadId::new(),
            "my_tool",
            "My Tool",
            serde_json::json!({"param": "value"}),
            time::OffsetDateTime::now_utc(),
        );
        store.record_execution(pending).await?;

        // It is returned unchanged and still in flight.
        let stored = store
            .get_execution("tool_call_123")
            .await?
            .expect("execution should exist");
        assert_eq!(stored.tool_call_id, "tool_call_123");
        assert_eq!(stored.tool_name, "my_tool");
        assert!(stored.is_in_flight());

        Ok(())
    }

    /// Updating a recorded execution with a result marks it completed.
    #[tokio::test]
    async fn test_execution_store_complete_execution() -> Result<()> {
        let store = InMemoryExecutionStore::new();

        // Record the execution while it is still in flight.
        let mut tracked = ToolExecution::new_in_flight(
            "tool_call_456",
            ThreadId::new(),
            "my_tool",
            "My Tool",
            serde_json::json!({}),
            time::OffsetDateTime::now_utc(),
        );
        store.record_execution(tracked.clone()).await?;

        // Attach a successful result and write it back.
        tracked.complete(ToolResult::success("Done!"));
        store.update_execution(tracked).await?;

        // The stored copy is completed and carries the successful result.
        let stored = store
            .get_execution("tool_call_456")
            .await?
            .expect("execution should exist");
        assert!(stored.is_completed());
        assert!(stored.result.is_some());
        assert!(stored.result.as_ref().is_some_and(|outcome| outcome.success));

        Ok(())
    }

    /// Executions carrying an operation id can be found through that id.
    #[tokio::test]
    async fn test_execution_store_operation_id_lookup() -> Result<()> {
        let store = InMemoryExecutionStore::new();

        // Record, then update, an execution that has an operation id.
        let mut tracked = ToolExecution::new_in_flight(
            "tool_call_789",
            ThreadId::new(),
            "async_tool",
            "Async Tool",
            serde_json::json!({}),
            time::OffsetDateTime::now_utc(),
        );
        tracked.set_operation_id("op_abc123");
        store.record_execution(tracked.clone()).await?;
        store.update_execution(tracked).await?;

        // The operation id resolves to the same execution.
        let stored = store
            .get_execution_by_operation_id("op_abc123")
            .await?
            .expect("execution should exist");
        assert_eq!(stored.tool_call_id, "tool_call_789");
        assert_eq!(stored.operation_id.as_deref(), Some("op_abc123"));

        // An unknown operation id resolves to nothing.
        let missing = store.get_execution_by_operation_id("nonexistent").await?;
        assert!(missing.is_none());

        Ok(())
    }
}