es-entity-macros 0.10.36

Proc macros for es-entity
Documentation
use darling::ToTokens;
use proc_macro2::TokenStream;
use quote::{TokenStreamExt, quote};

use super::options::*;

/// Code generator that emits the `persist_events` method (plus the
/// `extract_concurrent_modification` error-mapping helper) into a repository impl.
pub struct PersistEventsFn<'a> {
    // Entity id type; the generated query binds `id` as `&` this type.
    id: &'a syn::Ident,
    // Event type parameter of the `es_entity::EntityEvents<...>` being persisted.
    event: &'a syn::Ident,
    // Name of the table the generated `INSERT` statement writes to.
    events_table_name: &'a str,
    // When true, a `context` column is also inserted, fed from `$6::JSONB[]`.
    event_ctx: bool,
}

impl<'a> From<&'a RepositoryOptions> for PersistEventsFn<'a> {
    /// Borrows, from the parsed repository options, exactly the pieces the
    /// generated `persist_events` function needs: the id type, the event type,
    /// the events table name and whether per-event context is enabled.
    fn from(opts: &'a RepositoryOptions) -> Self {
        let id = opts.id();
        let event = opts.event();
        let events_table_name = opts.events_table_name();
        let event_ctx = opts.event_context_enabled();

        Self {
            id,
            event,
            events_table_name,
            event_ctx,
        }
    }
}

impl ToTokens for PersistEventsFn<'_> {
    /// Appends two generated items to `tokens`:
    ///
    /// * `extract_concurrent_modification` — turns a sqlx unique-violation
    ///   database error into the caller-supplied `concurrent_modification` value
    ///   and converts any other `sqlx::Error` through `From`.
    /// * `persist_events` — inserts every new event of an
    ///   `es_entity::EntityEvents` with a single
    ///   `INSERT ... SELECT ... FROM UNNEST(...) RETURNING recorded_at` query.
    ///   It then calls `mark_new_events_persisted_at` with the first returned
    ///   `recorded_at` and returns that call's count.
    ///
    /// When `event_ctx` is set, the query additionally writes a `context` column
    /// bound from `$6::JSONB[]`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        // SQL fragments that only appear when per-event context is stored.
        // The column name is needed twice: in the INSERT column list and in the
        // UNNEST alias list.
        let (ctx_column, ctx_select, ctx_unnest_param) = if self.event_ctx {
            (", context", ", unnested.context", ", $6::JSONB[]")
        } else {
            ("", "", "")
        };

        // $3 is the number of already-persisted events, so ROW_NUMBER() (which
        // starts at 1) continues the sequence after them. COALESCE falls back to
        // the database clock when the operation supplies no timestamp ($2).
        let query = format!(
            "INSERT INTO {} (id, recorded_at, sequence, event_type, event{}) SELECT $1, COALESCE($2, NOW()), ROW_NUMBER() OVER () + $3, unnested.event_type, unnested.event{} FROM UNNEST($4::TEXT[], $5::JSONB[]{}) AS unnested(event_type, event{}) RETURNING recorded_at",
            self.events_table_name,
            ctx_column,
            ctx_select,
            ctx_unnest_param,
            ctx_column,
        );

        // `None` interpolates to no tokens at all, so the non-context variant
        // neither serializes contexts nor binds a sixth query argument.
        let ctx_binding = self
            .event_ctx
            .then(|| quote! { let contexts = events.serialize_new_event_contexts(); });
        let ctx_query_arg = self.event_ctx.then(|| {
            quote! {
                contexts.as_deref() as Option<&[es_entity::ContextData]>,
            }
        });

        let id_ty = self.id;
        let event_ty = self.event;

        tokens.append_all(quote! {
            fn extract_concurrent_modification<T, __EsErr: From<sqlx::Error>>(
                res: Result<T, sqlx::Error>,
                concurrent_modification: __EsErr,
            ) -> Result<T, __EsErr> {
                match res {
                    Ok(v) => Ok(v),
                    Err(sqlx::Error::Database(ref db_err)) if db_err.is_unique_violation() => {
                        Err(concurrent_modification)
                    }
                    Err(e) => Err(__EsErr::from(e)),
                }
            }

            async fn persist_events<OP>(
                &self,
                op: &mut OP,
                events: &mut es_entity::EntityEvents<#event_ty>
            ) -> Result<usize, sqlx::Error>
            where
                OP: es_entity::AtomicOperation,
            {
                let id = events.id();
                let offset = events.len_persisted();
                let events_types = events.new_event_types();
                let serialized_events = events.serialize_new_events();
                #ctx_binding
                let now = op.maybe_now();

                let rows = sqlx::query!(
                        #query,
                        id as &#id_ty,
                        now,
                        offset as i32,
                        &events_types,
                        &serialized_events,
                        #ctx_query_arg
                    ).fetch_all(op.as_executor()).await?;

                let recorded_at = rows[0].recorded_at;
                let n_events = events.mark_new_events_persisted_at(recorded_at);

                Ok(n_events)
            }
        });
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the generator for an `EntityId` / `EntityEvent` repository backed by
    /// the `entity_events` table and returns the emitted code rendered as a string.
    fn generated_code(event_ctx: bool) -> String {
        let id: syn::Ident = syn::parse_str("EntityId").unwrap();
        let event = syn::Ident::new("EntityEvent", proc_macro2::Span::call_site());
        let generator = PersistEventsFn {
            id: &id,
            event: &event,
            events_table_name: "entity_events",
            event_ctx,
        };

        let mut rendered = TokenStream::new();
        generator.to_tokens(&mut rendered);
        rendered.to_string()
    }

    #[test]
    fn persist_events_fn() {
        let expected = quote! {
            fn extract_concurrent_modification<T, __EsErr: From<sqlx::Error>>(
                res: Result<T, sqlx::Error>,
                concurrent_modification: __EsErr,
            ) -> Result<T, __EsErr> {
                match res {
                    Ok(v) => Ok(v),
                    Err(sqlx::Error::Database(ref db_err)) if db_err.is_unique_violation() => {
                        Err(concurrent_modification)
                    }
                    Err(e) => Err(__EsErr::from(e)),
                }
            }

            async fn persist_events<OP>(
                &self,
                op: &mut OP,
                events: &mut es_entity::EntityEvents<EntityEvent>
            ) -> Result<usize, sqlx::Error>
            where
                OP: es_entity::AtomicOperation,
            {
                let id = events.id();
                let offset = events.len_persisted();
                let events_types = events.new_event_types();
                let serialized_events = events.serialize_new_events();
                let contexts = events.serialize_new_event_contexts();
                let now = op.maybe_now();

                let rows = sqlx::query!(
                        "INSERT INTO entity_events (id, recorded_at, sequence, event_type, event, context) SELECT $1, COALESCE($2, NOW()), ROW_NUMBER() OVER () + $3, unnested.event_type, unnested.event, unnested.context FROM UNNEST($4::TEXT[], $5::JSONB[], $6::JSONB[]) AS unnested(event_type, event, context) RETURNING recorded_at",
                        id as &EntityId,
                        now,
                        offset as i32,
                        &events_types,
                        &serialized_events,
                        contexts.as_deref() as Option<&[es_entity::ContextData]>,
                    ).fetch_all(op.as_executor()).await?;

                let recorded_at = rows[0].recorded_at;
                let n_events = events.mark_new_events_persisted_at(recorded_at);

                Ok(n_events)
            }
        };

        assert_eq!(generated_code(true), expected.to_string());
    }

    #[test]
    fn persist_events_fn_without_event_context() {
        let expected = quote! {
            fn extract_concurrent_modification<T, __EsErr: From<sqlx::Error>>(
                res: Result<T, sqlx::Error>,
                concurrent_modification: __EsErr,
            ) -> Result<T, __EsErr> {
                match res {
                    Ok(v) => Ok(v),
                    Err(sqlx::Error::Database(ref db_err)) if db_err.is_unique_violation() => {
                        Err(concurrent_modification)
                    }
                    Err(e) => Err(__EsErr::from(e)),
                }
            }

            async fn persist_events<OP>(
                &self,
                op: &mut OP,
                events: &mut es_entity::EntityEvents<EntityEvent>
            ) -> Result<usize, sqlx::Error>
            where
                OP: es_entity::AtomicOperation,
            {
                let id = events.id();
                let offset = events.len_persisted();
                let events_types = events.new_event_types();
                let serialized_events = events.serialize_new_events();
                let now = op.maybe_now();

                let rows = sqlx::query!(
                        "INSERT INTO entity_events (id, recorded_at, sequence, event_type, event) SELECT $1, COALESCE($2, NOW()), ROW_NUMBER() OVER () + $3, unnested.event_type, unnested.event FROM UNNEST($4::TEXT[], $5::JSONB[]) AS unnested(event_type, event) RETURNING recorded_at",
                        id as &EntityId,
                        now,
                        offset as i32,
                        &events_types,
                        &serialized_events,
                    ).fetch_all(op.as_executor()).await?;

                let recorded_at = rows[0].recorded_at;
                let n_events = events.mark_new_events_persisted_at(recorded_at);

                Ok(n_events)
            }
        };

        assert_eq!(generated_code(false), expected.to_string());
    }
}