1use std::time::SystemTime;
15
16use async_trait::async_trait;
17use cratestack_axum::idempotency::{
18 IDEMPOTENCY_TABLE_DDL, IdempotencyRecord, IdempotencyStore, ReservationOutcome,
19};
20use cratestack_core::CoolError;
21
22#[derive(Clone)]
23pub struct SqlxIdempotencyStore {
24 pool: sqlx::PgPool,
25}
26
27impl SqlxIdempotencyStore {
28 pub fn new(pool: sqlx::PgPool) -> Self {
29 Self { pool }
30 }
31
32 pub async fn ensure_schema(&self) -> Result<(), CoolError> {
35 for statement in IDEMPOTENCY_TABLE_DDL
38 .split(';')
39 .map(str::trim)
40 .filter(|s| !s.is_empty())
41 {
42 sqlx::query(statement)
43 .execute(&self.pool)
44 .await
45 .map_err(|error| CoolError::Database(error.to_string()))?;
46 }
47 Ok(())
48 }
49
50 pub async fn garbage_collect(&self) -> Result<u64, CoolError> {
55 let result = sqlx::query("DELETE FROM cratestack_idempotency WHERE expires_at < NOW()")
56 .execute(&self.pool)
57 .await
58 .map_err(|error| CoolError::Database(error.to_string()))?;
59 Ok(result.rows_affected())
60 }
61}
62
63#[async_trait]
64impl IdempotencyStore for SqlxIdempotencyStore {
65 async fn reserve_or_fetch(
66 &self,
67 principal: &str,
68 key: &str,
69 request_hash: [u8; 32],
70 expires_at: SystemTime,
71 ) -> Result<ReservationOutcome, CoolError> {
72 let expires_at: chrono::DateTime<chrono::Utc> = expires_at.into();
73 let new_token = uuid::Uuid::new_v4();
79 let row: Option<(
88 Vec<u8>,
89 uuid::Uuid,
90 Option<i32>,
91 Option<Vec<u8>>,
92 Option<Vec<u8>>,
93 chrono::DateTime<chrono::Utc>,
94 chrono::DateTime<chrono::Utc>,
95 bool,
96 )> = sqlx::query_as(
97 "INSERT INTO cratestack_idempotency (
98 principal_fingerprint, key, request_hash, reservation_id, expires_at
99 ) VALUES ($1, $2, $3, $4, $5)
100 ON CONFLICT (principal_fingerprint, key) DO UPDATE SET
101 request_hash = EXCLUDED.request_hash,
102 reservation_id = EXCLUDED.reservation_id,
103 response_status = NULL,
104 response_headers = NULL,
105 response_body = NULL,
106 created_at = NOW(),
107 expires_at = EXCLUDED.expires_at
108 WHERE cratestack_idempotency.expires_at <= NOW()
109 RETURNING request_hash, reservation_id, response_status, response_headers,
110 response_body, created_at, expires_at, (xmax = 0) AS was_inserted",
111 )
112 .bind(principal)
113 .bind(key)
114 .bind(request_hash.as_slice())
115 .bind(new_token)
116 .bind(expires_at)
117 .fetch_optional(&self.pool)
118 .await
119 .map_err(|error| CoolError::Database(error.to_string()))?;
120
121 if let Some((_, token, _, _, _, _, _, _)) = row {
122 return Ok(ReservationOutcome::Reserved { token });
127 }
128
129 let existing: Option<(
132 Vec<u8>,
133 Option<i32>,
134 Option<Vec<u8>>,
135 Option<Vec<u8>>,
136 chrono::DateTime<chrono::Utc>,
137 chrono::DateTime<chrono::Utc>,
138 )> = sqlx::query_as(
139 "SELECT request_hash, response_status, response_headers,
140 response_body, created_at, expires_at
141 FROM cratestack_idempotency
142 WHERE principal_fingerprint = $1 AND key = $2",
143 )
144 .bind(principal)
145 .bind(key)
146 .fetch_optional(&self.pool)
147 .await
148 .map_err(|error| CoolError::Database(error.to_string()))?;
149
150 let Some((stored_hash, status, headers, body, created_at, existing_expires_at)) = existing
151 else {
152 return Ok(ReservationOutcome::InFlight);
157 };
158
159 let stored: [u8; 32] = stored_hash
160 .as_slice()
161 .try_into()
162 .map_err(|_| CoolError::Internal("corrupt idempotency hash length".to_owned()))?;
163 if stored != request_hash {
164 return Ok(ReservationOutcome::Conflict);
165 }
166
167 match (status, body) {
168 (Some(s), Some(b)) => {
169 let response_status: u16 = u16::try_from(s).unwrap_or(500);
170 Ok(ReservationOutcome::Replay(IdempotencyRecord {
171 principal_fingerprint: principal.to_owned(),
172 key: key.to_owned(),
173 request_hash: stored,
174 response_status,
175 response_headers: headers.unwrap_or_default(),
176 response_body: b,
177 created_at: created_at.into(),
178 expires_at: existing_expires_at.into(),
179 }))
180 }
181 _ => Ok(ReservationOutcome::InFlight),
182 }
183 }
184
185 async fn complete(
186 &self,
187 principal: &str,
188 key: &str,
189 token: uuid::Uuid,
190 status: u16,
191 headers: &[u8],
192 body: &[u8],
193 ) -> Result<(), CoolError> {
194 sqlx::query(
200 "UPDATE cratestack_idempotency
201 SET response_status = $1,
202 response_headers = $2,
203 response_body = $3
204 WHERE principal_fingerprint = $4
205 AND key = $5
206 AND reservation_id = $6
207 AND response_body IS NULL",
208 )
209 .bind(status as i32)
210 .bind(headers)
211 .bind(body)
212 .bind(principal)
213 .bind(key)
214 .bind(token)
215 .execute(&self.pool)
216 .await
217 .map(|_| ())
218 .map_err(|error| CoolError::Database(error.to_string()))
219 }
220
221 async fn release(
222 &self,
223 principal: &str,
224 key: &str,
225 token: uuid::Uuid,
226 ) -> Result<(), CoolError> {
227 sqlx::query(
231 "DELETE FROM cratestack_idempotency
232 WHERE principal_fingerprint = $1
233 AND key = $2
234 AND reservation_id = $3
235 AND response_body IS NULL",
236 )
237 .bind(principal)
238 .bind(key)
239 .bind(token)
240 .execute(&self.pool)
241 .await
242 .map(|_| ())
243 .map_err(|error| CoolError::Database(error.to_string()))
244 }
245}
246
/// Returns the instant `ttl` after `created_at`.
///
/// # Panics
///
/// Panics if the resulting point in time cannot be represented as a
/// `SystemTime`, exactly as `SystemTime + Duration` does.
pub fn expiry_from(created_at: SystemTime, ttl: std::time::Duration) -> SystemTime {
    let mut expiry = created_at;
    expiry += ttl;
    expiry
}
253
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, SystemTime};

    /// A one-day TTL applied at the Unix epoch must land exactly one day
    /// after the epoch.
    #[test]
    fn expiry_adds_ttl_to_creation() {
        let one_day = Duration::from_secs(24 * 3600);
        let created_at = SystemTime::UNIX_EPOCH;

        let elapsed = expiry_from(created_at, one_day)
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap();

        assert_eq!(elapsed, one_day);
    }
}