//! diesel-derive-enum 1.0.0
//!
//! Derive diesel boilerplate for using enums in databases.
#![recursion_limit = "1024"]

extern crate proc_macro;

use heck::SnakeCase;
use proc_macro::TokenStream;
use quote::quote;
use syn::*;
use proc_macro2::{Ident, Span};

#[proc_macro_derive(DbEnum, attributes(PgType, DieselType, db_rename))]
pub fn derive(input: TokenStream) -> TokenStream {
    let input:DeriveInput =  parse_macro_input!(input as DeriveInput);
    let db_type = 
        type_from_attrs(&input.attrs, "PgType")
            .unwrap_or(input.ident.to_string().to_snake_case());
    let diesel_mapping = 
        type_from_attrs(&input.attrs, "DieselType")
            .unwrap_or(format!("{}Mapping", input.ident));

    let diesel_mapping = Ident::new(diesel_mapping.as_ref(), Span::call_site());
    let quoted = if let Data::Enum(syn::DataEnum{variants: data_variants, ..})= input.data {
        generate_derive_enum_impls(&db_type, &diesel_mapping, &input.ident, &data_variants)
    } else {
        return syn::Error::new(Span::call_site(), "derive(DbEnum) can only be applied to enums").to_compile_error().into()
    };
    quoted.into()
}



/// Returns the string value of the first `#[attrname = "..."]` attribute in `attrs`.
///
/// Returns `None` when no attribute named `attrname` is present, so the caller can fall
/// back to its default. Only the first attribute with that name is considered.
///
/// # Panics
///
/// Panics when an attribute named `attrname` is present but is not of the form
/// `#[attrname = "..."]` (e.g. `#[PgType(foo)]` or `#[db_rename = 1]`). A panic inside a
/// proc macro is reported as a compile error. This is preferable to silently ignoring
/// the attribute and generating code with the default name instead of the one the user
/// asked for.
fn type_from_attrs(attrs: &[Attribute], attrname: &str) -> Option<String> {
    let attr = attrs.iter().find(|attr| attr.path.is_ident(attrname))?;
    match attr.parse_meta() {
        Ok(Meta::NameValue(MetaNameValue {
            lit: Lit::Str(lit_str),
            ..
        })) => Some(lit_str.value()),
        _ => panic!(
            "Malformed #[{}] attribute: expected #[{} = \"...\"]",
            attrname, attrname
        ),
    }
}

/// Builds the complete token stream emitted for one `DbEnum` derive.
///
/// * `db_type` — database type name; only used by the Postgres impl.
/// * `diesel_mapping` — name of the generated SQL marker struct.
/// * `enum_ty` — the enum the derive is applied to.
/// * `variants` — the enum's variants; each must be fieldless.
///
/// A variant's database spelling comes from `#[db_rename = "..."]` and defaults to the
/// variant name in snake_case. All impls are generated inside a private module named
/// `db_enum_impl_<Enum>`, and only the mapping struct is re-exported from it.
/// Backend-specific impls are included according to this crate's `postgres`, `mysql` and
/// `sqlite` features, checked with `cfg!` when the proc-macro crate is compiled.
///
/// # Panics
///
/// Panics (reported as a compile error) if any variant has fields.
fn generate_derive_enum_impls(
    db_type: &str,
    diesel_mapping: &Ident,
    enum_ty: &Ident,
    variants: &syn::punctuated::Punctuated<Variant, syn::token::Comma>,
) -> TokenStream {
    let modname = Ident::new(&format!("db_enum_impl_{}", enum_ty), Span::call_site());

    // Validate each variant and compute its Rust path and database spelling in a single
    // pass. The two vectors are parallel by construction: index i describes the same
    // variant in both.
    let (variants_rs, variants_db): (Vec<proc_macro2::TokenStream>, Vec<LitByteStr>) = variants
        .iter()
        .map(|variant| {
            match variant.fields {
                Fields::Unit => {}
                _ => panic!("Variants must be fieldless"),
            }
            let id = &variant.ident;
            let dbname = type_from_attrs(&variant.attrs, "db_rename")
                .unwrap_or_else(|| id.to_string().to_snake_case());
            (
                quote! { #enum_ty::#id },
                LitByteStr::new(dbname.as_bytes(), Span::call_site()),
            )
        })
        .unzip();

    let common_impl = generate_common_impl(diesel_mapping, enum_ty, &variants_rs, &variants_db);

    let pg_impl = if cfg!(feature = "postgres") {
        generate_postgres_impl(db_type, diesel_mapping, enum_ty, &variants_rs, &variants_db)
    } else {
        quote! {}
    };
    let mysql_impl = if cfg!(feature = "mysql") {
        generate_mysql_impl(diesel_mapping, enum_ty, &variants_rs, &variants_db)
    } else {
        quote! {}
    };
    let sqlite_impl = if cfg!(feature = "sqlite") {
        generate_sqlite_impl(diesel_mapping, enum_ty, &variants_rs, &variants_db)
    } else {
        quote! {}
    };

    let quoted = quote! {
        pub use self::#modname::#diesel_mapping;
        #[allow(non_snake_case)]
        mod #modname {
            #common_impl
            #pg_impl
            #mysql_impl
            #sqlite_impl
        }
    };

    quoted.into()
}

/// Generates the backend-independent part of the derive output.
///
/// The returned tokens declare:
/// * the imports used here and by the backend-specific submodules, which pick them up
///   through `use super::*`;
/// * the unit struct `diesel_mapping` (the SQL type) with `QueryId`, `NotNull` and
///   `SingleValue` impls;
/// * `AsExpression` impls for `enum_ty`, `&enum_ty` and `&&enum_ty`, each targeting both
///   the mapping type and `Nullable` of it;
/// * a `ToSql` impl for every `Backend` that writes the variant's database spelling as
///   raw bytes, plus a `Nullable` impl that forwards to it.
///
/// `variants_rs` and `variants_db` are expected to be parallel slices (entry `i` of each
/// describes the same variant); they are expanded together in one `#(...)*` repetition.
fn generate_common_impl(
    diesel_mapping: &Ident,
    enum_ty: &Ident,
    variants_rs: &[proc_macro2::TokenStream],
    variants_db: &[LitByteStr],
) -> proc_macro2::TokenStream {
    quote! {
        // Only plain `//` comments are safe inside this `quote!`: they are dropped by the
        // tokenizer, whereas `///` would be emitted into the generated code as `#[doc]`.
        use super::*;
        use diesel::Queryable;
        use diesel::backend::Backend;
        use diesel::expression::AsExpression;
        use diesel::expression::bound::Bound;
        use diesel::row::Row;
        use diesel::sql_types::*;
        use diesel::serialize::{self, ToSql, IsNull, Output};
        use diesel::deserialize::{self, FromSql, FromSqlRow};
        use diesel::query_builder::QueryId;
        use std::io::Write;

        // Unit marker struct used as the SQL type for the enum in diesel's type system.
        pub struct #diesel_mapping;
        impl QueryId for #diesel_mapping {
            type QueryId = #diesel_mapping;
            const HAS_STATIC_QUERY_ID: bool = true;
        }
        impl NotNull for #diesel_mapping {}
        impl SingleValue for #diesel_mapping {}

        // Binding the enum (by value, `&`, and `&&`) as a query parameter, for both
        // non-null and nullable columns.
        impl AsExpression<#diesel_mapping> for #enum_ty {
            type Expression = Bound<#diesel_mapping, Self>;

            fn as_expression(self) -> Self::Expression {
                Bound::new(self)
            }
        }

        impl AsExpression<Nullable<#diesel_mapping>> for #enum_ty {
            type Expression = Bound<Nullable<#diesel_mapping>, Self>;

            fn as_expression(self) -> Self::Expression {
                Bound::new(self)
            }
        }

        impl<'a> AsExpression<#diesel_mapping> for &'a #enum_ty {
            type Expression = Bound<#diesel_mapping, Self>;

            fn as_expression(self) -> Self::Expression {
                Bound::new(self)
            }
        }

        impl<'a> AsExpression<Nullable<#diesel_mapping>> for &'a #enum_ty {
            type Expression = Bound<Nullable<#diesel_mapping>, Self>;

            fn as_expression(self) -> Self::Expression {
                Bound::new(self)
            }
        }

        impl<'a, 'b> AsExpression<#diesel_mapping> for &'a &'b #enum_ty {
            type Expression = Bound<#diesel_mapping, Self>;

            fn as_expression(self) -> Self::Expression {
                Bound::new(self)
            }
        }

        impl<'a, 'b> AsExpression<Nullable<#diesel_mapping>> for &'a &'b #enum_ty {
            type Expression = Bound<Nullable<#diesel_mapping>, Self>;

            fn as_expression(self) -> Self::Expression {
                Bound::new(self)
            }
        }

        // Serialization is identical for every backend: write the variant's database
        // spelling as raw bytes.
        impl<DB: Backend> ToSql<#diesel_mapping, DB> for #enum_ty {
            fn to_sql<W: Write>(&self, out: &mut Output<W, DB>) -> serialize::Result {
                match *self {
                    #(#variants_rs => out.write_all(#variants_db)?,)*
                }
                Ok(IsNull::No)
            }
        }

        // Nullable columns reuse the non-null serialization above.
        impl<DB> ToSql<Nullable<#diesel_mapping>, DB> for #enum_ty
        where
            DB: Backend,
            Self: ToSql<#diesel_mapping, DB>,
        {
            fn to_sql<W: ::std::io::Write>(&self, out: &mut Output<W, DB>) -> serialize::Result {
                ToSql::<#diesel_mapping, DB>::to_sql(self, out)
            }
        }
    }
}

/// Generates the Postgres-specific impls inside a nested `pg_impl` module. That module
/// reaches the common imports and the mapping struct through `use super::*`.
///
/// * `db_type` is interpolated as a string literal and passed to `lookup.lookup_type`
///   to obtain the type metadata. Presumably it must name the enum type as created in
///   the database; TODO confirm against diesel's lookup semantics.
/// * `variants_rs` and `variants_db` are expected to be parallel slices. Each
///   `Some(<db bytes>) => Ok(<variant>)` decode arm is generated from one pair.
///
/// Decoding returns an error for byte strings matching no variant and for SQL `NULL`.
fn generate_postgres_impl(
    db_type: &str,
    diesel_mapping: &Ident,
    enum_ty: &Ident,
    variants_rs: &[proc_macro2::TokenStream],
    variants_db: &[LitByteStr],
) -> proc_macro2::TokenStream {
    quote! {
        mod pg_impl {
            use super::*;
            use diesel::pg::Pg;

            // Type metadata is looked up by name at runtime.
            impl HasSqlType<#diesel_mapping> for Pg {
                fn metadata(lookup: &Self::MetadataLookup) -> Self::TypeMetadata {
                    lookup.lookup_type(#db_type)
                }
            }

            // Reading a single column delegates to the `FromSql` impl below.
            impl FromSqlRow<#diesel_mapping, Pg> for #enum_ty {
                fn build_from_row<T: Row<Pg>>(row: &mut T) -> deserialize::Result<Self> {
                    FromSql::<#diesel_mapping, Pg>::from_sql(row.take())
                }
            }

            impl FromSql<#diesel_mapping, Pg> for #enum_ty {
                fn from_sql(bytes: Option<&<Pg as Backend>::RawValue>) -> deserialize::Result<Self> {
                    match bytes {
                        #(Some(#variants_db) => Ok(#variants_rs),)*
                        // Unknown bytes are shown lossily as UTF-8 in the error.
                        Some(v) => Err(format!("Unrecognized enum variant: '{}'",
                                               String::from_utf8_lossy(v)).into()),
                        None => Err("Unexpected null for non-null column".into()),
                    }
                }
            }

            // Identity `Queryable`: the row is already the enum value.
            impl Queryable<#diesel_mapping, Pg> for #enum_ty {
                type Row = Self;

                fn build(row: Self::Row) -> Self {
                    row
                }
            }
        }
    }
}

/// Generates the MySQL-specific impls inside a nested `mysql_impl` module. That module
/// reaches the common imports and the mapping struct through `use super::*`.
///
/// The mapping reports `MysqlType::String` as its type metadata. Decoding compares the
/// raw column bytes against each variant's database spelling, so `variants_rs` and
/// `variants_db` are expected to be parallel slices. An error is returned for byte
/// strings matching no variant and for SQL `NULL`.
fn generate_mysql_impl(
    diesel_mapping: &Ident,
    enum_ty: &Ident,
    variants_rs: &[proc_macro2::TokenStream],
    variants_db: &[LitByteStr],
) -> proc_macro2::TokenStream {
    quote! {
        mod mysql_impl {
            use super::*;
            // Brings the crate name into this module so the `diesel::mysql::...` path in
            // `metadata` resolves. NOTE(review): presumably kept for 2015-edition user
            // crates, where expression paths are module-relative — confirm.
            use diesel;
            use diesel::mysql::Mysql;

            impl HasSqlType<#diesel_mapping> for Mysql {
                fn metadata(_lookup: &Self::MetadataLookup) -> Self::TypeMetadata {
                    diesel::mysql::MysqlType::String
                }
            }

            // Reading a single column delegates to the `FromSql` impl below.
            impl FromSqlRow<#diesel_mapping, Mysql> for #enum_ty {
                fn build_from_row<T: Row<Mysql>>(row: &mut T) -> deserialize::Result<Self> {
                    FromSql::<#diesel_mapping, Mysql>::from_sql(row.take())
                }
            }

            impl FromSql<#diesel_mapping, Mysql> for #enum_ty {
                fn from_sql(bytes: Option<&<Mysql as Backend>::RawValue>) -> deserialize::Result<Self> {
                    match bytes {
                        #(Some(#variants_db) => Ok(#variants_rs),)*
                        // Unknown bytes are shown lossily as UTF-8 in the error.
                        Some(v) => Err(format!("Unrecognized enum variant: '{}'",
                                               String::from_utf8_lossy(v)).into()),
                        None => Err("Unexpected null for non-null column".into()),
                    }
                }
            }

            // Identity `Queryable`: the row is already the enum value.
            impl Queryable<#diesel_mapping, Mysql> for #enum_ty {
                type Row = Self;

                fn build(row: Self::Row) -> Self {
                    row
                }
            }
        }
    }
}

/// Generates the SQLite-specific impls inside a nested `sqlite_impl` module. That module
/// reaches the common imports and the mapping struct through `use super::*`.
///
/// The mapping reports `SqliteType::Text` as its type metadata. Decoding reads the value
/// with `read_blob` and compares the bytes against each variant's database spelling, so
/// `variants_rs` and `variants_db` are expected to be parallel slices. An error is
/// returned for byte strings matching no variant and for SQL `NULL`. The error wording
/// matches the Postgres and MySQL impls, so callers see the same message on every backend.
fn generate_sqlite_impl(
    diesel_mapping: &Ident,
    enum_ty: &Ident,
    variants_rs: &[proc_macro2::TokenStream],
    variants_db: &[LitByteStr],
) -> proc_macro2::TokenStream {
    quote! {
        mod sqlite_impl {
            use super::*;
            // Brings the crate name into this module so the `diesel::sqlite::...` path in
            // `metadata` resolves.
            use diesel;
            use diesel::sqlite::Sqlite;

            impl HasSqlType<#diesel_mapping> for Sqlite {
                fn metadata(_lookup: &Self::MetadataLookup) -> Self::TypeMetadata {
                    diesel::sqlite::SqliteType::Text
                }
            }

            // Reading a single column delegates to the `FromSql` impl below.
            impl FromSqlRow<#diesel_mapping, Sqlite> for #enum_ty {
                fn build_from_row<T: Row<Sqlite>>(row: &mut T) -> deserialize::Result<Self> {
                    FromSql::<#diesel_mapping, Sqlite>::from_sql(row.take())
                }
            }

            impl FromSql<#diesel_mapping, Sqlite> for #enum_ty {
                fn from_sql(bytes: Option<&<Sqlite as Backend>::RawValue>) -> deserialize::Result<Self> {
                    match bytes.map(|v| v.read_blob()) {
                        #(Some(#variants_db) => Ok(#variants_rs),)*
                        // Same wording as the other backends; unknown bytes are shown
                        // lossily as UTF-8.
                        Some(v) => Err(format!("Unrecognized enum variant: '{}'",
                                               String::from_utf8_lossy(v)).into()),
                        None => Err("Unexpected null for non-null column".into()),
                    }
                }
            }

            // Identity `Queryable`: the row is already the enum value.
            impl Queryable<#diesel_mapping, Sqlite> for #enum_ty {
                type Row = Self;

                fn build(row: Self::Row) -> Self {
                    row
                }
            }
        }
    }
}