// es-entity-macros 0.10.36
//
// Proc macros for es-entity.
// Documentation
use darling::ToTokens;
use proc_macro2::TokenStream;
use quote::{TokenStreamExt, quote};

use super::options::*;

/// Code generator for a repository's `update` / `update_in_op` methods.
///
/// Built from `RepositoryOptions` and rendered through its `ToTokens` impl.
pub struct UpdateFn<'a> {
    /// Entity type accepted (as `&mut`) by the generated methods.
    entity: &'a syn::Ident,
    /// Error type returned by the generated methods.
    modify_error: syn::Ident,

    /// Table named in the generated `UPDATE ... WHERE id = $1` statement.
    table_name: &'a str,
    /// Column metadata; decides whether an `UPDATE` is emitted and what it sets.
    columns: &'a Columns,

    /// Nested-entity update functions, each called as `self.<name>(op, entity)`
    /// before the parent's new-events check.
    nested_fn_names: Vec<syn::Ident>,
    /// `Some` when a post-persist hook is configured; the generated code then
    /// calls `execute_post_persist_hook` after the events are persisted.
    post_persist_error: Option<&'a syn::Type>,

    /// snake_case repository name; the tracing span is named `<repo>.update`.
    #[cfg(feature = "instrument")]
    repo_name_snake: String,
}

impl<'a> From<&'a RepositoryOptions> for UpdateFn<'a> {
    /// Gathers everything the `update` generator needs from the parsed repository options.
    fn from(opts: &'a RepositoryOptions) -> Self {
        // One generated update function per nested field, called before the parent is saved.
        let nested_fn_names = opts
            .all_nested()
            .map(|nested| nested.update_nested_fn_name())
            .collect();
        // Only the hook's error type matters here; its presence toggles the hook call.
        let post_persist_error = opts
            .post_persist_hook
            .as_ref()
            .map(|hook| &hook.error);

        Self {
            entity: opts.entity(),
            modify_error: opts.modify_error(),
            columns: &opts.columns,
            table_name: opts.table_name(),
            nested_fn_names,
            post_persist_error,
            #[cfg(feature = "instrument")]
            repo_name_snake: opts.repo_name_snake_case(),
        }
    }
}

impl ToTokens for UpdateFn<'_> {
    /// Emits the private `extract_events` helper plus the `update` and `update_in_op`
    /// repository methods.
    ///
    /// The generated `update_in_op`:
    /// 1. calls every nested update function (before the new-events check, so they are
    ///    invoked even when the parent entity itself has no new events),
    /// 2. returns `Ok(0)` when the entity has no new events,
    /// 3. runs `UPDATE <table> SET ... WHERE id = $1` if any column needs updating,
    ///    mapping unique violations to `ConstraintViolation` and other sqlx errors to `Sqlx`,
    /// 4. hands the events to `persist_events`, mapping conflicts to `ConcurrentModification`,
    /// 5. calls `execute_post_persist_hook` when a post-persist hook is configured,
    /// 6. returns the number of persisted events.
    ///
    /// `update` runs `update_in_op` inside an operation from `begin_op` and commits it.
    /// With the `instrument` feature, `update_in_op` also gets a tracing span that records
    /// the entity id and, on failure, the error details.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let entity = self.entity;
        let modify_error = &self.modify_error;

        // `self.<nested_fn>(op, entity).await?;` for each nested field.
        let nested_calls: Vec<TokenStream> = self
            .nested_fn_names
            .iter()
            .map(|fn_name| quote! { self.#fn_name(op, entity).await?; })
            .collect();

        // The row UPDATE is only emitted when at least one column needs it.
        let column_update = self.columns.updates_needed().then(|| {
            let bindings = self
                .columns
                .variable_assignments_for_update(syn::parse_quote! { entity });
            let set_clause = self.columns.sql_updates();
            let sql = format!("UPDATE {} SET {set_clause} WHERE id = $1", self.table_name);
            let bind_args = self.columns.update_query_args();
            quote! {
            #bindings
            sqlx::query!(
                #sql,
                #(#bind_args),*
            )
                .execute(op.as_executor())
                .await
                .map_err(|e| match &e {
                    sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
                        #modify_error::ConstraintViolation {
                            column: Self::map_constraint_column(db_err.constraint()),
                            value: es_entity::extract_constraint_value(db_err.as_ref()),
                            inner: e,
                        }
                    }
                    _ => #modify_error::Sqlx(e),
                })?;
            }
        });

        #[cfg(feature = "instrument")]
        let (span_attr, record_entity_id, record_error) = {
            use convert_case::{Case, Casing};

            let entity_name = entity.to_string();
            let id_field = quote::format_ident!("{}_id", entity_name.to_case(Case::Snake));
            let span_name = format!("{}.update", self.repo_name_snake);

            let attr = quote! {
                #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, #id_field = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))]
            };
            let record_id = quote! {
                tracing::Span::current().record(stringify!(#id_field), tracing::field::display(&entity.id));
            };
            let record_err = quote! {
                if let Err(ref e) = __result {
                    tracing::Span::current().record("error", true);
                    tracing::Span::current().record("exception.message", tracing::field::display(e));
                    tracing::Span::current().record("exception.type", std::any::type_name_of_val(e));
                }
            };
            (attr, record_id, record_err)
        };
        #[cfg(not(feature = "instrument"))]
        let (span_attr, record_entity_id, record_error) = (quote! {}, quote! {}, quote! {});

        // Only wired in when the repository declares a post-persist hook.
        let post_persist_call = self.post_persist_error.is_some().then(|| {
            quote! {
                self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(#modify_error::PostPersistHookError)?;
            }
        });

        tokens.append_all(quote! {
            #[inline(always)]
            fn extract_events<Entity, Event>(entity: &mut Entity) -> &mut es_entity::EntityEvents<Event>
            where
                Entity: es_entity::EsEntity<Event = Event>,
                Event: es_entity::EsEvent,
            {
                entity.events_mut()
            }

            pub async fn update(
                &self,
                entity: &mut #entity
            ) -> Result<usize, #modify_error> {
                let mut op = self.begin_op().await?;
                let res = self.update_in_op(&mut op, entity).await?;
                op.commit().await?;
                Ok(res)
            }

            #span_attr
            pub async fn update_in_op<OP>(
                &self,
                op: &mut OP,
                entity: &mut #entity
            ) -> Result<usize, #modify_error>
            where
                OP: es_entity::AtomicOperation
            {
                let __result: Result<usize, #modify_error> = async {
                    #record_entity_id
                    #(#nested_calls)*

                    if !Self::extract_events(entity).any_new() {
                        return Ok(0);
                    }

                    #column_update
                    let n_events = {
                        let events = Self::extract_events(entity);
                        Self::extract_concurrent_modification(
                            self.persist_events(op, events).await,
                            #modify_error::ConcurrentModification,
                        )?
                    };

                    #post_persist_call

                    Ok(n_events)
                }.await;

                #record_error
                __result
            }
        });
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::Span;
    use syn::Ident;

    /// Builds an `UpdateFn` for the `Entity` / `entities` / `EntityModifyError` fixture.
    fn build_update_fn<'a>(
        entity: &'a Ident,
        columns: &'a Columns,
        nested_fn_names: Vec<Ident>,
        post_persist_error: Option<&'a syn::Type>,
    ) -> UpdateFn<'a> {
        UpdateFn {
            entity,
            table_name: "entities",
            modify_error: Ident::new("EntityModifyError", Span::call_site()),
            columns,
            nested_fn_names,
            post_persist_error,
            #[cfg(feature = "instrument")]
            repo_name_snake: "test_repo".to_string(),
        }
    }

    /// Renders the generator's output to a string for comparison.
    fn render(update_fn: &UpdateFn<'_>) -> String {
        let mut tokens = TokenStream::new();
        update_fn.to_tokens(&mut tokens);
        tokens.to_string()
    }

    /// Tokens the `instrument` feature splices into `update_in_op` for the fixture
    /// (span attribute, id recording, error recording). Without these, the expected
    /// output would not match when the tests run with `--features instrument`.
    #[cfg(feature = "instrument")]
    fn instrumentation() -> (TokenStream, TokenStream, TokenStream) {
        (
            quote! {
                #[tracing::instrument(name = "test_repo.update", skip_all, fields(entity = "Entity", entity_id = tracing::field::Empty, error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))]
            },
            quote! {
                tracing::Span::current().record(stringify!(entity_id), tracing::field::display(&entity.id));
            },
            quote! {
                if let Err(ref e) = __result {
                    tracing::Span::current().record("error", true);
                    tracing::Span::current().record("exception.message", tracing::field::display(e));
                    tracing::Span::current().record("exception.type", std::any::type_name_of_val(e));
                }
            },
        )
    }

    /// Without the `instrument` feature nothing extra is emitted.
    #[cfg(not(feature = "instrument"))]
    fn instrumentation() -> (TokenStream, TokenStream, TokenStream) {
        (TokenStream::new(), TokenStream::new(), TokenStream::new())
    }

    /// Expected generator output for the fixture, where `body` is everything inside
    /// `update_in_op`'s async block after the (feature-dependent) id recording.
    fn expected_output(body: TokenStream) -> String {
        let (instrument_attr, record_id, error_recording) = instrumentation();
        quote! {
            #[inline(always)]
            fn extract_events<Entity, Event>(entity: &mut Entity) -> &mut es_entity::EntityEvents<Event>
            where
                Entity: es_entity::EsEntity<Event = Event>,
                Event: es_entity::EsEvent,
            {
                entity.events_mut()
            }

            pub async fn update(
                &self,
                entity: &mut Entity
            ) -> Result<usize, EntityModifyError> {
                let mut op = self.begin_op().await?;
                let res = self.update_in_op(&mut op, entity).await?;
                op.commit().await?;
                Ok(res)
            }

            #instrument_attr
            pub async fn update_in_op<OP>(
                &self,
                op: &mut OP,
                entity: &mut Entity
            ) -> Result<usize, EntityModifyError>
            where
                OP: es_entity::AtomicOperation
            {
                let __result: Result<usize, EntityModifyError> = async {
                    #record_id
                    #body
                }.await;

                #error_recording
                __result
            }
        }
        .to_string()
    }

    #[test]
    fn update_fn() {
        let id = syn::parse_str("EntityId").unwrap();
        let entity = Ident::new("Entity", Span::call_site());

        let columns = Columns::new(
            &id,
            [Column::new(
                Ident::new("name", Span::call_site()),
                syn::parse_str("String").unwrap(),
            )],
        );

        let actual = render(&build_update_fn(&entity, &columns, Vec::new(), None));

        let expected = expected_output(quote! {
            if !Self::extract_events(entity).any_new() {
                return Ok(0);
            }

            let id = &entity.id;
            let name = &entity.name;
            sqlx::query!(
                "UPDATE entities SET name = $2 WHERE id = $1",
                id as &EntityId,
                name as &String
            )
                .execute(op.as_executor())
                .await
                .map_err(|e| match &e {
                    sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
                        EntityModifyError::ConstraintViolation {
                            column: Self::map_constraint_column(db_err.constraint()),
                            value: es_entity::extract_constraint_value(db_err.as_ref()),
                            inner: e,
                        }
                    }
                    _ => EntityModifyError::Sqlx(e),
                })?;

            let n_events = {
                let events = Self::extract_events(entity);
                Self::extract_concurrent_modification(
                    self.persist_events(op, events).await,
                    EntityModifyError::ConcurrentModification,
                )?
            };

            Ok(n_events)
        });

        assert_eq!(actual, expected);
    }

    #[test]
    fn update_fn_no_columns() {
        let id = syn::parse_str("EntityId").unwrap();
        let entity = Ident::new("Entity", Span::call_site());

        let mut columns = Columns::default();
        columns.set_id_column(&id);

        let actual = render(&build_update_fn(&entity, &columns, Vec::new(), None));

        let expected = expected_output(quote! {
            if !Self::extract_events(entity).any_new() {
                return Ok(0);
            }

            let n_events = {
                let events = Self::extract_events(entity);
                Self::extract_concurrent_modification(
                    self.persist_events(op, events).await,
                    EntityModifyError::ConcurrentModification,
                )?
            };

            Ok(n_events)
        });

        assert_eq!(actual, expected);
    }

    /// Nested update calls must precede the new-events check, and the post-persist
    /// hook call must follow event persistence.
    #[test]
    fn update_fn_with_nested_and_post_persist_hook() {
        let id = syn::parse_str("EntityId").unwrap();
        let entity = Ident::new("Entity", Span::call_site());
        let hook_error: syn::Type = syn::parse_str("HookError").unwrap();

        let mut columns = Columns::default();
        columns.set_id_column(&id);

        let nested = vec![Ident::new("update_nested_children", Span::call_site())];
        let actual = render(&build_update_fn(
            &entity,
            &columns,
            nested,
            Some(&hook_error),
        ));

        let expected = expected_output(quote! {
            self.update_nested_children(op, entity).await?;

            if !Self::extract_events(entity).any_new() {
                return Ok(0);
            }

            let n_events = {
                let events = Self::extract_events(entity);
                Self::extract_concurrent_modification(
                    self.persist_events(op, events).await,
                    EntityModifyError::ConcurrentModification,
                )?
            };

            self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(EntityModifyError::PostPersistHookError)?;

            Ok(n_events)
        });

        assert_eq!(actual, expected);
    }
}