use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::DeriveInput;
/// The kind of value a generated getter returns.
///
/// `gen_getter_tokens` maps each variant to a concrete return type and to the
/// expression that reads the backing field.
#[derive(Copy, Clone)]
pub(crate) enum ReturnKind {
    /// `&str`, borrowed from the field.
    RefStr,
    /// `Option<&str>`, obtained with `as_deref()`.
    OptionRefStr,
    /// `bool`, read by value.
    Bool,
    /// `u64`, read by value.
    U64,
    /// `chrono::DateTime<Utc>`, read by value.
    DateTime,
    /// `Option<chrono::DateTime<Utc>>`, read by value.
    OptionDateTime,
    /// `&serde_json::Value`, borrowed from the field.
    RefValue,
    /// `&InvitationStatus`, borrowed from the field.
    RefStatus,
    /// `Option<&serde_json::Value>`, obtained with `as_ref()`.
    OptionRefValue,
}
// Bring the variants into scope so call sites can write `RefStr` rather than
// `ReturnKind::RefStr`.
pub(crate) use ReturnKind::*;
/// A named struct field together with the `#[auth(...)]` settings found on it.
pub(crate) struct FieldInfo {
    /// The field's identifier as written in the struct definition.
    pub ident: syn::Ident,
    /// The string given in `#[auth(field = "...")]`, if any.
    pub auth_field_name: Option<String>,
    /// The string given in `#[auth(column = "...")]`, if any.
    pub auth_column: Option<String>,
    /// Tokens of the expression parsed from `#[auth(default = "...")]`, if any.
    pub auth_default: Option<TokenStream2>,
}
/// The field-level `#[auth(...)]` keys collected by `parse_auth_attrs`.
pub(crate) struct ParsedAuthAttrs {
    /// The string given in `field = "..."`, if present.
    pub field_name: Option<String>,
    /// The string given in `column = "..."`, if present.
    pub column_name: Option<String>,
    /// The expression from `default = "..."`, already parsed and re-emitted as tokens.
    pub default_expr: Option<TokenStream2>,
}
/// Parses the field-level `#[auth(...)]` attributes in `attrs`.
///
/// Recognised keys:
/// - `field = "..."`: stored in [`ParsedAuthAttrs::field_name`].
/// - `column = "..."`: stored in [`ParsedAuthAttrs::column_name`].
/// - `default = "..."`: the string is parsed as a Rust expression and its tokens are
///   stored in [`ParsedAuthAttrs::default_expr`].
///
/// The keys `from_row`, `json` and `table` are accepted and ignored here. Any
/// `= value` they carry is consumed so that parsing continues with the next key.
/// Attributes whose path is not `auth` are skipped entirely.
///
/// # Errors
///
/// Returns a [`syn::Error`] in any of these cases:
/// - a value of `field`, `column` or `default` is not a string literal;
/// - the `default` string does not parse as an expression;
/// - an unrecognised key appears inside `#[auth(...)]`.
pub(crate) fn parse_auth_attrs(attrs: &[syn::Attribute]) -> Result<ParsedAuthAttrs, syn::Error> {
    let mut field_name = None;
    let mut column_name = None;
    let mut default_expr = None;
    for attr in attrs {
        if !attr.path().is_ident("auth") {
            continue;
        }
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("field") {
                let lit: syn::LitStr = meta.value()?.parse()?;
                field_name = Some(lit.value());
            } else if meta.path.is_ident("column") {
                let lit: syn::LitStr = meta.value()?.parse()?;
                column_name = Some(lit.value());
            } else if meta.path.is_ident("default") {
                let lit: syn::LitStr = meta.value()?.parse()?;
                let parsed: syn::Expr = syn::parse_str(&lit.value()).map_err(|e| {
                    syn::Error::new_spanned(&lit, format!("invalid default expression: {e}"))
                })?;
                default_expr = Some(quote! { #parsed });
            } else if meta.path.is_ident("from_row")
                || meta.path.is_ident("json")
                || meta.path.is_ident("table")
            {
                // These keys are not handled by this parser. Their optional `= value`
                // must still be consumed: `parse_nested_meta` expects a `,` or the end
                // of input right after this callback returns.
                if meta.input.peek(syn::Token![=]) {
                    let _: syn::Expr = meta.value()?.parse()?;
                }
            } else {
                // Reject typos and unsupported keys with an error spanned on the key,
                // instead of silently ignoring them.
                return Err(meta.error(
                    "unknown `auth` attribute key; expected one of `field`, `column`, \
                     `default`, `from_row`, `json`, `table`",
                ));
            }
            Ok(())
        })?;
    }
    Ok(ParsedAuthAttrs {
        field_name,
        column_name,
        default_expr,
    })
}
/// Extracts the `table = "..."` value from struct-level `#[auth(...)]` attributes.
///
/// The keys `field`, `column`, `default`, `from_row` and `json` are accepted and
/// ignored here. Any `= value` they carry is consumed so that parsing continues with
/// the next key. Attributes whose path is not `auth` are skipped entirely. If `table`
/// appears more than once, the last occurrence wins.
///
/// Returns `Ok(None)` when no `table` key is present.
///
/// # Errors
///
/// Returns a [`syn::Error`] in either of these cases:
/// - the `table` value is not a string literal;
/// - an unrecognised key appears inside `#[auth(...)]`.
pub(crate) fn parse_struct_auth_table(
    attrs: &[syn::Attribute],
) -> Result<Option<String>, syn::Error> {
    let mut table_name = None;
    for attr in attrs {
        if !attr.path().is_ident("auth") {
            continue;
        }
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("table") {
                let lit: syn::LitStr = meta.value()?.parse()?;
                table_name = Some(lit.value());
            } else if meta.path.is_ident("field")
                || meta.path.is_ident("column")
                || meta.path.is_ident("default")
                || meta.path.is_ident("from_row")
                || meta.path.is_ident("json")
            {
                // These keys are not handled by this parser. Their optional `= value`
                // must still be consumed: `parse_nested_meta` expects a `,` or the end
                // of input right after this callback returns.
                if meta.input.peek(syn::Token![=]) {
                    let _: syn::Expr = meta.value()?.parse()?;
                }
            } else {
                // Reject typos and unsupported keys with an error spanned on the key,
                // instead of silently ignoring them.
                return Err(meta.error(
                    "unknown `auth` attribute key; expected one of `field`, `column`, \
                     `default`, `from_row`, `json`, `table`",
                ));
            }
            Ok(())
        })?;
    }
    Ok(table_name)
}
/// Returns the identifier of the struct field that backs the getter `getter_name`.
///
/// Lookup order:
/// 1. A field whose `auth_field_name` equals `getter_name` takes precedence.
/// 2. Otherwise, a field whose own identifier equals `getter_name` is used.
///
/// Within each pass the first match in `fields` order wins. Returns `None` when
/// neither pass finds a field.
pub(crate) fn find_field_for_getter<'a>(
    fields: &'a [FieldInfo],
    getter_name: &str,
) -> Option<&'a syn::Ident> {
    let explicitly_mapped = fields
        .iter()
        .find(|info| info.auth_field_name.as_deref() == Some(getter_name));
    explicitly_mapped
        .or_else(|| fields.iter().find(|info| info.ident == getter_name))
        .map(|info| &info.ident)
}
pub(crate) fn parse_named_fields(
input: &DeriveInput,
trait_name: &str,
) -> Result<Vec<FieldInfo>, TokenStream2> {
let struct_name = &input.ident;
let named_fields = match &input.data {
syn::Data::Struct(data) => match &data.fields {
syn::Fields::Named(fields) => &fields.named,
_ => {
return Err(syn::Error::new_spanned(
struct_name,
format!("{trait_name} can only be derived for structs with named fields"),
)
.to_compile_error());
}
},
_ => {
return Err(syn::Error::new_spanned(
struct_name,
format!("{trait_name} requires a struct"),
)
.to_compile_error());
}
};
let mut fields = Vec::new();
for f in named_fields {
let Some(ident) = f.ident.clone() else {
continue;
};
let parsed = parse_auth_attrs(&f.attrs).map_err(|e| e.to_compile_error())?;
fields.push(FieldInfo {
ident,
auth_field_name: parsed.field_name,
auth_column: parsed.column_name,
auth_default: parsed.default_expr,
});
}
Ok(fields)
}
pub(crate) fn gen_getter_tokens(
field_ident: &syn::Ident,
kind: ReturnKind,
) -> (TokenStream2, TokenStream2) {
match kind {
RefStr => (quote! { &str }, quote! { &self.#field_ident }),
OptionRefStr => (
quote! { ::core::option::Option<&str> },
quote! { self.#field_ident.as_deref() },
),
Bool => (quote! { bool }, quote! { self.#field_ident }),
U64 => (quote! { u64 }, quote! { self.#field_ident }),
DateTime => (
quote! { ::chrono::DateTime<::chrono::Utc> },
quote! { self.#field_ident },
),
OptionDateTime => (
quote! { ::core::option::Option<::chrono::DateTime<::chrono::Utc>> },
quote! { self.#field_ident },
),
RefValue => (
quote! { &::serde_json::Value },
quote! { &self.#field_ident },
),
RefStatus => (
quote! { &::better_auth_core::types::InvitationStatus },
quote! { &self.#field_ident },
),
OptionRefValue => (
quote! { ::core::option::Option<&::serde_json::Value> },
quote! { self.#field_ident.as_ref() },
),
}
}