use super::{ConfirmationExpired, ConfirmationStore, PendingActionInfo};
use crate::error::Error;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use ferro_events::dispatch_sync;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::task::AbortHandle;
use tokio::time::sleep;
11
12struct StoredAction {
14 payload: serde_json::Value,
15 created_at: DateTime<Utc>,
16 abort_handle: AbortHandle,
17}
18
19pub struct InMemoryConfirmationStore {
30 inner: Arc<DashMap<String, StoredAction>>,
31}
32
33impl InMemoryConfirmationStore {
34 pub fn new() -> Self {
36 Self {
37 inner: Arc::new(DashMap::new()),
38 }
39 }
40}
41
42impl Default for InMemoryConfirmationStore {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48#[async_trait]
49impl ConfirmationStore for InMemoryConfirmationStore {
50 async fn request_confirmation(
51 &self,
52 key: &str,
53 payload: serde_json::Value,
54 ttl: Duration,
55 ) -> Result<(), Error> {
56 if let Some((_, old)) = self.inner.remove(key) {
58 old.abort_handle.abort();
59 }
60
61 let store = Arc::clone(&self.inner);
62 let key_owned = key.to_string();
63
64 let abort_handle = tokio::spawn(async move {
65 sleep(ttl).await;
66 if store.remove(&key_owned).is_some() {
67 dispatch_sync(ConfirmationExpired {
68 key: key_owned,
69 expired_at: Utc::now(),
70 });
71 }
72 })
73 .abort_handle();
74
75 self.inner.insert(
76 key.to_string(),
77 StoredAction {
78 payload,
79 created_at: Utc::now(),
80 abort_handle,
81 },
82 );
83
84 Ok(())
85 }
86
87 async fn confirm(&self, key: &str) -> Result<Option<serde_json::Value>, Error> {
88 match self.inner.remove(key) {
89 Some((_, entry)) => {
90 entry.abort_handle.abort();
91 Ok(Some(entry.payload))
92 }
93 None => Ok(None),
94 }
95 }
96
97 async fn reject(&self, key: &str) -> Result<bool, Error> {
98 match self.inner.remove(key) {
99 Some((_, entry)) => {
100 entry.abort_handle.abort();
101 Ok(true)
102 }
103 None => Ok(false),
104 }
105 }
106
107 async fn get(&self, key: &str) -> Result<Option<serde_json::Value>, Error> {
108 let payload = self.inner.get(key).map(|entry| entry.payload.clone());
110 Ok(payload)
111 }
112
113 async fn list_pending(&self) -> Result<Vec<PendingActionInfo>, Error> {
114 let pending: Vec<PendingActionInfo> = self
116 .inner
117 .iter()
118 .map(|entry| PendingActionInfo {
119 key: entry.key().clone(),
120 created_at: entry.value().created_at,
121 })
122 .collect();
123 Ok(pending)
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh, empty store so every test starts from a clean state.
    fn store() -> InMemoryConfirmationStore {
        InMemoryConfirmationStore::new()
    }

    /// Builds a distinguishable JSON payload of the form `{"action": val}`.
    fn payload(val: &str) -> serde_json::Value {
        serde_json::json!({"action": val})
    }

    // A freshly requested payload is visible through `get`.
    #[tokio::test]
    async fn test_request_and_get() {
        let s = store();
        let p = payload("delete");
        s.request_confirmation("k1", p.clone(), Duration::from_secs(60))
            .await
            .unwrap();
        let result = s.get("k1").await.unwrap();
        assert_eq!(result, Some(p));
    }

    // `confirm` hands back the stored payload and consumes the entry.
    #[tokio::test]
    async fn test_confirm_returns_payload_and_removes_entry() {
        let s = store();
        let p = payload("publish");
        s.request_confirmation("k2", p.clone(), Duration::from_secs(60))
            .await
            .unwrap();
        let confirmed = s.confirm("k2").await.unwrap();
        assert_eq!(confirmed, Some(p));
        // The entry is gone, so a subsequent lookup finds nothing.
        assert_eq!(s.get("k2").await.unwrap(), None);
    }

    // Confirming an unknown key is not an error; it simply yields `None`.
    #[tokio::test]
    async fn test_confirm_nonexistent_returns_none() {
        let s = store();
        let result = s.confirm("no-such-key").await.unwrap();
        assert_eq!(result, None);
    }

    // `reject` reports that something was pending and removes it.
    #[tokio::test]
    async fn test_reject_removes_entry_and_returns_true() {
        let s = store();
        s.request_confirmation("k3", payload("send"), Duration::from_secs(60))
            .await
            .unwrap();
        let rejected = s.reject("k3").await.unwrap();
        assert!(rejected);
        assert_eq!(s.get("k3").await.unwrap(), None);
    }

    // Rejecting an unknown key returns `false` rather than an error.
    #[tokio::test]
    async fn test_reject_nonexistent_returns_false() {
        let s = store();
        let result = s.reject("ghost").await.unwrap();
        assert!(!result);
    }

    // Every pending key is listed. The store does not promise an order, so
    // the result is sorted by key before asserting.
    #[tokio::test]
    async fn test_list_pending_returns_all_entries() {
        let s = store();
        s.request_confirmation("a", payload("alpha"), Duration::from_secs(60))
            .await
            .unwrap();
        s.request_confirmation("b", payload("beta"), Duration::from_secs(60))
            .await
            .unwrap();

        let mut pending = s.list_pending().await.unwrap();
        pending.sort_by_key(|p| p.key.clone());
        assert_eq!(pending.len(), 2);
        assert_eq!(pending[0].key, "a");
        assert_eq!(pending[1].key, "b");
    }

    // Re-requesting an existing key replaces its payload instead of adding a
    // second entry.
    #[tokio::test]
    async fn test_overwrite_replaces_payload() {
        let s = store();
        s.request_confirmation("k4", payload("first"), Duration::from_secs(60))
            .await
            .unwrap();
        s.request_confirmation("k4", payload("second"), Duration::from_secs(60))
            .await
            .unwrap();
        let result = s.get("k4").await.unwrap();
        assert_eq!(result, Some(payload("second")));
        // Still exactly one pending entry for the key.
        assert_eq!(s.list_pending().await.unwrap().len(), 1);
    }

    /// Yields so that freshly spawned expiry tasks get a chance to run.
    ///
    /// The expiry task only calls `sleep(ttl)` when it is first polled, and
    /// `sleep` computes its deadline from the clock at that moment. Under
    /// `start_paused = true` the task therefore has to run *before*
    /// `tokio::time::advance`, or the advance will not count against its TTL.
    /// `#[tokio::test]` defaults to a current-thread runtime, where yielding
    /// hands control to queued tasks.
    async fn yield_to_register_timer() {
        tokio::task::yield_now().await;
    }

    /// Yields a few times after `tokio::time::advance` so expiry tasks woken
    /// by the advance can perform their removal before the test asserts.
    /// NOTE(review): five iterations looks like a safety margin rather than an
    /// exact requirement — confirm.
    async fn yield_after_advance() {
        for _ in 0..5 {
            tokio::task::yield_now().await;
        }
    }

    // With the clock paused, an entry survives until its TTL has elapsed and
    // is gone afterwards.
    #[tokio::test(start_paused = true)]
    async fn test_entry_removed_after_ttl_expires() {
        let s = store();
        s.request_confirmation("ttl-key", payload("expire-me"), Duration::from_millis(100))
            .await
            .unwrap();
        yield_to_register_timer().await;
        // The clock has not been advanced yet, so the entry is still pending.
        assert!(s.get("ttl-key").await.unwrap().is_some());

        tokio::time::advance(Duration::from_millis(150)).await;
        yield_after_advance().await;

        assert_eq!(s.get("ttl-key").await.unwrap(), None);
    }

    // Confirming before the TTL elapses resolves the entry ahead of its timer.
    #[tokio::test(start_paused = true)]
    async fn test_confirm_before_ttl_prevents_expiry() {
        let s = store();
        s.request_confirmation(
            "confirm-early",
            payload("confirm"),
            Duration::from_millis(100),
        )
        .await
        .unwrap();
        yield_to_register_timer().await;
        let result = s.confirm("confirm-early").await.unwrap();
        assert!(result.is_some());

        // Move well past the original 100 ms TTL.
        tokio::time::advance(Duration::from_millis(200)).await;
        yield_after_advance().await;

        // NOTE(review): `confirm` already removed the entry, so this assertion
        // holds whether or not the expiry timer was aborted; it does not
        // observe the abort (or the absence of a `ConfirmationExpired` event).
        assert_eq!(s.get("confirm-early").await.unwrap(), None);
    }

    // Rejecting before the TTL elapses resolves the entry ahead of its timer.
    #[tokio::test(start_paused = true)]
    async fn test_reject_before_ttl_prevents_expiry() {
        let s = store();
        s.request_confirmation(
            "reject-early",
            payload("reject"),
            Duration::from_millis(100),
        )
        .await
        .unwrap();
        yield_to_register_timer().await;
        let rejected = s.reject("reject-early").await.unwrap();
        assert!(rejected);

        // Move well past the original 100 ms TTL.
        tokio::time::advance(Duration::from_millis(200)).await;
        yield_after_advance().await;

        // NOTE(review): as in the confirm variant, this cannot distinguish an
        // aborted timer from one that fired and found nothing to remove.
        assert_eq!(s.get("reject-early").await.unwrap(), None);
    }

    // Each key expires on its own schedule: after 150 ms only the 100 ms entry
    // is gone; after a further 100 ms the 200 ms entry is gone too.
    #[tokio::test(start_paused = true)]
    async fn test_independent_ttls() {
        let s = store();
        s.request_confirmation("short", payload("s"), Duration::from_millis(100))
            .await
            .unwrap();
        s.request_confirmation("long", payload("l"), Duration::from_millis(200))
            .await
            .unwrap();
        yield_to_register_timer().await;
        // Presumably one yield per spawned timer task; extra yields are harmless.
        yield_to_register_timer().await;

        tokio::time::advance(Duration::from_millis(150)).await;
        // Past "short" (100 ms) but not yet "long" (200 ms).
        yield_after_advance().await;

        assert_eq!(
            s.get("short").await.unwrap(),
            None,
            "short should be expired"
        );
        assert!(
            s.get("long").await.unwrap().is_some(),
            "long should still be alive"
        );

        tokio::time::advance(Duration::from_millis(100)).await;
        // Cumulatively 250 ms: past "long" as well.
        yield_after_advance().await;

        assert_eq!(
            s.get("long").await.unwrap(),
            None,
            "long should now be expired"
        );
    }

    // Re-requesting a key cancels the old timer: the replacement outlives the
    // original 100 ms TTL and expires on its own 300 ms TTL instead.
    #[tokio::test(start_paused = true)]
    async fn test_overwrite_cancels_old_ttl_timer() {
        let s = store();
        s.request_confirmation("over-key", payload("original"), Duration::from_millis(100))
            // Original entry with a short TTL.
            .await
            .unwrap();
        yield_to_register_timer().await;
        // Replace it with a longer (300 ms) TTL.
        s.request_confirmation(
            "over-key",
            payload("replacement"),
            Duration::from_millis(300),
        )
        .await
        .unwrap();
        yield_to_register_timer().await;

        // After 150 ms the original 100 ms timer would have fired had it not
        // been cancelled by the overwrite.
        tokio::time::advance(Duration::from_millis(150)).await;
        yield_after_advance().await;

        assert_eq!(
            s.get("over-key").await.unwrap(),
            Some(payload("replacement")),
            "overwritten entry should still be alive under new TTL"
        );

        // Cumulatively 350 ms: past the replacement's 300 ms TTL.
        tokio::time::advance(Duration::from_millis(200)).await;
        yield_after_advance().await;

        assert_eq!(
            s.get("over-key").await.unwrap(),
            None,
            "new TTL should now expire"
        );
    }
}