disintegrate_postgres/
snapshotter.rs1use async_trait::async_trait;
6use disintegrate::{BoxDynError, Event, IntoState, StateSnapshotter, StreamQuery};
7use disintegrate::{StatePart, StateQuery};
8use md5::{Digest, Md5};
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use sqlx::PgPool;
12use sqlx::Row;
13use uuid::Uuid;
14
15use crate::{Error, PgEventId};
16
17#[cfg(test)]
18mod tests;
19
20#[derive(Clone)]
25pub struct PgSnapshotter {
26 pool: PgPool,
27 every: u64,
28}
29
30impl PgSnapshotter {
31 pub async fn new(pool: PgPool, every: u64) -> Result<Self, Error> {
42 setup(&pool).await?;
43 Ok(Self::new_uninitialized(pool, every))
44 }
45
46 pub fn new_uninitialized(pool: PgPool, every: u64) -> Self {
63 Self { pool, every }
64 }
65}
66
67#[async_trait]
68impl StateSnapshotter<PgEventId> for PgSnapshotter {
69 async fn load_snapshot<S>(&self, default: StatePart<PgEventId, S>) -> StatePart<PgEventId, S>
70 where
71 S: Send + Sync + DeserializeOwned + StateQuery + 'static,
72 {
73 let query = query_key(&default.query());
74 let stored_snapshot =
75 sqlx::query("SELECT name, query, payload, version FROM snapshot where id = $1")
76 .bind(snapshot_id(S::NAME, &query))
77 .fetch_one(&self.pool)
78 .await;
79 if let Ok(row) = stored_snapshot {
80 let snapshot_name: String = row.get(0);
81 let snapshot_query: String = row.get(1);
82 if S::NAME == snapshot_name && query == snapshot_query {
83 let payload = serde_json::from_str(row.get(2)).unwrap_or(default.into_state());
84 return StatePart::new(row.get(3), payload);
85 }
86 }
87
88 default
89 }
90
91 async fn store_snapshot<S>(&self, state: &StatePart<PgEventId, S>) -> Result<(), BoxDynError>
92 where
93 S: Send + Sync + Serialize + StateQuery + 'static,
94 {
95 if state.applied_events() <= self.every {
96 return Ok(());
97 }
98 let query = query_key(&state.query());
99 let id = snapshot_id(S::NAME, &query);
100 let version = state.version();
101 let payload = serde_json::to_string(&state.clone().into_state())?;
102 sqlx::query("INSERT INTO snapshot (id, name, query, payload, version) VALUES ($1,$2,$3,$4,$5) ON CONFLICT(id) DO UPDATE SET name = $2, query = $3, payload = $4, version = $5 WHERE snapshot.version < $5")
103 .bind(id)
104 .bind(S::NAME)
105 .bind(query)
106 .bind(payload)
107 .bind(version)
108 .execute(&self.pool)
109 .await?;
110
111 Ok(())
112 }
113}
114
115fn snapshot_id(state_name: &str, query: &str) -> Uuid {
116 let mut hasher = Md5::new();
117 hasher.update(state_name);
118
119 uuid::Uuid::new_v3(
120 &uuid::Uuid::from_bytes(hasher.finalize().into()),
121 query.as_bytes(),
122 )
123}
124
125fn query_key<E: Event + Clone>(query: &StreamQuery<PgEventId, E>) -> String {
126 let mut result = String::new();
127 for f in query.filters() {
128 let excluded_events = if let Some(exclued_events) = f.excluded_events() {
129 format!("-{}", exclued_events.join(","))
130 } else {
131 "".to_string()
132 };
133 result += &format!(
134 "({}|{}{}|{})",
135 f.origin(),
136 f.events().join(","),
137 excluded_events,
138 f.identifiers()
139 .iter()
140 .map(|(k, v)| format!("{k}={v}"))
141 .collect::<Vec<_>>()
142 .join(",")
143 );
144 }
145 result
146}
147
148pub async fn setup(pool: &PgPool) -> Result<(), Error> {
149 sqlx::query(include_str!("snapshotter/sql/table_snapshot.sql"))
150 .execute(pool)
151 .await?;
152 Ok(())
153}