// disintegrate_postgres/event_store.rs
mod append;
6mod query;
7#[cfg(test)]
8mod tests;
9
10use append::{InsertEventSequenceBuilder, InsertEventsBuilder};
11use futures::stream::BoxStream;
12use query::CriteriaBuilder;
13use sqlx::{PgPool, Row};
14use std::error::Error as StdError;
15use std::sync::Arc;
16use tokio::sync::Semaphore;
17
18use std::marker::PhantomData;
19
20use crate::{Error, PgEventId};
21use async_stream::stream;
22use async_trait::async_trait;
23use disintegrate::StreamQuery;
24use disintegrate::{DomainIdentifierInfo, EventStore};
25use disintegrate::{Event, PersistedEvent};
26use disintegrate_serde::Serde;
27
28use futures::StreamExt;
29
/// A PostgreSQL-backed event store.
///
/// Events of type `E` are (de)serialized with `S` and persisted through the
/// `pool`. Append operations additionally go through `concurrent_appends`,
/// a semaphore sized from the pool's maximum connections, so that writes
/// cannot monopolize every pooled connection.
#[derive(Clone)]
pub struct PgEventStore<E, S>
where
    S: Serde<E> + Send + Sync,
{
    // Shared with other modules of this crate (e.g. append/query builders).
    pub(crate) pool: PgPool,
    // Limits how many append calls run at once; sized in `new_uninitialized`
    // and re-sized by `with_max_appends_connections_percent`.
    concurrent_appends: Arc<tokio::sync::Semaphore>,
    // Payload serializer/deserializer.
    serde: S,
    // Marker so the store is generic over `E` without storing a value of it.
    event_type: PhantomData<E>,
}
41
42impl<E, S> PgEventStore<E, S>
43where
44 S: Serde<E> + Send + Sync,
45 E: Event,
46{
47 pub async fn new(pool: PgPool, serde: S) -> Result<Self, Error> {
54 setup::<E>(&pool).await?;
55 Ok(Self::new_uninitialized(pool, serde))
56 }
57 pub fn new_uninitialized(pool: PgPool, serde: S) -> Self {
73 const MAX_APPENDS_CONNECTIONS_PERCENT: f64 = 0.5;
74 let concurrent_appends = Arc::new(Semaphore::new(
75 (pool.options().get_max_connections() as f64 * MAX_APPENDS_CONNECTIONS_PERCENT).ceil()
76 as usize,
77 ));
78 Self {
79 pool,
80 concurrent_appends,
81 serde,
82 event_type: PhantomData,
83 }
84 }
85
86 pub fn with_max_appends_connections_percent(mut self, percentage: f64) -> Self {
103 assert!(
104 (0.0..=1.0).contains(&percentage),
105 "percentage must be between 0 and 1"
106 );
107
108 self.concurrent_appends = Arc::new(Semaphore::new(
109 (self.pool.options().get_max_connections() as f64 * percentage).ceil() as usize,
110 ));
111 self
112 }
113}
114
#[async_trait]
impl<E, S> EventStore<PgEventId, E> for PgEventStore<E, S>
where
    E: Event + Send + Sync,
    S: Serde<E> + Send + Sync,
{
    type Error = Error;

    /// Streams the persisted events matching `query`, ordered by event id.
    ///
    /// Reads only events whose id is at or below the value reported by the
    /// `event_store_current_epoch()` SQL function, so rows belonging to
    /// appends still in flight are not observed. Each payload is
    /// deserialized with the store's serde and converted into the query
    /// event type `QE`; a failed conversion yields
    /// `Error::QueryEventMapping`.
    fn stream<'a, QE>(
        &'a self,
        query: &'a StreamQuery<PgEventId, QE>,
    ) -> BoxStream<'a, Result<PersistedEvent<PgEventId, QE>, Self::Error>>
    where
        QE: TryFrom<E> + Event + 'static + Clone + Send + Sync,
        <QE as TryFrom<E>>::Error: StdError + 'static + Send + Sync,
    {
        stream! {
            // Snapshot the epoch first so the scan has a stable upper bound.
            let epoch: i64 = sqlx::query_scalar("SELECT event_store_current_epoch()").fetch_one(&self.pool).await?;
            // NOTE(review): the criteria are interpolated into the SQL text,
            // not bound as parameters — assumes CriteriaBuilder emits only
            // safely-quoted values derived from the typed query. Verify.
            let sql = format!("SELECT event_id, payload FROM event WHERE event_id <= {epoch} AND ({}) ORDER BY event_id ASC", CriteriaBuilder::new(query).build());

            for await row in sqlx::query(&sql)
                .fetch(&self.pool) {
                let row = row?;
                let id = row.get(0);

                let payload = self.serde.deserialize(row.get(1))?;
                yield Ok(PersistedEvent::new(id, payload.try_into().map_err(|e| Error::QueryEventMapping(Box::new(e)))?));
            }
        }
        .boxed()
    }

    /// Appends `events`, validating against `query` at `version` for
    /// optimistic-concurrency control.
    ///
    /// Outline:
    /// 1. Take a permit so concurrent appends cannot exhaust the pool
    ///    (each append uses a transaction connection plus pool fetches).
    /// 2. Open a transaction and call `event_store_begin_epoch()`.
    /// 3. Reserve ids by inserting one `event_sequence` row per event.
    ///    Note these inserts run on the pool, *outside* the transaction —
    ///    presumably intentional so ids are assigned immediately; the
    ///    transaction only finalizes them. TODO(review): confirm.
    /// 4. Run the conflict-detection UPDATE: it locks (FOR UPDATE) the new
    ///    rows plus rows matching the re-originated query, bumps `consumed`
    ///    and marks only the new ids `committed`. A check-constraint
    ///    violation (SQLSTATE 23514) is presumably how a concurrent
    ///    conflicting append is detected; it maps to `Error::Concurrency`
    ///    via `map_concurrency_err`.
    /// 5. Insert the payloads into `event` and commit.
    async fn append<QE>(
        &self,
        events: Vec<E>,
        query: StreamQuery<PgEventId, QE>,
        version: PgEventId,
    ) -> Result<Vec<PersistedEvent<PgEventId, E>>, Self::Error>
    where
        E: Clone + 'async_trait,
        QE: Event + Clone + Send + Sync,
    {
        let mut persisted_events = Vec::with_capacity(events.len());
        let mut persisted_events_ids: Vec<PgEventId> = Vec::with_capacity(events.len());
        let _permit = self.concurrent_appends.acquire().await?;
        let mut tx = self.pool.begin().await?;
        sqlx::query("SELECT event_store_begin_epoch()")
            .execute(&mut *tx)
            .await?;
        for event in events {
            // Sequence rows are fetched on the pool, not on `tx` (see doc above).
            let mut staged_event_insert = InsertEventSequenceBuilder::new(&event);
            let row = staged_event_insert.build().fetch_one(&self.pool).await?;
            persisted_events_ids.push(row.get(0));
            persisted_events.push(PersistedEvent::new(row.get(0), event));
        }

        // No events to append: return early; `tx` is rolled back on drop.
        let Some(last_event_id) = persisted_events_ids.last().copied() else {
            return Ok(vec![]);
        };
        sqlx::query(&format!(r#"UPDATE event_sequence es SET consumed = consumed + 1, committed = (es.event_id = ANY($1))
               FROM (SELECT event_id FROM event_sequence WHERE event_id = ANY($1)
               OR ((consumed = 0 OR committed = true)
               AND (event_id <= $2 AND ({}))) ORDER BY event_id FOR UPDATE) upd WHERE es.event_id = upd.event_id"#,
               CriteriaBuilder::new(&query.change_origin(version)).build()))
            .bind(persisted_events_ids)
            .bind(last_event_id)
            .execute(&mut *tx)
            .await
            .map_err(map_concurrency_err)?;

        InsertEventsBuilder::new(persisted_events.as_slice(), &self.serde)
            .build()
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;

        Ok(persisted_events)
    }

    /// Appends `events` without optimistic-concurrency validation.
    ///
    /// Same shape as `append`, but sequence rows are inserted already
    /// marked `consumed`, and the UPDATE simply flags them `committed`
    /// instead of running the conflict-detection query.
    async fn append_without_validation(
        &self,
        events: Vec<E>,
    ) -> Result<Vec<PersistedEvent<PgEventId, E>>, Self::Error>
    where
        E: Clone + 'async_trait,
    {
        let mut persisted_events = Vec::with_capacity(events.len());
        let mut persisted_events_ids: Vec<PgEventId> = Vec::with_capacity(events.len());
        let _permit = self.concurrent_appends.acquire().await?;
        let mut tx = self.pool.begin().await?;
        sqlx::query("SELECT event_store_begin_epoch()")
            .execute(&mut *tx)
            .await?;
        for event in events {
            // As in `append`, these inserts run on the pool, not on `tx`.
            let mut sequence_insert = InsertEventSequenceBuilder::new(&event).with_consumed(true);
            let row = sequence_insert.build().fetch_one(&self.pool).await?;
            persisted_events_ids.push(row.get(0));
            persisted_events.push(PersistedEvent::new(row.get(0), event));
        }

        sqlx::query("UPDATE event_sequence es SET committed = true WHERE event_id = ANY($1)")
            .bind(persisted_events_ids)
            .execute(&mut *tx)
            .await
            .map_err(map_concurrency_err)?;

        InsertEventsBuilder::new(persisted_events.as_slice(), &self.serde)
            .build()
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;

        Ok(persisted_events)
    }
}
284
285pub async fn setup<E: Event>(pool: &PgPool) -> Result<(), Error> {
286 const RESERVED_NAMES: &[&str] = &["event_id", "payload", "event_type", "inserted_at"];
287
288 sqlx::query(include_str!("event_store/sql/table_event.sql"))
289 .execute(pool)
290 .await?;
291 sqlx::query(include_str!("event_store/sql/idx_event_type.sql"))
292 .execute(pool)
293 .await?;
294 sqlx::query(include_str!("event_store/sql/table_event_sequence.sql"))
295 .execute(pool)
296 .await?;
297 sqlx::query(include_str!("event_store/sql/idx_event_sequence_type.sql"))
298 .execute(pool)
299 .await?;
300 sqlx::query(include_str!(
301 "event_store/sql/idx_event_sequence_committed.sql"
302 ))
303 .execute(pool)
304 .await?;
305 sqlx::query(include_str!(
306 "event_store/sql/fn_event_store_current_epoch.sql"
307 ))
308 .execute(pool)
309 .await?;
310 sqlx::query(include_str!(
311 "event_store/sql/fn_event_store_begin_epoch.sql"
312 ))
313 .execute(pool)
314 .await?;
315
316 for domain_identifier in E::SCHEMA.domain_identifiers {
317 if RESERVED_NAMES.contains(&domain_identifier.ident) {
318 panic!("Domain identifier name {domain_identifier} is reserved. Please use a different name.", domain_identifier = domain_identifier.ident);
319 }
320 add_domain_identifier_column(pool, "event", domain_identifier).await?;
321 add_domain_identifier_column(pool, "event_sequence", domain_identifier).await?;
322 }
323 Ok(())
324}
325
326fn map_concurrency_err(err: sqlx::Error) -> Error {
327 if let sqlx::Error::Database(ref description) = err {
328 if description.code().as_deref() == Some("23514") {
329 return Error::Concurrency;
330 }
331 }
332 Error::Database(err)
333}
334
335async fn add_domain_identifier_column(
336 pool: &PgPool,
337 table: &str,
338 domain_identifier: &DomainIdentifierInfo,
339) -> Result<(), Error> {
340 let column_name = domain_identifier.ident;
341 let sql_type = match domain_identifier.type_info {
342 disintegrate::IdentifierType::String => "TEXT",
343 disintegrate::IdentifierType::i64 => "BIGINT",
344 disintegrate::IdentifierType::Uuid => "UUID",
345 };
346 sqlx::query(&format!(
347 "ALTER TABLE {table} ADD COLUMN IF NOT EXISTS {column_name} {sql_type}"
348 ))
349 .execute(pool)
350 .await?;
351
352 sqlx::query(&format!(
353 "CREATE INDEX IF NOT EXISTS idx_{table}_{column_name} ON {table} USING HASH ({column_name}) WHERE {column_name} IS NOT NULL"
354 ))
355 .execute(pool)
356 .await?;
357 Ok(())
358}