use async_trait::async_trait;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use crate::jobs::JobResult;
/// Backend that caches [`JobResult`]s under an idempotency key for a limited time.
///
/// NOTE(review): presumably consulted before running a job so a repeated submission
/// can reuse the earlier outcome — confirm against the callers.
///
/// The `Send + Sync` bound lets one store instance be shared across threads/tasks
/// (e.g. behind an `Arc`).
#[async_trait]
pub trait IdempotencyStore: Send + Sync {
    /// Returns the result stored under `key`, or `None` when nothing is stored.
    ///
    /// The implementations in this file also return `None` once the entry's TTL
    /// has elapsed.
    async fn get(&self, key: &str) -> anyhow::Result<Option<Arc<JobResult>>>;
    /// Stores `result` under `key` for `ttl`, replacing any previous value for that key.
    async fn set(&self, key: &str, result: Arc<JobResult>, ttl: Duration) -> anyhow::Result<()>;
    /// Deletes whatever is stored under `key`; removing a missing key is not an error
    /// in the implementations in this file.
    async fn remove(&self, key: &str) -> anyhow::Result<()>;
}
/// A cached job result paired with the wall-clock time at which it stops being valid.
#[derive(Clone, Debug)]
struct IdempotencyEntry {
    /// Shared result; cloning the entry only bumps the `Arc` reference count.
    result: Arc<JobResult>,
    /// The in-memory store treats the entry as live only while this instant is
    /// strictly later than `SystemTime::now()`.
    expires_at: SystemTime,
}
/// Process-local [`IdempotencyStore`] backed by a concurrent `DashMap`.
///
/// Entries live only in this process's memory and expire according to wall-clock
/// (`SystemTime`) time.
#[derive(Debug)]
pub struct InMemoryIdempotencyStore {
    /// Cached entries keyed by the caller-supplied idempotency key.
    store: Arc<DashMap<String, IdempotencyEntry>>,
}
impl InMemoryIdempotencyStore {
    /// Builds an empty in-memory store; identical to the `Default` implementation.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Evicts every entry whose expiry is not strictly after the current wall-clock time.
    fn cleanup_expired(&self) {
        let cutoff = SystemTime::now();
        self.store.retain(|_key, entry| cutoff < entry.expires_at);
    }
}
impl Default for InMemoryIdempotencyStore {
fn default() -> Self {
Self {
store: Arc::new(DashMap::default()),
}
}
}
#[async_trait]
impl IdempotencyStore for InMemoryIdempotencyStore {
    /// Returns the cached result for `key` while its entry is unexpired.
    ///
    /// Only the requested key is inspected. An expired entry found here is removed on
    /// the spot, but the rest of the map is left alone; the full sweep runs in `set`.
    /// This keeps a lookup from locking and scanning every shard of the map.
    ///
    /// # Errors
    ///
    /// Never fails; the `Result` only satisfies the trait signature.
    async fn get(&self, key: &str) -> anyhow::Result<Option<Arc<JobResult>>> {
        let now = SystemTime::now();
        // The read guard returned by `DashMap::get` must be released before
        // `remove_if` takes the shard's write lock, so the lookup is confined to
        // this statement.
        let live = match self.store.get(key) {
            None => return Ok(None),
            Some(entry) => (entry.expires_at > now).then(|| Arc::clone(&entry.result)),
        };
        if live.is_none() {
            // Re-check expiry under the write lock so an entry that a concurrent
            // `set` just refreshed is not discarded.
            self.store.remove_if(key, |_, entry| entry.expires_at <= now);
        }
        Ok(live)
    }

    /// Stores `result` under `key` until `ttl` from now, replacing any previous entry.
    ///
    /// Expired entries are swept before inserting. Inserting is the only way the map
    /// grows, so this keeps its size bounded without taxing every lookup. A `ttl`
    /// too large to add to the current time is clamped to roughly 100 years.
    ///
    /// # Errors
    ///
    /// Never fails; the `Result` only satisfies the trait signature.
    async fn set(&self, key: &str, result: Arc<JobResult>, ttl: Duration) -> anyhow::Result<()> {
        // Horizon used when `now + ttl` overflows `SystemTime` (~100 years).
        const FALLBACK_TTL: Duration = Duration::from_secs(365 * 24 * 60 * 60 * 100);

        self.cleanup_expired();

        let now = SystemTime::now();
        let expires_at = now
            .checked_add(ttl)
            .unwrap_or_else(|| now + FALLBACK_TTL);
        self.store
            .insert(key.to_string(), IdempotencyEntry { result, expires_at });
        Ok(())
    }

    /// Removes the entry for `key`, if any.
    ///
    /// # Errors
    ///
    /// Never fails; removing a missing key is a no-op.
    async fn remove(&self, key: &str) -> anyhow::Result<()> {
        self.store.remove(key);
        Ok(())
    }
}
#[cfg(feature = "redis")]
/// [`IdempotencyStore`] that keeps JSON-encoded [`JobResult`]s in Redis.
pub struct RedisIdempotencyStore {
    /// Client from which a multiplexed async connection is obtained for every operation.
    client: redis::Client,
    /// String prepended to every caller-supplied key (see `make_key`).
    key_prefix: String,
}
#[cfg(feature = "redis")]
impl RedisIdempotencyStore {
    /// Creates a store talking to the Redis server at `redis_url`.
    ///
    /// Every key is namespaced with `key_prefix`, which defaults to
    /// `"riglr:idempotency:"` when `None` is given.
    ///
    /// # Errors
    ///
    /// Returns an error if `redis::Client::open` rejects `redis_url`.
    pub fn new(redis_url: &str, key_prefix: Option<&str>) -> anyhow::Result<Self> {
        let client = redis::Client::open(redis_url)?;
        let prefix = key_prefix.unwrap_or("riglr:idempotency:").to_owned();
        Ok(Self {
            client,
            key_prefix: prefix,
        })
    }

    /// Returns `key` namespaced with this store's prefix.
    fn make_key(&self, key: &str) -> String {
        [self.key_prefix.as_str(), key].concat()
    }
}
#[cfg(feature = "redis")]
#[async_trait]
impl IdempotencyStore for RedisIdempotencyStore {
    /// Fetches and deserializes the JSON-encoded result stored under the prefixed key.
    ///
    /// Keys are written with an expiry (see `set`), so Redis stops returning them
    /// once their TTL elapses.
    ///
    /// # Errors
    ///
    /// Fails if a connection cannot be obtained, the `GET` command fails, or the
    /// stored value is not valid `JobResult` JSON.
    async fn get(&self, key: &str) -> anyhow::Result<Option<Arc<JobResult>>> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let redis_key = self.make_key(key);
        let result: Option<String> = redis::cmd("GET")
            .arg(&redis_key)
            .query_async(&mut conn)
            .await?;
        match result {
            Some(json_str) => {
                let result: JobResult = serde_json::from_str(&json_str)?;
                Ok(Some(Arc::new(result)))
            }
            None => Ok(None),
        }
    }

    /// Stores `result` as JSON under the prefixed key with a millisecond-precision expiry.
    ///
    /// This aligns the TTL semantics with the in-memory store:
    /// * a zero `ttl` deletes the key, since such an entry is already expired and
    ///   Redis rejects a zero expiry;
    /// * sub-millisecond TTLs round up to 1ms instead of being truncated to an invalid 0;
    /// * huge TTLs are capped at ~100 years so the expiry stays within Redis' range.
    ///
    /// # Errors
    ///
    /// Fails if the result cannot be serialized, a connection cannot be obtained, or
    /// the `SET`/`DEL` command fails.
    async fn set(&self, key: &str, result: Arc<JobResult>, ttl: Duration) -> anyhow::Result<()> {
        // Same horizon the in-memory store falls back to when `now + ttl` overflows.
        const MAX_TTL: Duration = Duration::from_secs(365 * 24 * 60 * 60 * 100);

        let redis_key = self.make_key(key);
        let mut conn = self.client.get_multiplexed_async_connection().await?;

        if ttl.is_zero() {
            // An already-expired write must not leave a stale value readable.
            redis::cmd("DEL")
                .arg(&redis_key)
                .query_async::<()>(&mut conn)
                .await?;
            return Ok(());
        }

        let json_str = serde_json::to_string(&*result)?;
        // After capping, the value fits comfortably in u64; `max(1)` avoids `PX 0`.
        let ttl_millis = u64::try_from(ttl.min(MAX_TTL).as_millis())
            .unwrap_or(u64::MAX)
            .max(1);
        redis::cmd("SET")
            .arg(&redis_key)
            .arg(json_str)
            .arg("PX")
            .arg(ttl_millis)
            .query_async::<()>(&mut conn)
            .await?;
        Ok(())
    }

    /// Deletes the prefixed key; deleting a missing key is not an error.
    ///
    /// # Errors
    ///
    /// Fails if a connection cannot be obtained or the `DEL` command fails.
    async fn remove(&self, key: &str) -> anyhow::Result<()> {
        let mut conn = self.client.get_multiplexed_async_connection().await?;
        let redis_key = self.make_key(key);
        redis::cmd("DEL")
            .arg(&redis_key)
            .query_async::<()>(&mut conn)
            .await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_in_memory_idempotency_store_new() {
        // Call `new` itself so the constructor is actually covered.
        let store = InMemoryIdempotencyStore::new();
        assert!(store.store.is_empty());
    }

    #[test]
    fn test_in_memory_idempotency_store_default() {
        let store = InMemoryIdempotencyStore::default();
        assert!(store.store.is_empty());
    }

    #[tokio::test]
    async fn test_in_memory_idempotency_store_basic_operations() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"test_value").unwrap();
        let key = "test_key";
        assert!(store.get(key).await.unwrap().is_none());
        store
            .set(key, Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved = store
            .get(key)
            .await
            .unwrap()
            .expect("entry should be present before it expires");
        assert!(retrieved.is_success());
        store.remove(key).await.unwrap();
        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_in_memory_store_with_failure_results() {
        let store = InMemoryIdempotencyStore::default();
        let key = "failure_key";
        let retriable_failure = JobResult::Failure {
            error: crate::error::ToolError::retriable_string("Network timeout"),
        };
        store
            .set(key, Arc::new(retriable_failure), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved = store.get(key).await.unwrap().unwrap();
        assert!(!retrieved.is_success());
        assert!(retrieved.is_retriable());
        let permanent_failure = JobResult::Failure {
            error: crate::error::ToolError::permanent_string("Invalid input"),
        };
        store
            .set(key, Arc::new(permanent_failure), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved = store.get(key).await.unwrap().unwrap();
        assert!(!retrieved.is_success());
        assert!(!retrieved.is_retriable());
    }

    #[tokio::test]
    async fn test_in_memory_store_with_tx_hash() {
        let store = InMemoryIdempotencyStore::default();
        let key = "tx_key";
        let result = JobResult::success_with_tx(&json!({"amount": 100}), "0x123abc").unwrap();
        store
            .set(key, Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved = store.get(key).await.unwrap().unwrap();
        assert!(retrieved.is_success());
    }

    #[tokio::test]
    async fn test_idempotency_expiry() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"test_value").unwrap();
        let key = "test_key";
        store
            .set(key, Arc::new(result), Duration::from_millis(200))
            .await
            .unwrap();
        assert!(store.get(key).await.unwrap().is_some());
        tokio::time::sleep(Duration::from_millis(500)).await;
        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cleanup_expired_entries() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"test").unwrap();
        store
            .set(
                "short_ttl",
                Arc::new(result.clone()),
                Duration::from_millis(100),
            )
            .await
            .unwrap();
        store
            .set("long_ttl", Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(store.get("short_ttl").await.unwrap().is_some());
        assert!(store.get("long_ttl").await.unwrap().is_some());
        assert_eq!(store.store.len(), 2);
        tokio::time::sleep(Duration::from_millis(300)).await;
        let _ = store.get("long_ttl").await.unwrap();
        assert!(store.get("short_ttl").await.unwrap().is_none());
        assert!(store.get("long_ttl").await.unwrap().is_some());
        // Once the short-lived entry has been observed as expired it must be gone from the map.
        assert_eq!(store.store.len(), 1);
    }

    #[tokio::test]
    async fn test_get_expired_entry_returns_none() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"test").unwrap();
        let key = "expire_test";
        store
            .set(key, Arc::new(result), Duration::from_millis(50))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_remove_non_existent_key() {
        let store = InMemoryIdempotencyStore::default();
        store.remove("non_existent").await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_operations() {
        let store = Arc::new(InMemoryIdempotencyStore::default());
        let result = JobResult::success(&"concurrent_test").unwrap();
        let mut handles = vec![];
        for i in 0..10 {
            let store_clone = Arc::clone(&store);
            let result_clone = result.clone();
            let handle = tokio::spawn(async move {
                let key = format!("concurrent_key_{}", i);
                store_clone
                    .set(&key, Arc::new(result_clone), Duration::from_secs(60))
                    .await
                    .unwrap();
                let retrieved = store_clone.get(&key).await.unwrap();
                assert!(retrieved.is_some());
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap();
        }
        for i in 0..10 {
            let key = format!("concurrent_key_{}", i);
            assert!(store.get(&key).await.unwrap().is_some());
        }
    }

    #[tokio::test]
    async fn test_zero_duration_ttl() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"zero_ttl").unwrap();
        let key = "zero_key";
        store
            .set(key, Arc::new(result), Duration::from_secs(0))
            .await
            .unwrap();
        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_large_ttl() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"large_ttl").unwrap();
        let key = "large_key";
        store
            .set(key, Arc::new(result), Duration::from_secs(u64::MAX))
            .await
            .unwrap();
        assert!(store.get(key).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_empty_key() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"empty_key_test").unwrap();
        store
            .set("", Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(store.get("").await.unwrap().is_some());
        store.remove("").await.unwrap();
        assert!(store.get("").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_special_characters_in_key() {
        let store = InMemoryIdempotencyStore::default();
        let result = JobResult::success(&"special_chars").unwrap();
        let key = "key:with/special\\chars@#$%";
        store
            .set(key, Arc::new(result), Duration::from_secs(60))
            .await
            .unwrap();
        assert!(store.get(key).await.unwrap().is_some());
        store.remove(key).await.unwrap();
        assert!(store.get(key).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_overwrite_same_key() {
        let store = InMemoryIdempotencyStore::default();
        let key = "overwrite_key";
        let result1 = JobResult::success(&"first_value").unwrap();
        let result2 = JobResult::success(&"second_value").unwrap();
        store
            .set(key, Arc::new(result1), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved1 = store.get(key).await.unwrap().unwrap();
        store
            .set(key, Arc::new(result2), Duration::from_secs(60))
            .await
            .unwrap();
        let retrieved2 = store.get(key).await.unwrap().unwrap();
        // Serialize the `JobResult` itself so the test does not depend on serde's `rc` feature.
        assert_ne!(
            serde_json::to_string(&*retrieved1).unwrap(),
            serde_json::to_string(&*retrieved2).unwrap()
        );
    }

    #[test]
    fn test_idempotency_entry_creation() {
        let result = JobResult::success(&"test").unwrap();
        let expires_at = SystemTime::now() + Duration::from_secs(60);
        let entry = IdempotencyEntry {
            result: Arc::new(result),
            expires_at,
        };
        let cloned_entry = entry.clone();
        assert_eq!(cloned_entry.expires_at, entry.expires_at);
        // Cloning an entry shares the result rather than deep-copying it.
        assert!(Arc::ptr_eq(&cloned_entry.result, &entry.result));
    }

    #[cfg(feature = "redis")]
    mod redis_tests {
        use super::*;

        // `redis::Client::open` only validates the URL and does not connect, so a
        // well-formed URL must be accepted even when no server is running.

        #[test]
        fn test_redis_store_new_with_default_prefix() {
            let store = RedisIdempotencyStore::new("redis://127.0.0.1:6379", None)
                .expect("a well-formed redis URL should be accepted");
            assert_eq!(store.key_prefix, "riglr:idempotency:");
        }

        #[test]
        fn test_redis_store_new_with_custom_prefix() {
            let store = RedisIdempotencyStore::new("redis://127.0.0.1:6379", Some("custom:"))
                .expect("a well-formed redis URL should be accepted");
            assert_eq!(store.key_prefix, "custom:");
        }

        #[test]
        fn test_redis_make_key() {
            let store = RedisIdempotencyStore::new("redis://127.0.0.1:6379", Some("test:"))
                .expect("a well-formed redis URL should be accepted");
            assert_eq!(store.make_key("mykey"), "test:mykey");
            assert_eq!(store.make_key(""), "test:");
            assert_eq!(store.make_key("key:with:colons"), "test:key:with:colons");
        }

        #[test]
        fn test_redis_invalid_url() {
            let result = RedisIdempotencyStore::new("invalid_url", None);
            assert!(result.is_err());
        }
    }
}