use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::Ident;
use crate::generators::inputs::{FilterCategory, update_wrapper_ident};
/// Description of one model column that [`generate`] turns into a field of the
/// generated `<Model>UpdateInput` struct.
pub struct UpdateField {
    /// Rust identifier of the generated struct field; `into_ir` reads it as `self.<name>`.
    pub name: Ident,
    /// Column name emitted as the `String` half of each `(column, WriteOp)` pair.
    pub column: String,
    /// Column category. It selects the update wrapper type (via `update_wrapper_ident`),
    /// adds a doc note for Date/Time columns, and enables the arithmetic operators
    /// for numeric categories.
    pub category: FilterCategory,
    /// Passed to `update_wrapper_ident`; when `true`, the wrapper's `unset` flag is
    /// lowered to `WriteOp::Unset`.
    pub nullable: bool,
    /// Enum type used as the wrapper's generic argument. It must be `Some` when
    /// `category` is `FilterCategory::Enum`, otherwise [`generate`] panics.
    pub enum_ident: Option<Ident>,
}
/// Token streams produced by [`generate`] for a single model.
pub struct UpdateInputTokens {
    /// Declaration of `pub struct <Model>UpdateInput { ... }`.
    pub struct_tokens: TokenStream,
    /// `impl ::prax_query::inputs::UpdateInput for <module>::<Model>UpdateInput`.
    /// NOTE(review): the impl names the struct through `<module>::`, so the caller
    /// presumably emits `struct_tokens` inside that module. Confirm at the call site.
    pub impl_tokens: TokenStream,
}
fn category_has_arithmetic(cat: FilterCategory) -> bool {
matches!(
cat,
FilterCategory::Int
| FilterCategory::BigInt
| FilterCategory::Float
| FilterCategory::Decimal
)
}
pub fn generate(
model_ident: &Ident,
module_name: &Ident,
fields: &[UpdateField],
) -> UpdateInputTokens {
let update_ident = format_ident!("{}UpdateInput", model_ident);
let field_decls = fields.iter().map(|f| {
let n = &f.name;
let wrapper = update_wrapper_ident(f.category, f.nullable);
let doc = match f.category {
FilterCategory::Date => Some(
"Date column. The wrapper expects an `Option<String>` \
formatted as `YYYY-MM-DD`; `DateTimeFieldUpdate` is \
shared across Date/Time/DateTime by design.",
),
FilterCategory::Time => Some(
"Time column. The wrapper expects an `Option<String>` \
formatted as `HH:MM:SS`; `DateTimeFieldUpdate` is \
shared across Date/Time/DateTime by design.",
),
_ => None,
};
let doc_attr = doc.map(|d| quote! { #[doc = #d] });
if matches!(f.category, FilterCategory::Enum) {
let e = f
.enum_ident
.as_ref()
.expect("enum field requires enum ident");
quote! {
#doc_attr
pub #n: ::core::option::Option<::prax_query::inputs::#wrapper<#e>>
}
} else {
quote! {
#doc_attr
pub #n: ::core::option::Option<::prax_query::inputs::#wrapper>
}
}
});
let struct_tokens = quote! {
#[derive(Debug, Clone, Default, ::serde::Serialize, ::serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct #update_ident {
#(#field_decls,)*
}
};
let lowerings: Vec<TokenStream> = fields
.iter()
.filter_map(|f| {
if matches!(f.category, FilterCategory::Enum) && f.enum_ident.is_none() {
return None;
}
let n = &f.name;
let col = &f.column;
let arithmetic_arms = if category_has_arithmetic(f.category) {
quote! {
if let ::core::option::Option::Some(__v) = __w.increment {
__out.push((
::std::string::String::from(#col),
::prax_query::inputs::WriteOp::Increment(
::core::convert::Into::<
::prax_query::filter::FilterValue
>::into(__v),
),
));
}
if let ::core::option::Option::Some(__v) = __w.decrement {
__out.push((
::std::string::String::from(#col),
::prax_query::inputs::WriteOp::Decrement(
::core::convert::Into::<
::prax_query::filter::FilterValue
>::into(__v),
),
));
}
if let ::core::option::Option::Some(__v) = __w.multiply {
__out.push((
::std::string::String::from(#col),
::prax_query::inputs::WriteOp::Multiply(
::core::convert::Into::<
::prax_query::filter::FilterValue
>::into(__v),
),
));
}
if let ::core::option::Option::Some(__v) = __w.divide {
__out.push((
::std::string::String::from(#col),
::prax_query::inputs::WriteOp::Divide(
::core::convert::Into::<
::prax_query::filter::FilterValue
>::into(__v),
),
));
}
}
} else {
quote! {}
};
let unset_arm = if f.nullable {
quote! {
if let ::core::option::Option::Some(true) = __w.unset {
__out.push((
::std::string::String::from(#col),
::prax_query::inputs::WriteOp::Unset,
));
}
}
} else {
quote! {}
};
Some(quote! {
if let ::core::option::Option::Some(__w) = self.#n {
if let ::core::option::Option::Some(__v) = __w.set {
__out.push((
::std::string::String::from(#col),
::prax_query::inputs::WriteOp::Set(
::core::convert::Into::<
::prax_query::filter::FilterValue
>::into(__v),
),
));
}
#arithmetic_arms
#unset_arm
}
})
})
.collect();
let impl_tokens = quote! {
impl ::prax_query::inputs::UpdateInput for #module_name::#update_ident {
type Model = #model_ident;
type Data = ::prax_query::inputs::UpdatePayload;
fn into_ir(self) -> Self::Data {
let mut __out: ::prax_query::inputs::UpdatePayload =
::std::vec::Vec::new();
#(#lowerings)*
__out
}
}
};
UpdateInputTokens {
struct_tokens,
impl_tokens,
}
}