use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use super::{parse::EntityDef, sql::postgres::Context};
use crate::utils::marker;
/// Emits every transaction-related item for `entity`.
///
/// Nothing is generated unless `entity.has_transactions()` is true. Otherwise the
/// output is, in order: the `<Entity>TransactionRepo` adapter, the
/// `TransactionWith<Entity>` builder extension trait, and the `<Entity>ContextExt`
/// accessor trait.
pub fn generate(entity: &EntityDef) -> TokenStream {
    if entity.has_transactions() {
        let adapter = generate_repo_adapter(entity);
        let builder = generate_builder_extension(entity);
        let context = generate_context_extension(entity);
        quote! {
            #adapter
            #builder
            #context
        }
    } else {
        TokenStream::new()
    }
}
/// Generates `<Entity>TransactionRepo<'t>`, a repository whose methods all run on
/// a borrowed `sqlx::Transaction<'static, sqlx::Postgres>`.
///
/// Generated methods:
/// - `create` (only if the entity has create fields): inserts the row built from
///   the create DTO and returns the entity read back via `RETURNING *`.
/// - `find_by_id`: selects the entity's columns by primary key.
/// - `update` (only if the entity has update fields): sets every update field
///   from the DTO and returns the entity via `RETURNING *`. Like any `fetch_one`,
///   it fails with `sqlx::Error::RowNotFound` when no row matches.
/// - `delete`: for soft-deletable entities, sets `deleted_at = NOW()`; otherwise
///   issues a hard `DELETE`. Returns `true` if a row was affected.
/// - `list`: a `limit`/`offset` page ordered by primary key, descending.
///
/// For soft-deletable entities, every method (including `update`) ignores rows
/// whose `deleted_at` is set. A soft-deleted entity therefore cannot be read,
/// modified, or deleted again through this repository.
///
/// All SQL text is assembled with `concat!` from literals known at macro-expansion
/// time, so the generated methods never allocate a query string per call.
fn generate_repo_adapter(entity: &EntityDef) -> TokenStream {
    let vis = &entity.vis;
    let ctx = Context::new(entity);
    let entity_name = ctx.entity_name;
    let row_name = &ctx.row_name;
    let insertable_name = &ctx.insertable_name;
    let create_dto = &ctx.create_dto;
    let update_dto = &ctx.update_dto;
    let table = &ctx.table;
    let columns_str = &ctx.columns_str;
    let placeholders_str = &ctx.placeholders_str;
    let id_name = ctx.id_name;
    let id_type = ctx.id_type;
    let soft_delete = ctx.soft_delete;
    let repo_name = format_ident!("{}TransactionRepo", entity_name);
    let marker = marker::generated();
    let bindings = super::sql::postgres::helpers::insert_bindings(entity.all_fields());

    // Appended after `WHERE <id> = $n` so soft-deleted rows are never matched.
    let deleted_filter = if soft_delete {
        " AND deleted_at IS NULL"
    } else {
        ""
    };
    // `list` has no other predicate, so it needs a standalone WHERE. The trailing
    // space separates it from the `ORDER BY` that follows.
    let list_where = if soft_delete {
        "WHERE deleted_at IS NULL "
    } else {
        ""
    };

    let create_method = if entity.create_fields().is_empty() {
        TokenStream::new()
    } else {
        quote! {
            pub async fn create(
                &mut self,
                dto: #create_dto
            ) -> Result<#entity_name, sqlx::Error> {
                let entity = #entity_name::from(dto);
                let insertable = #insertable_name::from(&entity);
                let row: #row_name = sqlx::query_as(
                    concat!("INSERT INTO ", #table, " (", #columns_str, ") VALUES (", #placeholders_str, ") RETURNING *")
                )
                #(#bindings)*
                .fetch_one(&mut **self.tx).await?;
                Ok(#entity_name::from(row))
            }
        }
    };

    let update_fields = entity.update_fields();
    let update_method = if update_fields.is_empty() {
        TokenStream::new()
    } else {
        let field_names: Vec<String> = update_fields.iter().map(|f| f.name_str()).collect();
        let field_refs: Vec<&str> = field_names.iter().map(String::as_str).collect();
        let set_clause = ctx.dialect.set_clause(&field_refs);
        // The id is bound after every SET value, so its placeholder index is one
        // past the last update field.
        let where_placeholder = ctx.dialect.placeholder(update_fields.len() + 1);
        let update_bindings = super::sql::postgres::helpers::update_bindings(&update_fields);
        quote! {
            pub async fn update(
                &mut self,
                id: #id_type,
                dto: #update_dto
            ) -> Result<#entity_name, sqlx::Error> {
                let row: #row_name = sqlx::query_as(concat!(
                    "UPDATE ", #table, " SET ", #set_clause,
                    " WHERE ", stringify!(#id_name), " = ", #where_placeholder,
                    #deleted_filter, " RETURNING *"
                ))
                #(#update_bindings)*
                .bind(&id)
                .fetch_one(&mut **self.tx).await?;
                Ok(#entity_name::from(row))
            }
        }
    };

    // Soft-deletable entities are tombstoned instead of removed. The
    // `deleted_at IS NULL` guard makes a second delete report `false`.
    let delete_query = if soft_delete {
        quote! {
            concat!(
                "UPDATE ", #table, " SET deleted_at = NOW() WHERE ",
                stringify!(#id_name), " = $1 AND deleted_at IS NULL"
            )
        }
    } else {
        quote! {
            concat!("DELETE FROM ", #table, " WHERE ", stringify!(#id_name), " = $1")
        }
    };

    quote! {
        #marker
        #vis struct #repo_name<'t> {
            tx: &'t mut sqlx::Transaction<'static, sqlx::Postgres>,
        }
        impl<'t> #repo_name<'t> {
            #[doc(hidden)]
            pub fn new(tx: &'t mut sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
                Self { tx }
            }
            #create_method
            pub async fn find_by_id(
                &mut self,
                id: #id_type
            ) -> Result<Option<#entity_name>, sqlx::Error> {
                let row: Option<#row_name> = sqlx::query_as(concat!(
                    "SELECT ", #columns_str, " FROM ", #table,
                    " WHERE ", stringify!(#id_name), " = $1", #deleted_filter
                ))
                .bind(&id)
                .fetch_optional(&mut **self.tx)
                .await?;
                Ok(row.map(#entity_name::from))
            }
            #update_method
            pub async fn delete(
                &mut self,
                id: #id_type
            ) -> Result<bool, sqlx::Error> {
                let result = sqlx::query(#delete_query)
                    .bind(&id)
                    .execute(&mut **self.tx)
                    .await?;
                Ok(result.rows_affected() > 0)
            }
            pub async fn list(
                &mut self,
                limit: i64,
                offset: i64
            ) -> Result<Vec<#entity_name>, sqlx::Error> {
                let rows: Vec<#row_name> = sqlx::query_as(concat!(
                    "SELECT ", #columns_str, " FROM ", #table, " ", #list_where,
                    "ORDER BY ", stringify!(#id_name), " DESC LIMIT $1 OFFSET $2"
                ))
                .bind(limit)
                .bind(offset)
                .fetch_all(&mut **self.tx)
                .await?;
                Ok(rows.into_iter().map(#entity_name::from).collect())
            }
        }
    }
}
/// Generates the `TransactionWith<Entity>` extension trait and its implementation
/// for `entity_core::transaction::Transaction<'p, sqlx::PgPool>`.
///
/// The single method is named `with_<plural snake-case entity name>` (for
/// example `with_users`). Its generated body returns the transaction unchanged.
fn generate_builder_extension(entity: &EntityDef) -> TokenStream {
    let plural_snake = pluralize(&entity.name_str().to_case(Case::Snake));
    let with_method = format_ident!("with_{}", plural_snake);
    let name = entity.name();
    let ext_trait = format_ident!("TransactionWith{}", name);
    let visibility = &entity.vis;
    let generated = marker::generated();
    quote! {
        #generated
        #visibility trait #ext_trait<'p> {
            fn #with_method(self) -> Self;
        }
        impl<'p> #ext_trait<'p> for entity_core::transaction::Transaction<'p, sqlx::PgPool> {
            fn #with_method(self) -> Self {
                self
            }
        }
    }
}
/// Generates the `<Entity>ContextExt` trait and its implementation for
/// `entity_core::transaction::TransactionContext`.
///
/// The trait has one accessor named after the plural snake-case entity name (for
/// example `users()`). The accessor wraps `self.transaction()` in a new
/// `<Entity>TransactionRepo` borrowed for the lifetime of `&mut self`.
fn generate_context_extension(entity: &EntityDef) -> TokenStream {
    let name = entity.name();
    let ext_trait = format_ident!("{}ContextExt", name);
    let repo = format_ident!("{}TransactionRepo", name);
    let plural_snake = pluralize(&entity.name_str().to_case(Case::Snake));
    let accessor = format_ident!("{}", plural_snake);
    let visibility = &entity.vis;
    let generated = marker::generated();
    quote! {
        #generated
        #visibility trait #ext_trait {
            fn #accessor(&mut self) -> #repo<'_>;
        }
        impl #ext_trait for entity_core::transaction::TransactionContext {
            fn #accessor(&mut self) -> #repo<'_> {
                #repo::new(self.transaction())
            }
        }
    }
}
/// Returns a naive English plural of `word`, used to name generated accessors.
///
/// Rules, checked in order:
/// - Words ending in `s`, `x`, `z`, `ch` or `sh` take `es` (`box` → `boxes`).
/// - A trailing `y` preceded by a consonant becomes `ies` (`category` → `categories`).
/// - Every other word takes `s` (`day` → `days`, `user` → `users`).
///
/// No other rules exist, so irregular plurals are not handled
/// (`quiz` → `quizes`).
fn pluralize(word: &str) -> String {
    const ES_SUFFIXES: [&str; 5] = ["s", "x", "z", "ch", "sh"];

    if ES_SUFFIXES.iter().any(|&suffix| word.ends_with(suffix)) {
        return format!("{}es", word);
    }

    // The stem is kept only when the `y` follows a consonant. A bare "y"
    // (empty stem) or a vowel + y falls through to the plain `s` rule.
    let consonant_y_stem = word.strip_suffix('y').filter(|stem| {
        stem.chars()
            .next_back()
            .is_some_and(|prev| !"aeiou".contains(prev))
    });

    match consonant_y_stem {
        Some(stem) => format!("{}ies", stem),
        None => format!("{}s", word),
    }
}