use super::{ConfirmationExpired, ConfirmationStore, PendingActionInfo};
use crate::error::Error;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use ferro_events::dispatch_sync;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::task::AbortHandle;
use tokio::time::sleep;
struct StoredAction {
payload: serde_json::Value,
created_at: DateTime<Utc>,
abort_handle: AbortHandle,
}
pub struct InMemoryConfirmationStore {
inner: Arc<DashMap<String, StoredAction>>,
}
impl InMemoryConfirmationStore {
pub fn new() -> Self {
Self {
inner: Arc::new(DashMap::new()),
}
}
}
impl Default for InMemoryConfirmationStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ConfirmationStore for InMemoryConfirmationStore {
async fn request_confirmation(
&self,
key: &str,
payload: serde_json::Value,
ttl: Duration,
) -> Result<(), Error> {
if let Some((_, old)) = self.inner.remove(key) {
old.abort_handle.abort();
}
let store = Arc::clone(&self.inner);
let key_owned = key.to_string();
let abort_handle = tokio::spawn(async move {
sleep(ttl).await;
if store.remove(&key_owned).is_some() {
dispatch_sync(ConfirmationExpired {
key: key_owned,
expired_at: Utc::now(),
});
}
})
.abort_handle();
self.inner.insert(
key.to_string(),
StoredAction {
payload,
created_at: Utc::now(),
abort_handle,
},
);
Ok(())
}
async fn confirm(&self, key: &str) -> Result<Option<serde_json::Value>, Error> {
match self.inner.remove(key) {
Some((_, entry)) => {
entry.abort_handle.abort();
Ok(Some(entry.payload))
}
None => Ok(None),
}
}
async fn reject(&self, key: &str) -> Result<bool, Error> {
match self.inner.remove(key) {
Some((_, entry)) => {
entry.abort_handle.abort();
Ok(true)
}
None => Ok(false),
}
}
async fn get(&self, key: &str) -> Result<Option<serde_json::Value>, Error> {
let payload = self.inner.get(key).map(|entry| entry.payload.clone());
Ok(payload)
}
async fn list_pending(&self) -> Result<Vec<PendingActionInfo>, Error> {
let pending: Vec<PendingActionInfo> = self
.inner
.iter()
.map(|entry| PendingActionInfo {
key: entry.key().clone(),
created_at: entry.value().created_at,
})
.collect();
Ok(pending)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generous TTL for tests that never exercise expiry.
    const LONG: Duration = Duration::from_secs(60);

    /// Shorthand for a millisecond duration.
    fn ms(n: u64) -> Duration {
        Duration::from_millis(n)
    }

    /// Fresh, empty store for a single test.
    fn new_store() -> InMemoryConfirmationStore {
        InMemoryConfirmationStore::default()
    }

    /// JSON payload of the shape `{"action": name}`.
    fn action_payload(name: &str) -> serde_json::Value {
        serde_json::json!({ "action": name })
    }

    /// Registers `action_payload(name)` under `key` with the given TTL.
    async fn request(store: &InMemoryConfirmationStore, key: &str, name: &str, ttl: Duration) {
        store
            .request_confirmation(key, action_payload(name), ttl)
            .await
            .unwrap();
    }

    /// Reads the payload pending under `key`.
    async fn lookup(store: &InMemoryConfirmationStore, key: &str) -> Option<serde_json::Value> {
        store.get(key).await.unwrap()
    }

    /// A single yield so a freshly spawned expiry task gets polled and
    /// creates its `sleep` deadline against the paused clock.
    async fn let_timer_register() {
        tokio::task::yield_now().await;
    }

    /// A handful of yields so expiry tasks woken by `advance` get to run.
    async fn drain_ready_tasks() {
        for _ in 0..5 {
            tokio::task::yield_now().await;
        }
    }

    #[tokio::test]
    async fn test_request_and_get() {
        let store = new_store();
        request(&store, "k1", "delete", LONG).await;
        assert_eq!(lookup(&store, "k1").await, Some(action_payload("delete")));
    }

    #[tokio::test]
    async fn test_confirm_returns_payload_and_removes_entry() {
        let store = new_store();
        request(&store, "k2", "publish", LONG).await;
        let confirmed = store.confirm("k2").await.unwrap();
        assert_eq!(confirmed, Some(action_payload("publish")));
        assert_eq!(lookup(&store, "k2").await, None);
    }

    #[tokio::test]
    async fn test_confirm_nonexistent_returns_none() {
        let store = new_store();
        assert_eq!(store.confirm("no-such-key").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_reject_removes_entry_and_returns_true() {
        let store = new_store();
        request(&store, "k3", "send", LONG).await;
        assert!(store.reject("k3").await.unwrap());
        assert_eq!(lookup(&store, "k3").await, None);
    }

    #[tokio::test]
    async fn test_reject_nonexistent_returns_false() {
        let store = new_store();
        assert!(!store.reject("ghost").await.unwrap());
    }

    #[tokio::test]
    async fn test_list_pending_returns_all_entries() {
        let store = new_store();
        request(&store, "a", "alpha", LONG).await;
        request(&store, "b", "beta", LONG).await;
        let pending = store.list_pending().await.unwrap();
        let mut keys: Vec<&str> = pending.iter().map(|info| info.key.as_str()).collect();
        keys.sort_unstable();
        assert_eq!(keys, ["a", "b"]);
    }

    #[tokio::test]
    async fn test_overwrite_replaces_payload() {
        let store = new_store();
        request(&store, "k4", "first", LONG).await;
        request(&store, "k4", "second", LONG).await;
        assert_eq!(lookup(&store, "k4").await, Some(action_payload("second")));
        assert_eq!(store.list_pending().await.unwrap().len(), 1);
    }

    #[tokio::test(start_paused = true)]
    async fn test_entry_removed_after_ttl_expires() {
        let store = new_store();
        request(&store, "ttl-key", "expire-me", ms(100)).await;
        let_timer_register().await;
        assert!(lookup(&store, "ttl-key").await.is_some());

        tokio::time::advance(ms(150)).await;
        drain_ready_tasks().await;
        assert_eq!(lookup(&store, "ttl-key").await, None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_confirm_before_ttl_prevents_expiry() {
        let store = new_store();
        request(&store, "confirm-early", "confirm", ms(100)).await;
        let_timer_register().await;
        assert!(store.confirm("confirm-early").await.unwrap().is_some());

        tokio::time::advance(ms(200)).await;
        drain_ready_tasks().await;
        assert_eq!(lookup(&store, "confirm-early").await, None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_reject_before_ttl_prevents_expiry() {
        let store = new_store();
        request(&store, "reject-early", "reject", ms(100)).await;
        let_timer_register().await;
        assert!(store.reject("reject-early").await.unwrap());

        tokio::time::advance(ms(200)).await;
        drain_ready_tasks().await;
        assert_eq!(lookup(&store, "reject-early").await, None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_independent_ttls() {
        let store = new_store();
        request(&store, "short", "s", ms(100)).await;
        request(&store, "long", "l", ms(200)).await;
        // One yield per spawned timer.
        let_timer_register().await;
        let_timer_register().await;

        tokio::time::advance(ms(150)).await;
        drain_ready_tasks().await;
        assert_eq!(
            lookup(&store, "short").await,
            None,
            "short should be expired"
        );
        assert!(
            lookup(&store, "long").await.is_some(),
            "long should still be alive"
        );

        tokio::time::advance(ms(100)).await;
        drain_ready_tasks().await;
        assert_eq!(
            lookup(&store, "long").await,
            None,
            "long should now be expired"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_overwrite_cancels_old_ttl_timer() {
        let store = new_store();
        request(&store, "over-key", "original", ms(100)).await;
        let_timer_register().await;
        request(&store, "over-key", "replacement", ms(300)).await;
        let_timer_register().await;

        tokio::time::advance(ms(150)).await;
        drain_ready_tasks().await;
        assert_eq!(
            lookup(&store, "over-key").await,
            Some(action_payload("replacement")),
            "overwritten entry should still be alive under new TTL"
        );

        tokio::time::advance(ms(200)).await;
        drain_ready_tasks().await;
        assert_eq!(
            lookup(&store, "over-key").await,
            None,
            "new TTL should now expire"
        );
    }
}