// ferro_ai/confirmation/store.rs

use super::{ConfirmationExpired, ConfirmationStore, PendingActionInfo};
use crate::error::Error;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use ferro_events::dispatch_sync;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::task::AbortHandle;
use tokio::time::sleep;
11
12/// Internal storage entry for a pending confirmation action.
13struct StoredAction {
14    payload: serde_json::Value,
15    created_at: DateTime<Utc>,
16    abort_handle: AbortHandle,
17}
18
19/// In-memory confirmation store backed by a [`DashMap`].
20///
21/// Each entry carries an [`AbortHandle`] for its TTL expiry task, so confirming
22/// or rejecting an action immediately cancels the background timer.
23///
24/// # Concurrency
25///
26/// `DashMap` guards are **never** held across `.await` points. In `list_pending`,
27/// the entries are collected into a `Vec` while holding the shard locks, and the
28/// guard is dropped before returning.
29pub struct InMemoryConfirmationStore {
30    inner: Arc<DashMap<String, StoredAction>>,
31}
32
33impl InMemoryConfirmationStore {
34    /// Create a new store.
35    pub fn new() -> Self {
36        Self {
37            inner: Arc::new(DashMap::new()),
38        }
39    }
40}
41
42impl Default for InMemoryConfirmationStore {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48#[async_trait]
49impl ConfirmationStore for InMemoryConfirmationStore {
50    async fn request_confirmation(
51        &self,
52        key: &str,
53        payload: serde_json::Value,
54        ttl: Duration,
55    ) -> Result<(), Error> {
56        // Abort any existing TTL task for this key before inserting the new entry.
57        if let Some((_, old)) = self.inner.remove(key) {
58            old.abort_handle.abort();
59        }
60
61        let store = Arc::clone(&self.inner);
62        let key_owned = key.to_string();
63
64        let abort_handle = tokio::spawn(async move {
65            sleep(ttl).await;
66            if store.remove(&key_owned).is_some() {
67                dispatch_sync(ConfirmationExpired {
68                    key: key_owned,
69                    expired_at: Utc::now(),
70                });
71            }
72        })
73        .abort_handle();
74
75        self.inner.insert(
76            key.to_string(),
77            StoredAction {
78                payload,
79                created_at: Utc::now(),
80                abort_handle,
81            },
82        );
83
84        Ok(())
85    }
86
87    async fn confirm(&self, key: &str) -> Result<Option<serde_json::Value>, Error> {
88        match self.inner.remove(key) {
89            Some((_, entry)) => {
90                entry.abort_handle.abort();
91                Ok(Some(entry.payload))
92            }
93            None => Ok(None),
94        }
95    }
96
97    async fn reject(&self, key: &str) -> Result<bool, Error> {
98        match self.inner.remove(key) {
99            Some((_, entry)) => {
100                entry.abort_handle.abort();
101                Ok(true)
102            }
103            None => Ok(false),
104        }
105    }
106
107    async fn get(&self, key: &str) -> Result<Option<serde_json::Value>, Error> {
108        // Clone payload while guard is held, then drop guard before return.
109        let payload = self.inner.get(key).map(|entry| entry.payload.clone());
110        Ok(payload)
111    }
112
113    async fn list_pending(&self) -> Result<Vec<PendingActionInfo>, Error> {
114        // Collect into Vec while holding shard locks, drop all guards before returning.
115        let pending: Vec<PendingActionInfo> = self
116            .inner
117            .iter()
118            .map(|entry| PendingActionInfo {
119                key: entry.key().clone(),
120                created_at: entry.value().created_at,
121            })
122            .collect();
123        Ok(pending)
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    fn store() -> InMemoryConfirmationStore {
        InMemoryConfirmationStore::new()
    }

    fn payload(val: &str) -> serde_json::Value {
        serde_json::json!({"action": val})
    }

    // --- Basic CRUD ---

    #[tokio::test]
    async fn test_request_and_get() {
        let s = store();
        let p = payload("delete");
        s.request_confirmation("k1", p.clone(), Duration::from_secs(60))
            .await
            .unwrap();
        let result = s.get("k1").await.unwrap();
        assert_eq!(result, Some(p));
    }

    #[tokio::test]
    async fn test_confirm_returns_payload_and_removes_entry() {
        let s = store();
        let p = payload("publish");
        s.request_confirmation("k2", p.clone(), Duration::from_secs(60))
            .await
            .unwrap();
        let confirmed = s.confirm("k2").await.unwrap();
        assert_eq!(confirmed, Some(p));
        // Entry must be gone after confirm
        assert_eq!(s.get("k2").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_confirm_is_one_shot() {
        let s = store();
        s.request_confirmation("once", payload("pay"), Duration::from_secs(60))
            .await
            .unwrap();
        assert_eq!(s.confirm("once").await.unwrap(), Some(payload("pay")));
        // A second confirm / reject of the same key finds nothing.
        assert_eq!(s.confirm("once").await.unwrap(), None);
        assert!(!s.reject("once").await.unwrap());
    }

    #[tokio::test]
    async fn test_confirm_nonexistent_returns_none() {
        let s = store();
        let result = s.confirm("no-such-key").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_reject_removes_entry_and_returns_true() {
        let s = store();
        s.request_confirmation("k3", payload("send"), Duration::from_secs(60))
            .await
            .unwrap();
        let rejected = s.reject("k3").await.unwrap();
        assert!(rejected);
        assert_eq!(s.get("k3").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_reject_nonexistent_returns_false() {
        let s = store();
        let result = s.reject("ghost").await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_list_pending_empty_store() {
        let s = store();
        assert!(s.list_pending().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_list_pending_returns_all_entries() {
        let s = store();
        s.request_confirmation("a", payload("alpha"), Duration::from_secs(60))
            .await
            .unwrap();
        s.request_confirmation("b", payload("beta"), Duration::from_secs(60))
            .await
            .unwrap();

        let mut pending = s.list_pending().await.unwrap();
        // Compare by reference instead of cloning a key per comparison.
        pending.sort_by(|x, y| x.key.cmp(&y.key));
        assert_eq!(pending.len(), 2);
        assert_eq!(pending[0].key, "a");
        assert_eq!(pending[1].key, "b");
    }

    #[tokio::test]
    async fn test_overwrite_replaces_payload() {
        let s = store();
        s.request_confirmation("k4", payload("first"), Duration::from_secs(60))
            .await
            .unwrap();
        s.request_confirmation("k4", payload("second"), Duration::from_secs(60))
            .await
            .unwrap();
        let result = s.get("k4").await.unwrap();
        assert_eq!(result, Some(payload("second")));
        // Still only one pending entry
        assert_eq!(s.list_pending().await.unwrap().len(), 1);
    }

    // --- TTL expiry ---

    /// Yield to the tokio scheduler so spawned tasks can register timers before a time advance.
    ///
    /// Spawned tasks (e.g. TTL expiry tasks) are not polled until the current task yields.
    /// Calling this after `request_confirmation` lets the spawned task run until it
    /// registers its `sleep` timer with the runtime. After `tokio::time::advance`, the
    /// timer fires and further yields allow the task to complete.
    async fn yield_to_register_timer() {
        tokio::task::yield_now().await;
    }

    /// Yield after a time advance so woken tasks can complete their post-timer work.
    async fn yield_after_advance() {
        for _ in 0..5 {
            tokio::task::yield_now().await;
        }
    }

    #[tokio::test(start_paused = true)]
    async fn test_entry_removed_after_ttl_expires() {
        let s = store();
        s.request_confirmation("ttl-key", payload("expire-me"), Duration::from_millis(100))
            .await
            .unwrap();
        // Yield so the spawned TTL task registers its sleep timer.
        yield_to_register_timer().await;
        // Not expired yet
        assert!(s.get("ttl-key").await.unwrap().is_some());

        tokio::time::advance(Duration::from_millis(150)).await;
        yield_after_advance().await;

        assert_eq!(s.get("ttl-key").await.unwrap(), None);
    }

    #[tokio::test(start_paused = true)]
    async fn test_zero_ttl_entry_expires() {
        let s = store();
        s.request_confirmation("instant", payload("now"), Duration::ZERO)
            .await
            .unwrap();
        yield_to_register_timer().await;
        tokio::time::advance(Duration::from_millis(1)).await;
        yield_after_advance().await;

        assert_eq!(s.get("instant").await.unwrap(), None);
        assert!(s.list_pending().await.unwrap().is_empty());
    }

    #[tokio::test(start_paused = true)]
    async fn test_confirm_before_ttl_prevents_expiry() {
        let s = store();
        s.request_confirmation(
            "confirm-early",
            payload("confirm"),
            Duration::from_millis(100),
        )
        .await
        .unwrap();
        yield_to_register_timer().await;
        let result = s.confirm("confirm-early").await.unwrap();
        assert!(result.is_some());

        // Re-use the key with a much longer TTL. If confirm had not aborted the first
        // timer, it would evict this new entry once the original 100ms elapsed.
        s.request_confirmation("confirm-early", payload("again"), Duration::from_secs(60))
            .await
            .unwrap();
        yield_to_register_timer().await;

        // Advance past the original TTL.
        tokio::time::advance(Duration::from_millis(200)).await;
        yield_after_advance().await;

        assert_eq!(
            s.get("confirm-early").await.unwrap(),
            Some(payload("again")),
            "aborted timer must not evict a later request for the same key"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_reject_before_ttl_prevents_expiry() {
        let s = store();
        s.request_confirmation(
            "reject-early",
            payload("reject"),
            Duration::from_millis(100),
        )
        .await
        .unwrap();
        yield_to_register_timer().await;
        let rejected = s.reject("reject-early").await.unwrap();
        assert!(rejected);

        // Same idea as the confirm test: a surviving old timer would evict this entry.
        s.request_confirmation("reject-early", payload("again"), Duration::from_secs(60))
            .await
            .unwrap();
        yield_to_register_timer().await;

        tokio::time::advance(Duration::from_millis(200)).await;
        yield_after_advance().await;

        assert_eq!(
            s.get("reject-early").await.unwrap(),
            Some(payload("again")),
            "aborted timer must not evict a later request for the same key"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_independent_ttls() {
        let s = store();
        s.request_confirmation("short", payload("s"), Duration::from_millis(100))
            .await
            .unwrap();
        s.request_confirmation("long", payload("l"), Duration::from_millis(200))
            .await
            .unwrap();
        // Let both spawned tasks register their timers.
        yield_to_register_timer().await;
        yield_to_register_timer().await;

        // Advance past short TTL only
        tokio::time::advance(Duration::from_millis(150)).await;
        yield_after_advance().await;

        assert_eq!(
            s.get("short").await.unwrap(),
            None,
            "short should be expired"
        );
        assert!(
            s.get("long").await.unwrap().is_some(),
            "long should still be alive"
        );

        // Advance past long TTL too
        tokio::time::advance(Duration::from_millis(100)).await;
        yield_after_advance().await;

        assert_eq!(
            s.get("long").await.unwrap(),
            None,
            "long should now be expired"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_overwrite_cancels_old_ttl_timer() {
        let s = store();
        // Register with a short TTL
        s.request_confirmation("over-key", payload("original"), Duration::from_millis(100))
            .await
            .unwrap();
        // Yield to let the original task register its timer.
        yield_to_register_timer().await;
        // Overwrite with a longer TTL immediately — aborts the old timer.
        s.request_confirmation(
            "over-key",
            payload("replacement"),
            Duration::from_millis(300),
        )
        .await
        .unwrap();
        // Yield to let the replacement task register its timer.
        yield_to_register_timer().await;

        // Advance past the original (short) TTL — the old timer should be dead.
        tokio::time::advance(Duration::from_millis(150)).await;
        yield_after_advance().await;

        // The new payload should still be there (old timer was aborted).
        assert_eq!(
            s.get("over-key").await.unwrap(),
            Some(payload("replacement")),
            "overwritten entry should still be alive under new TTL"
        );

        // Advance past the new TTL.
        tokio::time::advance(Duration::from_millis(200)).await;
        yield_after_advance().await;

        assert_eq!(
            s.get("over-key").await.unwrap(),
            None,
            "new TTL should now expire"
        );
    }
}