diesel_enum/
lib.rs

1//! This crate allows the user to represent state in the database using Rust enums. This is achieved
2//! through a proc macro. First the macro looks at your chosen `sql_type`, and then it devises a
3//! corresponding Rust type. The mapping is as follows:
4//!
5//! | SQL        | Rust     |
6//! | ---------- | -------- |
7//! | `SmallInt` | `i16`    |
8//! | `Integer`  | `i32`    |
9//! | `Int`      | `i32`    |
10//! | `BigInt`   | `i64`    |
11//! | `VarChar`  | `String` |
12//! | `Text`     | `String` |
13//!
//! The macro then generates a `FromSql` impl and a `ToSql` impl, which convert between the SQL
//! type and the enum, a `TryFrom` impl, which converts the Rust type into the enum, and an
//! `Into` conversion from the enum back into the Rust type.
17//!
18//! ### Usage
19//! ```rust
20//! #[macro_use] extern crate diesel;
21//! use diesel_enum::DbEnum;
22//! use diesel::{deserialize::FromSqlRow, sql_types::SmallInt};
23//!
24//! #[derive(Debug, thiserror::Error)]
25//! #[error("CustomError: {msg}, {status}")]
26//! pub struct CustomError {
27//!     msg: String,
28//!     status: u16,
29//! }
30//!
31//! impl CustomError {
32//!     fn not_found(msg: String) -> Self {
33//!         Self {
34//!             msg,
35//!             status: 404,
36//!         }
37//!     }
38//! }
39//!
40//! #[derive(Debug, Clone, Copy, PartialEq, Eq, FromSqlRow, DbEnum)]
41//! #[diesel(sql_type = SmallInt)]
42//! #[diesel_enum(error_fn = CustomError::not_found)]
43//! #[diesel_enum(error_type = CustomError)]
44//! pub enum Status {
45//!     /// Will be represented as 0.
46//!     Ready,
47//!     /// Will be represented as 1.
48//!     Pending,
49//! }
50//! ```
//! Alternatively you can use strings, which are derived from the lowercased variant name (e.g.
//! `Status::Ready` will be stored as `"ready"` in the database):
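//! ```rust
//! # #[macro_use] extern crate diesel;
//! # use diesel_enum::DbEnum;
//! # use diesel::{deserialize::FromSqlRow, sql_types::VarChar};
//! #
//! # #[derive(Debug, thiserror::Error)]
//! # #[error("CustomError: {msg}, {status}")]
//! # pub struct CustomError {
//! #     msg: String,
//! #     status: u16,
//! # }
//! #
//! # impl CustomError {
//! #     fn not_found(msg: String) -> Self {
//! #         Self {
//! #             msg,
//! #             status: 404,
//! #         }
//! #     }
//! # }
//! #
//! #[derive(Debug, Clone, Copy, PartialEq, Eq, FromSqlRow, DbEnum)]
//! #[diesel(sql_type = VarChar)]
//! #[diesel_enum(error_fn = CustomError::not_found)]
//! #[diesel_enum(error_type = CustomError)]
//! pub enum Status {
//!     /// Will be represented as `"ready"`.
//!     Ready,
//!     /// Will be represented as `"pending"`.
//!     Pending,
//! }
//! ```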
65
66use quote::quote;
67use syn::spanned::Spanned;
68
/// Unwraps `Ok(value)` to `value`; on `Err(e)` returns early from the enclosing function with
/// `e.into()`.
///
/// Used by the derive entry point to turn `proc_macro2` error tokens into its
/// `proc_macro::TokenStream` return value without nesting `match`es.
macro_rules! try_or_return {
    ($result:expr) => {
        match $result {
            Err(early) => return early.into(),
            Ok(value) => value,
        }
    };
}
77
78struct MacroState<'a> {
79    name: syn::Ident,
80    variants: Vec<&'a syn::Variant>,
81    sql_type: syn::Ident,
82    rust_type: syn::Ident,
83    error_type: syn::Path,
84    error_fn: syn::Path,
85}
86
87impl<'a> MacroState<'a> {
88    fn val(variant: &syn::Variant) -> Option<syn::Lit> {
89        let val = variant
90            .attrs
91            .iter()
92            .find(|a| a.path.get_ident().map(|i| i == "val").unwrap_or(false))
93            .map(|a| a.tokens.to_string())?;
94        let trimmed = val[1..].trim();
95        syn::parse_str(trimmed).ok()
96    }
97
98    fn rust_type(sql_type: &syn::Ident) -> Result<syn::Ident, proc_macro2::TokenStream> {
99        let name = match sql_type.to_string().as_str() {
100            "SmallInt" => "i16",
101            "Integer" | "Int" => "i32",
102            "BigInt" => "i64",
103            "VarChar" | "Text" => "String",
104            _ => {
105                let sql_types = "`SmallInt`, `Integer`, `Int`, `BigInt`, `VarChar`, `Text`";
106                let message = format!(
107                    "`sql_type` must be one of {}, but was {}",
108                    sql_types, sql_type,
109                );
110                return Err(error(sql_type.span(), &message));
111            }
112        };
113        let span = proc_macro2::Span::call_site();
114        Ok(syn::Ident::new(name, span))
115    }
116
117    fn try_from(&self) -> proc_macro2::TokenStream {
118        let span = proc_macro2::Span::call_site();
119        let variants = self.variants.iter().map(|f| &f.ident);
120        let error_fn = &self.error_fn;
121        let name = self.name.to_string();
122        let conversion = match self.rust_type.to_string().as_str() {
123            "i16" | "i32" | "i64" => {
124                let nums = self
125                    .variants
126                    .iter()
127                    .enumerate()
128                    .map(|(idx, &var)| (syn::LitInt::new(&idx.to_string(), span), var))
129                    .map(|(idx, var)| (syn::Lit::Int(idx), var))
130                    .map(|(idx, var)| Self::val(var).unwrap_or(idx));
131                quote! {
132                    match inp {
133                        #(#nums => Ok(Self::#variants),)*
134                        otherwise => {
135                            Err(#error_fn(format!("Unexpected `{}`: {}", #name, otherwise)))
136                        },
137                    }
138                }
139            }
140            "String" => {
141                let field_names = self.variants.iter().map(|v| {
142                    use syn::{Lit::Str, LitStr};
143                    let fallback = v.ident.to_string().to_lowercase();
144                    Self::val(v).unwrap_or_else(|| Str(LitStr::new(&fallback, span)))
145                });
146
147                quote! {
148                    match inp.as_str() {
149                        #(#field_names => Ok(Self::#variants),)*
150                        otherwise => {
151                            Err(#error_fn(format!("Unexpected `{}`: {}", #name, otherwise)))
152                        },
153                    }
154                }
155            }
156            _ => panic!(),
157        };
158
159        let error_type = &self.error_type;
160        let rust_type = &self.rust_type;
161        let name = &self.name;
162        quote! {
163            impl TryFrom<#rust_type> for #name {
164                type Error = #error_type;
165
166                fn try_from(inp: #rust_type) -> std::result::Result<Self, Self::Error> {
167                    #conversion
168                }
169            }
170        }
171    }
172
173    fn as_impl(&self) -> proc_macro2::TokenStream {
174        let span = proc_macro2::Span::call_site();
175        let rust_type = &self.rust_type;
176        let name = &self.name;
177        let variants = self.variants.iter().map(|f| &f.ident);
178        let conversion = match self.rust_type.to_string().as_str() {
179            "i16" | "i32" | "i64" => {
180                let nums = self
181                    .variants
182                    .iter()
183                    .enumerate()
184                    .map(|(idx, &var)| (syn::LitInt::new(&idx.to_string(), span), var))
185                    .map(|(idx, var)| (syn::Lit::Int(idx), var))
186                    .map(|(idx, var)| Self::val(var).unwrap_or(idx));
187                quote! {
188                    match self {
189                        #(Self::#variants => #nums as #rust_type,)*
190                    }
191                }
192            }
193            "String" => {
194                let field_names = self.variants.iter().map(|v| {
195                    use syn::{Lit::Str, LitStr};
196                    let fallback = v.ident.to_string().to_lowercase();
197                    Self::val(v).unwrap_or_else(|| Str(LitStr::new(&fallback, span)))
198                });
199
200                quote! {
201                    match self {
202                        #(Self::#variants => #field_names,)*
203                    }
204                }
205            }
206            _ => panic!(),
207        };
208
209        quote! {
210            impl Into<#rust_type> for #name {
211                fn into(self) -> #rust_type {
212                    #conversion.into()
213                }
214            }
215        }
216    }
217
218    fn impl_for_from_sql(&self) -> proc_macro2::TokenStream {
219        let sql_type = &self.sql_type;
220        let rust_type = &self.rust_type;
221        let name = &self.name;
222
223        quote! {
224            impl<Db> FromSql<#sql_type, Db> for #name
225            where
226                Db: diesel::backend::Backend,
227                #rust_type: FromSql<#sql_type, Db>
228            {
229                fn from_sql(bytes: <Db as diesel::backend::Backend>::RawValue<'_>) -> deserialize::Result<Self> {
230                    let s = <#rust_type as FromSql<#sql_type, Db>>::from_sql(bytes)?;
231                    let v = s.try_into()?;
232                    Ok(v)
233                }
234            }
235        }
236    }
237
238    fn to_sql(&self) -> proc_macro2::TokenStream {
239        let span = proc_macro2::Span::call_site();
240        let sql_type = &self.sql_type;
241        let rust_type = &self.rust_type;
242        let rust_type_borrowed = if rust_type == "String" {
243            quote! { str }
244        } else {
245            quote! { #rust_type }
246        };
247        let name = &self.name;
248        let conversion = match self.rust_type.to_string().as_str() {
249            "i16" | "i32" | "i64" => {
250                let variants = self.variants.iter().map(|f| &f.ident);
251                let values = self.variants.iter().map(|&v| {
252                    let ident = &v.ident;
253                    quote! {
254                        (Self::#ident as #rust_type).to_sql(out)
255                    }
256                });
257
258                quote! {
259                    match self {
260                        #(Self::#variants => #values,)*
261                    }
262                }
263            }
264            "String" => {
265                let variants = self.variants.iter().map(|f| &f.ident);
266                let field_names = self.variants.iter().map(|&v| {
267                    use syn::{Lit::Str, LitStr};
268                    let fallback = v.ident.to_string().to_lowercase();
269                    let val = Self::val(v).unwrap_or_else(|| Str(LitStr::new(&fallback, span)));
270                    quote! {
271                        #val.to_sql(out)
272                    }
273                });
274
275                quote! {
276                    match self {
277                        #(Self::#variants => #field_names,)*
278                    }
279                }
280            }
281            _ => panic!(),
282        };
283
284        quote! {
285            impl<Db> ToSql<#sql_type, Db> for #name
286            where
287                Db: diesel::backend::Backend,
288                #rust_type_borrowed: ToSql<#sql_type, Db>
289            {
290                fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Db>) -> serialize::Result {
291                    #conversion
292                }
293            }
294        }
295    }
296}
297
298fn get_attr_ident(
299    attrs: &[syn::Attribute],
300    outer: &str,
301    inner: &str,
302) -> Result<syn::Ident, proc_macro2::TokenStream> {
303    let stream = attrs
304        .iter()
305        .filter(|a| a.path.get_ident().map(|i| i == outer).unwrap_or(false))
306        .map(|a| &a.tokens)
307        .find(|s| s.to_string().contains(inner))
308        .ok_or_else(|| {
309            let span = proc_macro2::Span::call_site();
310            let msg = format!(
311                "Usage of the `DbEnum` macro requires the `{}` attribute to be present",
312                outer
313            );
314            error(span, &msg)
315        })?;
316    let s = stream.to_string();
317    let s = s
318        .split('=')
319        .nth(1)
320        .ok_or_else(|| error(stream.span(), "malformed attribute"))?
321        .trim_matches(|c| " )".contains(c));
322    Ok(syn::Ident::new(s, stream.span()))
323}
324
325fn get_attr_path(
326    attrs: &[syn::Attribute],
327    outer: &str,
328    inner: &str,
329) -> Result<syn::Path, proc_macro2::TokenStream> {
330    let stream = attrs
331        .iter()
332        .filter(|a| a.path.get_ident().map(|i| i == outer).unwrap_or(false))
333        .map(|a| &a.tokens)
334        .find(|s| s.to_string().contains(inner))
335        .ok_or_else(|| {
336            let span = proc_macro2::Span::call_site();
337            let msg = format!(
338                "Usage of the `DbEnum` macro requires the `{}` attribute to be present",
339                outer
340            );
341            error(span, &msg)
342        })?;
343    let s = stream.to_string();
344    let s = s
345        .split('=')
346        .nth(1)
347        .ok_or_else(|| error(stream.span(), "malformed attribute"))?
348        .trim_matches(|c| " )".contains(c));
349    syn::parse_str(s).map_err(|_| error(stream.span(), "Invalid path"))
350}
351
352fn error(span: proc_macro2::Span, message: &str) -> proc_macro2::TokenStream {
353    syn::Error::new(span, message).into_compile_error()
354}
355
356#[proc_macro_derive(DbEnum, attributes(diesel, diesel_enum))]
357pub fn db_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
358    let input = syn::parse_macro_input!(input as syn::DeriveInput);
359    let name = input.ident;
360    let sql_type = try_or_return!(get_attr_ident(&input.attrs, "diesel", "sql_type"));
361    let error_fn = try_or_return!(get_attr_path(&input.attrs, "diesel_enum", "error_fn"));
362    let error_type = try_or_return!(get_attr_path(&input.attrs, "diesel_enum", "error_type"));
363    let rust_type = try_or_return!(MacroState::rust_type(&sql_type));
364    let span = proc_macro2::Span::call_site();
365    let data = match input.data {
366        syn::Data::Enum(data) => data,
367        _ => return error(span, "DbEnum should be called on an Enum").into(),
368    };
369    let variants = data.variants.iter().collect();
370    let state = MacroState {
371        name,
372        variants,
373        sql_type,
374        rust_type,
375        error_fn,
376        error_type,
377    };
378    let impl_for_from_sql = state.impl_for_from_sql();
379    let to_sql = state.to_sql();
380    let try_from = state.try_from();
381    let into = state.as_impl();
382    let name = state.name;
383    let mod_name = syn::Ident::new(
384        &format!("__impl_db_enum_{}", name),
385        proc_macro2::Span::call_site(),
386    );
387    let sql_type = state.sql_type;
388    let error_type = state.error_type;
389    let error_mod = state.error_fn.segments.first().expect("need `error_fn`");
390    let error_type_str = error_type
391        .segments
392        .iter()
393        .fold(String::new(), |a, b| a + &b.ident.to_string() + "::");
394    let error_type_str = &error_type_str[..error_type_str.len() - 2];
395    let error_import = if error_mod.ident == error_type_str {
396        quote! {}
397    } else {
398        quote! { use super::#error_mod; }
399    };
400
401    (quote! {
402        #[allow(non_snake_case, unused_extern_crates, unused_imports)]
403        mod #mod_name {
404            use super::{#name, #error_type};
405            #error_import
406
407            use diesel::{
408                self,
409                deserialize::{self, FromSql},
410                serialize::{self, Output, ToSql},
411                sql_types::#sql_type,
412            };
413            use std::{
414                convert::{TryFrom, TryInto},
415                io::Write,
416            };
417
418            #[automatically_derived]
419            #impl_for_from_sql
420
421            #[automatically_derived]
422            #to_sql
423
424            #[automatically_derived]
425            #try_from
426
427            #[automatically_derived]
428            #into
429        }
430    })
431    .into()
432}