Skip to main content

cratestack_sqlx/
idempotency.rs

//! Postgres-backed [`IdempotencyStore`].
//!
//! Banks need duplicate-execution protection even under concurrency, so
//! this implementation uses the atomic reservation pattern: a single
//! upsert claims the key when it is free or its row has expired; when a
//! live claim already exists, a follow-up read surfaces and classifies
//! it. The middleware runs the handler only when it owns the
//! reservation, then writes the captured response with `complete`.
//!
//! Expired rows are unconditionally replaced on the next reservation —
//! the `ON CONFLICT DO UPDATE WHERE cratestack_idempotency.expires_at <=
//! NOW()` clause lets the new caller take over a stale row in the same
//! statement that would otherwise have hit the unique-key wall.
13use crate::sqlx;
14
15
16use std::time::SystemTime;
17
18use async_trait::async_trait;
19use cratestack_axum::idempotency::{
20    IDEMPOTENCY_TABLE_DDL, IdempotencyRecord, IdempotencyStore, ReservationOutcome,
21};
22use cratestack_core::CoolError;
23
24#[derive(Clone)]
25pub struct SqlxIdempotencyStore {
26    pool: sqlx::PgPool,
27}
28
29impl SqlxIdempotencyStore {
30    pub fn new(pool: sqlx::PgPool) -> Self {
31        Self { pool }
32    }
33
34    /// Ensure the table exists. Banks typically run this via their own
35    /// migration tooling; we expose it here for convenience.
36    pub async fn ensure_schema(&self) -> Result<(), CoolError> {
37        // Multi-statement DDL (table + index) — prepared statements only
38        // accept one statement at a time, so split + execute sequentially.
39        for statement in IDEMPOTENCY_TABLE_DDL
40            .split(';')
41            .map(str::trim)
42            .filter(|s| !s.is_empty())
43        {
44            sqlx::query(statement)
45                .execute(&self.pool)
46                .await
47                .map_err(|error| CoolError::Database(error.to_string()))?;
48        }
49        Ok(())
50    }
51
52    /// Delete expired rows. Run periodically (e.g. via a scheduled task
53    /// or the `cratestack idempotency gc` CLI subcommand) — the request
54    /// path does not auto-GC, although `reserve_or_fetch` does take over
55    /// any single expired row it tries to claim.
56    pub async fn garbage_collect(&self) -> Result<u64, CoolError> {
57        let result = sqlx::query("DELETE FROM cratestack_idempotency WHERE expires_at < NOW()")
58            .execute(&self.pool)
59            .await
60            .map_err(|error| CoolError::Database(error.to_string()))?;
61        Ok(result.rows_affected())
62    }
63}
64
65#[async_trait]
66impl IdempotencyStore for SqlxIdempotencyStore {
67    async fn reserve_or_fetch(
68        &self,
69        principal: &str,
70        key: &str,
71        request_hash: [u8; 32],
72        expires_at: SystemTime,
73    ) -> Result<ReservationOutcome, CoolError> {
74        let expires_at: chrono::DateTime<chrono::Utc> = expires_at.into();
75        // Generate a fresh reservation token. If our INSERT or expired-
76        // row UPDATE wins, this token identifies our reservation. A
77        // handler that runs past the TTL and gets reclaimed by a retry
78        // will see its token replaced in-row, and any later complete/
79        // release from the stale handler becomes a no-op.
80        let new_token = uuid::Uuid::new_v4();
81        // Single upsert that:
82        //   - inserts a fresh pending row if the key is absent;
83        //   - takes over the row if the existing one has expired (the
84        //     `WHERE` filter on the DO UPDATE branch);
85        //   - leaves the row alone otherwise.
86        // The `xmax = 0` trick distinguishes a real INSERT (true) from
87        // an UPDATE-on-conflict (false). PG sets xmax to the locking
88        // transaction id on an UPDATE; pristine inserts read xmax = 0.
89        let row: Option<(
90            Vec<u8>,
91            uuid::Uuid,
92            Option<i32>,
93            Option<Vec<u8>>,
94            Option<Vec<u8>>,
95            chrono::DateTime<chrono::Utc>,
96            chrono::DateTime<chrono::Utc>,
97            bool,
98        )> = sqlx::query_as(
99            "INSERT INTO cratestack_idempotency (
100                principal_fingerprint, key, request_hash, reservation_id, expires_at
101             ) VALUES ($1, $2, $3, $4, $5)
102             ON CONFLICT (principal_fingerprint, key) DO UPDATE SET
103                request_hash = EXCLUDED.request_hash,
104                reservation_id = EXCLUDED.reservation_id,
105                response_status = NULL,
106                response_headers = NULL,
107                response_body = NULL,
108                created_at = NOW(),
109                expires_at = EXCLUDED.expires_at
110             WHERE cratestack_idempotency.expires_at <= NOW()
111             RETURNING request_hash, reservation_id, response_status, response_headers,
112                       response_body, created_at, expires_at, (xmax = 0) AS was_inserted",
113        )
114        .bind(principal)
115        .bind(key)
116        .bind(request_hash.as_slice())
117        .bind(new_token)
118        .bind(expires_at)
119        .fetch_optional(&self.pool)
120        .await
121        .map_err(|error| CoolError::Database(error.to_string()))?;
122
123        if let Some((_, token, _, _, _, _, _, _)) = row {
124            // Either a fresh insert (was_inserted = true) or an expired
125            // row we just reclaimed (was_inserted = false but UPDATE
126            // happened). In both cases the caller owns the reservation
127            // and the row carries the token we just generated.
128            return Ok(ReservationOutcome::Reserved { token });
129        }
130
131        // ON CONFLICT WHERE evaluated to false (existing row is live).
132        // Read it back and classify.
133        let existing: Option<(
134            Vec<u8>,
135            Option<i32>,
136            Option<Vec<u8>>,
137            Option<Vec<u8>>,
138            chrono::DateTime<chrono::Utc>,
139            chrono::DateTime<chrono::Utc>,
140        )> = sqlx::query_as(
141            "SELECT request_hash, response_status, response_headers,
142                    response_body, created_at, expires_at
143             FROM cratestack_idempotency
144             WHERE principal_fingerprint = $1 AND key = $2",
145        )
146        .bind(principal)
147        .bind(key)
148        .fetch_optional(&self.pool)
149        .await
150        .map_err(|error| CoolError::Database(error.to_string()))?;
151
152        let Some((stored_hash, status, headers, body, created_at, existing_expires_at)) = existing
153        else {
154            // Vanished between the upsert and the read (a concurrent GC
155            // could do this in theory). Surface as InFlight so the
156            // caller retries shortly rather than running the handler on
157            // a state we don't fully understand.
158            return Ok(ReservationOutcome::InFlight);
159        };
160
161        let stored: [u8; 32] = stored_hash
162            .as_slice()
163            .try_into()
164            .map_err(|_| CoolError::Internal("corrupt idempotency hash length".to_owned()))?;
165        if stored != request_hash {
166            return Ok(ReservationOutcome::Conflict);
167        }
168
169        match (status, body) {
170            (Some(s), Some(b)) => {
171                let response_status: u16 = u16::try_from(s).unwrap_or(500);
172                Ok(ReservationOutcome::Replay(IdempotencyRecord {
173                    principal_fingerprint: principal.to_owned(),
174                    key: key.to_owned(),
175                    request_hash: stored,
176                    response_status,
177                    response_headers: headers.unwrap_or_default(),
178                    response_body: b,
179                    created_at: created_at.into(),
180                    expires_at: existing_expires_at.into(),
181                }))
182            }
183            _ => Ok(ReservationOutcome::InFlight),
184        }
185    }
186
187    async fn complete(
188        &self,
189        principal: &str,
190        key: &str,
191        token: uuid::Uuid,
192        status: u16,
193        headers: &[u8],
194        body: &[u8],
195    ) -> Result<(), CoolError> {
196        // Only completes the row we actually reserved. `reservation_id =
197        // $token` is the proof; `response_body IS NULL` keeps us from
198        // double-writing a finished slot. A handler that ran past its
199        // TTL will find the row's `reservation_id` rotated out by the
200        // retry that reclaimed it, and this UPDATE matches zero rows.
201        sqlx::query(
202            "UPDATE cratestack_idempotency
203             SET response_status = $1,
204                 response_headers = $2,
205                 response_body = $3
206             WHERE principal_fingerprint = $4
207               AND key = $5
208               AND reservation_id = $6
209               AND response_body IS NULL",
210        )
211        .bind(status as i32)
212        .bind(headers)
213        .bind(body)
214        .bind(principal)
215        .bind(key)
216        .bind(token)
217        .execute(&self.pool)
218        .await
219        .map(|_| ())
220        .map_err(|error| CoolError::Database(error.to_string()))
221    }
222
223    async fn release(
224        &self,
225        principal: &str,
226        key: &str,
227        token: uuid::Uuid,
228    ) -> Result<(), CoolError> {
229        // Only drop our own pending row — never delete a completed one,
230        // and never delete a row whose reservation has been rotated to
231        // a different owner.
232        sqlx::query(
233            "DELETE FROM cratestack_idempotency
234             WHERE principal_fingerprint = $1
235               AND key = $2
236               AND reservation_id = $3
237               AND response_body IS NULL",
238        )
239        .bind(principal)
240        .bind(key)
241        .bind(token)
242        .execute(&self.pool)
243        .await
244        .map(|_| ())
245        .map_err(|error| CoolError::Database(error.to_string()))
246    }
247}
248
/// Expiry instant of a record captured at `created_at` that stays valid
/// for `ttl`.
///
/// Kept as a free function so unit tests can reach it directly; asserting
/// on `SystemTime` arithmetic is awkward without a clock injection point.
///
/// # Panics
/// Panics if `created_at + ttl` cannot be represented as a `SystemTime`.
pub fn expiry_from(created_at: SystemTime, ttl: std::time::Duration) -> SystemTime {
    let mut expires_at = created_at;
    expires_at += ttl;
    expires_at
}
255
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, SystemTime};

    /// A 24-hour TTL applied at the Unix epoch must land exactly one day
    /// after the epoch.
    #[test]
    fn expiry_adds_ttl_to_creation() {
        let one_day = Duration::from_secs(24 * 3600);
        let created = SystemTime::UNIX_EPOCH;

        let elapsed = expiry_from(created, one_day).duration_since(SystemTime::UNIX_EPOCH).unwrap();

        assert_eq!(elapsed, one_day);
    }
}