#![doc = include_str!("../README.md")]
/// Derives `sqlx::FromRow` for a struct with named fields, targeting the row type of the
/// database selected with the container attribute `#[sqlx_with(db = ...)]`.
///
/// Container attributes (`#[sqlx_with(...)]` on the struct):
/// - `db` (required): path to the `sqlx::Database` type whose `Row` is decoded.
/// - `rename_all`: one of `"snake_case"`, `"lowercase"`, `"UPPERCASE"`, `"camelCase"`,
///   `"PascalCase"`, `"SCREAMING_SNAKE_CASE"`, `"kebab-case"`; applied to field names that
///   have no explicit `rename`.
///
/// Field attributes (`#[sqlx_with(...)]` on a field):
/// - `rename = "..."`: explicit column name.
/// - `default`: use `Default::default()` when the column is missing (`ColumnNotFound`).
/// - `decode = ...`: call `decode(column_name, row)` instead of `row.try_get(column_name)`.
/// - `flatten`: build the field from the whole row via the field type's `from_row`.
///
/// Any attribute or input error is emitted as a compile error at the derive site.
#[proc_macro_derive(FromRow, attributes(sqlx_with))]
pub fn derive_sqlx_with(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let parsed = syn::parse_macro_input!(input as syn::DeriveInput);
    let tokens = expand_derive(parsed).unwrap_or_else(|err| err.to_compile_error());
    tokens.into()
}
/// Case conventions accepted by the container-level `#[sqlx_with(rename_all = "...")]`.
///
/// Each variant is parsed from exactly the string given in its `rename`. The conversion
/// itself is applied to field names when column names are computed.
#[derive(darling::FromMeta, Debug)]
enum RenameAll {
    /// `"snake_case"`
    #[darling(rename = "snake_case")]
    Snake,
    /// `"lowercase"`
    #[darling(rename = "lowercase")]
    Lower,
    /// `"UPPERCASE"`
    #[darling(rename = "UPPERCASE")]
    Upper,
    /// `"camelCase"`
    #[darling(rename = "camelCase")]
    Camel,
    /// `"PascalCase"`
    #[darling(rename = "PascalCase")]
    Pascal,
    /// `"SCREAMING_SNAKE_CASE"`
    #[darling(rename = "SCREAMING_SNAKE_CASE")]
    ScreamingSnake,
    /// `"kebab-case"`
    #[darling(rename = "kebab-case")]
    Kebab,
}
/// The struct being derived, together with its container-level `#[sqlx_with(...)]` options.
///
/// `supports(struct_named)` makes darling reject enums, tuple structs and unit structs.
#[derive(darling::FromDeriveInput, Debug)]
#[darling(attributes(sqlx_with), supports(struct_named))]
struct DeriveInput {
    /// Name of the struct; becomes the `Self` type of the generated impl.
    ident: syn::Ident,
    /// Generics of the struct, forwarded unchanged to the generated impl.
    generics: syn::Generics,
    /// Body of the struct; only the struct (named-fields) shape can occur here.
    data: darling::ast::Data<(), Field>,
    /// Required: the `sqlx::Database` type whose `Row` the impl decodes.
    db: syn::Path,
    /// Optional case convention for fields without an explicit `rename`.
    rename_all: Option<RenameAll>,
}
/// Per-field `#[sqlx_with(...)]` options controlling how one struct field is read from a row.
#[derive(darling::FromField, Debug)]
#[darling(attributes(sqlx_with))]
struct Field {
    /// Field name; the container only accepts named structs, so this is expected to be `Some`.
    ident: Option<syn::Ident>,
    /// Declared field type (used by `flatten`).
    ty: syn::Type,
    /// Explicit column name; takes precedence over the container's `rename_all`.
    rename: Option<String>,
    /// Fall back to `Default::default()` when the column is not present in the row.
    default: darling::util::Flag,
    /// Custom decoder invoked as `decode(column_name, row)` instead of `row.try_get`.
    decode: Option<syn::Path>,
    /// Build this field from the whole row via the field type's `from_row`.
    flatten: darling::util::Flag,
}
/// Expands `#[derive(FromRow)]` into
/// `impl ::sqlx::FromRow<'_, <DB as ::sqlx::Database>::Row> for Struct`.
///
/// Every named field becomes `field: <expr>` inside a `Self { .. }` literal. See
/// [`field_value_expr`] for how `<expr>` is chosen.
///
/// # Errors
/// Returns the darling error (converted to `syn::Error`) when the input or its
/// `#[sqlx_with(...)]` attributes cannot be parsed. Examples: a missing `db`, an unknown
/// `rename_all` value, or a derive on anything other than a struct with named fields.
fn expand_derive(input: syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    use darling::FromDeriveInput as _;

    let DeriveInput {
        ident: struct_ident,
        generics,
        data,
        db,
        rename_all,
    } = DeriveInput::from_derive_input(&input)?;

    // The row type decoded by the generated impl, and by every flattened field within it.
    let row_ty: syn::Type = syn::parse_quote!(<#db as ::sqlx::Database>::Row);

    let fields = data
        .take_struct()
        .expect("`supports(struct_named)` rejects everything but structs with named fields")
        .fields;

    let mut struct_expr: syn::ExprStruct = syn::parse_quote!(Self {});
    for field in fields {
        let id = field
            .ident
            .clone()
            .expect("`supports(struct_named)` guarantees every field has a name");
        let value = field_value_expr(&id, field, &row_ty, rename_all.as_ref());
        struct_expr.fields.push(syn::parse_quote!(#id: #value));
    }

    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
    Ok(quote::quote! {
        impl #impl_generics ::sqlx::FromRow<'_, #row_ty> for #struct_ident #type_generics #where_clause {
            fn from_row(row: &#row_ty) -> ::sqlx::Result<Self> {
                use ::sqlx::Row;
                // Fully qualified so a user-defined `Ok` in scope cannot hijack the expansion.
                ::std::result::Result::Ok(#struct_expr)
            }
        }
    })
}

/// Builds the expression that produces one field's value inside the generated `from_row`.
///
/// - `flatten`: `<Ty as ::sqlx::FromRow<'_, Row>>::from_row(row)?`. The whole row is
///   decoded by the field type's own impl; `rename`, `default` and `decode` are ignored.
/// - otherwise `row.try_get(column)?`, or `decode(column, row)?` when `decode` is set.
///   With `default`, a `ColumnNotFound` error becomes `Default::default()`; every other
///   error still propagates.
///
/// The expression refers to a local `row: &Row` and ends in `?`. It must therefore be
/// placed in a function returning `::sqlx::Result<_>` that has `::sqlx::Row` in scope.
fn field_value_expr(
    id: &syn::Ident,
    field: Field,
    row_ty: &syn::Type,
    rename_all: Option<&RenameAll>,
) -> syn::Expr {
    if field.flatten.is_present() {
        let ty = field.ty;
        // A qualified path is required. `#ty::from_row` is not a valid expression for
        // generic types such as `Wrapper<T>`, and it only resolves when the `FromRow`
        // trait happens to be imported at the use site.
        return syn::parse_quote!(<#ty as ::sqlx::FromRow<'_, #row_ty>>::from_row(row)?);
    }

    let column = column_name(id, field.rename, rename_all);
    let column_get_expr: syn::Expr = match field.decode {
        Some(decode) => syn::parse_quote!(#decode(#column, row)),
        None => syn::parse_quote!(row.try_get(#column)),
    };

    if field.default.is_present() {
        syn::parse_quote! {
            match #column_get_expr {
                ::std::result::Result::Err(::sqlx::Error::ColumnNotFound(_)) => {
                    ::std::result::Result::Ok(::std::default::Default::default())
                }
                val => val,
            }?
        }
    } else {
        syn::parse_quote!(#column_get_expr?)
    }
}

/// Resolves the SQL column name for the field named `id`.
///
/// An explicit `rename` wins verbatim. Otherwise the container's `rename_all` convention is
/// applied to the field name, or the field name is used as-is when there is none. The `r#`
/// prefix of raw identifiers is stripped first, so `r#type` maps to column `type`.
fn column_name(id: &syn::Ident, rename: Option<String>, rename_all: Option<&RenameAll>) -> String {
    use heck::*;

    if let Some(rename) = rename {
        return rename;
    }

    let raw = id.to_string();
    let name = raw.strip_prefix("r#").unwrap_or(raw.as_str());
    match rename_all {
        None => name.to_owned(),
        Some(RenameAll::Snake) => name.to_snake_case(),
        Some(RenameAll::Lower) => name.to_lowercase(),
        Some(RenameAll::Upper) => name.to_uppercase(),
        Some(RenameAll::Camel) => name.to_lower_camel_case(),
        Some(RenameAll::Pascal) => name.to_upper_camel_case(),
        Some(RenameAll::ScreamingSnake) => name.to_shouty_snake_case(),
        Some(RenameAll::Kebab) => name.to_kebab_case(),
    }
}