Skip to main content

klauthed_data/idempotency/
redis.rs

//! Redis-backed [`IdempotencyStore`].
//!
//! [`RedisIdempotencyStore`] implements the idempotency protocol on top of Redis:
//!
//! * **begin** — `SET key <json> NX PX <ttl_ms>`: claims the key atomically.
//!   If the key already exists, `GET` reveals whether a request is
//!   [`InProgress`](crate::idempotency::Outcome::InProgress) or
//!   [`Completed`](crate::idempotency::Outcome::Completed).
//! * **complete** — overwrites the stored record with status `Completed` and
//!   the caller's response payload, resetting the key's expiry to the full
//!   configured TTL.
//! * **get** — `GET` + deserialise.
//!
//! Keys auto-expire after the configured TTL, so the keyspace self-cleans without
//! a background job.
//!
//! # Caveats
//!
//! `complete` is a `GET` followed by a separate `SET`, not a single atomic step.
//! If the key has already expired when the `GET` runs, a `DataError::Idempotency`
//! is returned so the caller can decide how to handle the edge case. If it
//! expires between the `GET` and the `SET`, the `SET` simply recreates it as
//! `Completed`. This is the standard single-instance Redis trade-off; making the
//! read-modify-write atomic would require a Lua compare-and-swap script.
//!
//! Tests that need a live Redis are marked `#[ignore]`; run them with a server
//! at `REDIS_URL` via `cargo test -p klauthed-data --features redis -- --ignored`.

26use async_trait::async_trait;
27use klauthed_core::time::Timestamp;
28use redis::aio::ConnectionManager;
29use redis::{ExistenceCheck, SetExpiry, SetOptions};
30use serde::{Deserialize, Serialize};
31
32use crate::error::DataError;
33use crate::idempotency::{IdempotencyRecord, IdempotencyStatus, IdempotencyStore, Outcome};
34
/// Expiry applied to idempotency keys when no custom TTL is given: one day,
/// expressed in milliseconds (24 h × 60 min × 60 s × 1000 ms).
const DEFAULT_TTL_MS: u64 = 1_000 * 60 * 60 * 24;
37
38/// A Redis-backed [`IdempotencyStore`].
39///
40/// Records are serialised as JSON and stored with a configurable TTL so the
41/// keyspace self-cleans. Clone-cheap: holds only a [`ConnectionManager`] (an
42/// `Arc` internally).
43#[derive(Clone)]
44pub struct RedisIdempotencyStore {
45    conn: ConnectionManager,
46    ttl_ms: u64,
47}
48
49impl RedisIdempotencyStore {
50    /// Wrap a managed Redis connection with the default 24-hour TTL.
51    pub fn new(conn: ConnectionManager) -> Self {
52        Self { conn, ttl_ms: DEFAULT_TTL_MS }
53    }
54
55    /// Wrap a managed Redis connection with a custom TTL in milliseconds.
56    pub fn with_ttl_ms(conn: ConnectionManager, ttl_ms: u64) -> Self {
57        Self { conn, ttl_ms }
58    }
59}
60
61/// The shape stored in Redis for each idempotency key.
62#[derive(Serialize, Deserialize)]
63struct StoredRecord {
64    status: IdempotencyStatus,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    response: Option<serde_json::Value>,
67    created_at: Timestamp,
68    updated_at: Timestamp,
69}
70
71#[async_trait]
72impl IdempotencyStore for RedisIdempotencyStore {
73    async fn begin(&self, key: &str) -> Result<Outcome, DataError> {
74        let now = Timestamp::now();
75        let initial = serde_json::to_string(&StoredRecord {
76            status: IdempotencyStatus::InProgress,
77            response: None,
78            created_at: now,
79            updated_at: now,
80        })
81        .map_err(|e| DataError::Idempotency(format!("serialisation failed: {e}")))?;
82
83        let options = SetOptions::default()
84            .conditional_set(ExistenceCheck::NX)
85            .with_expiration(SetExpiry::PX(self.ttl_ms));
86
87        let mut conn = self.conn.clone();
88        // `SET … NX` returns `Some("OK")` on success, `None` when the key exists.
89        let claimed: Option<String> =
90            redis::cmd("SET").arg(key).arg(&initial).arg(&options).query_async(&mut conn).await?;
91
92        if claimed.is_some() {
93            return Ok(Outcome::New);
94        }
95
96        // Key already exists — inspect its current state.
97        let raw: Option<String> = redis::cmd("GET").arg(key).query_async(&mut conn).await?;
98
99        match raw {
100            // Key expired between our NX attempt and GET — treat as new claim.
101            None => Ok(Outcome::New),
102            Some(s) => {
103                let record: StoredRecord = serde_json::from_str(&s).map_err(|e| {
104                    DataError::Idempotency(format!("corrupt idempotency record for '{key}': {e}"))
105                })?;
106                match record.status {
107                    IdempotencyStatus::InProgress => Ok(Outcome::InProgress),
108                    IdempotencyStatus::Completed => {
109                        Ok(Outcome::Completed(record.response.unwrap_or(serde_json::Value::Null)))
110                    }
111                }
112            }
113        }
114    }
115
116    async fn complete(&self, key: &str, response: serde_json::Value) -> Result<(), DataError> {
117        let now = Timestamp::now();
118        let mut conn = self.conn.clone();
119
120        // Read the current record to preserve `created_at`.
121        let raw: Option<String> = redis::cmd("GET").arg(key).query_async(&mut conn).await?;
122
123        let created_at = match raw {
124            None => {
125                return Err(DataError::Idempotency(format!(
126                    "cannot complete unknown idempotency key '{key}'"
127                )));
128            }
129            Some(s) => {
130                serde_json::from_str::<StoredRecord>(&s).map(|r| r.created_at).unwrap_or(now)
131            }
132        };
133
134        let completed = serde_json::to_string(&StoredRecord {
135            status: IdempotencyStatus::Completed,
136            response: Some(response),
137            created_at,
138            updated_at: now,
139        })
140        .map_err(|e| DataError::Idempotency(format!("serialisation failed: {e}")))?;
141
142        // Overwrite with the same TTL (key existed a moment ago; if it expired
143        // in the gap the SET recreates it as Completed, which is correct).
144        redis::cmd("SET")
145            .arg(key)
146            .arg(&completed)
147            .arg("PX")
148            .arg(self.ttl_ms)
149            .query_async::<()>(&mut conn)
150            .await?;
151
152        Ok(())
153    }
154
155    async fn get(&self, key: &str) -> Result<Option<IdempotencyRecord>, DataError> {
156        let mut conn = self.conn.clone();
157        let raw: Option<String> = redis::cmd("GET").arg(key).query_async(&mut conn).await?;
158
159        match raw {
160            None => Ok(None),
161            Some(s) => {
162                let record: StoredRecord = serde_json::from_str(&s).map_err(|e| {
163                    DataError::Idempotency(format!("corrupt idempotency record for '{key}': {e}"))
164                })?;
165                Ok(Some(IdempotencyRecord {
166                    key: key.to_owned(),
167                    status: record.status,
168                    response: record.response,
169                    created_at: record.created_at,
170                    updated_at: record.updated_at,
171                }))
172            }
173        }
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::locks::LockToken;

    /// Key TTL for every test store: one minute comfortably outlives a test run.
    const TEST_TTL_MS: u64 = 60_000;

    /// Build a store against the Redis named by `REDIS_URL`, falling back to
    /// `redis://127.0.0.1/` when the variable is unset or unreadable.
    async fn live_store() -> RedisIdempotencyStore {
        let url = match std::env::var("REDIS_URL") {
            Ok(configured) => configured,
            Err(_) => "redis://127.0.0.1/".to_owned(),
        };
        let client = redis::Client::open(url).expect("open redis client");
        let manager = ConnectionManager::new(client).await.expect("connect redis");
        RedisIdempotencyStore::with_ttl_ms(manager, TEST_TTL_MS)
    }

    #[tokio::test]
    #[ignore = "requires a live Redis at REDIS_URL"]
    async fn new_in_progress_complete_replay() {
        let store = live_store().await;
        let key = format!("klauthed:test:idem:{}", LockToken::new());

        let first = store.begin(&key).await.unwrap();
        assert_eq!(first, Outcome::New);
        let second = store.begin(&key).await.unwrap();
        assert_eq!(second, Outcome::InProgress);

        let payload = serde_json::json!({"charged": true, "amount": 100});
        store.complete(&key, payload.clone()).await.unwrap();

        let replay = store.begin(&key).await.unwrap();
        assert_eq!(replay, Outcome::Completed(payload));
    }

    #[tokio::test]
    #[ignore = "requires a live Redis at REDIS_URL"]
    async fn complete_unknown_key_errors() {
        let store = live_store().await;
        let key = format!("klauthed:test:idem:{}:missing", LockToken::new());

        let outcome = store.complete(&key, serde_json::Value::Null).await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, DataError::Idempotency(_)));
    }

    #[tokio::test]
    #[ignore = "requires a live Redis at REDIS_URL"]
    async fn get_returns_record_lifecycle() {
        let store = live_store().await;
        let key = format!("klauthed:test:idem:{}", LockToken::new());

        let before = store.get(&key).await.unwrap();
        assert!(before.is_none());

        store.begin(&key).await.unwrap();
        let pending = store.get(&key).await.unwrap().unwrap();
        assert_eq!(pending.status, IdempotencyStatus::InProgress);
        assert!(pending.response.is_none());

        store.complete(&key, serde_json::json!(42)).await.unwrap();
        let done = store.get(&key).await.unwrap().unwrap();
        assert_eq!(done.status, IdempotencyStatus::Completed);
        assert_eq!(done.response, Some(serde_json::json!(42)));
    }

    #[tokio::test]
    #[ignore = "requires a live Redis at REDIS_URL"]
    async fn distinct_keys_are_independent() {
        let store = live_store().await;
        let key_a = format!("klauthed:test:idem:{}:a", LockToken::new());
        let key_b = format!("klauthed:test:idem:{}:b", LockToken::new());

        let claim_a = store.begin(&key_a).await.unwrap();
        assert_eq!(claim_a, Outcome::New);
        let claim_b = store.begin(&key_b).await.unwrap();
        assert_eq!(claim_b, Outcome::New);
        let reclaim_a = store.begin(&key_a).await.unwrap();
        assert_eq!(reclaim_a, Outcome::InProgress);
    }
}