use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use super::{parse::EntityDef, sql::postgres::Context};
use crate::utils::{marker, tracing::instrument};
/// Entry point for transaction-related code generation for one entity.
///
/// Returns an empty token stream when `entity.has_transactions()` is false.
/// Otherwise returns, concatenated in this order: the transaction repository
/// adapter, the builder extension trait, and the context extension trait.
pub fn generate(entity: &EntityDef) -> TokenStream {
    if !entity.has_transactions() {
        return TokenStream::new();
    }

    // Array elements are evaluated left to right, so the generators run in the
    // same order in which their output is spliced together.
    [
        generate_repo_adapter(entity),
        generate_builder_extension(entity),
        generate_context_extension(entity),
    ]
    .into_iter()
    .collect()
}
/// Builds the `{Entity}TransactionRepo<'t>` struct and its inherent impl.
///
/// The generated struct wraps `&'t mut sqlx::Transaction<'static, sqlx::Postgres>`.
/// `new`, `find_by_id`, `delete` and `list` are always emitted; `create` is emitted
/// only when `entity.create_fields()` is non-empty and `update` only when
/// `entity.update_fields()` is non-empty.
///
/// Every generated query method runs against `&mut **self.tx` and returns
/// `Result<_, sqlx::Error>`.
fn generate_repo_adapter(entity: &EntityDef) -> TokenStream {
let vis = &entity.vis;
// Names, SQL fragments and flags are read from the Postgres `Context` built for this entity.
let ctx = Context::new(entity);
let entity_name = ctx.entity_name;
let row_name = &ctx.row_name;
let insertable_name = &ctx.insertable_name;
let create_dto = &ctx.create_dto;
let update_dto = &ctx.update_dto;
let table = &ctx.table;
let columns_str = &ctx.columns_str;
let placeholders_str = &ctx.placeholders_str;
let id_name = ctx.id_name;
let id_type = ctx.id_type;
let soft_delete = ctx.soft_delete;
let repo_name = format_ident!("{}TransactionRepo", entity_name);
// `marker::generated()` is a project helper; its tokens are placed directly above the struct.
let marker = marker::generated();
// Tokens spliced after the INSERT `query_as(..)` call; produced by a helper defined
// elsewhere — presumably one `.bind(..)` per field of `all_fields()` (TODO confirm).
let bindings = super::sql::postgres::helpers::insert_bindings(entity.all_fields());
// Extra predicate appended to the `find_by_id` query when the entity is soft-deletable.
let deleted_filter = if soft_delete {
" AND deleted_at IS NULL"
} else {
""
};
let entity_name_str = entity_name.to_string();
// `instrument` is a project helper; its output is spliced directly above each generated
// method (presumably a tracing attribute — TODO confirm).
let create_span = instrument(&entity_name_str, "tx.create");
// No create fields: emit no `create` method at all.
let create_method = if entity.create_fields().is_empty() {
TokenStream::new()
} else {
quote! {
#create_span
pub async fn create(
&mut self,
dto: #create_dto
) -> Result<#entity_name, sqlx::Error> {
let entity = #entity_name::from(dto);
let insertable = #insertable_name::from(&entity);
let row: #row_name = sqlx::query_as(
concat!("INSERT INTO ", #table, " (", #columns_str, ") VALUES (", #placeholders_str, ") RETURNING *")
)
#(#bindings)*
.fetch_one(&mut **self.tx).await?;
Ok(#entity_name::from(row))
}
}
};
let update_span = instrument(&entity_name_str, "tx.update");
// No update fields: emit no `update` method at all.
let update_method = if entity.update_fields().is_empty() {
TokenStream::new()
} else {
let update_fields = entity.update_fields();
let field_names: Vec<String> = update_fields.iter().map(|f| f.name_str()).collect();
let field_refs: Vec<&str> = field_names.iter().map(String::as_str).collect();
let set_clause = ctx.dialect.set_clause(&field_refs);
// The id is bound after all update bindings, so its placeholder index is `len + 1`.
let where_placeholder = ctx.dialect.placeholder(update_fields.len() + 1);
let update_bindings = super::sql::postgres::helpers::update_bindings(&update_fields);
// NOTE(review): unlike `find_by_id`, the UPDATE below has no `deleted_at IS NULL`
// predicate, so it would also match soft-deleted rows — confirm this is intended.
quote! {
#update_span
pub async fn update(
&mut self,
id: #id_type,
dto: #update_dto
) -> Result<#entity_name, sqlx::Error> {
let row: #row_name = sqlx::query_as(
&format!("UPDATE {} SET {} WHERE {} = {} RETURNING *",
#table, #set_clause, stringify!(#id_name), #where_placeholder)
)
#(#update_bindings)*
.bind(&id)
.fetch_one(&mut **self.tx).await?;
Ok(#entity_name::from(row))
}
}
};
// Body of the generated `delete` method. With soft delete the row is stamped with
// `deleted_at = NOW()` (only if not already stamped); otherwise the row is removed.
// Both variants return whether at least one row was affected.
let delete_sql = if soft_delete {
quote! {
let result = sqlx::query(&format!(
"UPDATE {} SET deleted_at = NOW() WHERE {} = $1 AND deleted_at IS NULL",
#table, stringify!(#id_name)
)).bind(&id).execute(&mut **self.tx).await?;
Ok(result.rows_affected() > 0)
}
} else {
quote! {
let result = sqlx::query(&format!(
"DELETE FROM {} WHERE {} = $1",
#table, stringify!(#id_name)
)).bind(&id).execute(&mut **self.tx).await?;
Ok(result.rows_affected() > 0)
}
};
let find_span = instrument(&entity_name_str, "tx.find_by_id");
// The operation name passed to `instrument` reflects which delete variant was generated.
let delete_op = if soft_delete {
"tx.soft_delete"
} else {
"tx.delete"
};
let delete_span = instrument(&entity_name_str, delete_op);
let list_span = instrument(&entity_name_str, "tx.list");
// Assemble the struct and its impl; methods appear in the order
// new, create, find_by_id, update, delete, list.
quote! {
#marker
#vis struct #repo_name<'t> {
tx: &'t mut sqlx::Transaction<'static, sqlx::Postgres>,
}
impl<'t> #repo_name<'t> {
#[doc(hidden)]
pub fn new(tx: &'t mut sqlx::Transaction<'static, sqlx::Postgres>) -> Self {
Self { tx }
}
#create_method
#find_span
pub async fn find_by_id(
&mut self,
id: #id_type
) -> Result<Option<#entity_name>, sqlx::Error> {
let row: Option<#row_name> = sqlx::query_as(
&format!("SELECT {} FROM {} WHERE {} = $1{}",
#columns_str, #table, stringify!(#id_name), #deleted_filter)
).bind(&id).fetch_optional(&mut **self.tx).await?;
Ok(row.map(#entity_name::from))
}
#update_method
#delete_span
pub async fn delete(
&mut self,
id: #id_type
) -> Result<bool, sqlx::Error> {
#delete_sql
}
#list_span
pub async fn list(
&mut self,
limit: i64,
offset: i64
) -> Result<Vec<#entity_name>, sqlx::Error> {
// `#soft_delete` is interpolated into the emitted code, so this `if` lives in the
// generated method rather than being resolved while generating.
let where_clause = if #soft_delete { "WHERE deleted_at IS NULL " } else { "" };
let rows: Vec<#row_name> = sqlx::query_as(
&format!("SELECT {} FROM {} {}ORDER BY {} DESC LIMIT $1 OFFSET $2",
#columns_str, #table, where_clause, stringify!(#id_name))
).bind(limit).bind(offset).fetch_all(&mut **self.tx).await?;
Ok(rows.into_iter().map(#entity_name::from).collect())
}
}
}
}
/// Emits the deprecated `TransactionWith{Entity}` trait and its implementation for
/// `entity_core::transaction::Transaction<'p, sqlx::PgPool>`.
///
/// The trait has a single `with_{plural}` method, where `{plural}` is the
/// snake_case, pluralized entity name. The implementation returns `self`
/// unchanged, and the method carries a `#[deprecated]` note telling callers to
/// drop the call.
fn generate_builder_extension(entity: &EntityDef) -> TokenStream {
    let visibility = &entity.vis;
    let entity_ident = entity.name();

    // e.g. `User` -> `user` -> `users`
    let accessor = pluralize(&entity.name_str().to_case(Case::Snake));
    let builder_method = format_ident!("with_{}", accessor);
    let ext_trait = format_ident!("TransactionWith{}", entity_ident);
    let generated_marker = marker::generated();

    let note = format!(
        "no-op; repositories are accessed via `ctx.{accessor}()` inside the closure of \
         `Transaction::run` / `run_with_commit`. Drop the `.{builder_method}()` call. \
         Slated for removal in 0.8.0."
    );

    quote! {
        #generated_marker
        #visibility trait #ext_trait<'p> {
            #[deprecated(note = #note)]
            fn #builder_method(self) -> Self;
        }

        impl<'p> #ext_trait<'p> for entity_core::transaction::Transaction<'p, sqlx::PgPool> {
            fn #builder_method(self) -> Self {
                self
            }
        }
    }
}
/// Emits the `{Entity}ContextExt` trait and its implementation for
/// `entity_core::transaction::TransactionContext`.
///
/// The trait has one accessor, named after the snake_case pluralized entity
/// name, which builds a `{Entity}TransactionRepo` from `self.transaction()`.
fn generate_context_extension(entity: &EntityDef) -> TokenStream {
    let visibility = &entity.vis;
    let entity_ident = entity.name();

    // Accessor identifier, e.g. `User` -> `users`.
    let accessor = {
        let snake = entity.name_str().to_case(Case::Snake);
        let plural = pluralize(&snake);
        format_ident!("{}", plural)
    };
    let ext_trait = format_ident!("{}ContextExt", entity_ident);
    let repo = format_ident!("{}TransactionRepo", entity_ident);
    let generated_marker = marker::generated();

    quote! {
        #generated_marker
        #visibility trait #ext_trait {
            fn #accessor(&mut self) -> #repo<'_>;
        }

        impl #ext_trait for entity_core::transaction::TransactionContext {
            fn #accessor(&mut self) -> #repo<'_> {
                #repo::new(self.transaction())
            }
        }
    }
}
fn pluralize(word: &str) -> String {
if let Some(plural) = irregular_plural(word) {
return plural;
}
if word.ends_with('s')
|| word.ends_with('x')
|| word.ends_with('z')
|| word.ends_with("ch")
|| word.ends_with("sh")
{
format!("{word}es")
} else if let Some(without_y) = word.strip_suffix('y') {
if let Some(c) = without_y.chars().last()
&& !"aeiou".contains(c)
{
return format!("{without_y}ies");
}
format!("{word}s")
} else {
format!("{word}s")
}
}
/// Looks `word` up in a fixed table of irregular English plurals.
///
/// The whole word must match a table entry, ignoring ASCII case; prefixes such
/// as `childcare` do not match. Returns `None` when there is no entry.
///
/// When `word` starts with an uppercase character, the returned plural has its
/// first character uppercased; otherwise the table value is returned as-is.
fn irregular_plural(word: &str) -> Option<String> {
    const IRREGULARS: &[(&str, &str)] = &[
        ("child", "children"),
        ("person", "people"),
        ("mouse", "mice"),
        ("goose", "geese"),
        ("foot", "feet"),
        ("tooth", "teeth"),
        ("man", "men"),
        ("woman", "women"),
        ("datum", "data"),
        ("criterion", "criteria"),
    ];

    // Table keys are lowercase ASCII, so an ASCII case-insensitive comparison
    // against `word` is the same as comparing with the lowercased word.
    let plural = IRREGULARS
        .iter()
        .find_map(|&(singular, plural)| singular.eq_ignore_ascii_case(word).then_some(plural))?;

    let starts_upper = word.chars().next().is_some_and(char::is_uppercase);
    if !starts_upper {
        return Some(plural.to_owned());
    }

    // Mirror the input's leading capital on the plural form.
    let mut rest = plural.chars();
    Some(match rest.next() {
        Some(first) => first.to_uppercase().chain(rest).collect(),
        None => String::new(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `pluralize(singular) == expected` for each pair, in order.
    #[track_caller]
    fn assert_plurals(cases: &[(&str, &str)]) {
        for &(singular, expected) in cases {
            assert_eq!(pluralize(singular), expected);
        }
    }

    // --- regular suffix rules ---

    #[test]
    fn regular_plurals_unchanged() {
        assert_plurals(&[
            ("user", "users"),
            ("order", "orders"),
            ("account", "accounts"),
        ]);
    }

    #[test]
    fn sibilant_suffix_adds_es() {
        assert_plurals(&[
            ("box", "boxes"),
            ("class", "classes"),
            ("buzz", "buzzes"),
            ("match", "matches"),
            ("brush", "brushes"),
        ]);
    }

    #[test]
    fn consonant_y_becomes_ies() {
        assert_plurals(&[
            ("category", "categories"),
            ("country", "countries"),
            ("history", "histories"),
        ]);
    }

    #[test]
    fn vowel_y_stays_y_plus_s() {
        assert_plurals(&[("day", "days"), ("key", "keys"), ("boy", "boys")]);
    }

    // --- irregular table, one test per entry ---

    #[test]
    fn irregular_child_becomes_children() {
        assert_plurals(&[("child", "children")]);
    }

    #[test]
    fn irregular_person_becomes_people() {
        assert_plurals(&[("person", "people")]);
    }

    #[test]
    fn irregular_mouse_becomes_mice() {
        assert_plurals(&[("mouse", "mice")]);
    }

    #[test]
    fn irregular_goose_becomes_geese() {
        assert_plurals(&[("goose", "geese")]);
    }

    #[test]
    fn irregular_foot_becomes_feet() {
        assert_plurals(&[("foot", "feet")]);
    }

    #[test]
    fn irregular_tooth_becomes_teeth() {
        assert_plurals(&[("tooth", "teeth")]);
    }

    #[test]
    fn irregular_man_becomes_men() {
        assert_plurals(&[("man", "men")]);
    }

    #[test]
    fn irregular_woman_becomes_women() {
        assert_plurals(&[("woman", "women")]);
    }

    #[test]
    fn irregular_datum_becomes_data() {
        assert_plurals(&[("datum", "data")]);
    }

    #[test]
    fn irregular_criterion_becomes_criteria() {
        assert_plurals(&[("criterion", "criteria")]);
    }

    // --- irregular lookup edge cases ---

    #[test]
    fn irregular_lookup_is_case_insensitive() {
        assert_plurals(&[
            ("Child", "Children"),
            ("Person", "People"),
            ("child", "children"),
        ]);
    }

    #[test]
    fn irregular_does_not_match_compound_words() {
        assert_plurals(&[("childcare", "childcares"), ("manager", "managers")]);
    }
}