1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#![recursion_limit = "256"]
extern crate proc_macro;
extern crate syn;
#[macro_use]
extern crate quote;
extern crate proc_macro2;
use proc_macro::TokenStream;
use proc_macro2::{Literal, Span};
use syn::Ident;
#[proc_macro_derive(DieselEnum)]
pub fn diesel_enum(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = ast.ident;
if let syn::Data::Enum(enum_data) = ast.data {
let variants = enum_data
.variants
.iter()
.map(|vs| Ident::new(&vs.ident.to_string(), Span::call_site()))
.collect::<Vec<_>>();
let variants_literal = enum_data
.variants
.iter()
.map(|vs| vs.ident.to_string().to_lowercase())
.collect::<Vec<_>>();
impl_diesel_enum(name, &variants, &variants_literal)
} else {
panic!("#[derive(DieselEnum)] works with enums only!");
}
}
fn impl_diesel_enum(
name: Ident,
variants: &Vec<Ident>,
variants_literal: &Vec<String>,
) -> TokenStream {
let name_iter = std::iter::repeat(&name);
let name_iter1 = std::iter::repeat(&name);
let name_iter2 = std::iter::repeat(&name);
let name_iter3 = std::iter::repeat(&name);
let bytes_literal = &variants_literal
.iter()
.map(|vl| Literal::byte_string(vl.as_bytes()))
.collect::<Vec<_>>();
let expanded = quote! {
use diesel::deserialize::{self, FromSql, FromSqlRow, Queryable};
use diesel::dsl::AsExprOf;
use diesel::expression::AsExpression;
use diesel::pg::Pg;
use diesel::row::Row;
use diesel::serialize::{self, IsNull, Output, ToSql};
use diesel::sql_types::VarChar;
use std::error::Error;
use std::fmt;
use std::io::Write;
impl fmt::Display for #name {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}",
match *self {
#(#name_iter::#variants => #variants_literal,)*
}
)
}
}
impl AsExpression<VarChar> for #name {
type Expression = AsExprOf<String, VarChar>;
fn as_expression(self) -> Self::Expression {
<String as AsExpression<VarChar>>::as_expression(self.to_string())
}
}
impl<'a> AsExpression<VarChar> for &'a #name {
type Expression = AsExprOf<String, VarChar>;
fn as_expression(self) -> Self::Expression {
<String as AsExpression<VarChar>>::as_expression(self.to_string())
}
}
impl ToSql<VarChar, Pg> for #name {
fn to_sql<W: Write>(&self, out: &mut Output<W, Pg>) -> serialize::Result {
match *self {
#(#name_iter1::#variants => out.write_all(#bytes_literal)?,)*
}
Ok(IsNull::No)
}
}
impl FromSql<VarChar, Pg> for #name {
fn from_sql(bytes: Option<&[u8]>) -> deserialize::Result<Self> {
match not_none!(bytes) {
#(#bytes_literal => Ok(#name_iter2::#variants),)*
v => Err(format!("Unknown value {:?} for {}", v, stringify!(#name)).into()),
}
}
}
impl FromSqlRow<VarChar, Pg> for #name {
fn build_from_row<R: Row<Pg>>(row: &mut R) -> Result<Self, Box<Error + Send + Sync>> {
match String::build_from_row(row)?.as_ref() {
#(#variants_literal => Ok(#name_iter3::#variants),)*
v => Err(format!("Unknown value {} for {}", v, stringify!(#name)).into()),
}
}
}
impl Queryable<VarChar, Pg> for #name {
type Row = Self;
fn build(row: Self::Row) -> Self {
row
}
}
};
expanded.into()
}