use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::DeriveInput;
pub(crate) fn has_auth_from_row(attrs: &[syn::Attribute]) -> bool {
attrs.iter().any(|attr| {
if !attr.path().is_ident("auth") {
return false;
}
let mut found = false;
let _ = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("from_row") {
found = true;
}
Ok(())
});
found
})
}
/// Row-mapping description of one named struct field, as collected by
/// `parse_from_row_fields` from the field and its `#[auth(...)]` attributes.
struct FromRowField {
/// The field's identifier; also the fallback column name.
ident: syn::Ident,
/// The field's declared type.
ty: syn::Type,
/// Set when the `json` flag is present in `#[auth(...)]`.
is_json: bool,
/// Expression parsed from `default = "..."`, if given.
default_expr: Option<TokenStream2>,
/// Value of `field = "..."`, if given; used as column name when no `column` is set.
auth_field_name: Option<String>,
/// Value of `column = "..."`, if given; takes precedence as the column name.
auth_column: Option<String>,
}
fn parse_from_row_fields(input: &DeriveInput) -> Result<Vec<FromRowField>, TokenStream2> {
let struct_name = &input.ident;
let named_fields = match &input.data {
syn::Data::Struct(data) => match &data.fields {
syn::Fields::Named(fields) => &fields.named,
_ => {
return Err(syn::Error::new_spanned(
struct_name,
"from_row requires a struct with named fields",
)
.to_compile_error());
}
},
_ => {
return Err(
syn::Error::new_spanned(struct_name, "from_row requires a struct")
.to_compile_error(),
);
}
};
let mut fields = Vec::new();
for f in named_fields {
let Some(ident) = f.ident.clone() else {
continue;
};
let ty = f.ty.clone();
let mut is_json = false;
let mut auth_field_name = None;
let mut auth_column = None;
let mut default_expr = None;
for attr in &f.attrs {
if !attr.path().is_ident("auth") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("json") {
is_json = true;
} else if meta.path.is_ident("field") {
let value = meta.value()?;
let lit: syn::LitStr = value.parse()?;
auth_field_name = Some(lit.value());
} else if meta.path.is_ident("column") {
let value = meta.value()?;
let lit: syn::LitStr = value.parse()?;
auth_column = Some(lit.value());
} else if meta.path.is_ident("default") {
let value = meta.value()?;
let lit: syn::LitStr = value.parse()?;
let parsed: syn::Expr = syn::parse_str(&lit.value()).map_err(|e| {
syn::Error::new_spanned(&lit, format!("invalid default expression: {e}"))
})?;
default_expr = Some(quote! { #parsed });
}
Ok(())
})
.map_err(|e| e.to_compile_error())?;
}
fields.push(FromRowField {
ident,
ty,
is_json,
default_expr,
auth_field_name,
auth_column,
});
}
Ok(fields)
}
/// Returns the `T` of a type written as `Option<T>`.
///
/// Only the last path segment is inspected, so `Option<T>`,
/// `std::option::Option<T>` and similar spellings all match. Yields `None`
/// for non-path types, for a last segment not named `Option`, when there are
/// no angle-bracketed arguments, or when the first argument is not a type.
fn extract_option_inner(ty: &syn::Type) -> Option<&syn::Type> {
    let path = match ty {
        syn::Type::Path(type_path) => &type_path.path,
        _ => return None,
    };
    let last = path.segments.last().filter(|seg| seg.ident == "Option")?;
    let generics = match &last.arguments {
        syn::PathArguments::AngleBracketed(generics) => generics,
        _ => return None,
    };
    if let Some(syn::GenericArgument::Type(inner)) = generics.args.first() {
        Some(inner)
    } else {
        None
    }
}
/// Returns the identifier of the final path segment of `ty` as a `String`
/// (e.g. `"DateTime"` for `chrono::DateTime<Utc>`), or `None` when `ty` is
/// not a path type or its path has no segments.
fn type_last_segment_name(ty: &syn::Type) -> Option<String> {
    match ty {
        syn::Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map(|segment| segment.ident.to_string()),
        _ => None,
    }
}
/// Reports whether `name` (a bare type identifier, compared case-sensitively)
/// is one of the type names this module reads from a row directly instead of
/// going through a `String` conversion.
fn is_known_sqlx_type_name(name: &str) -> bool {
    const KNOWN_TYPE_NAMES: &[&str] = &[
        // Text and boolean
        "String",
        "bool",
        // Signed integers
        "i8",
        "i16",
        "i32",
        "i64",
        // Unsigned integers
        "u8",
        "u16",
        "u32",
        "u64",
        // Floating point
        "f32",
        "f64",
        // Date and time
        "DateTime",
        "DateTimeUtc",
        "NaiveDateTime",
        "NaiveDate",
        "NaiveTime",
        // Identifiers
        "Uuid",
    ];
    KNOWN_TYPE_NAMES.iter().any(|known| *known == name)
}
/// Reports whether `name` is a type identifier treated as JSON-backed:
/// exactly `Json` or `Value` (case-sensitive).
fn is_json_type_name(name: &str) -> bool {
    matches!(name, "Json" | "Value")
}
fn gen_from_row_field_expr(field: &FromRowField) -> TokenStream2 {
let ident = &field.ident;
let col_name = if let Some(ref col) = field.auth_column {
col.clone()
} else if let Some(ref field_name) = field.auth_field_name {
field_name.clone()
} else {
ident.to_string()
};
let (is_option, inner_ty) = match extract_option_inner(&field.ty) {
Some(inner) => (true, inner),
None => (false, &field.ty),
};
let inner_name = type_last_segment_name(inner_ty);
let is_json = field.is_json || inner_name.as_deref().is_some_and(is_json_type_name);
let is_known = inner_name.as_deref().is_some_and(is_known_sqlx_type_name);
if is_json && is_option {
quote! {
#ident: row.try_get::<
::core::option::Option<::sqlx::types::Json<::serde_json::Value>>, _
>(#col_name)?.map(|j| j.0)
}
} else if is_json {
quote! {
#ident: row.try_get::<
::sqlx::types::Json<::serde_json::Value>, _
>(#col_name)?.0
}
} else if let Some(ref default_expr) = field.default_expr {
quote! {
#ident: row.try_get(#col_name).unwrap_or_else(|_| #default_expr)
}
} else if !is_known && !is_option {
quote! {
#ident: {
let __s: ::std::string::String = row.try_get(#col_name)?;
::core::convert::From::from(__s)
}
}
} else {
quote! {
#ident: row.try_get(#col_name)?
}
}
}
pub(crate) fn maybe_gen_from_row(input: &DeriveInput) -> TokenStream2 {
if !has_auth_from_row(&input.attrs) {
return quote! {};
}
let fields = match parse_from_row_fields(input) {
Ok(f) => f,
Err(e) => return e,
};
let struct_name = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let field_exprs: Vec<TokenStream2> = fields.iter().map(gen_from_row_field_expr).collect();
quote! {
impl #impl_generics ::sqlx::FromRow<'_, ::sqlx::postgres::PgRow>
for #struct_name #ty_generics #where_clause
{
fn from_row(
row: &::sqlx::postgres::PgRow,
) -> ::core::result::Result<Self, ::sqlx::Error> {
use ::sqlx::Row as _;
::core::result::Result::Ok(Self {
#(#field_exprs),*
})
}
}
}
}