agents_aws/
dynamodb_checkpointer.rs

1//! DynamoDB-backed checkpointer implementation for AWS deployments.
2//!
3//! This checkpointer stores agent state in Amazon DynamoDB, providing:
4//! - Fully managed, serverless persistence
5//! - Automatic scaling and high availability
6//! - Global table support for multi-region deployments
7//! - Pay-per-request pricing with on-demand mode
8//!
9//! ## Table Schema
10//!
11//! The checkpointer expects a DynamoDB table with the following schema:
12//!
13//! - **Primary Key**: `thread_id` (String)
14//! - **Attributes**:
//!   - `state` (String) - The agent state serialized as a JSON string
16//!   - `updated_at` (String) - ISO 8601 timestamp
17//!   - `ttl` (Number, optional) - Unix epoch for automatic expiration
18//!
19//! ## Setup
20//!
21//! Create the table using AWS CLI:
22//!
23//! ```bash
24//! aws dynamodb create-table \
25//!   --table-name agent-checkpoints \
26//!   --attribute-definitions AttributeName=thread_id,AttributeType=S \
27//!   --key-schema AttributeName=thread_id,KeyType=HASH \
28//!   --billing-mode PAY_PER_REQUEST
29//! ```
30//!
31//! Or use Terraform (see `deploy/modules/dynamodb/`).
32
33use agents_core::persistence::{Checkpointer, ThreadId};
34use agents_core::state::AgentStateSnapshot;
35use anyhow::Context;
36use async_trait::async_trait;
37use aws_sdk_dynamodb::{types::AttributeValue, Client};
38use std::collections::HashMap;
39use std::time::Duration;
40
41/// DynamoDB-backed checkpointer for serverless AWS deployments.
42///
43/// # Examples
44///
45/// ```rust,no_run
46/// use agents_aws::DynamoDbCheckpointer;
47/// use std::time::Duration;
48///
49/// #[tokio::main]
50/// async fn main() -> anyhow::Result<()> {
51///     // Using default AWS configuration
52///     let checkpointer = DynamoDbCheckpointer::new("agent-checkpoints").await?;
53///
54///     // With custom configuration and TTL
55///     let checkpointer = DynamoDbCheckpointer::builder()
56///         .table_name("my-agents")
57///         .ttl(Duration::from_secs(86400 * 7)) // 7 days
58///         .build()
59///         .await?;
60///
61///     Ok(())
62/// }
63/// ```
64#[derive(Clone)]
65pub struct DynamoDbCheckpointer {
66    client: Client,
67    table_name: String,
68    ttl_seconds: Option<i64>,
69}
70
71impl DynamoDbCheckpointer {
72    /// Create a new DynamoDB checkpointer with default AWS configuration.
73    ///
74    /// This will use the default AWS credential chain (environment variables,
75    /// IAM roles, AWS config files, etc.).
76    ///
77    /// # Arguments
78    ///
79    /// * `table_name` - The name of the DynamoDB table
80    pub async fn new(table_name: impl Into<String>) -> anyhow::Result<Self> {
81        Self::builder().table_name(table_name).build().await
82    }
83
84    /// Create a builder for configuring the DynamoDB checkpointer.
85    pub fn builder() -> DynamoDbCheckpointerBuilder {
86        DynamoDbCheckpointerBuilder::default()
87    }
88
89    /// Calculate TTL timestamp for the current time.
90    fn calculate_ttl(&self) -> Option<i64> {
91        self.ttl_seconds.map(|ttl| {
92            std::time::SystemTime::now()
93                .duration_since(std::time::UNIX_EPOCH)
94                .unwrap()
95                .as_secs() as i64
96                + ttl
97        })
98    }
99}
100
101#[async_trait]
102impl Checkpointer for DynamoDbCheckpointer {
103    async fn save_state(
104        &self,
105        thread_id: &ThreadId,
106        state: &AgentStateSnapshot,
107    ) -> anyhow::Result<()> {
108        let state_json =
109            serde_json::to_string(state).context("Failed to serialize agent state to JSON")?;
110
111        let mut item = HashMap::new();
112        item.insert(
113            "thread_id".to_string(),
114            AttributeValue::S(thread_id.clone()),
115        );
116        item.insert("state".to_string(), AttributeValue::S(state_json));
117        item.insert(
118            "updated_at".to_string(),
119            AttributeValue::S(chrono::Utc::now().to_rfc3339()),
120        );
121
122        // Add TTL if configured
123        if let Some(ttl) = self.calculate_ttl() {
124            item.insert("ttl".to_string(), AttributeValue::N(ttl.to_string()));
125        }
126
127        self.client
128            .put_item()
129            .table_name(&self.table_name)
130            .set_item(Some(item))
131            .send()
132            .await
133            .context("Failed to save state to DynamoDB")?;
134
135        tracing::debug!(
136            thread_id = %thread_id,
137            table = %self.table_name,
138            "Saved agent state to DynamoDB"
139        );
140
141        Ok(())
142    }
143
144    async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
145        let mut key = HashMap::new();
146        key.insert(
147            "thread_id".to_string(),
148            AttributeValue::S(thread_id.clone()),
149        );
150
151        let result = self
152            .client
153            .get_item()
154            .table_name(&self.table_name)
155            .set_key(Some(key))
156            .send()
157            .await
158            .context("Failed to load state from DynamoDB")?;
159
160        match result.item {
161            Some(item) => {
162                let state_value = item
163                    .get("state")
164                    .and_then(|v| v.as_s().ok())
165                    .ok_or_else(|| anyhow::anyhow!("State attribute not found or invalid"))?;
166
167                let state: AgentStateSnapshot = serde_json::from_str(state_value)
168                    .context("Failed to deserialize agent state from JSON")?;
169
170                tracing::debug!(
171                    thread_id = %thread_id,
172                    table = %self.table_name,
173                    "Loaded agent state from DynamoDB"
174                );
175
176                Ok(Some(state))
177            }
178            None => {
179                tracing::debug!(
180                    thread_id = %thread_id,
181                    table = %self.table_name,
182                    "No saved state found in DynamoDB"
183                );
184                Ok(None)
185            }
186        }
187    }
188
189    async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
190        let mut key = HashMap::new();
191        key.insert(
192            "thread_id".to_string(),
193            AttributeValue::S(thread_id.clone()),
194        );
195
196        self.client
197            .delete_item()
198            .table_name(&self.table_name)
199            .set_key(Some(key))
200            .send()
201            .await
202            .context("Failed to delete thread from DynamoDB")?;
203
204        tracing::debug!(
205            thread_id = %thread_id,
206            table = %self.table_name,
207            "Deleted thread from DynamoDB"
208        );
209
210        Ok(())
211    }
212
213    async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
214        let mut threads = Vec::new();
215        let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
216
217        loop {
218            let mut scan = self
219                .client
220                .scan()
221                .table_name(&self.table_name)
222                .projection_expression("thread_id");
223
224            if let Some(key) = last_evaluated_key {
225                scan = scan.set_exclusive_start_key(Some(key));
226            }
227
228            let result = scan
229                .send()
230                .await
231                .context("Failed to list threads from DynamoDB")?;
232
233            if let Some(items) = result.items {
234                for item in items {
235                    if let Some(thread_id) = item
236                        .get("thread_id")
237                        .and_then(|v| v.as_s().ok())
238                        .map(|s| s.to_string())
239                    {
240                        threads.push(thread_id);
241                    }
242                }
243            }
244
245            last_evaluated_key = result.last_evaluated_key;
246
247            if last_evaluated_key.is_none() {
248                break;
249            }
250        }
251
252        Ok(threads)
253    }
254}
255
256/// Builder for configuring a DynamoDB checkpointer.
257#[derive(Default)]
258pub struct DynamoDbCheckpointerBuilder {
259    table_name: Option<String>,
260    ttl: Option<Duration>,
261    client: Option<Client>,
262}
263
264impl DynamoDbCheckpointerBuilder {
265    /// Set the DynamoDB table name.
266    pub fn table_name(mut self, table_name: impl Into<String>) -> Self {
267        self.table_name = Some(table_name.into());
268        self
269    }
270
271    /// Set the TTL (time-to-live) for stored states.
272    ///
273    /// DynamoDB will automatically delete items after this duration.
274    /// Note: You must enable TTL on the `ttl` attribute in your table.
275    pub fn ttl(mut self, ttl: Duration) -> Self {
276        self.ttl = Some(ttl);
277        self
278    }
279
280    /// Use a custom DynamoDB client.
281    ///
282    /// This is useful for testing with LocalStack or using custom endpoints.
283    pub fn client(mut self, client: Client) -> Self {
284        self.client = Some(client);
285        self
286    }
287
288    /// Build the DynamoDB checkpointer.
289    pub async fn build(self) -> anyhow::Result<DynamoDbCheckpointer> {
290        let table_name = self
291            .table_name
292            .ok_or_else(|| anyhow::anyhow!("Table name is required"))?;
293
294        let client = match self.client {
295            Some(client) => client,
296            None => {
297                let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
298                Client::new(&config)
299            }
300        };
301
302        Ok(DynamoDbCheckpointer {
303            client,
304            table_name,
305            ttl_seconds: self.ttl.map(|d| d.as_secs() as i64),
306        })
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::state::TodoItem;

    /// Build a snapshot holding one pending todo, one file and one scratchpad entry.
    fn sample_state() -> AgentStateSnapshot {
        let mut snapshot = AgentStateSnapshot::default();

        let todo = TodoItem::pending("Test todo");
        snapshot.todos.push(todo);

        let (file_name, file_body) = (String::from("test.txt"), String::from("content"));
        snapshot.files.insert(file_name, file_body);

        let note = serde_json::json!("value");
        snapshot.scratchpad.insert(String::from("key"), note);

        snapshot
    }

    /// Round-trips a snapshot through a real table: save, load, verify, delete.
    #[tokio::test]
    #[ignore] // Requires DynamoDB or LocalStack
    async fn test_dynamodb_save_and_load() {
        let store = DynamoDbCheckpointer::new("agent-checkpoints-test")
            .await
            .expect("Failed to create DynamoDB client");

        let thread_id = String::from("test-thread");
        let original = sample_state();

        // Persist the snapshot.
        let saved = store.save_state(&thread_id, &original).await;
        saved.expect("Failed to save state");

        // Read it back.
        let fetched = store.load_state(&thread_id).await;
        let restored = fetched.expect("Failed to load state");

        assert!(restored.is_some());
        let restored = restored.unwrap();

        assert_eq!(restored.todos.len(), 1);
        assert_eq!(restored.files.get("test.txt").unwrap(), "content");

        // Remove the item so repeated runs start from a clean table.
        let removed = store.delete_thread(&thread_id).await;
        removed.expect("Failed to delete thread");
    }
}