// atomr_persistence_sql/snapshot.rs
snapshot.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6use atomr_persistence::{JournalError, SnapshotMetadata, SnapshotStore};
7use sqlx::any::AnyPoolOptions;
8use sqlx::AnyPool;
9
10use crate::config::SqlConfig;
11use crate::schema::{ensure_schema, init_drivers};
12
13pub struct SqlSnapshotStore {
14 pool: AnyPool,
15 cfg: SqlConfig,
16}
17
18impl SqlSnapshotStore {
19 pub async fn connect(cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
20 init_drivers();
21 let pool = AnyPoolOptions::new()
22 .max_connections(cfg.max_connections)
23 .connect(&cfg.url)
24 .await
25 .map_err(JournalError::backend)?;
26 ensure_schema(&pool, &cfg).await?;
27 Ok(Arc::new(Self { pool, cfg }))
28 }
29
30 pub async fn from_pool(pool: AnyPool, cfg: SqlConfig) -> Result<Arc<Self>, JournalError> {
31 ensure_schema(&pool, &cfg).await?;
32 Ok(Arc::new(Self { pool, cfg }))
33 }
34
35 pub fn pool(&self) -> &AnyPool {
36 &self.pool
37 }
38
39 pub fn config(&self) -> &SqlConfig {
40 &self.cfg
41 }
42}
43
44#[async_trait]
45impl SnapshotStore for SqlSnapshotStore {
46 async fn save(&self, meta: SnapshotMetadata, payload: Vec<u8>) {
47 let created_at = chrono::Utc::now().timestamp_millis();
48 let _ = sqlx::query(
49 "INSERT INTO snapshot_store (persistence_id, sequence_nr, payload, timestamp, created_at) VALUES (?, ?, ?, ?, ?)",
50 )
51 .bind(&meta.persistence_id)
52 .bind(meta.sequence_nr as i64)
53 .bind(payload)
54 .bind(meta.timestamp as i64)
55 .bind(created_at)
56 .execute(&self.pool)
57 .await;
58 }
59
60 async fn load(&self, persistence_id: &str) -> Option<(SnapshotMetadata, Vec<u8>)> {
61 let row: Option<(String, i64, Vec<u8>, i64)> = sqlx::query_as(
62 "SELECT persistence_id, sequence_nr, payload, timestamp FROM snapshot_store \
63 WHERE persistence_id = ? ORDER BY sequence_nr DESC LIMIT 1",
64 )
65 .bind(persistence_id)
66 .fetch_optional(&self.pool)
67 .await
68 .ok()
69 .flatten();
70 row.map(|(pid, seq, payload, ts)| {
71 (SnapshotMetadata { persistence_id: pid, sequence_nr: seq as u64, timestamp: ts as u64 }, payload)
72 })
73 }
74
75 async fn delete(&self, persistence_id: &str, to_sequence_nr: u64) {
76 let _ = sqlx::query("DELETE FROM snapshot_store WHERE persistence_id = ? AND sequence_nr <= ?")
77 .bind(persistence_id)
78 .bind(to_sequence_nr as i64)
79 .execute(&self.pool)
80 .await;
81 }
82}