1use crate::sqlx;
14
15
16use std::time::SystemTime;
17
18use async_trait::async_trait;
19use cratestack_axum::idempotency::{
20 IDEMPOTENCY_TABLE_DDL, IdempotencyRecord, IdempotencyStore, ReservationOutcome,
21};
22use cratestack_core::CoolError;
23
24#[derive(Clone)]
25pub struct SqlxIdempotencyStore {
26 pool: sqlx::PgPool,
27}
28
29impl SqlxIdempotencyStore {
30 pub fn new(pool: sqlx::PgPool) -> Self {
31 Self { pool }
32 }
33
34 pub async fn ensure_schema(&self) -> Result<(), CoolError> {
37 for statement in IDEMPOTENCY_TABLE_DDL
40 .split(';')
41 .map(str::trim)
42 .filter(|s| !s.is_empty())
43 {
44 sqlx::query(statement)
45 .execute(&self.pool)
46 .await
47 .map_err(|error| CoolError::Database(error.to_string()))?;
48 }
49 Ok(())
50 }
51
52 pub async fn garbage_collect(&self) -> Result<u64, CoolError> {
57 let result = sqlx::query("DELETE FROM cratestack_idempotency WHERE expires_at < NOW()")
58 .execute(&self.pool)
59 .await
60 .map_err(|error| CoolError::Database(error.to_string()))?;
61 Ok(result.rows_affected())
62 }
63}
64
65#[async_trait]
66impl IdempotencyStore for SqlxIdempotencyStore {
67 async fn reserve_or_fetch(
68 &self,
69 principal: &str,
70 key: &str,
71 request_hash: [u8; 32],
72 expires_at: SystemTime,
73 ) -> Result<ReservationOutcome, CoolError> {
74 let expires_at: chrono::DateTime<chrono::Utc> = expires_at.into();
75 let new_token = uuid::Uuid::new_v4();
81 let row: Option<(
90 Vec<u8>,
91 uuid::Uuid,
92 Option<i32>,
93 Option<Vec<u8>>,
94 Option<Vec<u8>>,
95 chrono::DateTime<chrono::Utc>,
96 chrono::DateTime<chrono::Utc>,
97 bool,
98 )> = sqlx::query_as(
99 "INSERT INTO cratestack_idempotency (
100 principal_fingerprint, key, request_hash, reservation_id, expires_at
101 ) VALUES ($1, $2, $3, $4, $5)
102 ON CONFLICT (principal_fingerprint, key) DO UPDATE SET
103 request_hash = EXCLUDED.request_hash,
104 reservation_id = EXCLUDED.reservation_id,
105 response_status = NULL,
106 response_headers = NULL,
107 response_body = NULL,
108 created_at = NOW(),
109 expires_at = EXCLUDED.expires_at
110 WHERE cratestack_idempotency.expires_at <= NOW()
111 RETURNING request_hash, reservation_id, response_status, response_headers,
112 response_body, created_at, expires_at, (xmax = 0) AS was_inserted",
113 )
114 .bind(principal)
115 .bind(key)
116 .bind(request_hash.as_slice())
117 .bind(new_token)
118 .bind(expires_at)
119 .fetch_optional(&self.pool)
120 .await
121 .map_err(|error| CoolError::Database(error.to_string()))?;
122
123 if let Some((_, token, _, _, _, _, _, _)) = row {
124 return Ok(ReservationOutcome::Reserved { token });
129 }
130
131 let existing: Option<(
134 Vec<u8>,
135 Option<i32>,
136 Option<Vec<u8>>,
137 Option<Vec<u8>>,
138 chrono::DateTime<chrono::Utc>,
139 chrono::DateTime<chrono::Utc>,
140 )> = sqlx::query_as(
141 "SELECT request_hash, response_status, response_headers,
142 response_body, created_at, expires_at
143 FROM cratestack_idempotency
144 WHERE principal_fingerprint = $1 AND key = $2",
145 )
146 .bind(principal)
147 .bind(key)
148 .fetch_optional(&self.pool)
149 .await
150 .map_err(|error| CoolError::Database(error.to_string()))?;
151
152 let Some((stored_hash, status, headers, body, created_at, existing_expires_at)) = existing
153 else {
154 return Ok(ReservationOutcome::InFlight);
159 };
160
161 let stored: [u8; 32] = stored_hash
162 .as_slice()
163 .try_into()
164 .map_err(|_| CoolError::Internal("corrupt idempotency hash length".to_owned()))?;
165 if stored != request_hash {
166 return Ok(ReservationOutcome::Conflict);
167 }
168
169 match (status, body) {
170 (Some(s), Some(b)) => {
171 let response_status: u16 = u16::try_from(s).unwrap_or(500);
172 Ok(ReservationOutcome::Replay(IdempotencyRecord {
173 principal_fingerprint: principal.to_owned(),
174 key: key.to_owned(),
175 request_hash: stored,
176 response_status,
177 response_headers: headers.unwrap_or_default(),
178 response_body: b,
179 created_at: created_at.into(),
180 expires_at: existing_expires_at.into(),
181 }))
182 }
183 _ => Ok(ReservationOutcome::InFlight),
184 }
185 }
186
187 async fn complete(
188 &self,
189 principal: &str,
190 key: &str,
191 token: uuid::Uuid,
192 status: u16,
193 headers: &[u8],
194 body: &[u8],
195 ) -> Result<(), CoolError> {
196 sqlx::query(
202 "UPDATE cratestack_idempotency
203 SET response_status = $1,
204 response_headers = $2,
205 response_body = $3
206 WHERE principal_fingerprint = $4
207 AND key = $5
208 AND reservation_id = $6
209 AND response_body IS NULL",
210 )
211 .bind(status as i32)
212 .bind(headers)
213 .bind(body)
214 .bind(principal)
215 .bind(key)
216 .bind(token)
217 .execute(&self.pool)
218 .await
219 .map(|_| ())
220 .map_err(|error| CoolError::Database(error.to_string()))
221 }
222
223 async fn release(
224 &self,
225 principal: &str,
226 key: &str,
227 token: uuid::Uuid,
228 ) -> Result<(), CoolError> {
229 sqlx::query(
233 "DELETE FROM cratestack_idempotency
234 WHERE principal_fingerprint = $1
235 AND key = $2
236 AND reservation_id = $3
237 AND response_body IS NULL",
238 )
239 .bind(principal)
240 .bind(key)
241 .bind(token)
242 .execute(&self.pool)
243 .await
244 .map(|_| ())
245 .map_err(|error| CoolError::Database(error.to_string()))
246 }
247}
248
/// Returns the instant at which a record created at `created_at` expires,
/// given its time-to-live `ttl`.
///
/// # Panics
///
/// Panics if `created_at + ttl` cannot be represented as a `SystemTime`.
pub fn expiry_from(created_at: SystemTime, ttl: std::time::Duration) -> SystemTime {
    let mut deadline = created_at;
    deadline += ttl;
    deadline
}
255
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, SystemTime};

    /// Adding a one-day TTL to the epoch lands exactly one day after it.
    #[test]
    fn expiry_adds_ttl_to_creation() {
        let now = SystemTime::UNIX_EPOCH;
        let expiry = expiry_from(now, Duration::from_secs(24 * 3600));
        assert_eq!(
            expiry.duration_since(SystemTime::UNIX_EPOCH).unwrap(),
            Duration::from_secs(24 * 3600),
        );
    }

    /// With a non-epoch creation time, the result must be offset from
    /// `created_at`, not from the epoch. The epoch-based test above cannot
    /// tell these two apart.
    #[test]
    fn expiry_is_relative_to_creation_time() {
        let created_at = SystemTime::UNIX_EPOCH + Duration::from_secs(1_700_000_000);
        let ttl = Duration::from_millis(1_500);
        let expiry = expiry_from(created_at, ttl);
        assert_eq!(expiry.duration_since(created_at).unwrap(), ttl);
    }

    /// A zero TTL yields the creation time itself.
    #[test]
    fn zero_ttl_expires_at_creation() {
        let created_at = SystemTime::UNIX_EPOCH + Duration::from_secs(42);
        assert_eq!(expiry_from(created_at, Duration::ZERO), created_at);
    }
}