use darling::ToTokens;
use proc_macro2::TokenStream;
use quote::{TokenStreamExt, quote};
use super::options::*;
/// Code generator for a repository's `create_all` / `create_all_in_op` methods.
///
/// Built from the parsed repository options (see the `From` impl) and rendered
/// through `ToTokens`.
pub struct CreateAllFn<'a> {
    /// Target table for the batched `INSERT INTO ... SELECT ... FROM UNNEST(...)`.
    table_name: &'a str,
    /// Entity type; the generated methods accept `<Entity as EsEntity>::New` values
    /// and return `Vec<Entity>`.
    entity: &'a syn::Ident,
    /// Error type returned by both generated methods.
    create_error: syn::Ident,
    /// Index columns written by the insert statement.
    columns: &'a Columns,
    /// Methods invoked as `self.<name>(op, &mut entity).await?` on every hydrated
    /// entity; when non-empty the entity binding is emitted as `mut`.
    nested_fn_names: Vec<syn::Ident>,
    /// Error type of the post-hydrate hook. Only its presence is consulted here:
    /// `Some` makes the generated code call `execute_post_hydrate_hook`.
    post_hydrate_error: Option<&'a syn::Type>,
    /// Error type of the post-persist hook. Only its presence is consulted here:
    /// `Some` makes the generated code call `execute_post_persist_hook`.
    post_persist_error: Option<&'a syn::Type>,
    /// Snake-case repository name, used as the `<repo>.create_all` span name.
    #[cfg(feature = "instrument")]
    repo_name_snake: String,
}
impl<'a> From<&'a RepositoryOptions> for CreateAllFn<'a> {
    /// Collects everything the `create_all` generator needs from the parsed
    /// repository options, borrowing wherever the options already own the data.
    fn from(opts: &'a RepositoryOptions) -> Self {
        // One nested-creation method name per nested relationship.
        let nested_fn_names = opts
            .all_nested()
            .map(|nested| nested.create_nested_fn_name())
            .collect::<Vec<_>>();

        Self {
            entity: opts.entity(),
            table_name: opts.table_name(),
            columns: &opts.columns,
            create_error: opts.create_error(),
            nested_fn_names,
            post_hydrate_error: opts.post_hydrate_hook.as_ref().map(|hook| &hook.error),
            post_persist_error: opts.post_persist_hook.as_ref().map(|hook| &hook.error),
            #[cfg(feature = "instrument")]
            repo_name_snake: opts.repo_name_snake_case(),
        }
    }
}
impl ToTokens for CreateAllFn<'_> {
fn to_tokens(&self, tokens: &mut TokenStream) {
let entity = self.entity;
let create_error = &self.create_error;
let nested = self.nested_fn_names.iter().map(|f| {
quote! {
self.#f(op, &mut entity).await?;
}
});
let maybe_mut_entity = if self.nested_fn_names.is_empty() {
quote! { entity }
} else {
quote! { mut entity }
};
let table_name = self.table_name;
let column_names = self.columns.insert_column_names();
let placeholders = self.columns.insert_placeholders(1);
let (arg_collection, bindings) = self
.columns
.create_all_arg_collection(syn::parse_quote! { new_entity });
let query = format!(
"INSERT INTO {} (created_at, {}) \
SELECT COALESCE($1, NOW()), unnested.{} \
FROM UNNEST({}) \
AS unnested({})",
table_name,
column_names.join(", "),
column_names.join(", unnested."),
placeholders,
column_names.join(", "),
);
#[cfg(feature = "instrument")]
let (instrument_attr, error_recording) = {
let entity_name = entity.to_string();
let repo_name = &self.repo_name_snake;
let span_name = format!("{}.create_all", repo_name);
(
quote! {
#[tracing::instrument(name = #span_name, skip_all, fields(entity = #entity_name, count = new_entities.len(), error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))]
},
quote! {
if let Err(ref e) = __result {
tracing::Span::current().record("error", true);
tracing::Span::current().record("exception.message", tracing::field::display(e));
tracing::Span::current().record("exception.type", std::any::type_name_of_val(e));
}
},
)
};
#[cfg(not(feature = "instrument"))]
let (instrument_attr, error_recording) = (quote! {}, quote! {});
let post_hydrate_check = if self.post_hydrate_error.is_some() {
quote! {
self.execute_post_hydrate_hook(&entity).map_err(#create_error::PostHydrateError)?;
}
} else {
quote! {}
};
let post_persist_check = if self.post_persist_error.is_some() {
quote! {
self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(#create_error::PostPersistHookError)?;
}
} else {
quote! {}
};
tokens.append_all(quote! {
pub async fn create_all(
&self,
new_entities: Vec<<#entity as es_entity::EsEntity>::New>
) -> Result<Vec<#entity>, #create_error> {
let mut op = self.begin_op().await?;
let res = self.create_all_in_op(&mut op, new_entities).await?;
op.commit().await?;
Ok(res)
}
#instrument_attr
pub async fn create_all_in_op<OP>(
&self,
op: &mut OP,
new_entities: Vec<<#entity as es_entity::EsEntity>::New>
) -> Result<Vec<#entity>, #create_error>
where
OP: es_entity::AtomicOperation
{
let __result: Result<Vec<#entity>, #create_error> = async {
let mut res = Vec::new();
if new_entities.is_empty() {
return Ok(res);
}
#arg_collection
let now = op.maybe_now();
sqlx::query(#query)
.bind(now)
#(#bindings)*
.fetch_all(op.as_executor())
.await
.map_err(|e| match &e {
sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
#create_error::ConstraintViolation {
column: Self::map_constraint_column(db_err.constraint()),
value: es_entity::extract_constraint_value(db_err.as_ref()),
inner: e,
}
}
_ => #create_error::Sqlx(e),
})?;
let mut all_events: Vec<es_entity::EntityEvents<<#entity as es_entity::EsEntity>::Event>> = new_entities.into_iter().map(Self::convert_new).collect();
let mut n_persisted = Self::extract_concurrent_modification(
self.persist_events_batch(op, &mut all_events).await,
#create_error::ConcurrentModification,
)?;
for events in all_events.into_iter() {
let n_events = n_persisted.remove(events.id()).expect("n_events exists");
let #maybe_mut_entity = Self::hydrate_entity(events)?;
#(#nested)*
#post_hydrate_check
#post_persist_check
res.push(entity);
}
Ok(res)
}.await;
#error_recording
__result
}
});
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use darling::FromMeta;
    use proc_macro2::Span;
    use syn::Ident;

    /// Index columns shared by the tests: `id: EntityId`, `name: String`.
    fn test_columns() -> Columns {
        let input: syn::Meta = syn::parse_quote!(columns(id = "EntityId", name = "String",));
        Columns::from_meta(&input).expect("Failed to parse Fields")
    }

    /// Instrumentation tokens the generator emits around `create_all_in_op`
    /// (the `tracing::instrument` attribute, the error-recording block) when the
    /// `instrument` feature is enabled, for repo `test_repo` and entity `Entity`.
    #[cfg(feature = "instrument")]
    fn expected_instrumentation() -> (TokenStream, TokenStream) {
        (
            quote! {
                #[tracing::instrument(name = "test_repo.create_all", skip_all, fields(entity = "Entity", count = new_entities.len(), error = tracing::field::Empty, exception.message = tracing::field::Empty, exception.type = tracing::field::Empty))]
            },
            quote! {
                if let Err(ref e) = __result {
                    tracing::Span::current().record("error", true);
                    tracing::Span::current().record("exception.message", tracing::field::display(e));
                    tracing::Span::current().record("exception.type", std::any::type_name_of_val(e));
                }
            },
        )
    }

    /// Without the `instrument` feature nothing extra is emitted.
    #[cfg(not(feature = "instrument"))]
    fn expected_instrumentation() -> (TokenStream, TokenStream) {
        (TokenStream::new(), TokenStream::new())
    }

    #[test]
    fn create_all_fn() {
        let entity = Ident::new("Entity", Span::call_site());
        let create_error = Ident::new("EntityCreateError", Span::call_site());
        let columns = test_columns();
        let create_fn = CreateAllFn {
            table_name: "entities",
            entity: &entity,
            create_error,
            columns: &columns,
            nested_fn_names: Vec::new(),
            post_hydrate_error: None,
            post_persist_error: None,
            #[cfg(feature = "instrument")]
            repo_name_snake: "test_repo".to_string(),
        };

        let mut tokens = TokenStream::new();
        create_fn.to_tokens(&mut tokens);

        // The expected output depends on whether the `instrument` feature is enabled.
        let (instrument_attr, error_recording) = expected_instrumentation();
        let expected = quote! {
            pub async fn create_all(
                &self,
                new_entities: Vec<<Entity as es_entity::EsEntity>::New>
            ) -> Result<Vec<Entity>, EntityCreateError> {
                let mut op = self.begin_op().await?;
                let res = self.create_all_in_op(&mut op, new_entities).await?;
                op.commit().await?;
                Ok(res)
            }

            #instrument_attr
            pub async fn create_all_in_op<OP>(
                &self,
                op: &mut OP,
                new_entities: Vec<<Entity as es_entity::EsEntity>::New>
            ) -> Result<Vec<Entity>, EntityCreateError>
            where
                OP: es_entity::AtomicOperation
            {
                let __result: Result<Vec<Entity>, EntityCreateError> = async {
                    let mut res = Vec::new();
                    if new_entities.is_empty() {
                        return Ok(res);
                    }
                    let mut id_collection = Vec::new();
                    let mut name_collection = Vec::new();
                    for new_entity in new_entities.iter() {
                        let id: &EntityId = &new_entity.id;
                        let name: &String = &new_entity.name;
                        id_collection.push(id);
                        name_collection.push(name);
                    }
                    let now = op.maybe_now();
                    sqlx::query(
                        "INSERT INTO entities (created_at, id, name) SELECT COALESCE($1, NOW()), unnested.id, unnested.name FROM UNNEST($2, $3) AS unnested(id, name)")
                        .bind(now)
                        .bind(id_collection)
                        .bind(name_collection)
                        .fetch_all(op.as_executor())
                        .await
                        .map_err(|e| match &e {
                            sqlx::Error::Database(db_err) if db_err.is_unique_violation() => {
                                EntityCreateError::ConstraintViolation {
                                    column: Self::map_constraint_column(db_err.constraint()),
                                    value: es_entity::extract_constraint_value(db_err.as_ref()),
                                    inner: e,
                                }
                            }
                            _ => EntityCreateError::Sqlx(e),
                        })?;
                    let mut all_events: Vec<es_entity::EntityEvents<<Entity as es_entity::EsEntity>::Event>> = new_entities.into_iter().map(Self::convert_new).collect();
                    let mut n_persisted = Self::extract_concurrent_modification(
                        self.persist_events_batch(op, &mut all_events).await,
                        EntityCreateError::ConcurrentModification,
                    )?;
                    for events in all_events.into_iter() {
                        let n_events = n_persisted.remove(events.id()).expect("n_events exists");
                        let entity = Self::hydrate_entity(events)?;
                        res.push(entity);
                    }
                    Ok(res)
                }.await;
                #error_recording
                __result
            }
        };
        assert_eq!(tokens.to_string(), expected.to_string());
    }

    /// Nested creation and both hooks must be wired into the per-entity loop, and
    /// the entity binding must become `mut` so nested fns can take `&mut entity`.
    #[test]
    fn create_all_fn_with_nested_and_hooks() {
        let entity = Ident::new("Entity", Span::call_site());
        let create_error = Ident::new("EntityCreateError", Span::call_site());
        let hook_error: syn::Type = syn::parse_quote!(HookError);
        let columns = test_columns();
        let create_fn = CreateAllFn {
            table_name: "entities",
            entity: &entity,
            create_error,
            columns: &columns,
            nested_fn_names: vec![Ident::new("create_nested_children_in_op", Span::call_site())],
            post_hydrate_error: Some(&hook_error),
            post_persist_error: Some(&hook_error),
            #[cfg(feature = "instrument")]
            repo_name_snake: "test_repo".to_string(),
        };

        let mut tokens = TokenStream::new();
        create_fn.to_tokens(&mut tokens);
        let generated = tokens.to_string();

        for fragment in [
            quote! { let mut entity = Self::hydrate_entity(events)?; },
            quote! { self.create_nested_children_in_op(op, &mut entity).await?; },
            quote! {
                self.execute_post_hydrate_hook(&entity).map_err(EntityCreateError::PostHydrateError)?;
            },
            quote! {
                self.execute_post_persist_hook(op, &entity, entity.events().last_persisted(n_events)).await.map_err(EntityCreateError::PostPersistHookError)?;
            },
        ] {
            let fragment = fragment.to_string();
            assert!(
                generated.contains(&fragment),
                "missing `{fragment}` in generated code:\n{generated}"
            );
        }
    }
}