Skip to main content

atomr_persistence_sql/
snapshot.rs

1//! `SnapshotStore` implementation backed by sqlx.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use atomr_persistence::{JournalError, SnapshotMetadata, SnapshotStore};
7use sqlx::any::AnyPoolOptions;
8use sqlx::AnyPool;
9
10use crate::config::SqlConfig;
11use crate::schema::{ensure_schema, init_drivers};
12
13pub struct SqlSnapshotStore {
14    pool: AnyPool,
15    cfg: SqlConfig,
16}
17
18impl SqlSnapshotStore {
19    pub async fn connect(cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
20        init_drivers();
21        let pool = AnyPoolOptions::new()
22            .max_connections(cfg.max_connections)
23            .connect(&cfg.url)
24            .await
25            .map_err(JournalError::backend)?;
26        ensure_schema(&pool, &cfg).await?;
27        Ok(Arc::new(Self { pool, cfg }))
28    }
29
30    pub async fn from_pool(pool: AnyPool, cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
31        ensure_schema(&pool, &cfg).await?;
32        Ok(Arc::new(Self { pool, cfg }))
33    }
34
35    pub fn pool(&self) -> &AnyPool {
36        &self.pool
37    }
38
39    pub fn config(&self) -> &SqlConfig {
40        &self.cfg
41    }
42}
43
44#[async_trait]
45impl SnapshotStore for SqlSnapshotStore {
46    async fn save(&self, meta: SnapshotMetadata, payload: Vec<u8>) {
47        let created_at = chrono::Utc::now().timestamp_millis();
48        let _ = sqlx::query(
49            "INSERT INTO snapshot_store (persistence_id, sequence_nr, payload, timestamp, created_at) VALUES (?, ?, ?, ?, ?)",
50        )
51        .bind(&meta.persistence_id)
52        .bind(meta.sequence_nr as i64)
53        .bind(payload)
54        .bind(meta.timestamp as i64)
55        .bind(created_at)
56        .execute(&self.pool)
57        .await;
58    }
59
60    async fn load(&self, persistence_id: &str) -> Option<(SnapshotMetadata, Vec<u8>)> {
61        let row: Option<(String, i64, Vec<u8>, i64)> = sqlx::query_as(
62            "SELECT persistence_id, sequence_nr, payload, timestamp FROM snapshot_store \
63             WHERE persistence_id = ? ORDER BY sequence_nr DESC LIMIT 1",
64        )
65        .bind(persistence_id)
66        .fetch_optional(&self.pool)
67        .await
68        .ok()
69        .flatten();
70        row.map(|(pid, seq, payload, ts)| {
71            (SnapshotMetadata { persistence_id: pid, sequence_nr: seq as u64, timestamp: ts as u64 }, payload)
72        })
73    }
74
75    async fn delete(&self, persistence_id: &str, to_sequence_nr: u64) {
76        let _ = sqlx::query("DELETE FROM snapshot_store WHERE persistence_id = ? AND sequence_nr <= ?")
77            .bind(persistence_id)
78            .bind(to_sequence_nr as i64)
79            .execute(&self.pool)
80            .await;
81    }
82}