// disintegrate_postgres/event_store.rs
mod append;
mod query;
#[cfg(test)]
mod tests;

use append::{InsertEventSequenceBuilder, InsertEventsBuilder};
use futures::stream::BoxStream;
use query::CriteriaBuilder;
use sqlx::postgres::{PgPool, PgPoolOptions};
use sqlx::Row;
use std::error::Error as StdError;

use std::marker::PhantomData;

use crate::{Error, Migrator, PgEventId};
use async_stream::stream;
use async_trait::async_trait;
use disintegrate::EventStore;
use disintegrate::{Event, PersistedEvent};
use disintegrate::{StreamItem, StreamQuery};
use disintegrate_serde::Serde;

use futures::StreamExt;
29#[derive(Clone)]
31pub struct PgEventStore<E, S>
32where
33 S: Serde<E> + Send + Sync,
34{
35 pub(crate) pool: PgPool,
36 sequence_pool: PgPool,
37 serde: S,
38 event_type: PhantomData<E>,
39}
40
41impl<E, S> PgEventStore<E, S>
42where
43 S: Serde<E> + Send + Sync + Clone,
44 E: Event + Clone,
45{
46 pub async fn try_new(pool: PgPool, serde: S) -> Result<Self, Error> {
53 let event_store = Self::new_uninitialized(pool, serde);
54 Migrator::new(event_store.clone())
55 .init_event_store()
56 .await?;
57 Ok(event_store)
58 }
59 pub fn new_uninitialized(pool: PgPool, serde: S) -> Self {
75 let main_connections = pool.options().get_max_connections();
76
77 let sequence_connections = std::cmp::max(2, (main_connections as f32 * 0.25).ceil() as u32);
79
80 let sequence_pool = PgPoolOptions::new()
81 .max_connections(sequence_connections)
82 .connect_lazy_with((*pool.connect_options()).clone());
83
84 Self {
85 pool,
86 sequence_pool,
87 serde,
88 event_type: PhantomData,
89 }
90 }
91
92 pub fn with_sequence_pool_connections(mut self, connections: u32) -> Self {
105 self.sequence_pool = PgPoolOptions::new()
106 .max_connections(connections.min(self.pool.options().get_max_connections()))
107 .connect_lazy_with((*self.pool.connect_options()).clone());
108 self
109 }
110}
111
impl<E, S> PgEventStore<E, S>
where
    S: Serde<E> + Send + Sync,
    E: Event + Send + Sync,
{
    /// Streams the events matching `query` from the given executor.
    ///
    /// Events are yielded in ascending `event_id` order as
    /// [`StreamItem::Event`] values; the stream always terminates with a
    /// single [`StreamItem::End`] carrying the epoch id observed by the
    /// query.
    ///
    /// Type parameters:
    /// * `QE` — the query event type; each stored event of type `E` is
    ///   converted via `TryFrom<E>`, and a failed conversion surfaces as
    ///   [`Error::QueryEventMapping`].
    /// * `EX` — any sqlx Postgres executor (pool, connection, transaction).
    pub(crate) fn stream_with<'a, QE, EX>(
        &'a self,
        executor: EX,
        query: &'a StreamQuery<PgEventId, QE>,
    ) -> BoxStream<'a, Result<StreamItem<PgEventId, QE>, Error>>
    where
        EX: sqlx::PgExecutor<'a> + Send + Sync + 'a,
        QE: TryFrom<E> + Event + Clone + Send + Sync + 'static,
        <QE as TryFrom<E>>::Error: StdError + Send + Sync + 'static,
    {
        // The one-row VALUES clause evaluates event_store_current_epoch()
        // exactly once; the LEFT JOIN then selects matching events with ids
        // up to that epoch. Being a LEFT JOIN, at least one row comes back
        // even when no event matches (with a NULL event_id), so the epoch id
        // can always be read.
        let sql = format!(
            r#"SELECT event.event_id, event.payload, epoch.__epoch_id
            FROM (values (event_store_current_epoch())) AS epoch(__epoch_id)
            LEFT JOIN event ON event.event_id <= epoch.__epoch_id AND ({criteria})
            ORDER BY event_id ASC"#,
            criteria = CriteriaBuilder::new(query).build()
        );

        stream! {
            let mut rows = sqlx::query(&sql).fetch(executor);
            // Last epoch id seen; stays 0 only if the query returned no rows
            // at all (not expected given the VALUES clause above).
            let mut epoch_id: PgEventId = 0;
            while let Some(row) = rows.next().await {
                let row = row?;
                // NULL event_id means the LEFT JOIN matched no event for
                // this row; the epoch id is still recorded from it.
                let event_id: Option<i64> = row.get(0);
                epoch_id = row.get(2);
                if let Some(event_id) = event_id {
                    let payload = self.serde.deserialize(row.get(1))?;
                    let payload: QE = payload
                        .try_into()
                        .map_err(|e| Error::QueryEventMapping(Box::new(e)))?;
                    yield Ok(StreamItem::Event(PersistedEvent::new(event_id, payload)));
                }
            }
            // Terminal marker: tells the consumer up to which epoch the
            // stream is complete.
            yield Ok(StreamItem::End(epoch_id));
        }
        .boxed()
    }
}

#[async_trait]
impl<E, S> EventStore<PgEventId, E> for PgEventStore<E, S>
where
    E: Event + Send + Sync,
    S: Serde<E> + Send + Sync,
{
    type Error = Error;

    /// Streams the events matching `query` using the store's main pool.
    ///
    /// Delegates to [`PgEventStore::stream_with`] with `self.pool` as the
    /// executor.
    fn stream<'a, QE>(
        &'a self,
        query: &'a StreamQuery<PgEventId, QE>,
    ) -> BoxStream<'a, Result<StreamItem<PgEventId, QE>, Self::Error>>
    where
        QE: TryFrom<E> + Event + 'static + Clone + Send + Sync,
        <QE as TryFrom<E>>::Error: StdError + 'static + Send + Sync,
    {
        self.stream_with(&self.pool, query)
    }

    /// Appends `events` with optimistic-concurrency validation.
    ///
    /// `query` and `version` describe what the caller has already observed;
    /// a conflicting concurrent append is detected by the UPDATE below and
    /// surfaced as [`Error::Concurrency`].
    ///
    /// # Errors
    /// * [`Error::Concurrency`] on a detected concurrent conflict
    ///   (SQLSTATE 23514, see `map_concurrency_err`).
    /// * [`Error::Database`] for any other database failure.
    async fn append<QE>(
        &self,
        events: Vec<E>,
        query: StreamQuery<PgEventId, QE>,
        version: PgEventId,
    ) -> Result<Vec<PersistedEvent<PgEventId, E>>, Self::Error>
    where
        E: Clone + 'async_trait,
        QE: Event + Clone + Send + Sync,
    {
        let mut tx = self.pool.begin().await?;
        // NOTE(review): event_store_begin_epoch() is defined by the
        // migrations; presumably it opens this transaction's epoch —
        // confirm against the migration SQL.
        sqlx::query("SELECT event_store_begin_epoch()")
            .execute(&mut *tx)
            .await?;
        // Reserve one sequence id per event. This deliberately runs on the
        // dedicated `sequence_pool` (a separate connection, outside `tx`).
        // NOTE(review): reserved ids are therefore not rolled back if the
        // transaction below aborts — confirm they are reclaimed elsewhere.
        let mut sequence_insert = InsertEventSequenceBuilder::new(&events);
        let event_ids: Vec<PgEventId> = sequence_insert
            .build()
            .fetch_all(&self.sequence_pool)
            .await?
            .into_iter()
            .map(|r| r.get(0))
            .collect();

        // No events to append: nothing was reserved, succeed immediately.
        let Some(last_event_id) = event_ids.last().copied() else {
            return Ok(vec![]);
        };

        // Single statement that both detects conflicts and marks our rows:
        // it locks (FOR UPDATE) every sequence row that is either one of
        // our reserved ids or an unconsumed/committed row matching the
        // caller's criteria re-anchored at `version` (change_origin), bumps
        // its `consumed` counter, and sets `committed` only on our own ids.
        // A conflicting concurrent append trips a check constraint
        // (SQLSTATE 23514), which `map_concurrency_err` turns into
        // `Error::Concurrency`.
        sqlx::query(&format!(r#"UPDATE event_sequence es SET consumed = consumed + 1, committed = (es.event_id = ANY($1))
        FROM (SELECT event_id FROM event_sequence WHERE event_id = ANY($1)
        OR ((consumed = 0 OR committed = true)
        AND (event_id <= $2 AND ({}))) ORDER BY event_id FOR UPDATE) upd WHERE es.event_id = upd.event_id"#,
        CriteriaBuilder::new(&query.change_origin(version)).build()))
        .bind(&event_ids)
        .bind(last_event_id)
        .execute(&mut *tx)
        .await
        .map_err(map_concurrency_err)?;

        // Pair each reserved id with its event and persist the payloads
        // inside the same transaction.
        let persisted_events = event_ids
            .iter()
            .zip(events)
            .map(|(event_id, event)| PersistedEvent::new(*event_id, event))
            .collect::<Vec<_>>();
        InsertEventsBuilder::new(persisted_events.as_slice(), &self.serde)
            .build()
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;

        Ok(persisted_events)
    }

    /// Appends `events` without optimistic-concurrency validation.
    ///
    /// Sequence rows are inserted already flagged as consumed
    /// (`with_consumed(true)`) and are then marked `committed` directly; no
    /// conflict detection query is issued.
    ///
    /// # Errors
    /// Returns [`Error::Database`] on any database failure.
    async fn append_without_validation(
        &self,
        events: Vec<E>,
    ) -> Result<Vec<PersistedEvent<PgEventId, E>>, Self::Error>
    where
        E: Clone + 'async_trait,
    {
        let mut tx = self.pool.begin().await?;
        // See `append`: opens this transaction's epoch (defined in the
        // migrations).
        sqlx::query("SELECT event_store_begin_epoch()")
            .execute(&mut *tx)
            .await?;
        // Reserve sequence ids pre-marked as consumed, since no validation
        // pass will bump the counter. Runs on the dedicated pool, outside
        // `tx` (same caveat as in `append`).
        let mut sequence_insert = InsertEventSequenceBuilder::new(&events).with_consumed(true);
        let event_ids: Vec<PgEventId> = sequence_insert
            .build()
            .fetch_all(&self.sequence_pool)
            .await?
            .into_iter()
            .map(|r| r.get(0))
            .collect();

        // Mark our rows committed. Errors are still routed through
        // map_concurrency_err so a constraint violation maps consistently.
        sqlx::query("UPDATE event_sequence es SET committed = true WHERE event_id = ANY($1)")
            .bind(&event_ids)
            .execute(&mut *tx)
            .await
            .map_err(map_concurrency_err)?;

        // Pair each reserved id with its event and persist the payloads.
        let persisted_events = event_ids
            .iter()
            .zip(events)
            .map(|(event_id, event)| PersistedEvent::new(*event_id, event))
            .collect::<Vec<_>>();
        InsertEventsBuilder::new(persisted_events.as_slice(), &self.serde)
            .build()
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;

        Ok(persisted_events)
    }
}

333fn map_concurrency_err(err: sqlx::Error) -> Error {
334 if let sqlx::Error::Database(ref description) = err {
335 if description.code().as_deref() == Some("23514") {
336 return Error::Concurrency;
337 }
338 }
339 Error::Database(err)
340}