Skip to main content

diesel_enum_number/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Fields, ItemEnum, parse_macro_input};
4
5/// Attribute macro that implements `from_i16`, `ToSql<SmallInt, Pg>`, and
6/// `FromSql<SmallInt, Pg>` for an enum, and automatically applies `#[repr(i16)]`.
7///
8/// All variants must be unit variants with explicit discriminants.
9///
10/// # Example
11/// ```ignore
12/// use diesel_enum_number::diesel_enum_number;
13///
14/// #[diesel_enum_number]
15/// // serde not required, but can fit in well with this approach
16/// #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
17/// #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
18/// pub enum UserStatus {
19///     Active = 1,
20///     Inactive = 2,
21/// }
22/// ```
23///
24/// expands to:
25///
26/// ```ignore
27/// #[repr(i16)]
28/// #[derive(diesel::AsExpression, diesel::FromSqlRow)]
29/// #[diesel(sql_type = diesel::sql_types::SmallInt)]
30/// #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31/// #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
32/// pub enum UserStatus {
33///     Active = 1,
34///     Inactive = 2,
35/// }
36/// // + ToSql, FromSql, and from_i16 impls
37/// ```
38#[proc_macro_attribute]
39pub fn diesel_enum_number(_attr: TokenStream, item: TokenStream) -> TokenStream {
40    let input = parse_macro_input!(item as ItemEnum);
41    let name = &input.ident;
42
43    let match_arms: Vec<_> = input
44        .variants
45        .iter()
46        .map(|variant| {
47            assert!(
48                matches!(variant.fields, Fields::Unit),
49                "diesel_enum_number requires unit variants, but `{}::{}` has fields",
50                name,
51                variant.ident
52            );
53            let ident = &variant.ident;
54            let discriminant = variant
55                .discriminant
56                .as_ref()
57                .map(|(_, expr)| expr)
58                .unwrap_or_else(|| {
59                    panic!(
60                        "Variant `{}::{}` must have an explicit discriminant",
61                        name, ident
62                    )
63                });
64            quote! { #discriminant => Ok(#name::#ident), }
65        })
66        .collect();
67
68    let expanded = quote! {
69        #[repr(i16)]
70        #[derive(::diesel::AsExpression, ::diesel::FromSqlRow)]
71        #[diesel(sql_type = ::diesel::sql_types::SmallInt)]
72        #input
73
74        impl #name {
75            pub fn from_i16(value: i16) -> Result<Self, String> {
76                match value {
77                    #(#match_arms)*
78                    _ => Err(format!("Invalid value {} for {}", value, stringify!(#name))),
79                }
80            }
81        }
82
83        impl ::diesel::serialize::ToSql<::diesel::sql_types::SmallInt, ::diesel::pg::Pg> for #name {
84            fn to_sql<'b>(
85                &'b self,
86                out: &mut ::diesel::serialize::Output<'b, '_, ::diesel::pg::Pg>,
87            ) -> ::diesel::serialize::Result {
88                let value = *self as i16;
89                ::diesel::serialize::ToSql::<::diesel::sql_types::SmallInt, ::diesel::pg::Pg>::to_sql(
90                    &value,
91                    &mut out.reborrow(),
92                )
93            }
94        }
95
96        impl ::diesel::deserialize::FromSql<::diesel::sql_types::SmallInt, ::diesel::pg::Pg> for #name {
97            fn from_sql(
98                bytes: ::diesel::pg::PgValue<'_>,
99            ) -> ::diesel::deserialize::Result<Self> {
100                let value = <i16 as ::diesel::deserialize::FromSql<
101                    ::diesel::sql_types::SmallInt,
102                    ::diesel::pg::Pg,
103                >>::from_sql(bytes)?;
104                #name::from_i16(value).map_err(|e| e.into())
105            }
106        }
107    };
108
109    TokenStream::from(expanded)
110}