Skip to main content

cratestack_sqlx/
idempotency.rs

//! Postgres-backed [`IdempotencyStore`].
//!
//! Banks need duplicate-execution protection even under concurrency, so
//! this implementation uses the atomic reservation pattern: a single
//! upsert claims the key (or surfaces the existing claim), the middleware
//! runs the handler only when it owns the reservation, then writes the
//! captured response with `complete`.
//!
//! Expired rows are unconditionally replaced on the next reservation —
//! the `ON CONFLICT DO UPDATE WHERE cratestack_idempotency.expires_at <=
//! NOW()` clause lets the new caller take over a stale row in the same
//! statement that would otherwise have hit the unique-key wall.
13
14use std::time::SystemTime;
15
16use async_trait::async_trait;
17use cratestack_axum::idempotency::{
18    IDEMPOTENCY_TABLE_DDL, IdempotencyRecord, IdempotencyStore, ReservationOutcome,
19};
20use cratestack_core::CoolError;
21
/// Idempotency store whose queries run against a Postgres connection pool.
///
/// Derives `Clone`, so a handle can be duplicated and handed to several owners.
#[derive(Clone)]
pub struct SqlxIdempotencyStore {
    /// Pool every query issued by this store is executed against.
    pool: sqlx::PgPool,
}
26
27impl SqlxIdempotencyStore {
28    pub fn new(pool: sqlx::PgPool) -> Self {
29        Self { pool }
30    }
31
32    /// Ensure the table exists. Banks typically run this via their own
33    /// migration tooling; we expose it here for convenience.
34    pub async fn ensure_schema(&self) -> Result<(), CoolError> {
35        // Multi-statement DDL (table + index) — prepared statements only
36        // accept one statement at a time, so split + execute sequentially.
37        for statement in IDEMPOTENCY_TABLE_DDL
38            .split(';')
39            .map(str::trim)
40            .filter(|s| !s.is_empty())
41        {
42            sqlx::query(statement)
43                .execute(&self.pool)
44                .await
45                .map_err(|error| CoolError::Database(error.to_string()))?;
46        }
47        Ok(())
48    }
49
50    /// Delete expired rows. Run periodically (e.g. via a scheduled task
51    /// or the `cratestack idempotency gc` CLI subcommand) — the request
52    /// path does not auto-GC, although `reserve_or_fetch` does take over
53    /// any single expired row it tries to claim.
54    pub async fn garbage_collect(&self) -> Result<u64, CoolError> {
55        let result = sqlx::query("DELETE FROM cratestack_idempotency WHERE expires_at < NOW()")
56            .execute(&self.pool)
57            .await
58            .map_err(|error| CoolError::Database(error.to_string()))?;
59        Ok(result.rows_affected())
60    }
61}
62
63#[async_trait]
64impl IdempotencyStore for SqlxIdempotencyStore {
65    async fn reserve_or_fetch(
66        &self,
67        principal: &str,
68        key: &str,
69        request_hash: [u8; 32],
70        expires_at: SystemTime,
71    ) -> Result<ReservationOutcome, CoolError> {
72        let expires_at: chrono::DateTime<chrono::Utc> = expires_at.into();
73        // Generate a fresh reservation token. If our INSERT or expired-
74        // row UPDATE wins, this token identifies our reservation. A
75        // handler that runs past the TTL and gets reclaimed by a retry
76        // will see its token replaced in-row, and any later complete/
77        // release from the stale handler becomes a no-op.
78        let new_token = uuid::Uuid::new_v4();
79        // Single upsert that:
80        //   - inserts a fresh pending row if the key is absent;
81        //   - takes over the row if the existing one has expired (the
82        //     `WHERE` filter on the DO UPDATE branch);
83        //   - leaves the row alone otherwise.
84        // The `xmax = 0` trick distinguishes a real INSERT (true) from
85        // an UPDATE-on-conflict (false). PG sets xmax to the locking
86        // transaction id on an UPDATE; pristine inserts read xmax = 0.
87        let row: Option<(
88            Vec<u8>,
89            uuid::Uuid,
90            Option<i32>,
91            Option<Vec<u8>>,
92            Option<Vec<u8>>,
93            chrono::DateTime<chrono::Utc>,
94            chrono::DateTime<chrono::Utc>,
95            bool,
96        )> = sqlx::query_as(
97            "INSERT INTO cratestack_idempotency (
98                principal_fingerprint, key, request_hash, reservation_id, expires_at
99             ) VALUES ($1, $2, $3, $4, $5)
100             ON CONFLICT (principal_fingerprint, key) DO UPDATE SET
101                request_hash = EXCLUDED.request_hash,
102                reservation_id = EXCLUDED.reservation_id,
103                response_status = NULL,
104                response_headers = NULL,
105                response_body = NULL,
106                created_at = NOW(),
107                expires_at = EXCLUDED.expires_at
108             WHERE cratestack_idempotency.expires_at <= NOW()
109             RETURNING request_hash, reservation_id, response_status, response_headers,
110                       response_body, created_at, expires_at, (xmax = 0) AS was_inserted",
111        )
112        .bind(principal)
113        .bind(key)
114        .bind(request_hash.as_slice())
115        .bind(new_token)
116        .bind(expires_at)
117        .fetch_optional(&self.pool)
118        .await
119        .map_err(|error| CoolError::Database(error.to_string()))?;
120
121        if let Some((_, token, _, _, _, _, _, _)) = row {
122            // Either a fresh insert (was_inserted = true) or an expired
123            // row we just reclaimed (was_inserted = false but UPDATE
124            // happened). In both cases the caller owns the reservation
125            // and the row carries the token we just generated.
126            return Ok(ReservationOutcome::Reserved { token });
127        }
128
129        // ON CONFLICT WHERE evaluated to false (existing row is live).
130        // Read it back and classify.
131        let existing: Option<(
132            Vec<u8>,
133            Option<i32>,
134            Option<Vec<u8>>,
135            Option<Vec<u8>>,
136            chrono::DateTime<chrono::Utc>,
137            chrono::DateTime<chrono::Utc>,
138        )> = sqlx::query_as(
139            "SELECT request_hash, response_status, response_headers,
140                    response_body, created_at, expires_at
141             FROM cratestack_idempotency
142             WHERE principal_fingerprint = $1 AND key = $2",
143        )
144        .bind(principal)
145        .bind(key)
146        .fetch_optional(&self.pool)
147        .await
148        .map_err(|error| CoolError::Database(error.to_string()))?;
149
150        let Some((stored_hash, status, headers, body, created_at, existing_expires_at)) = existing
151        else {
152            // Vanished between the upsert and the read (a concurrent GC
153            // could do this in theory). Surface as InFlight so the
154            // caller retries shortly rather than running the handler on
155            // a state we don't fully understand.
156            return Ok(ReservationOutcome::InFlight);
157        };
158
159        let stored: [u8; 32] = stored_hash
160            .as_slice()
161            .try_into()
162            .map_err(|_| CoolError::Internal("corrupt idempotency hash length".to_owned()))?;
163        if stored != request_hash {
164            return Ok(ReservationOutcome::Conflict);
165        }
166
167        match (status, body) {
168            (Some(s), Some(b)) => {
169                let response_status: u16 = u16::try_from(s).unwrap_or(500);
170                Ok(ReservationOutcome::Replay(IdempotencyRecord {
171                    principal_fingerprint: principal.to_owned(),
172                    key: key.to_owned(),
173                    request_hash: stored,
174                    response_status,
175                    response_headers: headers.unwrap_or_default(),
176                    response_body: b,
177                    created_at: created_at.into(),
178                    expires_at: existing_expires_at.into(),
179                }))
180            }
181            _ => Ok(ReservationOutcome::InFlight),
182        }
183    }
184
185    async fn complete(
186        &self,
187        principal: &str,
188        key: &str,
189        token: uuid::Uuid,
190        status: u16,
191        headers: &[u8],
192        body: &[u8],
193    ) -> Result<(), CoolError> {
194        // Only completes the row we actually reserved. `reservation_id =
195        // $token` is the proof; `response_body IS NULL` keeps us from
196        // double-writing a finished slot. A handler that ran past its
197        // TTL will find the row's `reservation_id` rotated out by the
198        // retry that reclaimed it, and this UPDATE matches zero rows.
199        sqlx::query(
200            "UPDATE cratestack_idempotency
201             SET response_status = $1,
202                 response_headers = $2,
203                 response_body = $3
204             WHERE principal_fingerprint = $4
205               AND key = $5
206               AND reservation_id = $6
207               AND response_body IS NULL",
208        )
209        .bind(status as i32)
210        .bind(headers)
211        .bind(body)
212        .bind(principal)
213        .bind(key)
214        .bind(token)
215        .execute(&self.pool)
216        .await
217        .map(|_| ())
218        .map_err(|error| CoolError::Database(error.to_string()))
219    }
220
221    async fn release(
222        &self,
223        principal: &str,
224        key: &str,
225        token: uuid::Uuid,
226    ) -> Result<(), CoolError> {
227        // Only drop our own pending row — never delete a completed one,
228        // and never delete a row whose reservation has been rotated to
229        // a different owner.
230        sqlx::query(
231            "DELETE FROM cratestack_idempotency
232             WHERE principal_fingerprint = $1
233               AND key = $2
234               AND reservation_id = $3
235               AND response_body IS NULL",
236        )
237        .bind(principal)
238        .bind(key)
239        .bind(token)
240        .execute(&self.pool)
241        .await
242        .map(|_| ())
243        .map_err(|error| CoolError::Database(error.to_string()))
244    }
245}
246
/// Return the instant at which a record captured at `created_at` stops
/// being valid, given a time-to-live of `ttl`.
///
/// Kept as a free function so the `SystemTime` arithmetic can be unit
/// tested without a clock injection point.
pub fn expiry_from(created_at: SystemTime, ttl: std::time::Duration) -> SystemTime {
    let mut deadline = created_at;
    deadline += ttl;
    deadline
}
253
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, SystemTime};

    /// A record created at the Unix epoch with a one-day TTL must expire
    /// exactly one day after the epoch.
    #[test]
    fn expiry_adds_ttl_to_creation() {
        let one_day = Duration::from_secs(24 * 3600);
        let deadline = expiry_from(SystemTime::UNIX_EPOCH, one_day);
        let elapsed = deadline.duration_since(SystemTime::UNIX_EPOCH).unwrap();
        assert_eq!(elapsed, one_day);
    }
}