1use crate::pool::Pool;
2use error_ext::BoxError;
3use futures::{future::ok, Stream, StreamExt, TryStreamExt};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use sqlx::{postgres::PgRow, Encode, Postgres, Row, Transaction, Type};
7use std::{
8 fmt::{Debug, Display},
9 num::NonZeroU64,
10};
11use thiserror::Error;
12use tracing::instrument;
13
/// A command which can be handled by an [`Entity`] of type `Self::Entity`.
///
/// The `#[trait_variant::make(Send)]` attribute requires the future returned by `handle` to be
/// `Send`.
#[trait_variant::make(Send)]
pub trait Command {
    /// The entity type this command is handled by.
    type Entity: Entity;

    /// The error returned by `handle` when the command is rejected; it is passed through to the
    /// caller of `EventSourcedEntity::handle_command` as the inner `Err`.
    type Rejection: Debug;

    /// Decides which events to emit, given the entity's id and its current state.
    ///
    /// Returns the events to persist (each convertible into an [`EventWithMetadata`]) or a
    /// rejection. An empty `Vec` is accepted; then nothing is persisted.
    async fn handle(
        self,
        id: &<Self::Entity as Entity>::Id,
        entity: &Self::Entity,
    ) -> Result<
        Vec<
            impl Into<
                EventWithMetadata<
                    <Self::Entity as Entity>::Event,
                    <Self::Entity as Entity>::Metadata,
                >,
            >,
        >,
        Self::Rejection,
    >;
}
41
/// An event-sourced entity whose state is rebuilt by applying its persisted events in order.
pub trait Entity {
    /// The entity id. It is bound as `entity_id` in the SQL queries against the `event` table,
    /// hence the Postgres `Encode`/`Type` bounds.
    type Id: Debug
        + Display
        + for<'q> Encode<'q, Postgres>
        + Type<Postgres>
        + Serialize
        + for<'de> Deserialize<'de>
        + Sync;

    /// The event type; stored as JSON in the `event` column and deserialized when loading.
    type Event: Debug + Serialize + for<'de> Deserialize<'de> + Sync;

    /// Metadata persisted as JSON in the `metadata` column alongside each event.
    type Metadata: Debug + Serialize + Sync;

    /// Name stored in the `type_name` column. Events are looked up by `entity_id` AND this name,
    /// so entities of different types may share ids.
    const TYPE_NAME: &'static str;

    /// Applies the given event to the entity's state.
    fn handle_event(&mut self, event: Self::Event);
}
65
66#[allow(async_fn_in_trait)]
68pub trait EntityExt
69where
70 Self: Entity + Sized,
71{
72 fn entity(self) -> EventSourcedEntityBuilder<Self, NoOpEventListener> {
75 EventSourcedEntityBuilder {
76 entity: self,
77 listener: None,
78 }
79 }
80}
81
82impl<E> EntityExt for E where E: Entity {}
83
/// An event together with the metadata persisted alongside it.
///
/// `Clone`/`PartialEq`/`Eq` are derived (conditionally on `E` and `M`) so callers can copy and
/// compare these values, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EventWithMetadata<E, M> {
    /// The event itself.
    pub event: E,
    /// The metadata stored next to the event.
    pub metadata: M,
}
90
91impl<E> From<E> for EventWithMetadata<E, ()> {
92 fn from(event: E) -> Self {
93 EventWithMetadata {
94 event,
95 metadata: (),
96 }
97 }
98}
99
100pub trait EventExt
102where
103 Self: Sized,
104{
105 fn with_metadata<M>(self, metadata: M) -> EventWithMetadata<Self, M> {
107 EventWithMetadata {
108 event: self,
109 metadata,
110 }
111 }
112}
113
114impl<E> EventExt for E {}
115
/// Builder for an [`EventSourcedEntity`]: holds the initial entity state and an optional event
/// listener of type `L`.
pub struct EventSourcedEntityBuilder<E, L> {
    // State before any persisted event is applied.
    entity: E,
    // Listener to be handed over to the built entity; `None` means no listener.
    listener: Option<L>,
}
121
122impl<E, L> EventSourcedEntityBuilder<E, L>
123where
124 E: Entity,
125{
126 pub fn with_listener<T>(self, listener: T) -> EventSourcedEntityBuilder<E, T> {
128 EventSourcedEntityBuilder {
129 entity: self.entity,
130 listener: Some(listener),
131 }
132 }
133
134 pub async fn build(self, id: E::Id, pool: Pool) -> Result<EventSourcedEntity<E, L>, Error> {
136 let events = current_events_by_id::<E>(&id, &pool).await;
137 let (last_version, entity) = events
138 .try_fold((None, self.entity), |(_, mut state), (version, event)| {
139 state.handle_event(event);
140 ok((Some(version), state))
141 })
142 .await?;
143
144 Ok(EventSourcedEntity {
145 entity,
146 id,
147 last_version,
148 pool,
149 listener: self.listener,
150 })
151 }
152}
153
/// An entity whose state has been rebuilt from its persisted events and which can handle
/// commands; created via [`EventSourcedEntityBuilder::build`].
pub struct EventSourcedEntity<E, L>
where
    E: Entity,
{
    // Current in-memory state.
    entity: E,
    // Optional listener invoked for each event while it is persisted.
    listener: Option<L>,
    id: E::Id,
    pool: Pool,
    // Version of the last event applied to `entity`; `None` if no event has been applied yet.
    last_version: Option<NonZeroU64>,
}
165
166impl<E, L> EventSourcedEntity<E, L>
167where
168 E: Entity,
169 L: EventListener<E::Event, E::Metadata>,
170{
171 pub async fn handle_command<C>(&mut self, command: C) -> Result<Result<&E, C::Rejection>, Error>
174 where
175 C: Command<Entity = E>,
176 {
177 let result = command.handle(&self.id, &self.entity).await.map(|es| {
178 es.into_iter()
179 .map(|into_ewm| into_ewm.into())
180 .collect::<Vec<_>>()
181 });
182 match result {
183 Ok(ewms) => {
184 if !ewms.is_empty() {
185 let version = persist::<E, _>(
186 &self.id,
187 self.last_version,
188 &ewms,
189 &self.pool,
190 &mut self.listener,
191 )
192 .await?;
193 self.last_version = Some(version);
194
195 for EventWithMetadata { event, .. } in ewms {
196 self.entity.handle_event(event);
197 }
198 }
199
200 Ok(Ok(&self.entity))
201 }
202
203 Err(rejection) => Ok(Err(rejection)),
204 }
205 }
206}
207
/// A hook invoked for every event while it is being persisted, with access to the open database
/// transaction (the tests, for example, use it to maintain a `counters` table).
#[trait_variant::make(Send)]
pub trait EventListener<E, M> {
    /// Called after `ewm` has been inserted, within the same transaction `tx`.
    ///
    /// Returning an error makes persisting fail with `Error::Listener` before the transaction is
    /// committed.
    async fn listen(
        &mut self,
        ewm: &EventWithMetadata<E, M>,
        tx: &mut Transaction<'_, Postgres>,
    ) -> Result<(), BoxError>
    where
        E: Sync,
        M: Sync;
}
221
/// An [`EventListener`] which does nothing; used as the listener type by `EntityExt::entity`.
///
/// It is a stateless unit struct, so the common traits are derived.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoOpEventListener;
224
225impl<E, M> EventListener<E, M> for NoOpEventListener {
226 async fn listen(
227 &mut self,
228 _ewm: &EventWithMetadata<E, M>,
229 _tx: &mut Transaction<'_, Postgres>,
230 ) -> Result<(), BoxError>
231 where
232 E: Sync,
233 M: Sync,
234 {
235 Ok(())
236 }
237}
238
/// Errors from loading or persisting events.
#[derive(Debug, Error)]
pub enum Error {
    /// A database operation failed; the string describes which one.
    #[error("{0}")]
    Sqlx(String, #[source] sqlx::Error),

    /// An event or its metadata could not be serialized to JSON.
    #[error("cannot serialize event")]
    Ser(#[source] serde_json::Error),

    /// A stored event could not be deserialized from JSON.
    #[error("cannot deserialize event")]
    De(#[source] serde_json::Error),

    /// The maximum version stored in the database differs from the version last known to the
    /// entity. Per the message, the first field is the expected and the second the actual
    /// version. NOTE(review): confirm call sites pass them in that order.
    #[error("expected version {0:?}, but was {1:?}")]
    UnexpectedVersion(Option<NonZeroU64>, Option<NonZeroU64>),

    /// An event listener returned an error.
    #[error("listener error")]
    Listener(#[source] BoxError),
}
257
258#[instrument(skip(pool))]
259async fn current_events_by_id<'a, E>(
260 id: &'a E::Id,
261 pool: &'a Pool,
262) -> impl Stream<Item = Result<(NonZeroU64, E::Event), Error>> + Send + 'a
263where
264 E: Entity,
265{
266 sqlx::query(
267 "SELECT version, event
268 FROM event
269 WHERE entity_id = $1 AND type_name = $2
270 ORDER BY seq_no ASC",
271 )
272 .bind(id)
273 .bind(E::TYPE_NAME)
274 .fetch(&**pool)
275 .map_err(|error| Error::Sqlx("cannot get next event".to_string(), error))
276 .map(|row| {
277 row.and_then(|row| {
278 let version = (row.get::<i64, _>(0) as u64)
279 .try_into()
280 .expect("version greater zero");
281 let value = row.get::<Value, _>(1);
282 let event = serde_json::from_value::<E::Event>(value).map_err(Error::De)?;
283 Ok((version, event))
284 })
285 })
286}
287
288#[instrument(skip(ewms, listener))]
289async fn persist<E, L>(
290 id: &E::Id,
291 last_version: Option<NonZeroU64>,
292 ewms: &[EventWithMetadata<E::Event, E::Metadata>],
293 pool: &Pool,
294 listener: &mut Option<L>,
295) -> Result<NonZeroU64, Error>
296where
297 E: Entity,
298 L: EventListener<E::Event, E::Metadata>,
299{
300 assert!(!ewms.is_empty());
301
302 let mut tx = pool
303 .begin()
304 .await
305 .map_err(|error| Error::Sqlx("cannot begin transaction".to_string(), error))?;
306
307 let version = sqlx::query(
308 "SELECT MAX(version)
309 FROM event
310 WHERE entity_id = $1 AND type_name = $2",
311 )
312 .bind(id)
313 .bind(E::TYPE_NAME)
314 .fetch_one(&mut *tx)
315 .await
316 .map_err(|error| Error::Sqlx("cannot select max version".to_string(), error))
317 .map(into_version)?;
318
319 if version != last_version {
320 return Err(Error::UnexpectedVersion(version, last_version));
321 }
322
323 let mut version = last_version.map(|n| n.get() as i64).unwrap_or_default();
324 for ewm @ EventWithMetadata { event, metadata } in ewms.iter() {
325 version += 1;
326 let bytes = serde_json::to_value(event).map_err(Error::Ser)?;
327 let metadata = serde_json::to_value(metadata).map_err(Error::Ser)?;
328 sqlx::query(
329 "INSERT INTO event (entity_id, version, type_name, event, metadata)
330 VALUES ($1, $2, $3, $4, $5)",
331 )
332 .bind(id)
333 .bind(version)
334 .bind(E::TYPE_NAME)
335 .bind(&bytes)
336 .bind(metadata)
337 .execute(&mut *tx)
338 .await
339 .map_err(|error| Error::Sqlx("cannot execute statement".to_string(), error))?;
340
341 if let Some(listener) = listener {
342 listener
343 .listen(ewm, &mut tx)
344 .await
345 .map_err(Error::Listener)?;
346 }
347 }
348
349 tx.commit()
350 .await
351 .map_err(|error| Error::Sqlx("cannot commit transaction".to_string(), error))?;
352
353 let version = (version as u64).try_into().expect("version greater zero");
354 Ok(version)
355}
356
357fn into_version(row: PgRow) -> Option<NonZeroU64> {
358 row.try_get::<i64, _>(0)
360 .ok()
361 .map(|version| (version as u64).try_into().expect("version greater zero"))
362}
363
#[cfg(test)]
mod tests {
    use crate::{
        entity::{Command, Entity, EntityExt, EventExt, EventListener, EventWithMetadata},
        pool::{Config, Pool},
        test::run_postgres,
    };
    use error_ext::BoxError;
    use serde::{Deserialize, Serialize};
    use serde_json::{json, Value};
    use sqlx::{postgres::PgSslMode, Executor, Row, Transaction};
    use std::error::Error as StdError;
    use time::OffsetDateTime;
    use uuid::Uuid;

    type TestResult = Result<(), Box<dyn StdError>>;

    /// Test entity: a counter which is increased and decreased by its events.
    #[derive(Debug, Default, PartialEq, Eq)]
    pub struct Counter(u64);

    impl Entity for Counter {
        type Id = Uuid;
        type Event = Event;
        type Metadata = Metadata;

        const TYPE_NAME: &'static str = "counter";

        fn handle_event(&mut self, event: Self::Event) {
            match event {
                Event::Increased { inc, .. } => self.0 += inc,
                Event::Decreased { dec, .. } => self.0 -= dec,
            }
        }
    }

    #[derive(Debug, Serialize, Deserialize)]
    pub enum Event {
        Increased { id: Uuid, inc: u64 },
        Decreased { id: Uuid, dec: u64 },
    }

    /// Increases the counter; rejected with [`Overflow`] if the result would exceed `u64::MAX`.
    #[derive(Debug)]
    pub struct Increase(u64);

    impl Command for Increase {
        type Entity = Counter;
        type Rejection = Overflow;

        async fn handle(
            self,
            id: &<Self::Entity as Entity>::Id,
            entity: &Self::Entity,
        ) -> Result<
            Vec<
                impl Into<
                    EventWithMetadata<
                        <Self::Entity as Entity>::Event,
                        <Self::Entity as Entity>::Metadata,
                    >,
                >,
            >,
            Self::Rejection,
        > {
            let Increase(inc) = self;
            if entity.0 > u64::MAX - inc {
                Err(Overflow)
            } else {
                let increased = Event::Increased { id: *id, inc };
                let metadata = Metadata {
                    timestamp: OffsetDateTime::now_utc(),
                };
                Ok(vec![increased.with_metadata(metadata)])
            }
        }
    }

    #[derive(Debug, PartialEq, Eq)]
    pub struct Overflow;

    /// Decreases the counter; rejected with [`Underflow`] if the result would be negative.
    #[derive(Debug)]
    pub struct Decrease(u64);

    impl Command for Decrease {
        type Entity = Counter;
        type Rejection = Underflow;

        async fn handle(
            self,
            id: &<Self::Entity as Entity>::Id,
            entity: &Self::Entity,
        ) -> Result<
            Vec<
                impl Into<
                    EventWithMetadata<
                        <Self::Entity as Entity>::Event,
                        <Self::Entity as Entity>::Metadata,
                    >,
                >,
            >,
            Self::Rejection,
        > {
            let Decrease(dec) = self;
            if entity.0 < dec {
                Err(Underflow)
            } else {
                let decreased = Event::Decreased { id: *id, dec };
                let metadata = Metadata {
                    timestamp: OffsetDateTime::now_utc(),
                };
                Ok(vec![decreased.with_metadata(metadata)])
            }
        }
    }

    #[derive(Debug, PartialEq, Eq)]
    pub struct Underflow;

    #[derive(Debug, Serialize, Deserialize)]
    pub struct Metadata {
        timestamp: OffsetDateTime,
    }

    /// Listener maintaining a `counters` table. For `Increased` events it adds `inc` to the row of
    /// the event's `id`, inserting the row if it is missing. All other events are ignored.
    struct Listener;

    impl EventListener<Event, Metadata> for Listener {
        async fn listen(
            &mut self,
            ewm: &EventWithMetadata<Event, Metadata>,
            tx: &mut Transaction<'_, sqlx::Postgres>,
        ) -> Result<(), BoxError> {
            match ewm {
                EventWithMetadata {
                    event: Event::Increased { id, inc },
                    ..
                } => {
                    let value = sqlx::query(
                        "SELECT value
                         FROM counters
                         WHERE id = $1",
                    )
                    .bind(id)
                    .fetch_optional(&mut **tx)
                    .await
                    .map_err(Box::new)?
                    .map(|row| row.try_get::<i64, _>(0))
                    .transpose()?;
                    match value {
                        Some(value) => {
                            sqlx::query(
                                "UPDATE counters
                                 SET value = $1
                                 WHERE id = $2",
                            )
                            .bind(value + *inc as i64)
                            .bind(id)
                            .execute(&mut **tx)
                            .await
                            .map_err(Box::new)?;
                        }

                        None => {
                            sqlx::query(
                                "INSERT INTO counters
                                 VALUES ($1, $2)",
                            )
                            .bind(id)
                            .bind(*inc as i64)
                            .execute(&mut **tx)
                            .await
                            .map_err(Box::new)?;
                        }
                    }
                    Ok(())
                }

                _ => Ok(()),
            }
        }
    }

    /// Connects to the Postgres instance listening on `port` on localhost and runs
    /// `create_event_log_uuid.sql` against it.
    async fn create_pool(port: u16) -> Result<Pool, Box<dyn StdError>> {
        let config = Config {
            host: "localhost".to_string(),
            port,
            user: "postgres".to_string(),
            password: "postgres".to_string().into(),
            dbname: "postgres".to_string(),
            sslmode: PgSslMode::Prefer,
        };

        let pool = Pool::new(config).await?;
        let ddl = include_str!("../sql/create_event_log_uuid.sql");
        (&*pool).execute(ddl).await?;
        Ok(pool)
    }

    /// Inserts an event row directly, bypassing any entity, with `Value::Null` as metadata.
    async fn insert_event(
        pool: &Pool,
        id: Uuid,
        version: i64,
        type_name: &str,
        event: Value,
    ) -> Result<(), sqlx::Error> {
        sqlx::query(
            "INSERT INTO event (entity_id, version, type_name, event, metadata)
             VALUES ($1, $2, $3, $4, $5)",
        )
        .bind(id)
        .bind(version)
        .bind(type_name)
        .bind(event)
        .bind(Value::Null)
        .execute(&**pool)
        .await?;
        Ok(())
    }

    /// Reads the `counters` table value for `id`, if a row exists.
    async fn counter_value(pool: &Pool, id: Uuid) -> Result<Option<i64>, sqlx::Error> {
        let row = sqlx::query(
            "SELECT value
             FROM counters
             WHERE id = $1",
        )
        .bind(id)
        .fetch_optional(&**pool)
        .await?;
        Ok(row.map(|row| row.get::<i64, _>(0)))
    }

    #[tokio::test]
    async fn test_load() -> TestResult {
        // Keep the container alive for the whole test.
        let container = run_postgres().await?;
        let pool = create_pool(container.get_host_port_ipv4(5432).await?).await?;

        let id = Uuid::from_u128(0);
        let events = [
            Event::Increased { id, inc: 40 },
            Event::Decreased { id, dec: 20 },
            Event::Increased { id, inc: 22 },
        ];
        for (version, event) in (1_i64..).zip(events.iter()) {
            insert_event(
                &pool,
                id,
                version,
                Counter::TYPE_NAME,
                serde_json::to_value(event)?,
            )
            .await?;
        }

        let counter = Counter::default().entity().build(id, pool).await?;
        assert_eq!(counter.entity.0, 42);

        Ok(())
    }

    #[tokio::test]
    async fn test_handle_command() -> TestResult {
        let container = run_postgres().await?;
        let pool = create_pool(container.get_host_port_ipv4(5432).await?).await?;

        let id = Uuid::from_u128(0);

        // An event with the same id but another type name must not be applied to the counter.
        insert_event(
            &pool,
            id,
            1,
            "faker",
            json!({ "name": "Meier", "address": "Musterstraße 42" }),
        )
        .await?;

        let mut counter = Counter::default().entity().build(id, pool.clone()).await?;
        assert_eq!(counter.entity, Counter(0));

        let result = counter.handle_command(Decrease(1)).await?;
        assert_eq!(result, Err(Underflow));

        let result = counter.handle_command(Increase(40)).await?;
        assert_eq!(result, Ok(&Counter(40)));

        let result = counter.handle_command(Decrease(20)).await?;
        assert_eq!(result, Ok(&Counter(20)));

        // Rebuilding from the database must yield the same state.
        let mut counter = Counter::default().entity().build(id, pool).await?;
        let result = counter.handle_command(Increase(22)).await?;
        assert_eq!(result, Ok(&Counter(42)));

        Ok(())
    }

    #[tokio::test]
    async fn test_event_listener() -> TestResult {
        let container = run_postgres().await?;
        let pool = create_pool(container.get_host_port_ipv4(5432).await?).await?;

        let ddl = "CREATE TABLE
                   IF NOT EXISTS
                   counters (id uuid, value bigint, PRIMARY KEY (id));";
        (&*pool).execute(ddl).await?;

        let id_0 = Uuid::from_u128(0);
        let id_1 = Uuid::from_u128(1);
        let id_2 = Uuid::from_u128(2);

        let _ = Counter::default()
            .entity()
            .with_listener(Listener)
            .build(id_0, pool.clone())
            .await?;
        let mut counter_1 = Counter::default()
            .entity()
            .with_listener(Listener)
            .build(id_1, pool.clone())
            .await?;
        let mut counter_2 = Counter::default()
            .entity()
            .with_listener(Listener)
            .build(id_2, pool.clone())
            .await?;

        let _ = counter_1.handle_command(Increase(1)).await?;
        let _ = counter_2.handle_command(Increase(1)).await?;
        let _ = counter_2.handle_command(Increase(1)).await?;

        assert!(counter_value(&pool, id_0).await?.is_none());
        assert_eq!(counter_value(&pool, id_1).await?, Some(1));
        assert_eq!(counter_value(&pool, id_2).await?, Some(2));

        Ok(())
    }
}