1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
//! # PostgreSQL Snapshotter
//!
//! This module provides an implementation of the `StateSnapshotter` trait using PostgreSQL as the underlying storage.
//! It allows storing and retrieving snapshots from a PostgreSQL database.
use async_trait::async_trait;
use disintegrate::stream_query::StreamFilter;
use disintegrate::{BoxDynError, IntoState, StateSnapshotter};
use disintegrate::{StatePart, StateQuery};
use md5::{Digest, Md5};
use serde::de::DeserializeOwned;
use serde::Serialize;
use sqlx::PgPool;
use sqlx::Row;
use uuid::Uuid;

use crate::Error;

#[cfg(test)]
mod tests;

/// PostgreSQL implementation of the [`StateSnapshotter`] trait.
///
/// `PgSnapshotter` stores and retrieves snapshots of [`StateQuery`] states in a
/// PostgreSQL database.
#[derive(Clone, Debug)]
pub struct PgSnapshotter {
    /// Connection pool used to read and write the `snapshot` table.
    pool: PgPool,
    /// Snapshot threshold: a state is persisted only when its `applied_events()`
    /// count is strictly greater than this value.
    every: u64,
}

impl PgSnapshotter {
    /// Builds a `PgSnapshotter` backed by `pool`, after first running [`setup`]
    /// against that pool.
    ///
    /// # Arguments
    ///
    /// - `pool`: the PostgreSQL connection pool used to read and write snapshots.
    /// - `every`: snapshot threshold; `store_snapshot` only persists a state whose
    ///   `applied_events()` count is greater than this value.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`setup`].
    pub async fn new(pool: PgPool, every: u64) -> Result<Self, Error> {
        // `setup` only borrows the pool; once its result is in hand the pool can be
        // moved into the new instance.
        let prepared = setup(&pool).await;
        prepared.map(|()| Self { pool, every })
    }
}

#[async_trait]
impl StateSnapshotter for PgSnapshotter {
    /// Loads the stored snapshot for the state query described by `default`.
    ///
    /// Loading is best-effort. `default` is returned unchanged (including its own
    /// version) when:
    /// - no row is found or the database query fails;
    /// - a column cannot be decoded;
    /// - the stored name/query do not match the requested ones (guarding against an
    ///   id collision);
    /// - the stored payload cannot be deserialized into `S`.
    async fn load_snapshot<S>(&self, default: StatePart<S>) -> StatePart<S>
    where
        S: Send + Sync + DeserializeOwned + StateQuery + 'static,
    {
        let query = query_key(default.query().filter());
        let stored_snapshot =
            sqlx::query("SELECT name, query, payload, version FROM snapshot where id = $1")
                .bind(snapshot_id(S::NAME, &query))
                .fetch_one(&self.pool)
                .await;
        let row = match stored_snapshot {
            Ok(row) => row,
            Err(_) => return default,
        };

        // The id is a hash, so verify the row really belongs to this state and query.
        let snapshot_name: Result<String, _> = row.try_get(0);
        let snapshot_query: Result<String, _> = row.try_get(1);
        match (snapshot_name, snapshot_query) {
            (Ok(name), Ok(stored_query)) if name == S::NAME && stored_query == query => {}
            _ => return default,
        }

        // `try_get` instead of `get`: `Row::get` panics on decode failure, which would
        // turn this best-effort lookup into a crash.
        let payload: &str = match row.try_get(2) {
            Ok(payload) => payload,
            Err(_) => return default,
        };
        let version = match row.try_get(3) {
            Ok(version) => version,
            Err(_) => return default,
        };

        // The stored version must only ever be paired with the stored payload. If the
        // payload cannot be deserialized, fall back to `default` with its own version.
        // Pairing the stored version with the default state would skip every event up
        // to that version.
        match serde_json::from_str::<S>(payload) {
            Ok(state) => StatePart::new(version, state),
            Err(_) => default,
        }
    }

    /// Persists `state` as a snapshot once it has applied more than `self.every` events.
    ///
    /// The row is upserted by snapshot id. An existing row is only overwritten when the
    /// new version is greater than the stored one, so an older snapshot never replaces a
    /// newer one.
    ///
    /// # Errors
    ///
    /// Returns an error if the state cannot be serialized to JSON or the database write
    /// fails.
    async fn store_snapshot<S>(&self, state: &StatePart<S>) -> Result<(), BoxDynError>
    where
        S: Send + Sync + Serialize + StateQuery + 'static,
    {
        if state.applied_events() <= self.every {
            return Ok(());
        }
        let query = query_key(state.query().filter());
        let id = snapshot_id(S::NAME, &query);
        let version = state.version();
        let payload = serde_json::to_string(&state.clone().into_state())?;
        sqlx::query("INSERT INTO snapshot (id, name, query, payload, version) VALUES ($1,$2,$3,$4,$5) ON CONFLICT(id) DO UPDATE SET name = $2, query = $3, payload = $4, version = $5 WHERE snapshot.version < $5")
            .bind(id)
            .bind(S::NAME)
            .bind(query)
            .bind(payload)
            .bind(version)
            .execute(&self.pool)
            .await?;

        Ok(())
    }
}

fn snapshot_id(state_name: &str, query: &str) -> Uuid {
    let mut hasher = Md5::new();
    hasher.update(state_name);

    uuid::Uuid::new_v3(
        &uuid::Uuid::from_bytes(hasher.finalize().into()),
        query.as_bytes(),
    )
}

/// Renders `filter` as the textual key used to derive the snapshot id.
///
/// The same key is stored alongside the snapshot and compared on load.
///
/// Identifier values are passed through [`escape_key_value`]. This stops a value that
/// contains the `&` / `|` separators from making two different filters produce the
/// same key, which would load a snapshot belonging to another entity. Values
/// containing none of `\`, `&` or `|` are rendered verbatim.
fn query_key(filter: &StreamFilter) -> String {
    match filter {
        StreamFilter::Events { names } => format!("({})", names.join(",")),
        StreamFilter::ExcludeEvents { names } => format!("not({})", names.join(",")),
        StreamFilter::Eq { ident, value } => {
            format!("{ident}={}", escape_key_value(&value.to_string()))
        }
        StreamFilter::And { l, r } => format!("{}&{}", query_key(l), query_key(r)),
        StreamFilter::Or { l, r } => format!("{}|{}", query_key(l), query_key(r)),
        StreamFilter::Origin { id } => format!(">={id}"),
    }
}

/// Prefixes every `\`, `&` and `|` in `raw` with a backslash so the value cannot be
/// confused with the separators used by [`query_key`].
fn escape_key_value(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        if matches!(ch, '\\' | '&' | '|') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}

pub async fn setup(pool: &PgPool) -> Result<(), Error> {
    sqlx::query(include_str!("snapshotter/sql/table_snapshot.sql"))
        .execute(pool)
        .await?;
    Ok(())
}