1use async_trait::async_trait;
4use dashmap::DashMap;
5use std::sync::Arc;
6use std::time::{Duration, SystemTime};
7
8use crate::jobs::JobResult;
9
10#[async_trait]
12pub trait IdempotencyStore: Send + Sync {
13 async fn get(&self, key: &str) -> anyhow::Result<Option<Arc<JobResult>>>;
15
16 async fn set(&self, key: &str, result: Arc<JobResult>, ttl: Duration) -> anyhow::Result<()>;
18
19 async fn remove(&self, key: &str) -> anyhow::Result<()>;
21}
22
23#[derive(Clone, Debug)]
25struct IdempotencyEntry {
26 result: Arc<JobResult>,
27 expires_at: SystemTime,
28}
29
30#[derive(Debug)]
32pub struct InMemoryIdempotencyStore {
33 store: Arc<DashMap<String, IdempotencyEntry>>,
34}
35
36impl InMemoryIdempotencyStore {
37 #[must_use]
39 pub fn new() -> Self {
40 Self::default()
41 }
42
43 fn cleanup_expired(&self) {
45 let now = SystemTime::now();
46 self.store.retain(|_, entry| entry.expires_at > now);
47 }
48}
49
50impl Default for InMemoryIdempotencyStore {
51 fn default() -> Self {
52 Self {
53 store: Arc::new(DashMap::default()),
54 }
55 }
56}
57
58#[async_trait]
59impl IdempotencyStore for InMemoryIdempotencyStore {
60 async fn get(&self, key: &str) -> anyhow::Result<Option<Arc<JobResult>>> {
61 self.cleanup_expired();
63
64 self.store.get(key).map_or_else(
65 || Ok(None),
66 |entry| {
67 if entry.expires_at > SystemTime::now() {
68 Ok(Some(Arc::clone(&entry.result)))
69 } else {
70 Ok(None)
71 }
72 },
73 )
74 }
75
76 async fn set(&self, key: &str, result: Arc<JobResult>, ttl: Duration) -> anyhow::Result<()> {
77 let expires_at = SystemTime::now()
78 .checked_add(ttl)
79 .unwrap_or_else(|| SystemTime::now() + Duration::from_secs(365 * 24 * 60 * 60 * 100)); self.store
81 .insert(key.to_string(), IdempotencyEntry { result, expires_at });
82 Ok(())
83 }
84
85 async fn remove(&self, key: &str) -> anyhow::Result<()> {
86 self.store.remove(key);
87 Ok(())
88 }
89}
90
/// [`IdempotencyStore`] whose entries live in a Redis server as JSON strings.
///
/// Every caller key is namespaced with `key_prefix`; expiry is left to Redis key TTLs.
#[cfg(feature = "redis")]
pub struct RedisIdempotencyStore {
    /// Redis client built from the URL given to `new`.
    client: redis::Client,
    /// Text prepended verbatim to every caller-supplied key.
    key_prefix: String,
}
97
#[cfg(feature = "redis")]
impl RedisIdempotencyStore {
    /// Builds a store for the Redis instance at `redis_url`.
    ///
    /// When `key_prefix` is `None`, keys are namespaced under `riglr:idempotency:`.
    ///
    /// # Errors
    ///
    /// Returns an error if `redis::Client::open` rejects `redis_url`.
    pub fn new(redis_url: &str, key_prefix: Option<&str>) -> anyhow::Result<Self> {
        let client = redis::Client::open(redis_url)?;
        let key_prefix = key_prefix.unwrap_or("riglr:idempotency:").to_owned();
        Ok(Self { client, key_prefix })
    }

    /// Joins the configured prefix and `key` into the full Redis key.
    fn make_key(&self, key: &str) -> String {
        let mut full_key = String::with_capacity(self.key_prefix.len() + key.len());
        full_key.push_str(&self.key_prefix);
        full_key.push_str(key);
        full_key
    }
}
118
#[cfg(feature = "redis")]
#[async_trait]
impl IdempotencyStore for RedisIdempotencyStore {
    /// Reads the JSON payload stored under the prefixed key and deserializes it.
    ///
    /// Expiry is enforced by Redis key TTLs, so an expired entry is simply a missing key.
    ///
    /// # Errors
    ///
    /// Fails if a connection cannot be obtained, the `GET` command errors, or the stored
    /// payload is not valid `JobResult` JSON.
    async fn get(&self, key: &str) -> anyhow::Result<Option<Arc<JobResult>>> {
        // NOTE(review): every method asks the client for a new multiplexed connection; caching
        // one (e.g. `redis::aio::ConnectionManager`) would avoid per-operation connection setup.
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let redis_key = self.make_key(key);

        let payload: Option<String> = redis::cmd("GET")
            .arg(&redis_key)
            .query_async(&mut conn)
            .await?;

        match payload {
            Some(json_str) => {
                let result: JobResult = serde_json::from_str(&json_str)?;
                Ok(Some(Arc::new(result)))
            }
            None => Ok(None),
        }
    }

    /// Serializes `result` to JSON and stores it under the prefixed key for `ttl`.
    ///
    /// - The TTL is sent with millisecond precision and sub-millisecond remainders are rounded
    ///   up, so a short but non-zero TTL is never sent to Redis as `0`.
    /// - A zero TTL deletes the key instead. This mirrors the in-memory store, where such an
    ///   entry is immediately expired, and avoids an invalid expire time.
    /// - TTLs are capped at ~100 years, the same horizon the in-memory store falls back to.
    ///
    /// # Errors
    ///
    /// Fails if `result` cannot be serialized, a connection cannot be obtained, or the Redis
    /// command errors.
    async fn set(&self, key: &str, result: Arc<JobResult>, ttl: Duration) -> anyhow::Result<()> {
        /// ~100 years in milliseconds; keeps the expiry well inside the range Redis accepts.
        const MAX_TTL_MILLIS: u64 = 100 * 365 * 24 * 60 * 60 * 1000;

        // Serialize first so a bad payload never costs a connection.
        let json_str = serde_json::to_string(&*result)?;
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let redis_key = self.make_key(key);

        // Round up to whole milliseconds, then clamp. After the clamp the value fits in a
        // `u64`, so the cast cannot truncate.
        let rounded_up_millis = (ttl.as_nanos() + 999_999) / 1_000_000;
        let ttl_millis = rounded_up_millis.min(u128::from(MAX_TTL_MILLIS)) as u64;

        if ttl_millis == 0 {
            redis::cmd("DEL")
                .arg(&redis_key)
                .query_async::<()>(&mut conn)
                .await?;
            return Ok(());
        }

        // `SET … PX` is the non-deprecated replacement for `SETEX`/`PSETEX`.
        redis::cmd("SET")
            .arg(&redis_key)
            .arg(json_str)
            .arg("PX")
            .arg(ttl_millis)
            .query_async::<()>(&mut conn)
            .await?;

        Ok(())
    }

    /// Deletes the prefixed key; deleting a key that does not exist is not an error.
    ///
    /// # Errors
    ///
    /// Fails if a connection cannot be obtained or the `DEL` command errors.
    async fn remove(&self, key: &str) -> anyhow::Result<()> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let redis_key = self.make_key(key);

        redis::cmd("DEL")
            .arg(&redis_key)
            .query_async::<()>(&mut conn)
            .await?;

        Ok(())
    }
}
168
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_in_memory_idempotency_store_new() {
        // Exercise the `new` constructor itself, not just `Default`.
        let store = InMemoryIdempotencyStore::new();
        assert!(store.store.is_empty());
    }

    #[test]
    fn test_in_memory_idempotency_store_default() {
        let store = InMemoryIdempotencyStore::default();
        assert!(store.store.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_idempotency_store_basic_operations() {
        let store = InMemoryIdempotencyStore::default();

        let result = JobResult::success(&"test_value").unwrap();
        let key = "test_key";

        assert!(store.get(key).await.unwrap().is_none());

        store
            .set(key, Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();

        let retrieved = store.get(key).await.unwrap();
        assert!(retrieved.is_some());
        assert!(retrieved.unwrap().is_success());

        store.remove(key).await.unwrap();
        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_in_memory_store_with_failure_results() {
        let store = InMemoryIdempotencyStore::default();
        let key = "failure_key";

        let retriable_failure = JobResult::Failure {
            error: crate::error::ToolError::retriable_string("Network timeout"),
        };
        store
            .set(key, Arc::new(retriable_failure), Duration::from_secs(60))
            .await
            .unwrap();

        let retrieved = store.get(key).await.unwrap().unwrap();
        assert!(!retrieved.is_success());
        assert!(retrieved.is_retriable());

        let permanent_failure = JobResult::Failure {
            error: crate::error::ToolError::permanent_string("Invalid input"),
        };
        store
            .set(key, Arc::new(permanent_failure), Duration::from_secs(60))
            .await
            .unwrap();

        let retrieved = store.get(key).await.unwrap().unwrap();
        assert!(!retrieved.is_success());
        assert!(!retrieved.is_retriable());
    }

    #[tokio::test]
    async fn test_in_memory_store_with_tx_hash() {
        let store = InMemoryIdempotencyStore::default();
        let key = "tx_key";

        let result = JobResult::success_with_tx(&json!({"amount": 100}), "0x123abc").unwrap();
        store
            .set(key, Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();

        let retrieved = store.get(key).await.unwrap().unwrap();
        assert!(retrieved.is_success());
    }

    #[tokio::test]
    async fn test_idempotency_expiry() {
        let store = InMemoryIdempotencyStore::default();

        let result = JobResult::success(&"test_value").unwrap();
        let key = "test_key";

        store
            .set(key, Arc::new(result), Duration::from_millis(200))
            .await
            .unwrap();

        assert!(store.get(key).await.unwrap().is_some());

        tokio::time::sleep(Duration::from_millis(500)).await;

        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cleanup_expired_entries() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"test").unwrap();

        store
            .set(
                "short_ttl",
                Arc::new(result.clone()),
                Duration::from_millis(100),
            )
            .await
            .unwrap();
        store
            .set("long_ttl", Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();

        assert!(store.get("short_ttl").await.unwrap().is_some());
        assert!(store.get("long_ttl").await.unwrap().is_some());
        assert_eq!(store.store.len(), 2);

        tokio::time::sleep(Duration::from_millis(300)).await;

        // Reading an unrelated key must sweep the expired entry out of the map entirely,
        // not merely hide it from `get`.
        let _ = store.get("long_ttl").await.unwrap();
        assert_eq!(store.store.len(), 1);

        assert!(store.get("short_ttl").await.unwrap().is_none());
        assert!(store.get("long_ttl").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_get_expired_entry_returns_none() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"test").unwrap();
        let key = "expire_test";

        store
            .set(key, Arc::new(result), Duration::from_millis(50))
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(150)).await;

        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_remove_non_existent_key() {
        let store = InMemoryIdempotencyStore::default();

        store.remove("non_existent").await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_operations() {
        let store = Arc::new(InMemoryIdempotencyStore::default());
        let result = JobResult::success(&"concurrent_test").unwrap();

        let mut handles = vec![];
        for i in 0..10 {
            let store_clone = Arc::clone(&store);
            let result_clone = result.clone();
            let handle = tokio::spawn(async move {
                let key = format!("concurrent_key_{}", i);
                store_clone
                    .set(&key, Arc::new(result_clone), Duration::from_secs(60))
                    .await
                    .unwrap();

                let retrieved = store_clone.get(&key).await.unwrap();
                assert!(retrieved.is_some());
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }

        for i in 0..10 {
            let key = format!("concurrent_key_{}", i);
            assert!(store.get(&key).await.unwrap().is_some());
        }
    }

    #[tokio::test]
    async fn test_zero_duration_ttl() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"zero_ttl").unwrap();
        let key = "zero_key";

        store
            .set(key, Arc::new(result), Duration::from_secs(0))
            .await
            .unwrap();

        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_large_ttl() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"large_ttl").unwrap();
        let key = "large_key";

        store
            .set(key, Arc::new(result), Duration::from_secs(u64::MAX))
            .await
            .unwrap();

        assert!(store.get(key).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_empty_key() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"empty_key_test").unwrap();

        store
            .set("", Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(store.get("").await.unwrap().is_some());
        store.remove("").await.unwrap();
        assert!(store.get("").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_special_characters_in_key() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"special_chars").unwrap();
        let key = "key:with/special\\chars@#$%";

        store
            .set(key, Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(store.get(key).await.unwrap().is_some());
        store.remove(key).await.unwrap();
        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_overwrite_same_key() {
        let store = InMemoryIdempotencyStore::default();
        let key = "overwrite_key";

        let result1 = JobResult::success(&"first_value").unwrap();
        let result2 = JobResult::success(&"second_value").unwrap();

        store
            .set(key, Arc::new(result1), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved1 = store.get(key).await.unwrap().unwrap();

        store
            .set(key, Arc::new(result2), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved2 = store.get(key).await.unwrap().unwrap();

        // Serialize the inner `JobResult`, not the `Arc`, so this does not rely on serde's
        // optional `rc` feature.
        assert_ne!(
            serde_json::to_string(&*retrieved1).unwrap(),
            serde_json::to_string(&*retrieved2).unwrap()
        );
    }

    #[test]
    fn test_idempotency_entry_creation() {
        let result = JobResult::success(&"test").unwrap();
        let expires_at = SystemTime::now() + Duration::from_secs(60);

        let entry = IdempotencyEntry {
            result: Arc::new(result),
            expires_at,
        };

        let cloned_entry = entry.clone();
        assert_eq!(cloned_entry.expires_at, entry.expires_at);
        // Cloning an entry shares the cached result rather than deep-copying it.
        assert!(Arc::ptr_eq(&cloned_entry.result, &entry.result));
    }

    #[cfg(feature = "redis")]
    mod redis_tests {
        use super::*;

        // `redis::Client::open` only validates the URL and does not connect, so these
        // constructors must succeed even without a running server.

        #[test]
        fn test_redis_store_new_with_default_prefix() {
            let store = RedisIdempotencyStore::new("redis://127.0.0.1:6379", None)
                .expect("a well-formed redis URL should be accepted without connecting");
            assert_eq!(store.key_prefix, "riglr:idempotency:");
        }

        #[test]
        fn test_redis_store_new_with_custom_prefix() {
            let store = RedisIdempotencyStore::new("redis://127.0.0.1:6379", Some("custom:"))
                .expect("a well-formed redis URL should be accepted without connecting");
            assert_eq!(store.key_prefix, "custom:");
        }

        #[test]
        fn test_redis_make_key() {
            let store = RedisIdempotencyStore::new("redis://127.0.0.1:6379", Some("test:"))
                .expect("a well-formed redis URL should be accepted without connecting");
            assert_eq!(store.make_key("mykey"), "test:mykey");
            assert_eq!(store.make_key(""), "test:");
            assert_eq!(store.make_key("key:with:colons"), "test:key:with:colons");
        }

        #[test]
        fn test_redis_invalid_url() {
            let result = RedisIdempotencyStore::new("invalid_url", None);
            assert!(result.is_err());
        }
    }
}