use {
darling::{ast::Data, util::Ignored, FromDeriveInput},
proc_macro2::TokenStream,
syn::{DeriveInput, Generics, Ident},
};
/// Options for the enum derive implemented by `Kind::derive`, parsed by `darling` from the
/// deriving item and its `#[kind(...)]` attributes.
///
/// `supports(enum_unit)` restricts parsing to enums whose variants are all unit variants;
/// darling reports any other input shape as a parse error.
#[derive(Debug, darling::FromDeriveInput)]
#[darling(attributes(kind), supports(enum_unit))]
pub struct Kind {
/// Optional `#[kind(sql = "...")]` value (absent => `None`).
/// NOTE(review): not read by any code in this section — confirm whether it is consumed
/// elsewhere or only accepted for forward compatibility.
#[darling(default)]
sql: Option<String>,
/// Optional `#[kind(postgres = "...")]` value; the field only exists when this crate is
/// compiled with its `postgresql` feature.
/// NOTE(review): not read by any code in this section — confirm intended use.
#[cfg(feature = "postgresql")]
#[darling(default)]
postgres: Option<String>,
/// Optional `#[kind(sqlite = "...")]` value; the field only exists when this crate is
/// compiled with its `sqlite` feature.
/// NOTE(review): not read by any code in this section — confirm intended use.
#[cfg(feature = "sqlite")]
#[darling(default)]
sqlite: Option<String>,
/// `#[kind(builder)]` flag, `false` when omitted.
/// NOTE(review): not read by any code in this section — confirm intended use.
#[darling(default)]
builder: bool,
/// The enum's variants, each parsed as a `KindVariant`; struct fields are `Ignored`
/// because only unit-variant enums are supported.
data: Data<KindVariant, Ignored>,
/// The deriving type's generics, reused for the generated impl headers.
generics: Generics,
}
/// Per-variant options, parsed by `darling` from each variant and its `#[kind(...)]`
/// attributes.
#[derive(Clone, Debug, darling::FromVariant)]
#[darling(attributes(kind))]
pub struct KindVariant {
/// The variant's identifier.
ident: Ident,
/// Optional `#[kind(rename = "...")]`: the string this variant is written as and matched
/// against in the database. When absent, `Kind::derive` uses the identifier as written.
#[darling(default)]
rename: Option<String>,
}
impl Kind {
#[allow(clippy::or_fun_call)]
pub fn derive(input: DeriveInput) -> TokenStream {
let kind = Kind::from_derive_input(&input).expect("derive input is not a valid enum type");
let enum_data = kind.data.clone().take_enum();
if enum_data.is_none() {
panic!("The derived enum requires at least one variant");
}
let pairs = enum_data
.unwrap()
.into_iter()
.map(|var| (var.rename.unwrap_or(var.ident.to_string()), var.ident))
.collect::<Vec<(String, Ident)>>();
let ident = &input.ident;
let mut new_generics = kind.generics.clone();
new_generics.params.push(syn::parse_quote!('__from_sql));
let (impl_generics, type_generics, _) = kind.generics.split_for_impl();
let (postgres_impl_generics, _, where_clause) = new_generics.split_for_impl();
let postgresql = if cfg!(feature = "postgresql") {
let postgres_from = Self::postgres_from(&ident, &pairs);
let postgres_to = Self::postgres_to(&ident, &pairs);
let postgres_accepts = Self::postgres_accepts(&ident, &pairs);
Some(quote::quote! {
impl #postgres_impl_generics ::db_derive::internal::PostgresFromSQL<'__from_sql> for #ident #type_generics #where_clause {
fn accepts(type_: &::db_derive::internal::PostgresType) -> bool {
#postgres_accepts
}
fn from_sql(_type: &::db_derive::internal::PostgresType, buf: &'__from_sql [u8]) -> ::std::result::Result<#ident, ::std::boxed::Box<dyn ::std::error::Error + ::std::marker::Sync + ::std::marker::Send>> {
#postgres_from
}
}
impl #impl_generics ::db_derive::internal::PostgresToSQL for #ident #type_generics #where_clause {
fn accepts(type_: &::db_derive::internal::PostgresType) -> bool {
#postgres_accepts
}
fn to_sql(&self, _type: &::db_derive::internal::PostgresType, buf: &mut ::db_derive::internal::PostgresBytesMut) -> ::std::result::Result<::db_derive::internal::PostgresIsNull, ::std::boxed::Box<::std::error::Error + ::std::marker::Sync + ::std::marker::Send>> {
#postgres_to
}
db_derive::postgres_to_sql_checked!();
}
})
} else {
None
};
let sqlite = if cfg!(feature = "sqlite") {
let sqlite_from = Self::sqlite_from(&ident, &pairs);
let sqlite_to = Self::sqlite_to(&ident, &pairs);
Some(quote::quote! {
impl #impl_generics ::db_derive::internal::SQLiteFromSQL for #ident #type_generics #where_clause {
fn column_result(value: ::db_derive::internal::SQLiteValueRef) -> ::db_derive::internal::SQLiteFromSqlResult<#ident> {
#sqlite_from
}
}
impl #impl_generics ::db_derive::internal::SQLiteToSQL for #ident #type_generics #where_clause {
fn to_sql(&self) -> ::db_derive::internal::SQLiteResult<::db_derive::internal::SQLiteToSqlOutput> {
#sqlite_to
}
}
})
} else {
None
};
quote::quote! {
#postgresql
#sqlite
impl #postgres_impl_generics ::db_derive::prelude::Kind<'__from_sql> for #ident #type_generics #where_clause {}
}
}
fn postgres_accepts(ident: &Ident, pairs: &[(String, Ident)]) -> TokenStream {
let name = ident.to_string();
let values = pairs.iter().map(|v| &v.0);
let values_length = values.len();
quote::quote! {
if type_.name() != #name {
return false;
}
match *type_.kind() {
::db_derive::internal::PostgresKind::Enum(ref variants) => {
if variants.len() != #values_length {
return false;
}
variants.iter().all(|v| {
match &**v {
#( #values => true, )*
_ => false,
}
})
}
_ => false,
}
}
}
fn postgres_from(ident: &Ident, pairs: &[(String, Ident)]) -> TokenStream {
let values = pairs.iter().map(|v| &v.0);
let idents = pairs.iter().map(|v| &v.1);
quote::quote! {
match ::std::str::from_utf8(buf)? {
#( #values => ::std::result::Result::Ok(#ident::#idents), )*
s => {
::std::result::Result::Err(::std::convert::Into::into(format!("invalid variant for derived enum type `{}`", s)))
}
}
}
}
fn postgres_to(ident: &Ident, pairs: &[(String, Ident)]) -> TokenStream {
let values = pairs.iter().map(|v| &v.0);
let idents = pairs.iter().map(|v| &v.1);
quote::quote! {
let typ = match self {
#( #ident::#idents => #values, )*
};
buf.extend_from_slice(typ.as_bytes());
::std::result::Result::Ok(::db_derive::internal::PostgresIsNull::No)
}
}
fn sqlite_from(ident: &Ident, pairs: &[(String, Ident)]) -> TokenStream {
let values = pairs.iter().map(|v| &v.0);
let idents = pairs.iter().map(|v| &v.1);
quote::quote! {
use ::db_derive::internal::SQLiteFromSQL;
::std::string::String::column_result(value).map(|as_str| match as_str.as_str() {
#( #values => #ident::#idents, )*
_ => panic!("returned value for derived enum type is unknown, check your database"),
})
}
}
fn sqlite_to(ident: &Ident, pairs: &[(String, Ident)]) -> TokenStream {
let values = pairs.iter().map(|v| &v.0);
let idents = pairs.iter().map(|v| &v.1);
quote::quote! {
::std::result::Result::Ok(match self {
#( #ident::#idents => #values, )*
}.into())
}
}
}