// agents_aws/dynamodb_checkpointer.rs
use std::collections::HashMap;
use std::time::Duration;

use anyhow::Context;
use async_trait::async_trait;
use aws_sdk_dynamodb::{types::AttributeValue, Client};

use agents_core::persistence::{Checkpointer, ThreadId};
use agents_core::state::AgentStateSnapshot;
40
41#[derive(Clone)]
65pub struct DynamoDbCheckpointer {
66 client: Client,
67 table_name: String,
68 ttl_seconds: Option<i64>,
69}
70
71impl DynamoDbCheckpointer {
72 pub async fn new(table_name: impl Into<String>) -> anyhow::Result<Self> {
81 Self::builder().table_name(table_name).build().await
82 }
83
84 pub fn builder() -> DynamoDbCheckpointerBuilder {
86 DynamoDbCheckpointerBuilder::default()
87 }
88
89 fn calculate_ttl(&self) -> Option<i64> {
91 self.ttl_seconds.map(|ttl| {
92 std::time::SystemTime::now()
93 .duration_since(std::time::UNIX_EPOCH)
94 .unwrap()
95 .as_secs() as i64
96 + ttl
97 })
98 }
99}
100
101#[async_trait]
102impl Checkpointer for DynamoDbCheckpointer {
103 async fn save_state(
104 &self,
105 thread_id: &ThreadId,
106 state: &AgentStateSnapshot,
107 ) -> anyhow::Result<()> {
108 let state_json =
109 serde_json::to_string(state).context("Failed to serialize agent state to JSON")?;
110
111 let mut item = HashMap::new();
112 item.insert(
113 "thread_id".to_string(),
114 AttributeValue::S(thread_id.clone()),
115 );
116 item.insert("state".to_string(), AttributeValue::S(state_json));
117 item.insert(
118 "updated_at".to_string(),
119 AttributeValue::S(chrono::Utc::now().to_rfc3339()),
120 );
121
122 if let Some(ttl) = self.calculate_ttl() {
124 item.insert("ttl".to_string(), AttributeValue::N(ttl.to_string()));
125 }
126
127 self.client
128 .put_item()
129 .table_name(&self.table_name)
130 .set_item(Some(item))
131 .send()
132 .await
133 .context("Failed to save state to DynamoDB")?;
134
135 tracing::debug!(
136 thread_id = %thread_id,
137 table = %self.table_name,
138 "Saved agent state to DynamoDB"
139 );
140
141 Ok(())
142 }
143
144 async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
145 let mut key = HashMap::new();
146 key.insert(
147 "thread_id".to_string(),
148 AttributeValue::S(thread_id.clone()),
149 );
150
151 let result = self
152 .client
153 .get_item()
154 .table_name(&self.table_name)
155 .set_key(Some(key))
156 .send()
157 .await
158 .context("Failed to load state from DynamoDB")?;
159
160 match result.item {
161 Some(item) => {
162 let state_value = item
163 .get("state")
164 .and_then(|v| v.as_s().ok())
165 .ok_or_else(|| anyhow::anyhow!("State attribute not found or invalid"))?;
166
167 let state: AgentStateSnapshot = serde_json::from_str(state_value)
168 .context("Failed to deserialize agent state from JSON")?;
169
170 tracing::debug!(
171 thread_id = %thread_id,
172 table = %self.table_name,
173 "Loaded agent state from DynamoDB"
174 );
175
176 Ok(Some(state))
177 }
178 None => {
179 tracing::debug!(
180 thread_id = %thread_id,
181 table = %self.table_name,
182 "No saved state found in DynamoDB"
183 );
184 Ok(None)
185 }
186 }
187 }
188
189 async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
190 let mut key = HashMap::new();
191 key.insert(
192 "thread_id".to_string(),
193 AttributeValue::S(thread_id.clone()),
194 );
195
196 self.client
197 .delete_item()
198 .table_name(&self.table_name)
199 .set_key(Some(key))
200 .send()
201 .await
202 .context("Failed to delete thread from DynamoDB")?;
203
204 tracing::debug!(
205 thread_id = %thread_id,
206 table = %self.table_name,
207 "Deleted thread from DynamoDB"
208 );
209
210 Ok(())
211 }
212
213 async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
214 let mut threads = Vec::new();
215 let mut last_evaluated_key: Option<HashMap<String, AttributeValue>> = None;
216
217 loop {
218 let mut scan = self
219 .client
220 .scan()
221 .table_name(&self.table_name)
222 .projection_expression("thread_id");
223
224 if let Some(key) = last_evaluated_key {
225 scan = scan.set_exclusive_start_key(Some(key));
226 }
227
228 let result = scan
229 .send()
230 .await
231 .context("Failed to list threads from DynamoDB")?;
232
233 if let Some(items) = result.items {
234 for item in items {
235 if let Some(thread_id) = item
236 .get("thread_id")
237 .and_then(|v| v.as_s().ok())
238 .map(|s| s.to_string())
239 {
240 threads.push(thread_id);
241 }
242 }
243 }
244
245 last_evaluated_key = result.last_evaluated_key;
246
247 if last_evaluated_key.is_none() {
248 break;
249 }
250 }
251
252 Ok(threads)
253 }
254}
255
256#[derive(Default)]
258pub struct DynamoDbCheckpointerBuilder {
259 table_name: Option<String>,
260 ttl: Option<Duration>,
261 client: Option<Client>,
262}
263
264impl DynamoDbCheckpointerBuilder {
265 pub fn table_name(mut self, table_name: impl Into<String>) -> Self {
267 self.table_name = Some(table_name.into());
268 self
269 }
270
271 pub fn ttl(mut self, ttl: Duration) -> Self {
276 self.ttl = Some(ttl);
277 self
278 }
279
280 pub fn client(mut self, client: Client) -> Self {
284 self.client = Some(client);
285 self
286 }
287
288 pub async fn build(self) -> anyhow::Result<DynamoDbCheckpointer> {
290 let table_name = self
291 .table_name
292 .ok_or_else(|| anyhow::anyhow!("Table name is required"))?;
293
294 let client = match self.client {
295 Some(client) => client,
296 None => {
297 let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
298 Client::new(&config)
299 }
300 };
301
302 Ok(DynamoDbCheckpointer {
303 client,
304 table_name,
305 ttl_seconds: self.ttl.map(|d| d.as_secs() as i64),
306 })
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::state::TodoItem;

    /// Builds a snapshot with one todo, one file and one scratchpad entry.
    fn sample_state() -> AgentStateSnapshot {
        let mut state = AgentStateSnapshot::default();
        state.todos.push(TodoItem::pending("Test todo"));
        state
            .files
            .insert("test.txt".to_string(), "content".to_string());
        state
            .scratchpad
            .insert("key".to_string(), serde_json::json!("value"));
        state
    }

    /// The builder rejects a missing table name before loading any AWS configuration,
    /// so this test runs offline.
    #[tokio::test]
    async fn build_without_table_name_fails() {
        let Err(err) = DynamoDbCheckpointer::builder().build().await else {
            panic!("building without a table name should fail");
        };
        assert!(err.to_string().contains("Table name is required"));
    }

    /// Round-trips a snapshot through a real table.
    ///
    /// Requires AWS credentials and an existing `agent-checkpoints-test` table keyed by
    /// the string attribute `thread_id`. Run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_dynamodb_save_and_load() {
        let checkpointer = DynamoDbCheckpointer::new("agent-checkpoints-test")
            .await
            .expect("Failed to create DynamoDB client");

        let thread_id = "test-thread".to_string();
        let state = sample_state();

        checkpointer
            .save_state(&thread_id, &state)
            .await
            .expect("Failed to save state");

        let loaded_state = checkpointer
            .load_state(&thread_id)
            .await
            .expect("Failed to load state")
            .expect("Saved state should be present");

        assert_eq!(loaded_state.todos.len(), 1);
        assert_eq!(loaded_state.files.get("test.txt").unwrap(), "content");
        assert_eq!(
            loaded_state.scratchpad.get("key"),
            Some(&serde_json::json!("value"))
        );

        checkpointer
            .delete_thread(&thread_id)
            .await
            .expect("Failed to delete thread");
    }
}