Skip to main content

sourcery_postgres/
lib.rs

1//! Postgres-backed event sourcing implementations.
2//!
3//! This crate provides `PostgreSQL` implementations of the core Sourcery
4//! traits:
5//!
6//! - [`Store`] - An implementation of [`sourcery_core::store::EventStore`]
7//! - [`snapshot::Store`] - An implementation of
8//!   [`sourcery_core::snapshot::SnapshotStore`]
9//!
10//! Both use the same database and can share a connection pool.
11
12pub mod snapshot;
13
14use std::marker::PhantomData;
15
16use nonempty::NonEmpty;
17use serde::{Serialize, de::DeserializeOwned};
18use sourcery_core::{
19    concurrency::ConcurrencyConflict,
20    event::DomainEvent,
21    store::{
22        CommitError, Committed, EventFilter, EventStore, GloballyOrderedStore, LoadEventsResult,
23        OptimisticCommitError, StoredEvent,
24    },
25};
26use sqlx::{PgPool, Postgres, QueryBuilder, Row};
27
28#[derive(Debug, thiserror::Error)]
29pub enum Error {
30    #[error("database error: {0}")]
31    Database(#[from] sqlx::Error),
32    #[error("invalid position value from database: {0}")]
33    InvalidPosition(i64),
34    #[error("database did not return an inserted position")]
35    MissingReturnedPosition,
36    #[error("serialization error: {0}")]
37    Serialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
38    #[error("deserialization error: {0}")]
39    Deserialization(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
40}
41
42/// A PostgreSQL-backed [`EventStore`].
43///
44/// Defaults are intentionally conservative:
45/// - Positions are global and monotonic (`i64`, backed by `BIGSERIAL`).
46/// - Metadata is stored as `jsonb` (`M: Serialize + DeserializeOwned`).
47/// - Event data is stored as `jsonb`.
48#[derive(Clone)]
49pub struct Store<M> {
50    pool: PgPool,
51    _phantom: PhantomData<M>,
52}
53
54impl<M> Store<M> {
55    #[must_use]
56    pub const fn new(pool: PgPool) -> Self {
57        Self {
58            pool,
59            _phantom: PhantomData,
60        }
61    }
62}
63
64impl<M> Store<M>
65where
66    M: Sync,
67{
68    /// Apply the initial schema (idempotent).
69    ///
70    /// This uses `CREATE TABLE IF NOT EXISTS` style DDL so it can be run on
71    /// startup.
72    ///
73    /// # Errors
74    ///
75    /// Returns a `sqlx::Error` if any of the schema creation queries fail.
76    #[tracing::instrument(skip(self))]
77    pub async fn migrate(&self) -> Result<(), sqlx::Error> {
78        // Streams track per-aggregate last position for optimistic concurrency.
79        sqlx::query(
80            r"
81            CREATE TABLE IF NOT EXISTS es_streams (
82                aggregate_kind TEXT NOT NULL,
83                aggregate_id   UUID NOT NULL,
84                last_position  BIGINT NULL,
85                PRIMARY KEY (aggregate_kind, aggregate_id)
86            )
87            ",
88        )
89        .execute(&self.pool)
90        .await?;
91
92        sqlx::query(
93            r"
94            CREATE TABLE IF NOT EXISTS es_events (
95                position       BIGSERIAL PRIMARY KEY,
96                aggregate_kind TEXT NOT NULL,
97                aggregate_id   UUID NOT NULL,
98                event_kind     TEXT NOT NULL,
99                data           JSONB NOT NULL,
100                metadata       JSONB NOT NULL,
101                created_at     TIMESTAMPTZ NOT NULL DEFAULT now()
102            )
103            ",
104        )
105        .execute(&self.pool)
106        .await?;
107
108        sqlx::query(
109            r"CREATE INDEX IF NOT EXISTS es_events_by_kind_and_position ON es_events(event_kind, position)",
110        )
111        .execute(&self.pool)
112        .await?;
113
114        sqlx::query(
115            r"CREATE INDEX IF NOT EXISTS es_events_by_stream_and_position ON es_events(aggregate_kind, aggregate_id, position)",
116        )
117        .execute(&self.pool)
118        .await?;
119
120        Ok(())
121    }
122}
123
124impl<M> EventStore for Store<M>
125where
126    M: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
127{
128    type Data = serde_json::Value;
129    type Error = Error;
130    type Id = uuid::Uuid;
131    type Metadata = M;
132    type Position = i64;
133
134    fn decode_event<E>(
135        &self,
136        stored: &StoredEvent<Self::Id, Self::Position, Self::Data, Self::Metadata>,
137    ) -> Result<E, Self::Error>
138    where
139        E: DomainEvent + serde::de::DeserializeOwned,
140    {
141        serde_json::from_value(stored.data.clone()).map_err(|e| Error::Deserialization(Box::new(e)))
142    }
143
144    async fn stream_version<'a>(
145        &'a self,
146        aggregate_kind: &'a str,
147        aggregate_id: &'a Self::Id,
148    ) -> Result<Option<Self::Position>, Self::Error> {
149        let result: Option<i64> = sqlx::query_scalar(
150            r"SELECT last_position FROM es_streams WHERE aggregate_kind = $1 AND aggregate_id = $2",
151        )
152        .bind(aggregate_kind)
153        .bind(aggregate_id)
154        .fetch_optional(&self.pool)
155        .await?
156        .flatten();
157
158        Ok(result)
159    }
160
161    #[tracing::instrument(
162        skip(self, events, metadata),
163        fields(
164            aggregate_kind,
165            aggregate_id = %aggregate_id,
166            events_len = events.len()
167        )
168    )]
169    async fn commit_events<'a, E>(
170        &'a self,
171        aggregate_kind: &'a str,
172        aggregate_id: &'a Self::Id,
173        events: NonEmpty<E>,
174        metadata: &'a Self::Metadata,
175    ) -> Result<Committed<i64>, CommitError<Self::Error>>
176    where
177        E: sourcery_core::event::EventKind + serde::Serialize + Send + Sync + 'a,
178        Self::Metadata: Clone,
179    {
180        // Serialize all events first
181        let mut prepared: Vec<(String, serde_json::Value)> = Vec::with_capacity(events.len());
182        for (index, event) in events.iter().enumerate() {
183            let data = serde_json::to_value(event).map_err(|e| CommitError::Serialization {
184                index,
185                source: Error::Serialization(Box::new(e)),
186            })?;
187            prepared.push((event.kind().to_string(), data));
188        }
189
190        let mut tx = self
191            .pool
192            .begin()
193            .await
194            .map_err(|e| CommitError::Store(Error::Database(e)))?;
195
196        sqlx::query(
197            r"
198                INSERT INTO es_streams (aggregate_kind, aggregate_id, last_position)
199                VALUES ($1, $2, NULL)
200                ON CONFLICT (aggregate_kind, aggregate_id) DO NOTHING
201                ",
202        )
203        .bind(aggregate_kind)
204        .bind(aggregate_id)
205        .execute(&mut *tx)
206        .await
207        .map_err(|e| CommitError::Store(Error::Database(e)))?;
208
209        let mut qb = QueryBuilder::<Postgres>::new(
210            "INSERT INTO es_events (aggregate_kind, aggregate_id, event_kind, data, metadata) ",
211        );
212        qb.push_values(prepared, |mut b, (kind, data)| {
213            b.push_bind(aggregate_kind);
214            b.push_bind(aggregate_id);
215            b.push_bind(kind);
216            b.push_bind(sqlx::types::Json(data));
217            b.push_bind(sqlx::types::Json(metadata.clone()));
218        });
219        qb.push(" RETURNING position");
220
221        let rows: Vec<i64> = qb
222            .build_query_scalar()
223            .fetch_all(&mut *tx)
224            .await
225            .map_err(|e| CommitError::Store(Error::Database(e)))?;
226
227        let last_position = rows
228            .last()
229            .ok_or_else(|| CommitError::Store(Error::MissingReturnedPosition))?;
230
231        sqlx::query(
232            r"
233                UPDATE es_streams
234                SET last_position = $1
235                WHERE aggregate_kind = $2 AND aggregate_id = $3
236                ",
237        )
238        .bind(last_position)
239        .bind(aggregate_kind)
240        .bind(aggregate_id)
241        .execute(&mut *tx)
242        .await
243        .map_err(|e| CommitError::Store(Error::Database(e)))?;
244
245        tx.commit()
246            .await
247            .map_err(|e| CommitError::Store(Error::Database(e)))?;
248
249        Ok(Committed {
250            last_position: *last_position,
251        })
252    }
253
254    #[tracing::instrument(
255        skip(self, events, metadata),
256        fields(
257            aggregate_kind,
258            aggregate_id = %aggregate_id,
259            expected_version,
260            events_len = events.len()
261        )
262    )]
263    async fn commit_events_optimistic<'a, E>(
264        &'a self,
265        aggregate_kind: &'a str,
266        aggregate_id: &'a Self::Id,
267        expected_version: Option<Self::Position>,
268        events: NonEmpty<E>,
269        metadata: &'a Self::Metadata,
270    ) -> Result<Committed<i64>, OptimisticCommitError<i64, Self::Error>>
271    where
272        E: sourcery_core::event::EventKind + serde::Serialize + Send + Sync + 'a,
273        Self::Metadata: Clone,
274    {
275        // Serialize all events first
276        let mut prepared: Vec<(String, serde_json::Value)> = Vec::with_capacity(events.len());
277        for (index, event) in events.iter().enumerate() {
278            let data =
279                serde_json::to_value(event).map_err(|e| OptimisticCommitError::Serialization {
280                    index,
281                    source: Error::Serialization(Box::new(e)),
282                })?;
283            prepared.push((event.kind().to_string(), data));
284        }
285
286        let mut tx = self
287            .pool
288            .begin()
289            .await
290            .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;
291
292        sqlx::query(
293            r"
294                INSERT INTO es_streams (aggregate_kind, aggregate_id, last_position)
295                VALUES ($1, $2, NULL)
296                ON CONFLICT (aggregate_kind, aggregate_id) DO NOTHING
297                ",
298        )
299        .bind(aggregate_kind)
300        .bind(aggregate_id)
301        .execute(&mut *tx)
302        .await
303        .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;
304
305        let current: Option<i64> = sqlx::query_scalar::<_, Option<i64>>(
306            r"
307                SELECT last_position
308                FROM es_streams
309                WHERE aggregate_kind = $1 AND aggregate_id = $2
310                FOR UPDATE
311                ",
312        )
313        .bind(aggregate_kind)
314        .bind(aggregate_id)
315        .fetch_one(&mut *tx)
316        .await
317        .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;
318
319        // Version check
320        match expected_version {
321            Some(expected) => {
322                if current != Some(expected) {
323                    return Err(OptimisticCommitError::Conflict(ConcurrencyConflict {
324                        expected: Some(expected),
325                        actual: current,
326                    }));
327                }
328            }
329            None => {
330                // Expected new stream (no events)
331                if let Some(actual) = current {
332                    return Err(OptimisticCommitError::Conflict(ConcurrencyConflict {
333                        expected: None,
334                        actual: Some(actual),
335                    }));
336                }
337            }
338        }
339
340        let mut qb = QueryBuilder::<Postgres>::new(
341            "INSERT INTO es_events (aggregate_kind, aggregate_id, event_kind, data, metadata) ",
342        );
343        qb.push_values(prepared, |mut b, (kind, data)| {
344            b.push_bind(aggregate_kind);
345            b.push_bind(aggregate_id);
346            b.push_bind(kind);
347            b.push_bind(sqlx::types::Json(data));
348            b.push_bind(sqlx::types::Json(metadata.clone()));
349        });
350        qb.push(" RETURNING position");
351
352        let rows: Vec<i64> = qb
353            .build_query_scalar()
354            .fetch_all(&mut *tx)
355            .await
356            .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;
357
358        let last_position = rows
359            .last()
360            .ok_or_else(|| OptimisticCommitError::Store(Error::MissingReturnedPosition))?;
361
362        sqlx::query(
363            r"
364                UPDATE es_streams
365                SET last_position = $1
366                WHERE aggregate_kind = $2 AND aggregate_id = $3
367                ",
368        )
369        .bind(last_position)
370        .bind(aggregate_kind)
371        .bind(aggregate_id)
372        .execute(&mut *tx)
373        .await
374        .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;
375
376        tx.commit()
377            .await
378            .map_err(|e| OptimisticCommitError::Store(Error::Database(e)))?;
379
380        Ok(Committed {
381            last_position: *last_position,
382        })
383    }
384
385    #[allow(clippy::type_complexity)]
386    #[tracing::instrument(skip(self, filters), fields(filters_len = filters.len()))]
387    async fn load_events<'a>(
388        &'a self,
389        filters: &'a [EventFilter<Self::Id, Self::Position>],
390    ) -> LoadEventsResult<Self::Id, Self::Position, Self::Data, Self::Metadata, Self::Error> {
391        if filters.is_empty() {
392            return Ok(Vec::new());
393        }
394
395        let mut qb = QueryBuilder::<Postgres>::new(
396            "SELECT aggregate_kind, aggregate_id, event_kind, position, data, metadata FROM (",
397        );
398
399        for (i, filter) in filters.iter().enumerate() {
400            if i > 0 {
401                qb.push(" UNION ALL ");
402            }
403
404            qb.push(
405                "SELECT aggregate_kind, aggregate_id, event_kind, position, data, metadata FROM \
406                 es_events WHERE event_kind = ",
407            )
408            .push_bind(&filter.event_kind);
409
410            if let Some(kind) = &filter.aggregate_kind {
411                qb.push(" AND aggregate_kind = ").push_bind(kind);
412            }
413
414            if let Some(id) = &filter.aggregate_id {
415                qb.push(" AND aggregate_id = ").push_bind(id);
416            }
417
418            if let Some(after) = filter.after_position {
419                if after < 0 {
420                    return Err(Error::InvalidPosition(after));
421                }
422                qb.push(" AND position > ").push_bind(after);
423            }
424        }
425
426        qb.push(") t ORDER BY position ASC");
427
428        let rows = qb.build().fetch_all(&self.pool).await?;
429
430        let mut out = Vec::with_capacity(rows.len());
431        for row in rows {
432            let aggregate_kind: String = row.try_get("aggregate_kind")?;
433            let aggregate_id: uuid::Uuid = row.try_get("aggregate_id")?;
434            let event_kind: String = row.try_get("event_kind")?;
435            let position: i64 = row.try_get("position")?;
436            let data: sqlx::types::Json<serde_json::Value> = row.try_get("data")?;
437            let metadata: sqlx::types::Json<M> = row.try_get("metadata")?;
438
439            out.push(StoredEvent {
440                aggregate_kind,
441                aggregate_id,
442                kind: event_kind,
443                position,
444                data: data.0,
445                metadata: metadata.0,
446            });
447        }
448
449        Ok(out)
450    }
451}
452
453impl<M> GloballyOrderedStore for Store<M> where
454    M: Serialize + DeserializeOwned + Clone + Send + Sync + 'static
455{
456}