riglr_core/
idempotency.rs

1//! Idempotency store for preventing duplicate execution of jobs.
2
3use async_trait::async_trait;
4use dashmap::DashMap;
5use std::sync::Arc;
6use std::time::{Duration, SystemTime};
7
8use crate::jobs::JobResult;
9
10/// Trait for idempotency store implementations
11#[async_trait]
12pub trait IdempotencyStore: Send + Sync {
13    /// Check if a result exists for the given idempotency key
14    async fn get(&self, key: &str) -> anyhow::Result<Option<Arc<JobResult>>>;
15
16    /// Store a result with the given idempotency key and TTL
17    async fn set(&self, key: &str, result: Arc<JobResult>, ttl: Duration) -> anyhow::Result<()>;
18
19    /// Remove an entry by key
20    async fn remove(&self, key: &str) -> anyhow::Result<()>;
21}
22
23/// Entry in the idempotency store
24#[derive(Clone, Debug)]
25struct IdempotencyEntry {
26    result: Arc<JobResult>,
27    expires_at: SystemTime,
28}
29
30/// In-memory idempotency store for testing and development
31#[derive(Debug)]
32pub struct InMemoryIdempotencyStore {
33    store: Arc<DashMap<String, IdempotencyEntry>>,
34}
35
36impl InMemoryIdempotencyStore {
37    /// Create a new in-memory idempotency store
38    #[must_use]
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    /// Clean up expired entries
44    fn cleanup_expired(&self) {
45        let now = SystemTime::now();
46        self.store.retain(|_, entry| entry.expires_at > now);
47    }
48}
49
50impl Default for InMemoryIdempotencyStore {
51    fn default() -> Self {
52        Self {
53            store: Arc::new(DashMap::default()),
54        }
55    }
56}
57
58#[async_trait]
59impl IdempotencyStore for InMemoryIdempotencyStore {
60    async fn get(&self, key: &str) -> anyhow::Result<Option<Arc<JobResult>>> {
61        // Clean up expired entries periodically
62        self.cleanup_expired();
63
64        self.store.get(key).map_or_else(
65            || Ok(None),
66            |entry| {
67                if entry.expires_at > SystemTime::now() {
68                    Ok(Some(Arc::clone(&entry.result)))
69                } else {
70                    Ok(None)
71                }
72            },
73        )
74    }
75
76    async fn set(&self, key: &str, result: Arc<JobResult>, ttl: Duration) -> anyhow::Result<()> {
77        let expires_at = SystemTime::now()
78            .checked_add(ttl)
79            .unwrap_or_else(|| SystemTime::now() + Duration::from_secs(365 * 24 * 60 * 60 * 100)); // 100 years
80        self.store
81            .insert(key.to_string(), IdempotencyEntry { result, expires_at });
82        Ok(())
83    }
84
85    async fn remove(&self, key: &str) -> anyhow::Result<()> {
86        self.store.remove(key);
87        Ok(())
88    }
89}
90
/// Idempotency store backed by Redis, intended for production use.
///
/// Only compiled when the `redis` feature is enabled.
#[cfg(feature = "redis")]
pub struct RedisIdempotencyStore {
    /// Redis client from which each operation obtains its connection.
    client: redis::Client,
    /// Text prepended to every idempotency key before it is sent to Redis.
    key_prefix: String,
}
97
#[cfg(feature = "redis")]
impl RedisIdempotencyStore {
    /// Build a store from a Redis URL and an optional key prefix.
    ///
    /// # Arguments
    /// * `redis_url` - Redis connection URL (e.g., "redis://127.0.0.1:6379")
    /// * `key_prefix` - Prefix applied to every key; `None` selects
    ///   "riglr:idempotency:"
    ///
    /// # Errors
    /// Fails when `redis::Client::open` rejects `redis_url`.
    pub fn new(redis_url: &str, key_prefix: Option<&str>) -> anyhow::Result<Self> {
        let key_prefix = key_prefix.unwrap_or("riglr:idempotency:").to_string();
        let client = redis::Client::open(redis_url)?;
        Ok(Self { client, key_prefix })
    }

    /// Prepend the configured prefix to `key`, giving the name used in Redis.
    fn make_key(&self, key: &str) -> String {
        [self.key_prefix.as_str(), key].concat()
    }
}
118
#[cfg(feature = "redis")]
#[async_trait]
impl IdempotencyStore for RedisIdempotencyStore {
    /// Fetch the result stored under `key` and deserialize it from JSON.
    ///
    /// Returns `Ok(None)` when Redis holds no value for the prefixed key.
    ///
    /// # Errors
    /// Fails when a connection cannot be obtained, when the `GET` command
    /// fails, or when the stored payload does not deserialize as a `JobResult`.
    async fn get(&self, key: &str) -> anyhow::Result<Option<Arc<JobResult>>> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let redis_key = self.make_key(key);

        let cached: Option<String> = redis::cmd("GET")
            .arg(&redis_key)
            .query_async(&mut conn)
            .await?;

        match cached {
            Some(json_str) => {
                let result: JobResult = serde_json::from_str(&json_str)?;
                Ok(Some(Arc::new(result)))
            }
            None => Ok(None),
        }
    }

    /// Serialize `result` to JSON and store it under `key` with an expiry.
    ///
    /// The expiry is sent with millisecond precision (`SET ... PX`), so
    /// sub-second TTLs are honoured instead of being rounded down to zero
    /// seconds, which Redis rejects as an invalid expire time. The TTL is
    /// clamped to the range 1 ms ..= 100 years: a zero or sub-millisecond TTL
    /// expires almost immediately, and an oversized TTL is capped at 100 years,
    /// the same ceiling the in-memory store falls back to.
    ///
    /// # Errors
    /// Fails when a connection cannot be obtained, when `result` cannot be
    /// serialized, or when the `SET` command fails.
    async fn set(&self, key: &str, result: Arc<JobResult>, ttl: Duration) -> anyhow::Result<()> {
        // 100 years in milliseconds.
        const MAX_TTL_MILLIS: u64 = 100 * 365 * 24 * 60 * 60 * 1000;

        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let redis_key = self.make_key(key);
        let json_str = serde_json::to_string(&*result)?;

        // `as_millis` yields a u128; saturate instead of truncating, then
        // clamp into the range Redis accepts for a relative expiry.
        let ttl_millis = u64::try_from(ttl.as_millis())
            .unwrap_or(u64::MAX)
            .clamp(1, MAX_TTL_MILLIS);

        redis::cmd("SET")
            .arg(&redis_key)
            .arg(json_str)
            .arg("PX")
            .arg(ttl_millis)
            .query_async::<()>(&mut conn)
            .await?;

        Ok(())
    }

    /// Delete the value stored under `key`; a missing key is not an error.
    ///
    /// # Errors
    /// Fails when a connection cannot be obtained or the `DEL` command fails.
    async fn remove(&self, key: &str) -> anyhow::Result<()> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let redis_key = self.make_key(key);

        redis::cmd("DEL")
            .arg(&redis_key)
            .query_async::<()>(&mut conn)
            .await?;

        Ok(())
    }
}
168
#[cfg(test)]
mod tests {
    //! Unit tests for the idempotency stores.
    //!
    //! The in-memory store is exercised end to end. The Redis store is only
    //! checked for construction and key formatting, which need no server.

    use super::*;
    use serde_json::json;

    // Test InMemoryIdempotencyStore::new()
    #[test]
    fn test_in_memory_idempotency_store_new() {
        // Call `new()` itself so the constructor is actually covered.
        let store = InMemoryIdempotencyStore::new();
        assert!(store.store.is_empty());
    }

    // Test InMemoryIdempotencyStore::default()
    #[test]
    fn test_in_memory_idempotency_store_default() {
        let store = InMemoryIdempotencyStore::default();
        assert!(store.store.is_empty());
    }

    // Test basic get/set/remove operations
    #[tokio::test]
    async fn test_in_memory_idempotency_store_basic_operations() {
        let store = InMemoryIdempotencyStore::default();

        let result = JobResult::success(&"test_value").unwrap();
        let key = "test_key";

        // Initially, key should not exist
        assert!(store.get(key).await.unwrap().is_none());

        // Store a result
        store
            .set(key, Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();

        // Should be able to retrieve it
        let retrieved = store.get(key).await.unwrap();
        assert!(retrieved.is_some());
        assert!(retrieved.unwrap().is_success());

        // Remove the entry
        store.remove(key).await.unwrap();
        assert!(store.get(key).await.unwrap().is_none());
    }

    // Test with failure result types
    #[tokio::test]
    async fn test_in_memory_store_with_failure_results() {
        let store = InMemoryIdempotencyStore::default();
        let key = "failure_key";

        // Test retriable failure
        let retriable_failure = JobResult::Failure {
            error: crate::error::ToolError::retriable_string("Network timeout"),
        };
        store
            .set(key, Arc::new(retriable_failure), Duration::from_secs(60))
            .await
            .unwrap();

        let retrieved = store.get(key).await.unwrap().unwrap();
        assert!(!retrieved.is_success());
        assert!(retrieved.is_retriable());

        // Test permanent failure
        let permanent_failure = JobResult::Failure {
            error: crate::error::ToolError::permanent_string("Invalid input"),
        };
        store
            .set(key, Arc::new(permanent_failure), Duration::from_secs(60))
            .await
            .unwrap();

        let retrieved = store.get(key).await.unwrap().unwrap();
        assert!(!retrieved.is_success());
        assert!(!retrieved.is_retriable());
    }

    // Test with success result with transaction hash
    #[tokio::test]
    async fn test_in_memory_store_with_tx_hash() {
        let store = InMemoryIdempotencyStore::default();
        let key = "tx_key";

        let result = JobResult::success_with_tx(&json!({"amount": 100}), "0x123abc").unwrap();
        store
            .set(key, Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();

        let retrieved = store.get(key).await.unwrap().unwrap();
        assert!(retrieved.is_success());
    }

    // Test expiry behavior
    #[tokio::test]
    async fn test_idempotency_expiry() {
        let store = InMemoryIdempotencyStore::default();

        let result = JobResult::success(&"test_value").unwrap();
        let key = "test_key";

        // Store with short TTL (very generous for instrumented runs)
        store
            .set(key, Arc::new(result), Duration::from_millis(200))
            .await
            .unwrap();

        // Should exist initially
        assert!(store.get(key).await.unwrap().is_some());

        // Wait for expiry (very generous timeout for instrumented runs)
        tokio::time::sleep(Duration::from_millis(500)).await;

        // Should be expired now
        assert!(store.get(key).await.unwrap().is_none());
    }

    // Test cleanup_expired functionality
    #[tokio::test]
    async fn test_cleanup_expired_entries() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"test").unwrap();

        // Add multiple entries with different TTLs
        store
            .set(
                "short_ttl",
                Arc::new(result.clone()),
                Duration::from_millis(100),
            )
            .await
            .unwrap();
        store
            .set("long_ttl", Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();

        // Both should exist initially
        assert!(store.get("short_ttl").await.unwrap().is_some());
        assert!(store.get("long_ttl").await.unwrap().is_some());
        assert_eq!(store.store.len(), 2);

        // Wait for short TTL to expire
        tokio::time::sleep(Duration::from_millis(300)).await;

        // Accessing any key should trigger cleanup
        let _ = store.get("long_ttl").await.unwrap();

        // The sweep must have physically evicted the expired entry, not just
        // hidden it from lookups.
        assert_eq!(store.store.len(), 1);

        // Short TTL should be cleaned up, long TTL should remain
        assert!(store.get("short_ttl").await.unwrap().is_none());
        assert!(store.get("long_ttl").await.unwrap().is_some());
    }

    // Test that get returns None for an entry whose TTL has elapsed
    #[tokio::test]
    async fn test_get_expired_entry_returns_none() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"test").unwrap();
        let key = "expire_test";

        // Store with very short TTL
        store
            .set(key, Arc::new(result), Duration::from_millis(50))
            .await
            .unwrap();

        // Wait for expiry
        tokio::time::sleep(Duration::from_millis(150)).await;

        // No lookup has happened since the entry expired, so it is still in
        // the map at this point; get must nevertheless report it as absent.
        assert!(store.get(key).await.unwrap().is_none());
    }

    // Test remove non-existent key
    #[tokio::test]
    async fn test_remove_non_existent_key() {
        let store = InMemoryIdempotencyStore::default();

        // Should not panic or error when removing non-existent key
        store.remove("non_existent").await.unwrap();
    }

    // Test multiple concurrent operations
    #[tokio::test]
    async fn test_concurrent_operations() {
        let store = Arc::new(InMemoryIdempotencyStore::default());
        let result = JobResult::success(&"concurrent_test").unwrap();

        // Spawn multiple tasks setting different keys
        let mut handles = vec![];
        for i in 0..10 {
            let store_clone = Arc::clone(&store);
            let result_clone = result.clone();
            let handle = tokio::spawn(async move {
                let key = format!("concurrent_key_{}", i);
                store_clone
                    .set(&key, Arc::new(result_clone), Duration::from_secs(60))
                    .await
                    .unwrap();

                // Verify we can retrieve it
                let retrieved = store_clone.get(&key).await.unwrap();
                assert!(retrieved.is_some());
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }

        // Verify all entries exist
        for i in 0..10 {
            let key = format!("concurrent_key_{}", i);
            assert!(store.get(&key).await.unwrap().is_some());
        }
    }

    // Test zero duration TTL
    #[tokio::test]
    async fn test_zero_duration_ttl() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"zero_ttl").unwrap();
        let key = "zero_key";

        // Set with zero duration (should expire immediately)
        store
            .set(key, Arc::new(result), Duration::from_secs(0))
            .await
            .unwrap();

        // Should return None as it's already expired
        assert!(store.get(key).await.unwrap().is_none());
    }

    // Test very large TTL
    #[tokio::test]
    async fn test_large_ttl() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"large_ttl").unwrap();
        let key = "large_key";

        // Set with very large TTL
        store
            .set(key, Arc::new(result), Duration::from_secs(u64::MAX))
            .await
            .unwrap();

        // Should still be retrievable
        assert!(store.get(key).await.unwrap().is_some());
    }

    // Test empty key
    #[tokio::test]
    async fn test_empty_key() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"empty_key_test").unwrap();

        // Should handle empty key without issues
        store
            .set("", Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(store.get("").await.unwrap().is_some());
        store.remove("").await.unwrap();
        assert!(store.get("").await.unwrap().is_none());
    }

    // Test special characters in key
    #[tokio::test]
    async fn test_special_characters_in_key() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"special_chars").unwrap();
        let key = "key:with/special\\chars@#$%";

        store
            .set(key, Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(store.get(key).await.unwrap().is_some());
        store.remove(key).await.unwrap();
        assert!(store.get(key).await.unwrap().is_none());
    }

    // Test multiple sets to same key (overwrite)
    #[tokio::test]
    async fn test_overwrite_same_key() {
        let store = InMemoryIdempotencyStore::default();
        let key = "overwrite_key";

        let result1 = JobResult::success(&"first_value").unwrap();
        let result2 = JobResult::success(&"second_value").unwrap();

        // Set first value
        store
            .set(key, Arc::new(result1), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved1 = store.get(key).await.unwrap().unwrap();

        // Set second value (should overwrite)
        store
            .set(key, Arc::new(result2), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved2 = store.get(key).await.unwrap().unwrap();

        // Values should be different (second should have overwritten first)
        assert_ne!(
            serde_json::to_string(&retrieved1).unwrap(),
            serde_json::to_string(&retrieved2).unwrap()
        );
    }

    // Test IdempotencyEntry creation and cloning
    #[test]
    fn test_idempotency_entry_creation() {
        let result = JobResult::success(&"test").unwrap();
        let expires_at = SystemTime::now() + Duration::from_secs(60);

        let entry = IdempotencyEntry {
            result: Arc::new(result),
            expires_at,
        };

        // Entry should be cloneable, and the clone keeps the same expiry
        let cloned_entry = entry.clone();
        assert_eq!(cloned_entry.expires_at, entry.expires_at);
    }

    // Redis tests (only compiled when redis feature is enabled)
    #[cfg(feature = "redis")]
    mod redis_tests {
        use super::*;

        // Constructing the store only builds a client from the URL; these
        // tests never talk to a server, so a well-formed URL is expected to
        // succeed and a failure must fail the test rather than be ignored.

        #[test]
        fn test_redis_store_new_with_default_prefix() {
            let store = RedisIdempotencyStore::new("redis://127.0.0.1:6379", None)
                .expect("a well-formed Redis URL should be accepted");
            assert_eq!(store.key_prefix, "riglr:idempotency:");
        }

        #[test]
        fn test_redis_store_new_with_custom_prefix() {
            let store = RedisIdempotencyStore::new("redis://127.0.0.1:6379", Some("custom:"))
                .expect("a well-formed Redis URL should be accepted");
            assert_eq!(store.key_prefix, "custom:");
        }

        #[test]
        fn test_redis_make_key() {
            let store = RedisIdempotencyStore::new("redis://127.0.0.1:6379", Some("test:"))
                .expect("a well-formed Redis URL should be accepted");
            assert_eq!(store.make_key("mykey"), "test:mykey");
            assert_eq!(store.make_key(""), "test:");
            assert_eq!(store.make_key("key:with:colons"), "test:key:with:colons");
        }

        #[test]
        fn test_redis_invalid_url() {
            let result = RedisIdempotencyStore::new("invalid_url", None);
            assert!(result.is_err());
        }
    }
}