es-entity-macros 0.10.36

Proc macros for es-entity
Documentation
use darling::ToTokens;
use proc_macro2::TokenStream;
use quote::{TokenStreamExt, quote};

use super::options::*;

/// Code generator for a repository's batch-create methods
/// (`create_all` and `create_all_in_op`).
pub struct CreateAllFn<'a> {
    /// Entity type the generated methods create and return.
    entity: &'a syn::Ident,
    /// Error type returned by the generated methods.
    create_error: syn::Ident,
    /// Table that the batch `INSERT` targets.
    table_name: &'a str,
    /// Columns written by the batch `INSERT`.
    columns: &'a Columns,
    /// Nested-create methods invoked on every hydrated entity.
    nested_fn_names: Vec<syn::Ident>,
    /// Error type of the post-hydrate hook; `Some` only when that hook is configured.
    post_hydrate_error: Option<&'a syn::Type>,
    /// Error type of the post-persist hook; `Some` only when that hook is configured.
    post_persist_error: Option<&'a syn::Type>,
    /// Snake-case repository name, used as the prefix of the tracing span name.
    #[cfg(feature = "instrument")]
    repo_name_snake: String,
}

impl<'a> From<&'a RepositoryOptions> for CreateAllFn<'a> {
    /// Gathers everything the batch-create generator needs from the parsed
    /// repository options.
    fn from(opts: &'a RepositoryOptions) -> Self {
        let table_name = opts.table_name();
        let entity = opts.entity();
        let create_error = opts.create_error();
        let nested_fn_names: Vec<_> = opts
            .all_nested()
            .map(|nested| nested.create_nested_fn_name())
            .collect();
        let columns = &opts.columns;
        // Only the hooks' error types are kept; their presence decides whether
        // the hook calls are emitted at all.
        let post_hydrate_error = opts.post_hydrate_hook.as_ref().map(|hook| &hook.error);
        let post_persist_error = opts.post_persist_hook.as_ref().map(|hook| &hook.error);

        Self {
            entity,
            create_error,
            table_name,
            columns,
            nested_fn_names,
            post_hydrate_error,
            post_persist_error,
            #[cfg(feature = "instrument")]
            repo_name_snake: opts.repo_name_snake_case(),
        }
    }
}

impl ToTokens for CreateAllFn<'_> {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let entity = self.entity;
        let create_error = &self.create_error;

        let nested = self.nested_fn_names.iter().map(|f| {
            quote! {
                self.#f(op, &mut entity).await?;
            }
        });
        let maybe_mut_entity = if self.nested_fn_names.is_empty() {
            quote! { entity }
        } else {
            quote! { mut entity }
        };

        let table_name = self.table_name;

        let column_names = self.columns.insert_column_names();
        let placeholders = self.columns.insert_placeholders(1);
        let (arg_collection, bindings) = self
            .columns
            .create_all_arg_collection(syn::parse_quote! { new_entity });

        let query = format!(
            "INSERT INTO {} (created_at, {}) \
            SELECT COALESCE($1, NOW()), unnested.{} \
            FROM UNNEST({}) \
            AS unnested({})",
            table_name,
            column_names.join(", "),
            column_names.join(", unnested."),
            placeholders,
            column_names.join(", "),
        );

        #[cfg(feature = "instrument")]
        let (instrument_attr, error_recording) = {
            let entity_name = entity.to_string();
            let repo_name = &self.repo_name_snake;
            let span_name = format!("{}.create_all", repo_name);
            (
                quote! {
                    #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, count = new_entities.len(), error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))]
                },
                quote! {
                    if let Err(ref e) = __result {
                        tracing::Span::current().record("error", true);
                        tracing::Span::current().record("exception.message", tracing::field::display(e));
                        tracing::Span::current().record("exception.type", std::any::type_name_of_val(e));
                    }
                },
            )
        };
        #[cfg(not(feature = "instrument"))]
        let (instrument_attr, error_recording) = (quote! {}, quote! {});

        let post_hydrate_check = if self.post_hydrate_error.is_some() {
            quote! {
                self.execute_post_hydrate_hook(&entity).map_err(#create_error::PostHydrateError)?;
            }
        } else {
            quote! {}
        };

        let post_persist_check = if self.post_persist_error.is_some() {
            quote! {
                self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(#create_error::PostPersistHookError)?;
            }
        } else {
            quote! {}
        };

        tokens.append_all(quote! {
            pub async fn create_all(
                &self,
                new_entities: Vec<<#entity as es_entity::EsEntity>::New>
            ) -> Result<Vec<#entity>, #create_error> {
                let mut op = self.begin_op().await?;
                let res = self.create_all_in_op(&mut op, new_entities).await?;
                op.commit().await?;
                Ok(res)
            }

            #instrument_attr
            pub async fn create_all_in_op<OP>(
                &self,
                op: &mut OP,
                new_entities: Vec<<#entity as es_entity::EsEntity>::New>
            ) -> Result<Vec<#entity>, #create_error>
            where
                OP: es_entity::AtomicOperation
            {
                let __result: Result<Vec<#entity>, #create_error> = async {
                    let mut res = Vec::new();
                    if new_entities.is_empty() {
                        return Ok(res);
                    }

                    #arg_collection

                    let now = op.maybe_now();
                    sqlx::query(#query)
                       .bind(now)
                       #(#bindings)*
                       .fetch_all(op.as_executor())
                       .await
                       .map_err(|e| match &e {
                           sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
                               #create_error::ConstraintViolation {
                                   column: Self::map_constraint_column(db_err.constraint()),
                                   value: es_entity::extract_constraint_value(db_err.as_ref()),
                                   inner: e,
                               }
                           }
                           _ => #create_error::Sqlx(e),
                       })?;


                    let mut all_events: Vec<es_entity::EntityEvents<<#entity as es_entity::EsEntity>::Event>> = new_entities.into_iter().map(Self::convert_new).collect();
                    let mut n_persisted = Self::extract_concurrent_modification(
                        self.persist_events_batch(op, &mut all_events).await,
                        #create_error::ConcurrentModification,
                    )?;

                    for events in all_events.into_iter() {
                        let n_events = n_persisted.remove(events.id()).expect("n_events exists");
                        let #maybe_mut_entity = Self::hydrate_entity(events)?;

                        #(#nested)*

                        #post_hydrate_check
                        #post_persist_check
                        res.push(entity);
                    }

                    Ok(res)
                }.await;

                #error_recording
                __result
            }
        });
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use darling::FromMeta;
    use proc_macro2::Span;
    use syn::Ident;

    /// Checks the exact token output of `CreateAllFn` for a two-column entity
    /// with no nested fields and no hooks, under either `instrument` setting.
    #[test]
    fn create_all_fn() {
        let entity = Ident::new("Entity", Span::call_site());
        let create_error = Ident::new("EntityCreateError", Span::call_site());

        let input: syn::Meta = syn::parse_quote!(columns(id = "EntityId", name = "String",));
        let columns = Columns::from_meta(&input).expect("Failed to parse Fields");

        let create_fn = CreateAllFn {
            table_name: "entities",
            entity: &entity,
            create_error,
            columns: &columns,
            nested_fn_names: Vec::new(),
            post_hydrate_error: None,
            post_persist_error: None,
            #[cfg(feature = "instrument")]
            repo_name_snake: "test_repo".to_string(),
        };

        let mut tokens = TokenStream::new();
        create_fn.to_tokens(&mut tokens);

        // With the `instrument` feature the generator adds a tracing span
        // attribute and error recording. Mirror that here so the expectation
        // holds for both feature sets instead of failing under `instrument`.
        #[cfg(feature = "instrument")]
        let (instrument_attr, error_recording) = (
            quote! {
                #[tracing::instrument(name = "test_repo.create_all", skip_all, fields(entity = "Entity", count = new_entities.len(), error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))]
            },
            quote! {
                if let Err(ref e) = __result {
                    tracing::Span::current().record("error", true);
                    tracing::Span::current().record("exception.message", tracing::field::display(e));
                    tracing::Span::current().record("exception.type", std::any::type_name_of_val(e));
                }
            },
        );
        #[cfg(not(feature = "instrument"))]
        let (instrument_attr, error_recording) = (quote! {}, quote! {});

        let expected = quote! {
            pub async fn create_all(
                &self,
                new_entities: Vec<<Entity as es_entity::EsEntity>::New>
            ) -> Result<Vec<Entity>, EntityCreateError> {
                let mut op = self.begin_op().await?;
                let res = self.create_all_in_op(&mut op, new_entities).await?;
                op.commit().await?;
                Ok(res)
            }

            #instrument_attr
            pub async fn create_all_in_op<OP>(
                &self,
                op: &mut OP,
                new_entities: Vec<<Entity as es_entity::EsEntity>::New>
            ) -> Result<Vec<Entity>, EntityCreateError>
            where
                OP: es_entity::AtomicOperation
            {
                let __result: Result<Vec<Entity>, EntityCreateError> = async {
                    let mut res = Vec::new();
                    if new_entities.is_empty() {
                        return Ok(res);
                    }

                    let mut id_collection = Vec::new();
                    let mut name_collection = Vec::new();

                    for new_entity in new_entities.iter() {
                        let id: &EntityId = &new_entity.id;
                        let name: &String = &new_entity.name;

                        id_collection.push(id);
                        name_collection.push(name);
                    }

                    let now = op.maybe_now();
                    sqlx::query(
                        "INSERT INTO entities (created_at, id, name) SELECT COALESCE($1, NOW()), unnested.id, unnested.name FROM UNNEST($2, $3) AS unnested(id, name)")
                        .bind(now)
                        .bind(id_collection)
                        .bind(name_collection)
                        .fetch_all(op.as_executor())
                        .await
                        .map_err(|e| match &e {
                            sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
                                EntityCreateError::ConstraintViolation {
                                    column: Self::map_constraint_column(db_err.constraint()),
                                    value: es_entity::extract_constraint_value(db_err.as_ref()),
                                    inner: e,
                                }
                            }
                            _ => EntityCreateError::Sqlx(e),
                        })?;


                    let mut all_events: Vec<es_entity::EntityEvents<<Entity as es_entity::EsEntity>::Event>> = new_entities.into_iter().map(Self::convert_new).collect();
                    let mut n_persisted = Self::extract_concurrent_modification(
                        self.persist_events_batch(op, &mut all_events).await,
                        EntityCreateError::ConcurrentModification,
                    )?;

                    for events in all_events.into_iter() {
                        let n_events = n_persisted.remove(events.id()).expect("n_events exists");
                        let entity = Self::hydrate_entity(events)?;

                        res.push(entity);
                    }

                    Ok(res)
                }.await;

                #error_recording
                __result
            }
        };

        assert_eq!(tokens.to_string(), expected.to_string());
    }
}