use crate::sqlx;
use std::time::SystemTime;
use async_trait::async_trait;
use cratestack_axum::idempotency::{
IDEMPOTENCY_TABLE_DDL, IdempotencyRecord, IdempotencyStore, ReservationOutcome,
};
use cratestack_core::CoolError;
/// [`IdempotencyStore`] implementation that keeps its records in the
/// `cratestack_idempotency` Postgres table, accessed through `sqlx`.
#[derive(Clone)]
pub struct SqlxIdempotencyStore {
/// Connection pool every query issued by this store is executed against.
pool: sqlx::PgPool,
}
impl SqlxIdempotencyStore {
pub fn new(pool: sqlx::PgPool) -> Self {
Self { pool }
}
pub async fn ensure_schema(&self) -> Result<(), CoolError> {
for statement in IDEMPOTENCY_TABLE_DDL
.split(';')
.map(str::trim)
.filter(|s| !s.is_empty())
{
sqlx::query(statement)
.execute(&self.pool)
.await
.map_err(|error| CoolError::Database(error.to_string()))?;
}
Ok(())
}
pub async fn garbage_collect(&self) -> Result<u64, CoolError> {
let result = sqlx::query("DELETE FROM cratestack_idempotency WHERE expires_at < NOW()")
.execute(&self.pool)
.await
.map_err(|error| CoolError::Database(error.to_string()))?;
Ok(result.rows_affected())
}
}
#[async_trait]
impl IdempotencyStore for SqlxIdempotencyStore {
    /// Tries to reserve `(principal, key)` and, failing that, reports what is
    /// already stored for it.
    ///
    /// The first statement inserts a row with a freshly generated v4 UUID as
    /// reservation token. If a row for `(principal_fingerprint, key)` already
    /// exists it is overwritten (response columns reset to `NULL`) only when
    /// its `expires_at` is not later than the database's `NOW()`. Whenever that
    /// statement returns a row the result is `ReservationOutcome::Reserved`
    /// with the token read back from the row.
    ///
    /// Otherwise the existing row is read and classified:
    ///
    /// * no row found: `ReservationOutcome::InFlight`;
    /// * stored hash differs from `request_hash`: `ReservationOutcome::Conflict`;
    /// * status and body are both present: `ReservationOutcome::Replay` with
    ///   the stored response (missing headers become the default value, and a
    ///   stored status that does not fit in `u16` is reported as `500`);
    /// * anything else: `ReservationOutcome::InFlight`.
    ///
    /// # Errors
    ///
    /// * [`CoolError::Database`] if either query fails.
    /// * [`CoolError::Internal`] if the stored hash is not exactly 32 bytes.
    async fn reserve_or_fetch(
        &self,
        principal: &str,
        key: &str,
        request_hash: [u8; 32],
        expires_at: SystemTime,
    ) -> Result<ReservationOutcome, CoolError> {
        // Columns of the upsert's RETURNING clause, in order.
        type UpsertRow = (
            Vec<u8>,
            uuid::Uuid,
            Option<i32>,
            Option<Vec<u8>>,
            Option<Vec<u8>>,
            chrono::DateTime<chrono::Utc>,
            chrono::DateTime<chrono::Utc>,
            bool,
        );
        // Columns of the follow-up SELECT, in order.
        type StoredRow = (
            Vec<u8>,
            Option<i32>,
            Option<Vec<u8>>,
            Option<Vec<u8>>,
            chrono::DateTime<chrono::Utc>,
            chrono::DateTime<chrono::Utc>,
        );

        let deadline: chrono::DateTime<chrono::Utc> = expires_at.into();
        let fresh_token = uuid::Uuid::new_v4();

        // The conditional DO UPDATE yields no row when a live (unexpired)
        // record already owns this key, which is how the two cases are told apart.
        let upserted: Option<UpsertRow> = sqlx::query_as(
            "INSERT INTO cratestack_idempotency (
principal_fingerprint, key, request_hash, reservation_id, expires_at
) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (principal_fingerprint, key) DO UPDATE SET
request_hash = EXCLUDED.request_hash,
reservation_id = EXCLUDED.reservation_id,
response_status = NULL,
response_headers = NULL,
response_body = NULL,
created_at = NOW(),
expires_at = EXCLUDED.expires_at
WHERE cratestack_idempotency.expires_at <= NOW()
RETURNING request_hash, reservation_id, response_status, response_headers,
response_body, created_at, expires_at, (xmax = 0) AS was_inserted",
        )
        .bind(principal)
        .bind(key)
        .bind(&request_hash[..])
        .bind(fresh_token)
        .bind(deadline)
        .fetch_optional(&self.pool)
        .await
        .map_err(|err| CoolError::Database(err.to_string()))?;

        if let Some(reserved) = upserted {
            // Only the reservation token (second column) is needed here.
            return Ok(ReservationOutcome::Reserved { token: reserved.1 });
        }

        let lookup: Option<StoredRow> = sqlx::query_as(
            "SELECT request_hash, response_status, response_headers,
response_body, created_at, expires_at
FROM cratestack_idempotency
WHERE principal_fingerprint = $1 AND key = $2",
        )
        .bind(principal)
        .bind(key)
        .fetch_optional(&self.pool)
        .await
        .map_err(|err| CoolError::Database(err.to_string()))?;

        let (hash_bytes, status, headers, body, created_at, stored_expiry) = match lookup {
            Some(found) => found,
            // The row was there for the upsert but is gone now.
            None => return Ok(ReservationOutcome::InFlight),
        };

        let stored = <[u8; 32]>::try_from(hash_bytes.as_slice())
            .map_err(|_| CoolError::Internal(String::from("corrupt idempotency hash length")))?;
        if stored != request_hash {
            return Ok(ReservationOutcome::Conflict);
        }

        // A replay needs both a status and a body; otherwise the original
        // request has not recorded its response yet.
        let (Some(code), Some(payload)) = (status, body) else {
            return Ok(ReservationOutcome::InFlight);
        };

        let record = IdempotencyRecord {
            principal_fingerprint: String::from(principal),
            key: String::from(key),
            request_hash: stored,
            response_status: u16::try_from(code).unwrap_or(500),
            response_headers: headers.unwrap_or_default(),
            response_body: payload,
            created_at: created_at.into(),
            expires_at: stored_expiry.into(),
        };
        Ok(ReservationOutcome::Replay(record))
    }

    /// Stores the response for the reservation identified by `token`.
    ///
    /// The update only touches the row matching `principal`, `key` and
    /// `token` whose `response_body` is still `NULL`. If no such row exists
    /// the statement changes nothing and the call still returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`CoolError::Database`] if the statement fails.
    async fn complete(
        &self,
        principal: &str,
        key: &str,
        token: uuid::Uuid,
        status: u16,
        headers: &[u8],
        body: &[u8],
    ) -> Result<(), CoolError> {
        sqlx::query(
            "UPDATE cratestack_idempotency
SET response_status = $1,
response_headers = $2,
response_body = $3
WHERE principal_fingerprint = $4
AND key = $5
AND reservation_id = $6
AND response_body IS NULL",
        )
        .bind(i32::from(status))
        .bind(headers)
        .bind(body)
        .bind(principal)
        .bind(key)
        .bind(token)
        .execute(&self.pool)
        .await
        .map_err(|err| CoolError::Database(err.to_string()))?;

        Ok(())
    }

    /// Drops the reservation identified by `token` without recording a response.
    ///
    /// Only the row matching `principal`, `key` and `token` whose
    /// `response_body` is still `NULL` is deleted. If no such row exists the
    /// statement changes nothing and the call still returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`CoolError::Database`] if the statement fails.
    async fn release(
        &self,
        principal: &str,
        key: &str,
        token: uuid::Uuid,
    ) -> Result<(), CoolError> {
        sqlx::query(
            "DELETE FROM cratestack_idempotency
WHERE principal_fingerprint = $1
AND key = $2
AND reservation_id = $3
AND response_body IS NULL",
        )
        .bind(principal)
        .bind(key)
        .bind(token)
        .execute(&self.pool)
        .await
        .map_err(|err| CoolError::Database(err.to_string()))?;

        Ok(())
    }
}
/// Returns the instant `ttl` after `created_at`.
///
/// # Panics
///
/// `SystemTime` addition may panic if the resulting instant cannot be
/// represented by the platform's time type.
pub fn expiry_from(created_at: SystemTime, ttl: std::time::Duration) -> SystemTime {
    let mut expires_at = created_at;
    expires_at += ttl;
    expires_at
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, SystemTime};

    /// A one-day TTL applied to the Unix epoch must land exactly one day
    /// after the epoch.
    #[test]
    fn expiry_adds_ttl_to_creation() {
        let one_day = Duration::from_secs(24 * 3600);
        let created = SystemTime::UNIX_EPOCH;

        let elapsed = expiry_from(created, one_day)
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap();

        assert_eq!(elapsed, one_day);
    }
}