use chrono::{DateTime, Utc};
use sqlx::{Postgres, Transaction};
use crate::{DataStoreError, InvariantID};
/// Result type returned by the invariant data-access functions (`create`,
/// `get`, `update`, `delete`, `list`); the error side is always
/// [`DataStoreError`].
pub type SqlResult<T> = Result<T, DataStoreError>;
/// One row of the `invariants` table, as materialized by `get` and `list`.
#[derive(Debug, Clone)]
pub struct InvariantRecord {
/// Identifier decoded from the stored `invariant_id` bytes; rows whose
/// stored id is not exactly 32 bytes are rejected when read.
pub invariant_id: InvariantID,
/// Text of the `asserts` column, e.g. an expression such as `x > 0 && y < 100`.
pub asserts: String,
/// Value of the `created_at` column. `create` does not supply it, so it is
/// presumably filled by a column default — NOTE(review): confirm against the schema.
pub created_at: DateTime<Utc>,
/// Value of the `updated_at` column; `update` sets it to `CURRENT_TIMESTAMP`.
pub updated_at: DateTime<Utc>,
}
pub async fn create(
tx: &mut Transaction<'_, Postgres>,
invariant_id: &InvariantID,
asserts: &str,
) -> SqlResult<()> {
let invariant_bytes = invariant_id.as_bytes();
let result = sqlx::query!(
r#"
INSERT INTO invariants (invariant_id, asserts)
VALUES ($1, $2)
"#,
invariant_bytes.as_slice(),
asserts
)
.execute(&mut **tx)
.await;
match result {
Ok(_) => Ok(()),
Err(sqlx::Error::Database(db_err)) if db_err.is_unique_violation() => {
Err(DataStoreError::AlreadyExists)
}
Err(e) => {
eprintln!("Database error creating invariant: {}", e);
Err(DataStoreError::Internal(e.to_string()))
}
}
}
/// Fetches the invariant identified by `invariant_id` within `tx`.
///
/// Returns `Ok(None)` when no row matches.
///
/// # Errors
///
/// * [`DataStoreError::Internal`] if the query fails (also logged to stderr).
/// * [`DataStoreError::Internal`] with `"invalid invariant_id length"` if the
///   stored id is not exactly 32 bytes (not logged).
pub async fn get(
    tx: &mut Transaction<'_, Postgres>,
    invariant_id: &InvariantID,
) -> SqlResult<Option<InvariantRecord>> {
    let id_bytes = invariant_id.as_bytes();
    let maybe_row = sqlx::query!(
        r#"
SELECT invariant_id, asserts, created_at, updated_at
FROM invariants
WHERE invariant_id = $1
"#,
        id_bytes.as_slice()
    )
    .fetch_optional(&mut **tx)
    .await
    .map_err(|err| {
        eprintln!("Database error getting invariant: {}", err);
        DataStoreError::Internal(err.to_string())
    })?;

    let row = match maybe_row {
        Some(row) => row,
        None => return Ok(None),
    };

    // The column is stored as raw bytes; anything other than 32 bytes is corrupt.
    let raw_id: [u8; 32] = row
        .invariant_id
        .try_into()
        .map_err(|_| DataStoreError::Internal("invalid invariant_id length".to_string()))?;

    Ok(Some(InvariantRecord {
        invariant_id: InvariantID::new(raw_id),
        asserts: row.asserts,
        created_at: row.created_at,
        updated_at: row.updated_at,
    }))
}
/// Replaces the `asserts` text of an existing invariant and refreshes its
/// `updated_at` column, inside `tx`.
///
/// Returns `Ok(true)` when a row matched `invariant_id`, `Ok(false)` otherwise.
///
/// NOTE: in PostgreSQL `CURRENT_TIMESTAMP` is the start time of the current
/// transaction, not the wall-clock time of this statement.
///
/// # Errors
///
/// Any database failure is logged to stderr and returned as
/// [`DataStoreError::Internal`] carrying the driver's error text.
pub async fn update(
    tx: &mut Transaction<'_, Postgres>,
    invariant_id: &InvariantID,
    asserts: &str,
) -> SqlResult<bool> {
    let id_bytes = invariant_id.as_bytes();
    let outcome = sqlx::query!(
        r#"
UPDATE invariants
SET asserts = $2, updated_at = CURRENT_TIMESTAMP
WHERE invariant_id = $1
"#,
        id_bytes.as_slice(),
        asserts
    )
    .execute(&mut **tx)
    .await
    .map_err(|err| {
        eprintln!("Database error updating invariant: {}", err);
        DataStoreError::Internal(err.to_string())
    })?;
    Ok(outcome.rows_affected() != 0)
}
/// Deletes the invariant identified by `invariant_id` inside `tx`.
///
/// Returns `Ok(true)` when a row was removed and `Ok(false)` when nothing
/// matched.
///
/// # Errors
///
/// Any database failure is logged to stderr and returned as
/// [`DataStoreError::Internal`] carrying the driver's error text.
pub async fn delete(
    tx: &mut Transaction<'_, Postgres>,
    invariant_id: &InvariantID,
) -> SqlResult<bool> {
    let id_bytes = invariant_id.as_bytes();
    let outcome = sqlx::query!(
        r#"
DELETE FROM invariants
WHERE invariant_id = $1
"#,
        id_bytes.as_slice()
    )
    .execute(&mut **tx)
    .await
    .map_err(|err| {
        eprintln!("Database error deleting invariant: {}", err);
        DataStoreError::Internal(err.to_string())
    })?;
    Ok(outcome.rows_affected() != 0)
}
/// Returns every invariant row, ordered by `created_at` ascending.
///
/// Rows that share the same `created_at` value come back in no guaranteed
/// relative order.
///
/// # Errors
///
/// * [`DataStoreError::Internal`] if the query fails (also logged to stderr).
/// * [`DataStoreError::Internal`] with `"invalid invariant_id length"` if any
///   stored id is not exactly 32 bytes; one bad row fails the whole listing.
pub async fn list(tx: &mut Transaction<'_, Postgres>) -> SqlResult<Vec<InvariantRecord>> {
    let rows = sqlx::query!(
        r#"
SELECT invariant_id, asserts, created_at, updated_at
FROM invariants
ORDER BY created_at ASC
"#
    )
    .fetch_all(&mut **tx)
    .await
    .map_err(|e| {
        eprintln!("Database error listing invariants: {}", e);
        DataStoreError::Internal(e.to_string())
    })?;

    // Fallible collect: stops at the first row whose id cannot be decoded.
    rows.into_iter()
        .map(|row| -> SqlResult<InvariantRecord> {
            let invariant_bytes: [u8; 32] = row.invariant_id.try_into().map_err(|_| {
                DataStoreError::Internal("invalid invariant_id length".to_string())
            })?;
            Ok(InvariantRecord {
                invariant_id: InvariantID::new(invariant_bytes),
                asserts: row.asserts,
                created_at: row.created_at,
                updated_at: row.updated_at,
            })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    //! Database-backed tests for the invariant CRUD helpers.
    //!
    //! Each test obtains a pool from the parent module's `setup_test_db()` and
    //! works on a freshly generated `InvariantID`, so tests sharing a database
    //! do not touch each other's rows.

    use super::*;

    /// Builds a test-local `InvariantID`.
    ///
    /// Bytes are laid out as follows; any remaining bytes are zero:
    /// * `0..4`: the process id, little-endian.
    /// * `4..12`: the current Unix time in microseconds, little-endian.
    /// * `12..32`: up to the first 20 bytes of `test_name`. Longer names are truncated.
    fn unique_invariant(test_name: &str) -> InvariantID {
        use std::time::{SystemTime, UNIX_EPOCH};

        let pid = std::process::id();
        // `as u64` truncates the u128 microsecond count; u64 microseconds
        // cover roughly 584,000 years, so nothing is lost in practice.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_micros() as u64;

        let mut bytes = [0u8; 32];
        bytes[0..4].copy_from_slice(&pid.to_le_bytes());
        bytes[4..12].copy_from_slice(&now.to_le_bytes());
        let test_bytes = test_name.as_bytes();
        let copy_len = test_bytes.len().min(20);
        bytes[12..12 + copy_len].copy_from_slice(&test_bytes[..copy_len]);
        InvariantID::new(bytes)
    }

    #[tokio::test]
    async fn create_and_get() {
        let pool = super::super::tests::setup_test_db().await;
        let invariant_id = unique_invariant("create_and_get");
        let asserts = "x > 0 && y < 100";

        // Bracket the insert with database-clock readings so the stored
        // timestamps can be range-checked without trusting the local clock.
        let db_before = sqlx::query_scalar::<_, DateTime<Utc>>("SELECT CURRENT_TIMESTAMP")
            .fetch_one(&pool)
            .await
            .unwrap();
        let mut tx = pool.begin().await.unwrap();
        create(&mut tx, &invariant_id, asserts).await.unwrap();
        tx.commit().await.unwrap();
        let db_after = sqlx::query_scalar::<_, DateTime<Utc>>("SELECT CURRENT_TIMESTAMP")
            .fetch_one(&pool)
            .await
            .unwrap();

        let mut tx = pool.begin().await.unwrap();
        let record = get(&mut tx, &invariant_id).await.unwrap();
        tx.commit().await.unwrap();
        let record = record.expect("invariant should exist after create");

        assert_eq!(record.invariant_id, invariant_id);
        assert_eq!(record.asserts, asserts);
        assert!(record.created_at >= db_before);
        assert!(record.created_at <= db_after);
        assert!(record.updated_at >= db_before);
        assert!(record.updated_at <= db_after);
        assert_eq!(record.created_at, record.updated_at);
    }

    #[tokio::test]
    async fn create_duplicate_fails() {
        let pool = super::super::tests::setup_test_db().await;
        let invariant_id = unique_invariant("create_duplicate_fails");

        let mut tx = pool.begin().await.unwrap();
        create(&mut tx, &invariant_id, "x > 0").await.unwrap();
        tx.commit().await.unwrap();

        let mut tx = pool.begin().await.unwrap();
        let result = create(&mut tx, &invariant_id, "y > 0").await;
        assert!(
            matches!(result, Err(DataStoreError::AlreadyExists)),
            "expected AlreadyExists, got {:?}",
            result
        );
        // The failed INSERT leaves the transaction aborted; roll it back
        // explicitly rather than relying on drop.
        tx.rollback().await.unwrap();

        // The rejected duplicate must not have overwritten the original row.
        let mut tx = pool.begin().await.unwrap();
        let record = get(&mut tx, &invariant_id).await.unwrap();
        tx.commit().await.unwrap();
        let record = record.expect("original invariant should still exist");
        assert_eq!(record.asserts, "x > 0");
    }

    #[tokio::test]
    async fn update_existing() {
        let pool = super::super::tests::setup_test_db().await;
        let invariant_id = unique_invariant("update_existing");

        let mut tx = pool.begin().await.unwrap();
        create(&mut tx, &invariant_id, "x > 0").await.unwrap();
        tx.commit().await.unwrap();

        let mut tx = pool.begin().await.unwrap();
        let record_before = get(&mut tx, &invariant_id)
            .await
            .unwrap()
            .expect("invariant should exist after create");
        tx.commit().await.unwrap();
        assert_eq!(record_before.asserts, "x > 0");
        assert_eq!(record_before.created_at, record_before.updated_at);

        // Ensure the update runs in a later transaction with a later timestamp.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let mut tx = pool.begin().await.unwrap();
        let updated = update(&mut tx, &invariant_id, "y < 100").await.unwrap();
        tx.commit().await.unwrap();
        assert!(updated);

        let mut tx = pool.begin().await.unwrap();
        let record_after = get(&mut tx, &invariant_id)
            .await
            .unwrap()
            .expect("invariant should still exist after update");
        tx.commit().await.unwrap();
        assert_eq!(record_after.asserts, "y < 100");
        assert_eq!(record_after.created_at, record_before.created_at);
        assert!(record_after.updated_at > record_before.updated_at);
    }

    #[tokio::test]
    async fn update_nonexistent() {
        let pool = super::super::tests::setup_test_db().await;
        let invariant_id = unique_invariant("update_nonexistent");

        let mut tx = pool.begin().await.unwrap();
        let updated = update(&mut tx, &invariant_id, "x > 0").await.unwrap();
        tx.commit().await.unwrap();
        assert!(!updated);
    }

    #[tokio::test]
    async fn delete_existing() {
        let pool = super::super::tests::setup_test_db().await;
        let invariant_id = unique_invariant("delete_existing");

        let mut tx = pool.begin().await.unwrap();
        create(&mut tx, &invariant_id, "x > 0").await.unwrap();
        tx.commit().await.unwrap();

        let mut tx = pool.begin().await.unwrap();
        let deleted = delete(&mut tx, &invariant_id).await.unwrap();
        tx.commit().await.unwrap();
        assert!(deleted);

        let mut tx = pool.begin().await.unwrap();
        let record = get(&mut tx, &invariant_id).await.unwrap();
        tx.commit().await.unwrap();
        assert!(record.is_none());
    }

    #[tokio::test]
    async fn delete_nonexistent() {
        let pool = super::super::tests::setup_test_db().await;
        let invariant_id = unique_invariant("delete_nonexistent");

        let mut tx = pool.begin().await.unwrap();
        let deleted = delete(&mut tx, &invariant_id).await.unwrap();
        tx.commit().await.unwrap();
        assert!(!deleted);
    }

    #[tokio::test]
    async fn list_multiple() {
        let pool = super::super::tests::setup_test_db().await;
        let invariant1 = unique_invariant("list_multiple_1");
        let invariant2 = unique_invariant("list_multiple_2");
        let invariant3 = unique_invariant("list_multiple_3");

        // All three are created in one transaction, so their created_at values
        // may tie; only membership is asserted below, not relative order.
        let mut tx = pool.begin().await.unwrap();
        create(&mut tx, &invariant1, "x > 0").await.unwrap();
        create(&mut tx, &invariant2, "y > 0").await.unwrap();
        create(&mut tx, &invariant3, "z > 0").await.unwrap();
        tx.commit().await.unwrap();

        let mut tx = pool.begin().await.unwrap();
        let invariants = list(&mut tx).await.unwrap();
        tx.commit().await.unwrap();

        // The table may hold rows from other tests, so check containment only.
        let ids: Vec<_> = invariants.iter().map(|r| r.invariant_id).collect();
        assert!(ids.contains(&invariant1));
        assert!(ids.contains(&invariant2));
        assert!(ids.contains(&invariant3));
    }
}