disintegrate_postgres/
snapshotter.rs

//! # PostgreSQL Snapshotter
//!
//! This module provides an implementation of the `StateSnapshotter` trait using PostgreSQL as the underlying storage.
//! It allows storing and retrieving snapshots from a PostgreSQL database.
5use async_trait::async_trait;
6use disintegrate::{BoxDynError, Event, IntoState, StateSnapshotter, StreamQuery};
7use disintegrate::{StatePart, StateQuery};
8use md5::{Digest, Md5};
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use sqlx::PgPool;
12use sqlx::Row;
13use uuid::Uuid;
14
15use crate::{Error, PgEventId};
16
17#[cfg(test)]
18mod tests;
19
20/// PostgreSQL implementation for the `Snapshotter` trait.
21///
22/// The `PgSnapshotter` struct implements the `Snapshotter` trait for PostgreSQL databases.
23/// It allows for stroring and retrieving snapshots of `StateQuery` from PostgreSQL database.
24#[derive(Clone)]
25pub struct PgSnapshotter {
26    pool: PgPool,
27    every: u64,
28}
29
30impl PgSnapshotter {
31    /// Creates and initializes a new instance of `PgSnapshotter` with the specified PostgreSQL connection pool and snapshot frequency.
32    ///
33    /// # Arguments
34    ///
35    /// - `pool`: A PostgreSQL connection pool (`PgPool`) representing the database connection.
36    /// - `every`: The frequency of snapshot creation, specified as the number of events between consecutive snapshots.
37    ///
38    /// # Returns
39    ///
40    /// A new `PgSnapshotter` instance.
41    pub async fn new(pool: PgPool, every: u64) -> Result<Self, Error> {
42        setup(&pool).await?;
43        Ok(Self::new_uninitialized(pool, every))
44    }
45
46    /// Creates a new instance of `PgSnapshotter` with the specified PostgreSQL connection pool and snapshot frequency.
47    ///
48    /// This constructor does not initialize the database. If you need to initialize the database,
49    /// use `PgSnapshotter::new` instead.
50    ///
51    /// If you use this constructor, ensure that the database is already initialized.
52    /// Refer to the SQL files in the `snapshotter/sql` folder for the necessary schema.
53    ///
54    /// # Arguments
55    ///
56    /// - `pool`: A PostgreSQL connection pool (`PgPool`) representing the database connection.
57    /// - `every`: The frequency of snapshot creation, defined as the number of events between consecutive snapshots.
58    ///
59    /// # Returns
60    ///
61    /// A new `PgSnapshotter` instance.
62    pub fn new_uninitialized(pool: PgPool, every: u64) -> Self {
63        Self { pool, every }
64    }
65}
66
67#[async_trait]
68impl StateSnapshotter<PgEventId> for PgSnapshotter {
69    async fn load_snapshot<S>(&self, default: StatePart<PgEventId, S>) -> StatePart<PgEventId, S>
70    where
71        S: Send + Sync + DeserializeOwned + StateQuery + 'static,
72    {
73        let query = query_key(&default.query());
74        let stored_snapshot =
75            sqlx::query("SELECT name, query, payload, version FROM snapshot where id = $1")
76                .bind(snapshot_id(S::NAME, &query))
77                .fetch_one(&self.pool)
78                .await;
79        if let Ok(row) = stored_snapshot {
80            let snapshot_name: String = row.get(0);
81            let snapshot_query: String = row.get(1);
82            if S::NAME == snapshot_name && query == snapshot_query {
83                let payload = serde_json::from_str(row.get(2)).unwrap_or(default.into_state());
84                return StatePart::new(row.get(3), payload);
85            }
86        }
87
88        default
89    }
90
91    async fn store_snapshot<S>(&self, state: &StatePart<PgEventId, S>) -> Result<(), BoxDynError>
92    where
93        S: Send + Sync + Serialize + StateQuery + 'static,
94    {
95        if state.applied_events() <= self.every {
96            return Ok(());
97        }
98        let query = query_key(&state.query());
99        let id = snapshot_id(S::NAME, &query);
100        let version = state.version();
101        let payload = serde_json::to_string(&state.clone().into_state())?;
102        sqlx::query("INSERT INTO snapshot (id, name, query, payload, version) VALUES ($1,$2,$3,$4,$5) ON CONFLICT(id) DO UPDATE SET name = $2, query = $3, payload = $4, version = $5 WHERE snapshot.version < $5")
103        .bind(id)
104        .bind(S::NAME)
105        .bind(query)
106        .bind(payload)
107        .bind(version)
108        .execute(&self.pool)
109        .await?;
110
111        Ok(())
112    }
113}
114
115fn snapshot_id(state_name: &str, query: &str) -> Uuid {
116    let mut hasher = Md5::new();
117    hasher.update(state_name);
118
119    uuid::Uuid::new_v3(
120        &uuid::Uuid::from_bytes(hasher.finalize().into()),
121        query.as_bytes(),
122    )
123}
124
125fn query_key<E: Event + Clone>(query: &StreamQuery<PgEventId, E>) -> String {
126    let mut result = String::new();
127    for f in query.filters() {
128        let excluded_events = if let Some(exclued_events) = f.excluded_events() {
129            format!("-{}", exclued_events.join(","))
130        } else {
131            "".to_string()
132        };
133        result += &format!(
134            "({}|{}{}|{})",
135            f.origin(),
136            f.events().join(","),
137            excluded_events,
138            f.identifiers()
139                .iter()
140                .map(|(k, v)| format!("{k}={v}"))
141                .collect::<Vec<_>>()
142                .join(",")
143        );
144    }
145    result
146}
147
148pub async fn setup(pool: &PgPool) -> Result<(), Error> {
149    sqlx::query(include_str!("snapshotter/sql/table_snapshot.sql"))
150        .execute(pool)
151        .await?;
152    Ok(())
153}