use crate::{
error::{LoadError, PersistError},
util::{BorrowedJson, Json, RawJsonPersist, RawJsonRead, Sequence},
};
use cqrs_core::{
Aggregate, AggregateEvent, AggregateId, Before, DeserializableEvent, EventNumber, EventSink,
EventSource, NeverSnapshot, Precondition, SerializableEvent, Since, SnapshotRecommendation,
SnapshotSink, SnapshotSource, SnapshotStrategy, Version, VersionedAggregate, VersionedEvent,
VersionedEventWithMetadata,
};
use fallible_iterator::{FallibleIterator, IntoFallibleIterator};
use num_traits::FromPrimitive;
use postgres::Connection;
use serde::{de::DeserializeOwned, Serialize};
use std::{fmt, marker::PhantomData};
/// A PostgreSQL-backed event/snapshot store for aggregates of type `A`,
/// with events `E`, metadata `M`, and snapshot strategy `S`.
///
/// Holds a shared borrow of the connection; the aggregate/event/metadata
/// types are only phantom markers.
pub struct PostgresStore<'conn, A, E, M, S = NeverSnapshot>
where
    A: Aggregate,
    E: AggregateEvent<A>,
    S: SnapshotStrategy,
{
    // Shared connection; all operations borrow it, so the store itself is cheap to copy.
    conn: &'conn Connection,
    // Decides when `persist_snapshot` should actually write a snapshot.
    snapshot_strategy: S,
    // Ties the store to its type parameters without storing values of those types.
    _phantom: PhantomData<&'conn (A, E, M)>,
}

// Manual impl instead of `#[derive(Clone)]`: the derive would demand
// `A: Clone, E: Clone, M: Clone` even though those parameters only appear in
// `PhantomData`. Only the strategy is actually cloned; `conn` is a shared
// reference and copies freely.
impl<'conn, A, E, M, S> Clone for PostgresStore<'conn, A, E, M, S>
where
    A: Aggregate,
    E: AggregateEvent<A>,
    S: SnapshotStrategy + Clone,
{
    fn clone(&self) -> Self {
        PostgresStore {
            conn: self.conn,
            snapshot_strategy: self.snapshot_strategy.clone(),
            _phantom: PhantomData,
        }
    }
}
impl<'conn, A, E, M, S> fmt::Debug for PostgresStore<'conn, A, E, M, S>
where
    A: Aggregate,
    E: AggregateEvent<A>,
    S: SnapshotStrategy + fmt::Debug,
{
    /// Renders the store as a struct with its connection, snapshot strategy,
    /// and phantom marker.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("PostgresStore");
        builder.field("conn", &*self.conn);
        builder.field("strategy", &self.snapshot_strategy);
        builder.field("phantom", &self._phantom);
        builder.finish()
    }
}
impl<'conn, A, E, M, S> PostgresStore<'conn, A, E, M, S>
where
    A: Aggregate,
    E: AggregateEvent<A>,
    S: SnapshotStrategy + Default,
{
    /// The schema version this build of the store knows how to create.
    const DB_VERSION: u32 = 1;

    /// Constructs a store over `conn` using the default snapshot strategy.
    pub fn new(conn: &'conn Connection) -> Self {
        PostgresStore {
            conn,
            snapshot_strategy: S::default(),
            _phantom: PhantomData,
        }
    }

    /// Constructs a store over `conn` with an explicitly provided snapshot
    /// strategy.
    pub fn with_snapshot_strategy(conn: &'conn Connection, snapshot_strategy: S) -> Self {
        PostgresStore {
            conn,
            snapshot_strategy,
            _phantom: PhantomData,
        }
    }

    /// Returns the highest migration version recorded in the `migrations`
    /// table, or 0 when no migration has been applied yet (`MAX(version)` is
    /// NULL on an empty table).
    fn current_migration_version(&self) -> Result<i32, postgres::Error> {
        Ok(self
            .conn
            .query("SELECT MAX(version) from migrations", &[])?
            .iter()
            .next()
            .and_then(|r| r.get(0))
            .unwrap_or_default())
    }

    /// Creates the migrations bookkeeping table, then applies any schema
    /// migrations that have not yet been recorded.
    pub fn create_tables(&self) -> Result<(), postgres::Error> {
        self.conn
            .batch_execute(include_str!("migrations/00_create_migrations.sql"))?;
        if self.current_migration_version()? < 1 {
            self.conn
                .batch_execute(include_str!("migrations/01_create_tables.sql"))?;
        }
        Ok(())
    }

    /// Returns `true` when the database schema is exactly at
    /// [`Self::DB_VERSION`].
    pub fn is_latest(&self) -> Result<bool, postgres::Error> {
        Ok(Self::DB_VERSION == self.current_migration_version()? as u32)
    }

    /// Returns `true` when the database schema is at or below
    /// [`Self::DB_VERSION`], i.e. this build can still operate against it.
    pub fn is_compatible(&self) -> Result<bool, postgres::Error> {
        Ok(Self::DB_VERSION >= self.current_migration_version()? as u32)
    }

    /// Counts the distinct entities of this aggregate type that have at least
    /// one stored event.
    pub fn get_entity_count(&self) -> Result<u64, postgres::Error> {
        let stmt = self.conn.prepare_cached(
            "SELECT COUNT(DISTINCT entity_id) \
             FROM events \
             WHERE aggregate_type = $1",
        )?;
        let rows = stmt.query(&[&A::aggregate_type()])?;
        Ok(rows
            .iter()
            .next()
            .map(|r| r.get::<_, i64>(0) as u64)
            .unwrap_or_default())
    }

    /// Returns a page of distinct entity ids for this aggregate type.
    ///
    /// NOTE(review): `OFFSET`/`LIMIT` without an `ORDER BY` gives no stable
    /// pagination order — confirm callers do not rely on one.
    pub fn get_entity_ids(&self, offset: u32, limit: u32) -> Result<Vec<String>, postgres::Error> {
        let stmt = self.conn.prepare_cached(
            "SELECT DISTINCT entity_id \
             FROM events \
             WHERE aggregate_type = $1 \
             OFFSET $2 LIMIT $3",
        )?;
        let rows = stmt.query(&[
            &A::aggregate_type(),
            &(i64::from(offset)),
            &(i64::from(limit)),
        ])?;
        Ok(rows.iter().map(|r| r.get(0)).collect())
    }

    /// Counts distinct entity ids matching a SQL `LIKE` pattern.
    pub fn get_entity_count_matching_pattern(&self, pattern: &str) -> Result<u64, postgres::Error> {
        let stmt = self.conn.prepare_cached(
            "SELECT COUNT(DISTINCT entity_id) \
             FROM events \
             WHERE aggregate_type = $1 AND entity_id LIKE $2",
        )?;
        let rows = stmt.query(&[&A::aggregate_type(), &pattern])?;
        Ok(rows
            .iter()
            .next()
            .map(|r| r.get::<_, i64>(0) as u64)
            .unwrap_or_default())
    }

    /// Returns a page of distinct entity ids matching a SQL `LIKE` pattern.
    ///
    /// NOTE(review): no `ORDER BY`, so pagination order is unspecified.
    pub fn get_entity_ids_matching_pattern(
        &self,
        pattern: &str,
        offset: u32,
        limit: u32,
    ) -> Result<Vec<String>, postgres::Error> {
        let stmt = self.conn.prepare_cached(
            "SELECT DISTINCT entity_id \
             FROM events \
             WHERE aggregate_type = $1 AND entity_id LIKE $2 \
             OFFSET $3 LIMIT $4",
        )?;
        let rows = stmt.query(&[
            &A::aggregate_type(),
            &pattern,
            &(i64::from(offset)),
            &(i64::from(limit)),
        ])?;
        Ok(rows.iter().map(|r| r.get(0)).collect())
    }

    /// Counts distinct entity ids matching a SQL `SIMILAR TO` regex.
    pub fn get_entity_count_matching_sql_regex(&self, regex: &str) -> Result<u64, postgres::Error> {
        let stmt = self.conn.prepare_cached(
            "SELECT COUNT(DISTINCT entity_id) \
             FROM events \
             WHERE aggregate_type = $1 AND entity_id SIMILAR TO $2",
        )?;
        // Fixed: parameter reference was corrupted to `®ex` (mojibake of `&reg`).
        let rows = stmt.query(&[&A::aggregate_type(), &regex])?;
        Ok(rows
            .iter()
            .next()
            .map(|r| r.get::<_, i64>(0) as u64)
            .unwrap_or_default())
    }

    /// Returns a page of distinct entity ids matching a SQL `SIMILAR TO`
    /// regex.
    ///
    /// NOTE(review): no `ORDER BY`, so pagination order is unspecified.
    pub fn get_entity_ids_matching_sql_regex(
        &self,
        regex: &str,
        offset: u32,
        limit: u32,
    ) -> Result<Vec<String>, postgres::Error> {
        let stmt = self.conn.prepare_cached(
            "SELECT DISTINCT entity_id \
             FROM events \
             WHERE aggregate_type = $1 AND entity_id SIMILAR TO $2 \
             OFFSET $3 LIMIT $4",
        )?;
        // Fixed: parameter reference was corrupted to `®ex` (mojibake of `&reg`).
        let rows = stmt.query(&[
            &A::aggregate_type(),
            &regex,
            &(i64::from(offset)),
            &(i64::from(limit)),
        ])?;
        Ok(rows.iter().map(|r| r.get(0)).collect())
    }

    /// Counts distinct entity ids matching a POSIX (`~`) regex.
    pub fn get_entity_count_matching_posix_regex(
        &self,
        regex: &str,
    ) -> Result<u64, postgres::Error> {
        let stmt = self.conn.prepare_cached(
            "SELECT COUNT(DISTINCT entity_id) \
             FROM events \
             WHERE aggregate_type = $1 AND entity_id ~ $2",
        )?;
        // Fixed: parameter reference was corrupted to `®ex` (mojibake of `&reg`).
        let rows = stmt.query(&[&A::aggregate_type(), &regex])?;
        Ok(rows
            .iter()
            .next()
            .map(|r| r.get::<_, i64>(0) as u64)
            .unwrap_or_default())
    }

    /// Returns a page of distinct entity ids matching a POSIX (`~`) regex.
    ///
    /// NOTE(review): no `ORDER BY`, so pagination order is unspecified.
    pub fn get_entity_ids_matching_posix_regex(
        &self,
        regex: &str,
        offset: u32,
        limit: u32,
    ) -> Result<Vec<String>, postgres::Error> {
        let stmt = self.conn.prepare_cached(
            "SELECT DISTINCT entity_id \
             FROM events \
             WHERE aggregate_type = $1 AND entity_id ~ $2 \
             OFFSET $3 LIMIT $4",
        )?;
        // Fixed: parameter reference was corrupted to `®ex` (mojibake of `&reg`).
        let rows = stmt.query(&[
            &A::aggregate_type(),
            &regex,
            &(i64::from(offset)),
            &(i64::from(limit)),
        ])?;
        Ok(rows.iter().map(|r| r.get(0)).collect())
    }

    /// Reads events (with their stored metadata) for `id`, ascending by
    /// sequence, strictly after `since`, streaming rows in batches of 100
    /// inside a read-only transaction.
    ///
    /// Each element is itself a `Result`, so one undeserializable event does
    /// not discard the rest of the page. `max_count` caps the number of rows
    /// fetched. NOTE(review): an entity with no events yields
    /// `Ok(Some(vec![]))`, never `Ok(None)` — confirm callers expect that.
    pub fn read_events_with_metadata<I>(
        &self,
        id: &I,
        since: Since,
        max_count: Option<u64>,
    ) -> Result<
        Option<Vec<Result<VersionedEventWithMetadata<E, M>, LoadError<E::Error>>>>,
        LoadError<E::Error>,
    >
    where
        I: AggregateId<A>,
        E: DeserializableEvent,
        M: for<'de> serde::Deserialize<'de>,
    {
        // `BeginningOfStream` maps to 0 so `sequence > 0` matches every row.
        let last_sequence = match since {
            cqrs_core::Since::BeginningOfStream => 0,
            cqrs_core::Since::Event(x) => x.get(),
        } as i64;
        let events;
        let trans = self
            .conn
            .transaction_with(postgres::transaction::Config::default().read_only(true))?;
        // Converts one row into a versioned event + metadata, surfacing
        // unknown event types and payload failures as per-event errors.
        let handle_row = |row: postgres::rows::Row| {
            let sequence: Sequence = row.get(0);
            let event_type: String = row.get(1);
            let raw: RawJsonRead = row.get(2);
            let metadata: Json<M> = row.get(3);
            let event = E::deserialize_event_from_buffer(&raw.0, &event_type)
                .map_err(LoadError::DeserializationError)?
                .ok_or_else(|| LoadError::UnknownEventType(event_type.clone()))?;
            log::trace!(
                "entity {}: loaded event; sequence: {}, type: {}",
                id.as_str(),
                sequence.0,
                event_type
            );
            Ok(VersionedEventWithMetadata {
                sequence: sequence.0,
                event,
                metadata: metadata.0,
            })
        };
        // `stmt` must outlive the lazily-fetched rows, so it is declared
        // outside the scope that drives the row cursor.
        let stmt;
        {
            let mut rows;
            if let Some(max_count) = max_count {
                stmt = trans.prepare_cached(
                    "SELECT sequence, event_type, payload, metadata \
                     FROM events \
                     WHERE aggregate_type = $1 AND entity_id = $2 AND sequence > $3 \
                     ORDER BY sequence ASC \
                     LIMIT $4",
                )?;
                rows = stmt.lazy_query(
                    &trans,
                    &[
                        &A::aggregate_type(),
                        &id.as_str(),
                        &last_sequence,
                        // Clamp to i64 range for the SQL parameter.
                        &(max_count.min(i64::max_value() as u64) as i64),
                    ],
                    100,
                )?;
            } else {
                stmt = trans.prepare_cached(
                    "SELECT sequence, event_type, payload, metadata \
                     FROM events \
                     WHERE aggregate_type = $1 AND entity_id = $2 AND sequence > $3 \
                     ORDER BY sequence ASC",
                )?;
                rows = stmt.lazy_query(
                    &trans,
                    &[&A::aggregate_type(), &id.as_str(), &last_sequence],
                    100,
                )?;
            }
            let (lower, upper) = rows.size_hint();
            let cap = upper.unwrap_or(lower);
            let mut inner_events = Vec::with_capacity(cap);
            while let Some(row) = rows.next()? {
                inner_events.push(handle_row(row));
            }
            events = inner_events;
        }
        trans.commit()?;
        log::trace!("entity {}: read {} events", id.as_str(), events.len());
        Ok(Some(events))
    }

    /// Reads events (with their stored metadata) for `id` in descending
    /// sequence order, strictly before `before`. Otherwise behaves like
    /// [`Self::read_events_with_metadata`].
    pub fn read_events_reverse_with_metadata<I>(
        &self,
        id: &I,
        before: Before,
        max_count: Option<u64>,
    ) -> Result<
        Option<Vec<Result<VersionedEventWithMetadata<E, M>, LoadError<E::Error>>>>,
        LoadError<E::Error>,
    >
    where
        I: AggregateId<A>,
        E: DeserializableEvent,
        M: for<'de> serde::Deserialize<'de>,
    {
        // `EndOfStream` maps to i64::MAX so `sequence < $3` matches every row.
        let last_sequence = match before {
            Before::EndOfStream => std::i64::MAX,
            Before::Event(x) => x.get() as i64,
        };
        let events;
        let trans = self
            .conn
            .transaction_with(postgres::transaction::Config::default().read_only(true))?;
        // Converts one row into a versioned event + metadata, surfacing
        // unknown event types and payload failures as per-event errors.
        let handle_row = |row: postgres::rows::Row| {
            let sequence: Sequence = row.get(0);
            let event_type: String = row.get(1);
            let raw: RawJsonRead = row.get(2);
            let metadata: Json<M> = row.get(3);
            let event = E::deserialize_event_from_buffer(&raw.0, &event_type)
                .map_err(LoadError::DeserializationError)?
                .ok_or_else(|| LoadError::UnknownEventType(event_type.clone()))?;
            log::trace!(
                "entity {}: loaded event; sequence: {}, type: {}",
                id.as_str(),
                sequence.0,
                event_type
            );
            Ok(VersionedEventWithMetadata {
                sequence: sequence.0,
                event,
                metadata: metadata.0,
            })
        };
        // `stmt` must outlive the lazily-fetched rows, so it is declared
        // outside the scope that drives the row cursor.
        let stmt;
        {
            let mut rows;
            if let Some(max_count) = max_count {
                stmt = trans.prepare_cached(
                    "SELECT sequence, event_type, payload, metadata \
                     FROM events \
                     WHERE aggregate_type = $1 AND entity_id = $2 AND sequence < $3 \
                     ORDER BY sequence DESC \
                     LIMIT $4",
                )?;
                rows = stmt.lazy_query(
                    &trans,
                    &[
                        &A::aggregate_type(),
                        &id.as_str(),
                        &last_sequence,
                        // Clamp to i64 range for the SQL parameter.
                        &(max_count.min(i64::max_value() as u64) as i64),
                    ],
                    100,
                )?;
            } else {
                stmt = trans.prepare_cached(
                    "SELECT sequence, event_type, payload, metadata \
                     FROM events \
                     WHERE aggregate_type = $1 AND entity_id = $2 AND sequence < $3 \
                     ORDER BY sequence DESC",
                )?;
                rows = stmt.lazy_query(
                    &trans,
                    &[&A::aggregate_type(), &id.as_str(), &last_sequence],
                    100,
                )?;
            }
            let (lower, upper) = rows.size_hint();
            let cap = upper.unwrap_or(lower);
            let mut inner_events = Vec::with_capacity(cap);
            while let Some(row) = rows.next()? {
                inner_events.push(handle_row(row));
            }
            events = inner_events;
        }
        trans.commit()?;
        log::trace!("entity {}: read {} events", id.as_str(), events.len());
        Ok(Some(events))
    }
}
impl<'conn, A, E, M, S> EventSink<A, E, M> for PostgresStore<'conn, A, E, M, S>
where
    A: Aggregate,
    E: AggregateEvent<A> + SerializableEvent + fmt::Debug,
    M: Serialize + fmt::Debug,
    S: SnapshotStrategy,
{
    type Error = PersistError<<E as SerializableEvent>::Error>;

    /// Appends `events` for entity `id` inside a single transaction, checking
    /// `precondition` against the entity's current version first. Returns the
    /// sequence number assigned to the first appended event (or the next
    /// event number when `events` is empty).
    fn append_events<I>(
        &self,
        id: &I,
        events: &[E],
        precondition: Option<Precondition>,
        metadata: M,
    ) -> Result<EventNumber, Self::Error>
    where
        I: AggregateId<A>,
    {
        let txn = self.conn.transaction()?;

        // Current version = highest stored sequence for this entity, if any.
        let version_query = txn.prepare_cached(
            "SELECT MAX(sequence) FROM events WHERE aggregate_type = $1 AND entity_id = $2",
        )?;
        let version_rows = version_query.query(&[&A::aggregate_type(), &id.as_str()])?;
        let current_version = version_rows.iter().next().and_then(|row| {
            row.get::<_, Option<Sequence>>(0)
                .map(|seq| Version::from(seq.0))
        });
        log::trace!(
            "entity {}: current version: {:?}",
            id.as_str(),
            current_version
        );

        // Nothing to write: report where the next event would go.
        if events.is_empty() {
            return Ok(current_version.unwrap_or_default().next_event());
        }

        if let Some(precondition) = precondition {
            precondition.verify(current_version)?;
        }
        log::trace!("entity {}: precondition satisfied", id.as_str());

        let first_sequence = current_version.unwrap_or_default().next_event();
        let mut sequence = Version::Number(first_sequence);
        let insert = txn.prepare_cached(
            "INSERT INTO events (aggregate_type, entity_id, sequence, event_type, payload, metadata, timestamp) \
             VALUES ($1, $2, $3, $4, $5, $6, CURRENT_TIMESTAMP)",
        )?;

        // One serialization buffer reused across all events.
        let mut payload_buf = Vec::with_capacity(128);
        for event in events {
            payload_buf.clear();
            event
                .serialize_event_to_buffer(&mut payload_buf)
                .map_err(PersistError::SerializationError)?;
            let inserted = insert.execute(&[
                &A::aggregate_type(),
                &id.as_str(),
                &(sequence.get() as i64),
                &event.event_type(),
                &RawJsonPersist(&payload_buf),
                &BorrowedJson(&metadata),
            ])?;
            debug_assert!(inserted > 0);
            log::trace!(
                "entity {}: inserted event; sequence: {}",
                id.as_str(),
                sequence
            );
            sequence.incr();
        }

        txn.commit()?;
        Ok(first_sequence)
    }
}
impl<'conn, A, E, M, S> EventSource<A, E> for PostgresStore<'conn, A, E, M, S>
where
A: Aggregate,
E: AggregateEvent<A> + DeserializableEvent,
S: SnapshotStrategy,
{
type Error = LoadError<<E as DeserializableEvent>::Error>;
type Events = Vec<VersionedEvent<E>>;
/// Reads events for `id` ascending by sequence, strictly after `since`,
/// streaming rows in batches of 100 inside a read-only transaction.
///
/// Fails fast: the first row whose payload cannot be deserialized aborts the
/// whole read. NOTE(review): an entity with no stored events yields
/// `Ok(Some(vec![]))`, never `Ok(None)` — confirm callers expect that.
fn read_events<I>(
&self,
id: &I,
since: Since,
max_count: Option<u64>,
) -> Result<Option<Self::Events>, Self::Error>
where
I: AggregateId<A>,
{
// `BeginningOfStream` maps to 0 so the `sequence > $3` filter matches all rows.
let last_sequence = match since {
cqrs_core::Since::BeginningOfStream => 0,
cqrs_core::Since::Event(x) => x.get(),
} as i64;
let trans = self
.conn
.transaction_with(postgres::transaction::Config::default().read_only(true))?;
// Converts one row into a `VersionedEvent`, mapping unknown event types and
// payload deserialization failures into `LoadError`s.
let handle_row = |row: postgres::rows::Row| {
let sequence: Sequence = row.get(0);
let event_type: String = row.get(1);
let raw: RawJsonRead = row.get(2);
let event = E::deserialize_event_from_buffer(&raw.0, &event_type)
.map_err(LoadError::DeserializationError)?
.ok_or_else(|| LoadError::UnknownEventType(event_type.clone()))?;
log::trace!(
"entity {}: loaded event; sequence: {}, type: {}",
id.as_str(),
sequence.0,
event_type
);
Ok(VersionedEvent {
sequence: sequence.0,
event,
})
};
// NOTE(review): a `max_count` that does not fit in i64 fails `from_u64` and
// falls through to the unlimited branch — effectively "no limit", which is
// harmless since that many rows cannot exist.
let events: Vec<VersionedEvent<E>> =
if let Some(max_count) = max_count.and_then(i64::from_u64) {
let stmt = trans.prepare_cached(
"SELECT sequence, event_type, payload \
FROM events \
WHERE aggregate_type = $1 AND entity_id = $2 AND sequence > $3 \
ORDER BY sequence ASC \
LIMIT $4",
)?;
let rows = stmt.lazy_query(
&trans,
&[
&A::aggregate_type(),
&id.as_str(),
&last_sequence,
&max_count,
],
100,
)?;
// Short-circuits on the first row-level error via the fallible iterator.
rows.into_fallible_iterator()
.map_err(LoadError::Postgres)
.and_then(handle_row)
.collect()?
} else {
let stmt = trans.prepare_cached(
"SELECT sequence, event_type, payload \
FROM events \
WHERE aggregate_type = $1 AND entity_id = $2 AND sequence > $3 \
ORDER BY sequence ASC",
)?;
let rows = stmt.lazy_query(
&trans,
&[&A::aggregate_type(), &id.as_str(), &last_sequence],
100,
)?;
// Short-circuits on the first row-level error via the fallible iterator.
rows.into_fallible_iterator()
.map_err(LoadError::Postgres)
.and_then(handle_row)
.collect()?
};
trans.commit()?;
log::trace!("entity {}: read {} events", id.as_str(), events.len());
Ok(Some(events))
}
}
impl<'conn, A, E, M, S> SnapshotSink<A> for PostgresStore<'conn, A, E, M, S>
where
    A: Aggregate + Serialize + fmt::Debug,
    E: AggregateEvent<A>,
    S: SnapshotStrategy,
{
    type Error = PersistError<serde_json::Error>;

    /// Persists a snapshot of `aggregate` at `version` if the aggregate has
    /// advanced past the last snapshot and the strategy recommends it;
    /// otherwise returns the existing snapshot version unchanged.
    fn persist_snapshot<I>(
        &self,
        id: &I,
        aggregate: &A,
        version: Version,
        last_snapshot_version: Option<Version>,
    ) -> Result<Version, Self::Error>
    where
        I: AggregateId<A>,
    {
        // Skip when the aggregate hasn't advanced past the last snapshot.
        // (Checked first so the strategy is only consulted when it matters.)
        if version <= last_snapshot_version.unwrap_or_default() {
            return Ok(last_snapshot_version.unwrap_or_default());
        }
        // Skip when the configured strategy advises against snapshotting now.
        let recommendation = self
            .snapshot_strategy
            .snapshot_recommendation(version, last_snapshot_version);
        if recommendation == SnapshotRecommendation::DoNotSnapshot {
            return Ok(last_snapshot_version.unwrap_or_default());
        }

        let insert = self.conn.prepare_cached(
            "INSERT INTO snapshots (aggregate_type, entity_id, sequence, payload) \
             VALUES ($1, $2, $3, $4)",
        )?;
        let _modified_count = insert.execute(&[
            &A::aggregate_type(),
            &id.as_str(),
            &(version.get() as i64),
            &Json(aggregate),
        ])?;
        log::trace!("entity {}: persisted snapshot", id.as_str());
        Ok(version)
    }
}
impl<'conn, A, E, M, S> SnapshotSource<A> for PostgresStore<'conn, A, E, M, S>
where
    A: Aggregate + DeserializeOwned,
    E: AggregateEvent<A>,
    S: SnapshotStrategy,
{
    type Error = postgres::Error;

    /// Loads the most recent snapshot for entity `id`, if one exists.
    fn get_snapshot<I>(&self, id: &I) -> Result<Option<VersionedAggregate<A>>, Self::Error>
    where
        I: AggregateId<A>,
    {
        // Only the newest snapshot matters, hence DESC + LIMIT 1.
        let query = self.conn.prepare_cached(
            "SELECT sequence, payload \
             FROM snapshots \
             WHERE aggregate_type = $1 AND entity_id = $2 \
             ORDER BY sequence DESC \
             LIMIT 1",
        )?;
        let rows = query.query(&[&A::aggregate_type(), &id.as_str()])?;
        match rows.iter().next() {
            Some(row) => {
                let sequence: Sequence = row.get(0);
                let payload: Json<A> = row.get(1);
                log::trace!("entity {}: loaded snapshot", id.as_str());
                Ok(Some(VersionedAggregate {
                    version: Version::from(sequence.0),
                    payload: payload.0,
                }))
            }
            None => {
                log::trace!("entity {}: no snapshot found", id.as_str());
                Ok(None)
            }
        }
    }
}