use async_trait::async_trait;
use disintegrate::stream_query::StreamFilter;
use disintegrate::{BoxDynError, IntoState, StateSnapshotter};
use disintegrate::{StatePart, StateQuery};
use md5::{Digest, Md5};
use serde::de::DeserializeOwned;
use serde::Serialize;
use sqlx::PgPool;
use sqlx::Row;
use uuid::Uuid;
use crate::Error;
#[cfg(test)]
mod tests;
/// State snapshotter that persists snapshots in a PostgreSQL `snapshot` table.
///
/// `every` is a write threshold: `store_snapshot` only persists a state part
/// that has applied strictly more than `every` events.
#[derive(Clone)]
pub struct PgSnapshotter {
    pool: PgPool,
    every: u64,
}

impl PgSnapshotter {
    /// Builds a snapshotter over `pool` after running [`setup`] on it.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`setup`] if the setup script fails.
    pub async fn new(pool: PgPool, every: u64) -> Result<Self, Error> {
        let ready = setup(&pool).await;
        ready.map(|()| Self { pool, every })
    }
}
#[async_trait]
impl StateSnapshotter for PgSnapshotter {
    /// Loads the stored snapshot matching `default`'s state type and query.
    ///
    /// Returns `default` unchanged when:
    /// - the lookup fails (including when no row exists);
    /// - the stored name/query do not match (guards against id collisions);
    /// - the stored payload cannot be deserialized into `S`.
    async fn load_snapshot<S>(&self, default: StatePart<S>) -> StatePart<S>
    where
        S: Send + Sync + DeserializeOwned + StateQuery + 'static,
    {
        let query = query_key(default.query().filter());
        let stored_snapshot =
            sqlx::query("SELECT name, query, payload, version FROM snapshot where id = $1")
                .bind(snapshot_id(S::NAME, &query))
                .fetch_one(&self.pool)
                .await;
        // Loading is best-effort: any error (including "no row") means
        // "no usable snapshot", and the caller rebuilds from events instead.
        let row = match stored_snapshot {
            Ok(row) => row,
            Err(_) => return default,
        };
        let snapshot_name: String = row.get(0);
        let snapshot_query: String = row.get(1);
        // The id is a hash; compare the clear-text key too, so a collision
        // never yields another state's snapshot.
        if S::NAME != snapshot_name || query != snapshot_query {
            return default;
        }
        let payload: &str = row.get(2);
        // The stored version only describes the stored payload. If the payload
        // cannot be decoded (e.g. the state's shape changed), fall back to
        // `default` with *its own* version. Pairing the snapshot version with a
        // fresh default state would skip every event up to that version.
        match serde_json::from_str::<S>(payload) {
            Ok(state) => StatePart::new(row.get(3), state),
            Err(_) => default,
        }
    }

    /// Persists `state` as a snapshot keyed by its state name and query.
    ///
    /// Writes nothing when `state` has applied `self.every` events or fewer.
    /// On conflict, the existing row is only overwritten when the stored
    /// version is lower than `state`'s version.
    ///
    /// # Errors
    ///
    /// Returns an error if serializing the state or executing the upsert fails.
    async fn store_snapshot<S>(&self, state: &StatePart<S>) -> Result<(), BoxDynError>
    where
        S: Send + Sync + Serialize + StateQuery + 'static,
    {
        if state.applied_events() <= self.every {
            return Ok(());
        }
        let query = query_key(state.query().filter());
        let id = snapshot_id(S::NAME, &query);
        let version = state.version();
        let payload = serde_json::to_string(&state.clone().into_state())?;
        sqlx::query("INSERT INTO snapshot (id, name, query, payload, version) VALUES ($1,$2,$3,$4,$5) ON CONFLICT(id) DO UPDATE SET name = $2, query = $3, payload = $4, version = $5 WHERE snapshot.version < $5")
            .bind(id)
            .bind(S::NAME)
            .bind(query)
            .bind(payload)
            .bind(version)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
}
fn snapshot_id(state_name: &str, query: &str) -> Uuid {
let mut hasher = Md5::new();
hasher.update(state_name);
uuid::Uuid::new_v3(
&uuid::Uuid::from_bytes(hasher.finalize().into()),
query.as_bytes(),
)
}
/// Renders a [`StreamFilter`] as a textual key.
///
/// The key is stored in the `query` column and hashed into the snapshot id by
/// `snapshot_id`, so changing this format orphans previously stored snapshots.
///
/// NOTE(review): `And`/`Or` operands are not parenthesised and `Eq` values are
/// not escaped, so distinct filters can render to the same key. For example,
/// `And(a, Or(b, c))` and `Or(And(a, b), c)` both render as `a&b|c`. Confirm
/// whether this can happen for a single state type before relying on it.
fn query_key(filter: &StreamFilter) -> String {
    // Renders `<left><op><right>` for the binary combinators.
    fn combine(left: &StreamFilter, op: &str, right: &StreamFilter) -> String {
        let mut key = query_key(left);
        key.push_str(op);
        key.push_str(&query_key(right));
        key
    }

    match filter {
        StreamFilter::Events { names } => {
            let joined = names.join(",");
            format!("({joined})")
        }
        StreamFilter::ExcludeEvents { names } => {
            let joined = names.join(",");
            format!("not({joined})")
        }
        StreamFilter::Eq { ident, value } => format!("{}={}", ident, value),
        StreamFilter::And { l, r } => combine(l, "&", r),
        StreamFilter::Or { l, r } => combine(l, "|", r),
        StreamFilter::Origin { id } => format!(">={}", id),
    }
}
pub async fn setup(pool: &PgPool) -> Result<(), Error> {
sqlx::query(include_str!("snapshotter/sql/table_snapshot.sql"))
.execute(pool)
.await?;
Ok(())
}