use serde::Deserialize;
use serde::{de::DeserializeOwned, Serialize};
use crate::event::Event;
use crate::event::EventId;
use crate::stream_query::StreamQuery;
use crate::{all_the_tuples, union, BoxDynError, StateSnapshotter, StreamItem};
use async_trait::async_trait;
use paste::paste;
use std::error::Error as StdError;
use std::ops::Deref;
/// A state that can be updated in place from events of its
/// [`StateQuery::Event`] type.
///
/// The event type itself is supplied by the [`StateQuery`] supertrait.
pub trait StateMutate: StateQuery {
    /// Applies `event` to this state, modifying it in place.
    fn mutate(&mut self, event: Self::Event);
}
/// Operations over a group of state parts that share one event id type `ID`
/// and one common event type `E`.
///
/// This module implements the trait for tuples of [`StatePart`] through the
/// `impl_multi_state!` macro.
pub trait MultiState<ID, E>
where
    ID: EventId,
    E: Event + Clone,
{
    /// Converts `item` into a [`StreamItem`] and offers it to the parts of the
    /// group so they can update themselves.
    fn mutate_all<I>(&mut self, item: I)
    where
        I: Into<StreamItem<ID, E>>;

    /// Returns one [`StreamQuery`] built from the queries of all parts.
    fn query_all(&self) -> StreamQuery<ID, E>;

    /// Returns the version of the group as a whole.
    fn version(&self) -> ID;
}
// Implements `MultiState` for one tuple shape of `StatePart`s.
//
// The macro receives the tuple's type parameters as `[T1, .., Tn-1], Tn`: a
// (possibly empty) list of leading parameters plus the last one. It is driven
// by `all_the_tuples!`, which presumably invokes it once per supported arity
// (confirm against that macro's definition).
//
// Inside each method the tuple is destructured into locals named
// `state_<lowercased type parameter>` with the help of `paste!`.
macro_rules! impl_multi_state {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // A one-element "tuple" is written as `(StatePart<ID, T>)`, which is
        // just a parenthesised type, hence the lint allowance.
        #[allow(unused_parens)]
        impl<ID: EventId, E, $($ty,)* $last> MultiState<ID, E>
            for ($(StatePart<ID, $ty>,)* StatePart<ID, $last>)
        where
            E: Event + Clone,
            $($ty: StateQuery + StateMutate,)*
            $last: StateQuery + StateMutate,
            // Every part's event type must convert from and into the common
            // event type `E`.
            $(<$ty as StateQuery>::Event: TryFrom<E> + Into<E>,)*
            <$last as StateQuery>::Event: TryFrom<E> + Into<E>,
            $(<<$ty as StateQuery>::Event as TryFrom<E>>::Error:
                StdError + 'static + Send + Sync,)*
            <<$last as StateQuery>::Event as TryFrom<E>>::Error:
                StdError + 'static + Send + Sync,
        {
            // Offers the item to every part; only parts whose `matches_item`
            // check accepts it are mutated.
            fn mutate_all<I: Into<StreamItem<ID, E>>>(&mut self, item: I) {
                let item = item.into();
                paste! {
                    let ($([<state_ $ty:lower>],)* [<state_ $last:lower>]) = self;
                    $(
                        if [<state_ $ty:lower>].matches_item(&item) {
                            // Later parts still need the item, so hand over a copy.
                            [<state_ $ty:lower>].mutate_part(item.clone());
                        }
                    )*
                    if [<state_ $last:lower>].matches_item(&item) {
                        // Last consumer: move the item instead of cloning it.
                        [<state_ $last:lower>].mutate_part(item);
                    }
                }
            }

            // Combines the per-part queries with `union!`.
            fn query_all(&self) -> StreamQuery<ID, E> {
                paste! {
                    let ($([<state_ $ty:lower>],)* [<state_ $last:lower>]) = self;
                    union!($([<state_ $ty:lower>].query_part(),)* [<state_ $last:lower>].query_part())
                }
            }

            // The group's version is the greatest version among its parts.
            fn version(&self) -> ID {
                paste! {
                    let ($([<state_ $ty:lower>],)* [<state_ $last:lower>]) = self;
                    let version = [<state_ $last:lower>].version();
                    $(
                        let version = version.max([<state_ $ty:lower>].version());
                    )*
                    version
                }
            }
        }
    }
}
all_the_tuples!(impl_multi_state);
/// Snapshot persistence for a group of state parts against a snapshot backend
/// of type `T`.
///
/// This module implements the trait for tuples of [`StatePart`] through the
/// `impl_multi_state_snapshot!` macro.
#[async_trait]
pub trait MultiStateSnapshot<ID, T>
where
    ID: EventId,
    T: StateSnapshotter<ID>,
{
    /// Replaces the parts of the group with what `backend` returns for them
    /// and yields an event id describing the loaded group.
    async fn load_all(&mut self, backend: &T) -> ID;

    /// Hands the parts of the group to `backend` for storage.
    ///
    /// # Errors
    ///
    /// Returns the backend's error if storing fails.
    async fn store_all(&self, backend: &T) -> Result<(), BoxDynError>;
}
// Implements `MultiStateSnapshot` for one tuple shape of `StatePart`s.
//
// The macro receives the tuple's type parameters as `[T1, .., Tn-1], Tn` and
// is driven by `all_the_tuples!`. Inside each method the tuple is destructured
// into locals named `state_<lowercased type parameter>` via `paste!`.
macro_rules! impl_multi_state_snapshot {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        #[async_trait]
        // A one-element "tuple" is just a parenthesised type.
        #[allow(unused_parens)]
        impl<ID: EventId, B, $($ty,)* $last> MultiStateSnapshot<ID, B>
            for ($(StatePart<ID, $ty>,)* StatePart<ID, $last>)
        where
            B: StateSnapshotter<ID> + Send + Sync,
            $($ty: StateQuery + Serialize + DeserializeOwned + 'static,)*
            $last: StateQuery + Serialize + DeserializeOwned + 'static,
        {
            // Each part is replaced by the value the backend returns when
            // given a clone of the current part. The last tuple element is
            // loaded first, then the leading ones in order. The result is the
            // greatest version among the loaded parts.
            async fn load_all(&mut self, backend: &B) -> ID {
                paste! {
                    let ($([<state_ $ty:lower>],)* [<state_ $last:lower>]) = self;
                    *[<state_ $last:lower>] = backend.load_snapshot([<state_ $last:lower>].clone()).await;
                    let newest_version = [<state_ $last:lower>].version();
                    $(
                        *[<state_ $ty:lower>] = backend.load_snapshot([<state_ $ty:lower>].clone()).await;
                        let newest_version = newest_version.max([<state_ $ty:lower>].version());
                    )*
                }
                newest_version
            }

            // Stores the parts in tuple order and stops at the first backend
            // error, so parts after the failing one are not stored.
            async fn store_all(&self, backend: &B) -> Result<(), BoxDynError> {
                paste! {
                    let ($([<state_ $ty:lower>],)* [<state_ $last:lower>]) = self;
                    $(
                        backend.store_snapshot([<state_ $ty:lower>]).await?;
                    )*
                    backend.store_snapshot([<state_ $last:lower>]).await?;
                }
                Ok(())
            }
        }
    }
}
all_the_tuples!(impl_multi_state_snapshot);
/// A state that knows which events it is built from.
///
/// Implementors must be `Clone + Send + Sync`.
pub trait StateQuery: Clone + Send + Sync {
    /// A static name for this kind of state.
    // NOTE(review): presumably used to identify the state type, e.g. by
    // snapshot backends; confirm against the users of `NAME`.
    const NAME: &'static str;

    /// The event type this state is queried for.
    type Event: Event + Clone + Send + Sync;

    /// Builds the [`StreamQuery`] describing the events relevant to this
    /// state, for any event id type `ID`.
    fn query<ID>(&self) -> StreamQuery<ID, Self::Event>
    where
        ID: EventId;
}
impl<ID, S, E: Event + Clone> From<&S> for StreamQuery<ID, E>
where
S: StateQuery<Event = E>,
ID: EventId,
{
fn from(state: &S) -> Self {
state.query()
}
}
/// A state `S` bundled with bookkeeping about the events applied to it.
///
/// The wrapped state is reachable read-only through `Deref`.
#[derive(Clone, Serialize, Deserialize)]
pub struct StatePart<ID, S>
where
    ID: EventId,
    S: StateQuery,
{
    /// Id of the last stream item this part was updated from.
    version: ID,
    /// Number of events that have been applied to `inner` through this part.
    applied_events: u64,
    /// The wrapped state.
    inner: S,
}
impl<ID: EventId, S: StateQuery> StatePart<ID, S> {
    /// Wraps `payload` as a part positioned at `version`, with an applied
    /// event count of zero.
    pub fn new(version: ID, payload: S) -> Self {
        Self {
            version,
            applied_events: 0,
            inner: payload,
        }
    }

    /// Returns the id of the last stream item this part was updated from.
    pub fn version(&self) -> ID {
        self.version
    }

    /// Returns how many events have been applied to the wrapped state.
    pub fn applied_events(&self) -> u64 {
        self.applied_events
    }

    /// Returns the wrapped state's query after `change_origin` has been called
    /// on it with this part's current version.
    // NOTE(review): presumably this restricts the query to events newer than
    // `version`; confirm against `StreamQuery::change_origin`.
    pub fn query_part(&self) -> StreamQuery<ID, <S as StateQuery>::Event> {
        self.inner.query().change_origin(self.version)
    }

    /// Tells whether `item` should be passed to [`Self::mutate_part`].
    ///
    /// `StreamItem::End` items always match. `StreamItem::Event` items match
    /// when this part's query, cast to the item's event type `U`, matches the
    /// event.
    pub fn matches_item<U>(&self, item: &StreamItem<ID, U>) -> bool
    where
        U: Event + Clone,
        <S as StateQuery>::Event: Into<U>,
    {
        match item {
            StreamItem::End(_) => true,
            StreamItem::Event(event) => self.query_part().cast().matches(event),
        }
    }

    /// Updates this part from `item`.
    ///
    /// * `StreamItem::End(version)` only moves the part's version.
    /// * `StreamItem::Event(event)` converts the event into the state's own
    ///   event type, sets the version to the event's id, increments the
    ///   applied event count and forwards the converted event to
    ///   [`StateMutate::mutate`].
    ///
    /// # Panics
    ///
    /// Panics if the event cannot be converted into `S::Event`. Callers are
    /// expected to check [`Self::matches_item`] first. When the panic fires,
    /// neither the version nor the applied event count has been changed.
    pub fn mutate_part<E>(&mut self, item: StreamItem<ID, E>)
    where
        E: Event,
        S: StateMutate,
        <S as StateQuery>::Event: TryFrom<E>,
        <<S as StateQuery>::Event as TryFrom<E>>::Error: StdError + 'static + Send + Sync,
    {
        match item {
            StreamItem::End(version) => {
                self.version = version;
            }
            StreamItem::Event(event) => {
                let id = event.id;
                // Convert before touching the bookkeeping so that a failed
                // conversion does not leave the part half-updated.
                let converted: <S as StateQuery>::Event =
                    event.event.try_into().unwrap_or_else(|err| {
                        panic!(
                            "event cannot be converted into the event type of state `{}`: {}",
                            <S as StateQuery>::NAME,
                            err
                        )
                    });
                self.version = id;
                self.applied_events += 1;
                self.inner.mutate(converted);
            }
        }
    }
}
/// Exposes the wrapped state by shared reference, so its methods and fields
/// can be used directly on a `StatePart`.
impl<ID, S> Deref for StatePart<ID, S>
where
    ID: EventId,
    S: StateQuery,
{
    type Target = S;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
/// Conversion of plain states into their [`StatePart`] wrappers.
///
/// In the tuple implementations generated by this module, `T` is the
/// implementing type itself and `Target` is the matching tuple of
/// `StatePart`s.
pub trait IntoStatePart<ID, T>: Sized
where
    ID: EventId,
{
    /// The wrapped form produced by the conversion.
    type Target;

    /// Consumes `self` and returns its wrapped form.
    fn into_state_part(self) -> Self::Target;
}
/// Conversion of wrapped state parts back into plain states of type `T`.
pub trait IntoState<T>: Sized {
    /// Consumes `self` and returns the plain state(s) it wraps.
    fn into_state(self) -> T;
}
// Implements `IntoStatePart` and `IntoState` for one tuple shape.
//
// The macro receives the tuple's type parameters as `[T1, .., Tn-1], Tn` and
// is driven by `all_the_tuples!`. `paste` is used to derive one local name per
// type parameter (`state_<lowercased type parameter>`).
macro_rules! impl_from_state {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // Plain states -> tuple of `StatePart`s.
        // A one-element "tuple" is just a parenthesised type or expression.
        #[allow(unused_parens)]
        impl<ID, $($ty,)* $last> IntoStatePart<ID, ($($ty,)* $last)> for ($($ty,)* $last)
        where
            ID: EventId,
            $($ty: StateQuery,)*
            $last: StateQuery,
        {
            type Target = ($(StatePart<ID, $ty>,)* StatePart<ID, $last>);

            paste::paste! {
                // Every state starts at the default version with no events applied.
                fn into_state_part(self) -> Self::Target {
                    let ($([<state_ $ty:lower>],)* [<state_ $last:lower>]) = self;
                    (
                        $(StatePart::<ID, $ty>::new(Default::default(), [<state_ $ty:lower>]),)*
                        StatePart::<ID, $last>::new(Default::default(), [<state_ $last:lower>])
                    )
                }
            }
        }

        // Tuple of `StatePart`s -> plain states; the bookkeeping is dropped.
        #[allow(unused_parens)]
        impl<ID, $($ty,)* $last> IntoState<($($ty,)* $last)> for ($(StatePart<ID, $ty>,)* StatePart<ID, $last>)
        where
            ID: EventId,
            $($ty: StateQuery,)*
            $last: StateQuery,
        {
            paste::paste! {
                fn into_state(self) -> ($($ty,)* $last) {
                    let (
                        $(StatePart { inner: [<state_ $ty:lower>], .. },)*
                        StatePart { inner: [<state_ $last:lower>], .. }
                    ) = self;
                    ($([<state_ $ty:lower>],)* [<state_ $last:lower>])
                }
            }
        }
    }
}
all_the_tuples!(impl_from_state);
// Unit tests for the tuple implementations generated above, using the `Cart`
// fixtures and the mock snapshotter from `crate::utils::tests`.
#[cfg(test)]
mod test {
use super::*;
use crate::{utils::tests::*, PersistedEvent};
// Each event must only reach the part that matches it: after two events for
// two different carts, every cart has exactly one applied event and keeps the
// version of its own event.
#[test]
fn it_mutates_all() {
let mut state = (Cart::new("c1"), Cart::new("c2")).into_state_part();
state.mutate_all(PersistedEvent::new(1, item_added_event("p1", "c1")));
state.mutate_all(PersistedEvent::new(2, item_added_event("p2", "c2")));
let (cart1, cart2) = state;
assert_eq!(cart1.version, 1);
assert_eq!(cart1.applied_events, 1);
assert_eq!(cart1.into_state(), cart("c1", ["p1".to_string()]));
assert_eq!(cart2.version, 2);
assert_eq!(cart2.applied_events, 1);
assert_eq!(cart2.into_state(), cart("c2", ["p2".to_string()]));
}
// `query_all` must equal the union of the individual part queries, each with
// its origin changed to the part's version (0 for freshly wrapped states).
#[test]
fn it_queries_all() {
let cart1 = Cart::new("c1");
let cart2 = Cart::new("c2");
let state = (cart1.clone(), cart2.clone()).into_state_part();
let query: StreamQuery<_, ShoppingCartEvent> = state.query_all();
assert_eq!(
query,
union!(
cart1.query().change_origin(0),
cart2.query().change_origin(0)
)
);
}
// `store_all` must call `store_snapshot` exactly once per part.
#[tokio::test]
async fn it_stores_all() {
let multi_state = (cart("c1", []), cart("c2", [])).into_state_part();
let mut snapshotter = MockStateSnapshotter::new();
// One expectation per cart, distinguished by the wrapped state.
snapshotter
.expect_store_snapshot()
.once()
.withf(|s: &StatePart<i64, Cart>| s.inner == cart("c1", []))
.return_once(|_| Ok(()));
snapshotter
.expect_store_snapshot()
.once()
.withf(|s: &StatePart<i64, Cart>| s.inner == cart("c2", []))
.return_once(|_| Ok(()));
multi_state.store_all(&snapshotter).await.unwrap();
}
// `load_all` must call `load_snapshot` exactly once per part and replace each
// part with what the snapshotter returned for it.
#[tokio::test]
async fn it_loads_all() {
let mut multi_state = (cart("c1", []), cart("c2", [])).into_state_part();
let mut snapshotter = MockStateSnapshotter::new();
// The mock answers each empty cart with a cart that already holds one item.
snapshotter
.expect_load_snapshot()
.once()
.withf(|q| q.inner == cart("c1", []))
.returning(|_| cart("c1", ["p1".to_owned()]).into_state_part());
snapshotter
.expect_load_snapshot()
.once()
.withf(|q| q.inner == cart("c2", []))
.returning(|_| cart("c2", ["p2".to_owned()]).into_state_part());
multi_state.load_all(&snapshotter).await;
let (cart1, cart2) = multi_state;
assert_eq!(cart1.inner, cart("c1", ["p1".to_owned()]));
assert_eq!(cart2.inner, cart("c2", ["p2".to_owned()]));
}
}