use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use crate::stripe::error::{StripeWebhookError, StripeWebhookResult};
/// Lifecycle state of a webhook event tracked by an idempotency store.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProcessingStatus {
    /// The event has been recorded but no outcome has been reported yet.
    InProgress,
    /// The event was reported as successfully handled.
    Completed,
    /// The event was reported as failed.
    Failed {
        /// Error message supplied by whoever reported the failure.
        error: String,
    },
}
/// Bookkeeping record kept for a single webhook event.
#[derive(Clone, Debug)]
pub struct IdempotencyEntry {
    /// Identifier of the event this record belongs to.
    pub event_id: String,
    /// Current processing state of the event.
    pub status: ProcessingStatus,
    /// Moment the event was first recorded; TTL expiry is measured from here.
    pub received_at: Instant,
    /// Moment the record was last written (creation or a status change).
    pub updated_at: Instant,
    /// Number of processing attempts; starts at 1 when the entry is created.
    pub attempts: u32,
}
/// Storage backend used to de-duplicate webhook events by their id.
///
/// Implementations must be shareable across tasks (`Send + Sync + 'static`).
#[async_trait::async_trait]
pub trait IdempotencyStore: Send + Sync + 'static {
    /// Atomically checks whether `event_id` is already tracked and, if it is
    /// not, records it.
    ///
    /// Returns `Ok(true)` when the caller should go ahead and process the
    /// event, and `Ok(false)` when the store already tracks it.
    async fn check_and_record(&self, event_id: &str) -> StripeWebhookResult<bool>;

    /// Records that processing of `event_id` finished successfully.
    async fn mark_completed(&self, event_id: &str) -> StripeWebhookResult<()>;

    /// Records that processing of `event_id` failed, keeping `error` as the
    /// failure description.
    async fn mark_failed(&self, event_id: &str, error: &str) -> StripeWebhookResult<()>;

    /// Looks up the stored record for `event_id`, if there is one.
    async fn get_status(&self, event_id: &str) -> StripeWebhookResult<Option<IdempotencyEntry>>;

    /// Purges stale records and returns how many were removed.
    async fn cleanup(&self) -> StripeWebhookResult<usize>;
}
/// Process-local [`IdempotencyStore`] backed by a `HashMap` behind an async
/// read/write lock.
pub struct InMemoryIdempotencyStore {
    /// Tracked events, keyed by event id.
    entries: Arc<RwLock<HashMap<String, IdempotencyEntry>>>,
    /// How long after `received_at` an entry is still considered live.
    ttl: Duration,
    /// Entry count at which recording a new event triggers eviction.
    max_entries: usize,
}
impl InMemoryIdempotencyStore {
    /// Creates an empty store.
    ///
    /// `ttl` is how long an entry stays live after it was received;
    /// `max_entries` is the size at which recording a new event starts
    /// evicting old ones.
    pub fn new(ttl: Duration, max_entries: usize) -> Self {
        let entries = Arc::new(RwLock::new(HashMap::new()));
        Self {
            entries,
            ttl,
            max_entries,
        }
    }

    /// Creates an empty store using the idempotency TTL and capacity from
    /// the webhook configuration.
    pub fn from_config(config: &crate::stripe::config::StripeWebhookConfig) -> Self {
        let ttl = config.idempotency_ttl;
        let capacity = config.idempotency_max_entries;
        Self::new(ttl, capacity)
    }

    /// Returns `true` once more than `ttl` has elapsed since the entry was
    /// received.
    fn is_expired(&self, entry: &IdempotencyEntry) -> bool {
        let age = entry.received_at.elapsed();
        age > self.ttl
    }

    /// Number of entries currently held, including expired ones that have
    /// not been cleaned up yet.
    pub async fn len(&self) -> usize {
        let guard = self.entries.read().await;
        guard.len()
    }

    /// Returns `true` when the store holds no entries at all.
    pub async fn is_empty(&self) -> bool {
        let guard = self.entries.read().await;
        guard.is_empty()
    }
}
#[async_trait::async_trait]
impl IdempotencyStore for InMemoryIdempotencyStore {
    /// Records `event_id` unless a live entry for it already exists.
    ///
    /// Returns `Ok(true)` when the caller should process the event:
    /// * the id is unknown or its previous entry has expired (a fresh
    ///   `InProgress` entry is inserted), or
    /// * the live entry is `Failed` (it is flipped back to `InProgress` and
    ///   `attempts` is incremented so a redelivery can retry the work).
    ///
    /// Returns `Ok(false)` when a live `InProgress` or `Completed` entry
    /// exists.
    ///
    /// When the map has reached `max_entries`, expired entries are dropped
    /// and, if that frees less than ~10% of the map (at least one slot), the
    /// oldest entries are evicted to make up the difference.
    async fn check_and_record(&self, event_id: &str) -> StripeWebhookResult<bool> {
        let mut entries = self.entries.write().await;
        let now = Instant::now();

        if let Some(existing) = entries.get_mut(event_id) {
            if !self.is_expired(existing) {
                if matches!(existing.status, ProcessingStatus::Failed { .. }) {
                    // A failed event has not been processed; let the
                    // redelivery through instead of reporting a duplicate.
                    // `received_at` is kept so the TTL still counts from the
                    // first delivery.
                    existing.status = ProcessingStatus::InProgress;
                    existing.updated_at = now;
                    existing.attempts = existing.attempts.saturating_add(1);
                    tracing::debug!(
                        event_id,
                        attempts = existing.attempts,
                        "Retrying previously failed event"
                    );
                    return Ok(true);
                }
                tracing::debug!(
                    event_id,
                    status = ?existing.status,
                    "Event already in idempotency store"
                );
                return Ok(false);
            }
            // Expired: drop it and fall through to record a fresh entry.
            entries.remove(event_id);
        }

        if entries.len() >= self.max_entries {
            let before = entries.len();
            // Evict in batches of ~10% to amortise the sort below, but always
            // at least one entry so small capacities (< 10) stay bounded.
            let target = (before / 10).max(1);

            entries.retain(|_, entry| !self.is_expired(entry));
            let freed = before - entries.len();

            if freed < target {
                let oldest: Vec<String> = {
                    let mut by_age: Vec<_> = entries.iter().collect();
                    by_age.sort_unstable_by_key(|(_, entry)| entry.received_at);
                    by_age
                        .into_iter()
                        .take(target - freed)
                        .map(|(id, _)| id.clone())
                        .collect()
                };
                for id in &oldest {
                    entries.remove(id);
                }
            }

            tracing::info!(
                remaining = entries.len(),
                max = self.max_entries,
                "Evicted old idempotency entries"
            );
        }

        entries.insert(
            event_id.to_string(),
            IdempotencyEntry {
                event_id: event_id.to_string(),
                status: ProcessingStatus::InProgress,
                received_at: now,
                updated_at: now,
                attempts: 1,
            },
        );
        tracing::debug!(event_id, "New event recorded in idempotency store");
        Ok(true)
    }

    /// Marks the entry for `event_id` as `Completed`.
    ///
    /// An unknown id is logged as a warning and otherwise ignored; the call
    /// still returns `Ok(())`.
    async fn mark_completed(&self, event_id: &str) -> StripeWebhookResult<()> {
        let mut entries = self.entries.write().await;
        match entries.get_mut(event_id) {
            Some(entry) => {
                entry.status = ProcessingStatus::Completed;
                entry.updated_at = Instant::now();
                tracing::debug!(event_id, "Event marked as completed");
            }
            None => {
                tracing::warn!(
                    event_id,
                    "Attempted to mark non-existent event as completed"
                );
            }
        }
        Ok(())
    }

    /// Marks the entry for `event_id` as `Failed`, storing `error` as the
    /// failure message.
    ///
    /// An unknown id is logged as a warning and otherwise ignored; the call
    /// still returns `Ok(())`.
    async fn mark_failed(&self, event_id: &str, error: &str) -> StripeWebhookResult<()> {
        let mut entries = self.entries.write().await;
        match entries.get_mut(event_id) {
            Some(entry) => {
                entry.status = ProcessingStatus::Failed {
                    error: error.to_string(),
                };
                entry.updated_at = Instant::now();
                tracing::debug!(event_id, error, "Event marked as failed");
            }
            None => {
                tracing::warn!(event_id, "Attempted to mark non-existent event as failed");
            }
        }
        Ok(())
    }

    /// Returns a copy of the live entry for `event_id`.
    ///
    /// Yields `Ok(None)` when the id is unknown or its entry has expired
    /// (expired entries are not removed here).
    async fn get_status(&self, event_id: &str) -> StripeWebhookResult<Option<IdempotencyEntry>> {
        let entries = self.entries.read().await;
        Ok(entries
            .get(event_id)
            .filter(|entry| !self.is_expired(entry))
            .cloned())
    }

    /// Removes every expired entry and returns how many were dropped.
    async fn cleanup(&self) -> StripeWebhookResult<usize> {
        let mut entries = self.entries.write().await;
        let before = entries.len();
        entries.retain(|_, entry| !self.is_expired(entry));
        let removed = before - entries.len();
        if removed > 0 {
            tracing::info!(
                removed,
                remaining = entries.len(),
                "Cleaned up expired idempotency entries"
            );
        }
        Ok(removed)
    }
}
/// Thin wrapper that exposes an [`IdempotencyStore`] through a
/// process / complete / fail workflow.
pub struct IdempotencyMiddleware<S>
where
    S: IdempotencyStore,
{
    /// Shared handle to the backing store.
    store: Arc<S>,
}
impl<S> IdempotencyMiddleware<S>
where
    S: IdempotencyStore,
{
    /// Wraps the given shared store.
    pub fn new(store: Arc<S>) -> Self {
        Self { store }
    }

    /// Asks the store whether `event_id` should be processed.
    ///
    /// Returns `Ok(true)` when the store accepted the event. When the store
    /// reports the event as already tracked, this returns
    /// `StripeWebhookError::AlreadyProcessed`; store errors are passed
    /// through unchanged.
    pub async fn should_process(&self, event_id: &str) -> StripeWebhookResult<bool> {
        let accepted = self.store.check_and_record(event_id).await?;
        if accepted {
            return Ok(true);
        }
        Err(StripeWebhookError::AlreadyProcessed {
            event_id: event_id.to_string(),
        })
    }

    /// Reports successful processing of `event_id` to the store.
    pub async fn complete(&self, event_id: &str) -> StripeWebhookResult<()> {
        let outcome = self.store.mark_completed(event_id).await;
        outcome
    }

    /// Reports failed processing of `event_id`, with `error` as the reason.
    pub async fn fail(&self, event_id: &str, error: &str) -> StripeWebhookResult<()> {
        let outcome = self.store.mark_failed(event_id, error).await;
        outcome
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new event is accepted once; an immediate duplicate is rejected.
    #[tokio::test]
    async fn test_check_and_record_new_event() {
        let store = InMemoryIdempotencyStore::new(Duration::from_secs(3600), 1000);
        let result = store.check_and_record("evt_123").await.unwrap();
        assert!(result);
        let result = store.check_and_record("evt_123").await.unwrap();
        assert!(!result);
    }

    /// Distinct ids are tracked independently of each other.
    #[tokio::test]
    async fn test_check_and_record_different_events() {
        let store = InMemoryIdempotencyStore::new(Duration::from_secs(3600), 1000);
        assert!(store.check_and_record("evt_1").await.unwrap());
        assert!(store.check_and_record("evt_2").await.unwrap());
        assert!(store.check_and_record("evt_3").await.unwrap());
        assert!(!store.check_and_record("evt_1").await.unwrap());
        assert!(!store.check_and_record("evt_2").await.unwrap());
    }

    /// `mark_completed` is visible through `get_status`.
    #[tokio::test]
    async fn test_mark_completed() {
        let store = InMemoryIdempotencyStore::new(Duration::from_secs(3600), 1000);
        store.check_and_record("evt_123").await.unwrap();
        store.mark_completed("evt_123").await.unwrap();
        let entry = store.get_status("evt_123").await.unwrap().unwrap();
        assert_eq!(entry.status, ProcessingStatus::Completed);
    }

    /// `mark_failed` is visible through `get_status`.
    #[tokio::test]
    async fn test_mark_failed() {
        let store = InMemoryIdempotencyStore::new(Duration::from_secs(3600), 1000);
        store.check_and_record("evt_123").await.unwrap();
        store
            .mark_failed("evt_123", "Database error")
            .await
            .unwrap();
        let entry = store.get_status("evt_123").await.unwrap().unwrap();
        assert!(matches!(entry.status, ProcessingStatus::Failed { .. }));
    }

    /// Once the TTL has elapsed the same id is accepted again.
    #[tokio::test]
    async fn test_expired_entries() {
        let store = InMemoryIdempotencyStore::new(Duration::from_millis(10), 1000);
        store.check_and_record("evt_123").await.unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(store.check_and_record("evt_123").await.unwrap());
    }

    /// `cleanup` removes every expired entry and reports the count.
    #[tokio::test]
    async fn test_cleanup() {
        let store = InMemoryIdempotencyStore::new(Duration::from_millis(10), 1000);
        store.check_and_record("evt_1").await.unwrap();
        store.check_and_record("evt_2").await.unwrap();
        tokio::time::sleep(Duration::from_millis(20)).await;
        let removed = store.cleanup().await.unwrap();
        assert_eq!(removed, 2);
        assert!(store.is_empty().await);
    }

    /// Inserting past capacity evicts old entries so the map stays bounded.
    #[tokio::test]
    async fn test_max_entries_eviction() {
        let store = InMemoryIdempotencyStore::new(Duration::from_secs(3600), 10);
        for i in 0..15 {
            store.check_and_record(&format!("evt_{}", i)).await.unwrap();
        }
        // The bound must be the configured capacity; comparing against the
        // number of inserts (15) could never fail.
        assert!(store.len().await <= 10);
        // The most recently recorded event must have survived eviction.
        assert!(!store.check_and_record("evt_14").await.unwrap());
    }

    /// The middleware accepts a new event and maps a duplicate to
    /// `AlreadyProcessed`.
    #[tokio::test]
    async fn test_idempotency_middleware() {
        let store = Arc::new(InMemoryIdempotencyStore::new(
            Duration::from_secs(3600),
            1000,
        ));
        let middleware = IdempotencyMiddleware::new(store);
        assert!(middleware.should_process("evt_123").await.is_ok());
        let result = middleware.should_process("evt_123").await;
        assert!(matches!(
            result,
            Err(StripeWebhookError::AlreadyProcessed { .. })
        ));
        middleware.complete("evt_123").await.unwrap();
    }
}