// es-entity-macros 0.10.36
//
// Proc macros for es-entity
// Documentation
use darling::ToTokens;
use proc_macro2::TokenStream;
use quote::{TokenStreamExt, quote};

use super::options::*;

/// Code generator for a repository's `find_all` / `find_all_in_op` methods.
///
/// The values are borrowed from (or derived from) a `RepositoryOptions` via the
/// `From` impl in this file and turned into tokens by the `ToTokens` impl.
pub struct FindAllFn<'a> {
    /// Optional table prefix; when present it is passed to `es_query!` as
    /// `tbl_prefix = ...` instead of `entity = ...`.
    prefix: Option<&'a syn::LitStr>,
    /// Identifier of the entity id type (slice element of `ids`, key of the returned map).
    id: &'a syn::Ident,
    /// Identifier of the entity type the generated methods load.
    entity: &'a syn::Ident,
    /// Table name interpolated into the generated `SELECT` statement.
    table_name: &'a str,
    /// Identifier of the error type returned by the generated methods.
    query_error: syn::Ident,
    /// Selects the signature shape of `find_all_in_op` (`&mut impl ...` without
    /// the `'a` lifetime when `true`, `impl ...` with `'a` otherwise).
    any_nested: bool,
    /// Error type of the post-hydrate hook. Only its presence is inspected here:
    /// when `Some`, the generated code runs the hook on every loaded entity.
    post_hydrate_error: Option<&'a syn::Type>,
    /// Snake-case repository name used to build the tracing span name
    /// (`<repo>.find_all`); only present with the `instrument` feature.
    #[cfg(feature = "instrument")]
    repo_name_snake: String,
}

impl<'a> From<&'a RepositoryOptions> for FindAllFn<'a> {
    /// Gathers the pieces of the repository options that the `find_all`
    /// generator needs.
    fn from(opts: &'a RepositoryOptions) -> Self {
        // Only the hook's error type is carried over; the hook itself stays in `opts`.
        let post_hydrate_error = opts.post_hydrate_hook.as_ref().map(|hook| &hook.error);

        Self {
            prefix: opts.table_prefix(),
            id: opts.id(),
            entity: opts.entity(),
            table_name: opts.table_name(),
            query_error: opts.query_error(),
            any_nested: opts.any_nested(),
            post_hydrate_error,
            #[cfg(feature = "instrument")]
            repo_name_snake: opts.repo_name_snake_case(),
        }
    }
}

impl ToTokens for FindAllFn<'_> {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let id = self.id;
        let entity = self.entity;
        let query_error = &self.query_error;
        let query_fn_op_traits = RepositoryOptions::query_fn_op_traits(self.any_nested);
        let query_fn_get_op = RepositoryOptions::query_fn_get_op(self.any_nested);

        let generics = if self.any_nested {
            quote! { <Out: From<#entity>> }
        } else {
            quote! { <'a, Out: From<#entity>> }
        };

        let query = format!("SELECT id FROM {} WHERE id = ANY($1)", self.table_name);

        let es_query_call = if let Some(prefix) = self.prefix {
            quote! {
                es_entity::es_query!(
                    tbl_prefix = #prefix,
                    #query,
                    ids as &[#id],
                )
            }
        } else {
            quote! {
                es_entity::es_query!(
                    entity = #entity,
                    #query,
                    ids as &[#id],
                )
            }
        };

        let op_param = if self.any_nested {
            quote! { op: &mut impl #query_fn_op_traits }
        } else {
            quote! { op: impl #query_fn_op_traits }
        };

        #[cfg(feature = "instrument")]
        let instrument_attr = {
            let entity_name = entity.to_string();
            let repo_name = &self.repo_name_snake;
            let span_name = format!("{}.find_all", repo_name);
            quote! {
                #[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, count = ids.len(), ids = tracing::field::debug(ids)), err)]
            }
        };
        #[cfg(not(feature = "instrument"))]
        let instrument_attr = quote! {};

        let post_hydrate_check = if self.post_hydrate_error.is_some() {
            quote! {
                for __entity in &entities {
                    self.execute_post_hydrate_hook(__entity).map_err(#query_error::PostHydrateError)?;
                }
            }
        } else {
            quote! {}
        };

        tokens.append_all(quote! {
            pub async fn find_all<Out: From<#entity>>(
                &self,
                ids: &[#id]
            ) -> Result<std::collections::HashMap<#id, Out>, #query_error> {
                self.find_all_in_op(#query_fn_get_op, ids).await
            }

            #instrument_attr
            pub async fn find_all_in_op #generics(
                &self,
                #op_param,
                ids: &[#id]
            ) -> Result<std::collections::HashMap<#id, Out>, #query_error> {
                 let (entities, _) = #es_query_call.fetch_n(op, ids.len()).await?;
                 #post_hydrate_check
                 Ok(entities.into_iter().map(|u| (u.id.clone(), Out::from(u))).collect())
            }
        });
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::Span;
    use syn::Ident;

    /// Builds a call-site identifier for use as a test fixture.
    fn ident(name: &str) -> Ident {
        Ident::new(name, Span::call_site())
    }

    /// Checks the tokens generated for a non-nested repository without a
    /// table prefix or post-hydrate hook.
    #[test]
    fn find_all_fn() {
        let id = ident("EntityId");
        let entity = ident("Entity");

        let generator = FindAllFn {
            prefix: None,
            id: &id,
            entity: &entity,
            table_name: "entities",
            query_error: ident("EntityQueryError"),
            any_nested: false,
            post_hydrate_error: None,
            #[cfg(feature = "instrument")]
            repo_name_snake: "test_repo".to_string(),
        };

        let generated = generator.to_token_stream();

        let expected = quote! {
            pub async fn find_all<Out: From<Entity>>(
                &self,
                ids: &[EntityId]
            ) -> Result<std::collections::HashMap<EntityId, Out>, EntityQueryError> {
                self.find_all_in_op(self.pool(), ids).await
            }

            pub async fn find_all_in_op<'a, Out: From<Entity>>(
                &self,
                op: impl es_entity::IntoOneTimeExecutor<'a>,
                ids: &[EntityId]
            ) -> Result<std::collections::HashMap<EntityId, Out>, EntityQueryError> {
                let (entities, _) = es_entity::es_query!(
                    entity = Entity,
                    "SELECT id FROM entities WHERE id = ANY($1)",
                    ids as &[EntityId],
                )
                .fetch_n(op, ids.len())
                .await?;
                Ok(entities.into_iter().map(|u| (u.id.clone(), Out::from(u))).collect())
            }
        };

        assert_eq!(generated.to_string(), expected.to_string());
    }
}