agent_sdk/
stores.rs

//! Storage traits for message history and agent state.
//!
//! The SDK uses two storage abstractions:
//!
//! - [`MessageStore`] - Stores conversation message history per thread
//! - [`StateStore`] - Stores agent state checkpoints for recovery
//!
//! # Built-in Implementation
//!
//! [`InMemoryStore`] implements both traits and is suitable for testing
//! and single-process deployments; its contents are lost when the process
//! exits. For production, implement custom stores backed by your database
//! (e.g., Postgres, Redis).

14use crate::llm;
15use crate::types::{AgentState, ThreadId};
16use anyhow::{Context, Result};
17use async_trait::async_trait;
18use std::collections::HashMap;
19use std::sync::RwLock;
20
21/// Trait for storing and retrieving conversation messages.
22/// Implement this trait to persist messages to your storage backend.
23#[async_trait]
24pub trait MessageStore: Send + Sync {
25    /// Append a message to the thread's history
26    ///
27    /// # Errors
28    /// Returns an error if the message cannot be stored.
29    async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()>;
30
31    /// Get all messages for a thread
32    ///
33    /// # Errors
34    /// Returns an error if the history cannot be retrieved.
35    async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>>;
36
37    /// Clear all messages for a thread
38    ///
39    /// # Errors
40    /// Returns an error if the messages cannot be cleared.
41    async fn clear(&self, thread_id: &ThreadId) -> Result<()>;
42
43    /// Get the message count for a thread
44    ///
45    /// # Errors
46    /// Returns an error if the count cannot be retrieved.
47    async fn count(&self, thread_id: &ThreadId) -> Result<usize> {
48        Ok(self.get_history(thread_id).await?.len())
49    }
50
51    /// Replace the entire message history for a thread.
52    /// Used for context compaction to replace old messages with a summary.
53    ///
54    /// # Errors
55    /// Returns an error if the history cannot be replaced.
56    async fn replace_history(
57        &self,
58        thread_id: &ThreadId,
59        messages: Vec<llm::Message>,
60    ) -> Result<()>;
61}
62
63/// Trait for storing agent state checkpoints.
64/// Implement this to enable conversation recovery and resume.
65#[async_trait]
66pub trait StateStore: Send + Sync {
67    /// Save the current agent state
68    ///
69    /// # Errors
70    /// Returns an error if the state cannot be saved.
71    async fn save(&self, state: &AgentState) -> Result<()>;
72
73    /// Load the most recent state for a thread
74    ///
75    /// # Errors
76    /// Returns an error if the state cannot be loaded.
77    async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>>;
78
79    /// Delete state for a thread
80    ///
81    /// # Errors
82    /// Returns an error if the state cannot be deleted.
83    async fn delete(&self, thread_id: &ThreadId) -> Result<()>;
84}
85
86/// In-memory implementation of `MessageStore` and `StateStore`.
87/// Useful for testing and simple use cases.
88#[derive(Default)]
89pub struct InMemoryStore {
90    messages: RwLock<HashMap<String, Vec<llm::Message>>>,
91    states: RwLock<HashMap<String, AgentState>>,
92}
93
94impl InMemoryStore {
95    #[must_use]
96    pub fn new() -> Self {
97        Self::default()
98    }
99}
100
101#[async_trait]
102impl MessageStore for InMemoryStore {
103    async fn append(&self, thread_id: &ThreadId, message: llm::Message) -> Result<()> {
104        self.messages
105            .write()
106            .ok()
107            .context("lock poisoned")?
108            .entry(thread_id.0.clone())
109            .or_default()
110            .push(message);
111        Ok(())
112    }
113
114    async fn get_history(&self, thread_id: &ThreadId) -> Result<Vec<llm::Message>> {
115        let messages = self.messages.read().ok().context("lock poisoned")?;
116        Ok(messages.get(&thread_id.0).cloned().unwrap_or_default())
117    }
118
119    async fn clear(&self, thread_id: &ThreadId) -> Result<()> {
120        self.messages
121            .write()
122            .ok()
123            .context("lock poisoned")?
124            .remove(&thread_id.0);
125        Ok(())
126    }
127
128    async fn replace_history(
129        &self,
130        thread_id: &ThreadId,
131        messages: Vec<llm::Message>,
132    ) -> Result<()> {
133        self.messages
134            .write()
135            .ok()
136            .context("lock poisoned")?
137            .insert(thread_id.0.clone(), messages);
138        Ok(())
139    }
140}
141
142#[async_trait]
143impl StateStore for InMemoryStore {
144    async fn save(&self, state: &AgentState) -> Result<()> {
145        self.states
146            .write()
147            .ok()
148            .context("lock poisoned")?
149            .insert(state.thread_id.0.clone(), state.clone());
150        Ok(())
151    }
152
153    async fn load(&self, thread_id: &ThreadId) -> Result<Option<AgentState>> {
154        let states = self.states.read().ok().context("lock poisoned")?;
155        Ok(states.get(&thread_id.0).cloned())
156    }
157
158    async fn delete(&self, thread_id: &ThreadId) -> Result<()> {
159        self.states
160            .write()
161            .ok()
162            .context("lock poisoned")?
163            .remove(&thread_id.0);
164        Ok(())
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::Message;

    #[tokio::test]
    async fn test_in_memory_message_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // Initially empty, and an unknown thread counts as zero.
        assert!(store.get_history(&thread_id).await?.is_empty());
        assert_eq!(store.count(&thread_id).await?, 0);

        // Add messages
        store.append(&thread_id, Message::user("Hello")).await?;
        store
            .append(&thread_id, Message::assistant("Hi there!"))
            .await?;

        // Retrieve messages
        let history = store.get_history(&thread_id).await?;
        assert_eq!(history.len(), 2);

        // Count agrees with the history length.
        assert_eq!(store.count(&thread_id).await?, 2);

        // Clear resets both the history and the count.
        store.clear(&thread_id).await?;
        assert!(store.get_history(&thread_id).await?.is_empty());
        assert_eq!(store.count(&thread_id).await?, 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_replace_history() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // Add some messages
        store.append(&thread_id, Message::user("Hello")).await?;
        store
            .append(&thread_id, Message::assistant("Hi there!"))
            .await?;
        store
            .append(&thread_id, Message::user("How are you?"))
            .await?;

        // Verify original messages
        assert_eq!(store.get_history(&thread_id).await?.len(), 3);

        // Replace with compacted history
        let new_history = vec![
            Message::user("[Summary] Previous conversation about greetings"),
            Message::assistant("I understand the context. Continuing..."),
        ];
        store.replace_history(&thread_id, new_history).await?;

        // Verify replaced history
        assert_eq!(store.get_history(&thread_id).await?.len(), 2);
        assert_eq!(store.count(&thread_id).await?, 2);

        // Replacing with an empty history leaves the thread empty.
        store.replace_history(&thread_id, Vec::new()).await?;
        assert!(store.get_history(&thread_id).await?.is_empty());
        assert_eq!(store.count(&thread_id).await?, 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_in_memory_state_store() -> Result<()> {
        let store = InMemoryStore::new();
        let thread_id = ThreadId::new();

        // Initially none
        assert!(store.load(&thread_id).await?.is_none());

        // Save state
        let state = AgentState::new(thread_id.clone());
        store.save(&state).await?;

        // Load state; a missing checkpoint fails the test instead of being skipped.
        let loaded = store
            .load(&thread_id)
            .await?
            .context("state should be present after save")?;
        assert_eq!(loaded.thread_id, thread_id);

        // Delete state
        store.delete(&thread_id).await?;
        assert!(store.load(&thread_id).await?.is_none());

        Ok(())
    }
}