1use quote::quote;
67use syn::spanned::Spanned;
68
/// Unwraps a `Result`, yielding the `Ok` payload to the surrounding expression.
///
/// On `Err`, the enclosing function returns immediately with the error payload
/// converted through `.into()` to the function's return type.
macro_rules! try_or_return {
    ($result:expr) => {
        match $result {
            Err(failure) => return failure.into(),
            Ok(value) => value,
        }
    };
}
77
78struct MacroState<'a> {
79 name: syn::Ident,
80 variants: Vec<&'a syn::Variant>,
81 sql_type: syn::Ident,
82 rust_type: syn::Ident,
83 error_type: syn::Path,
84 error_fn: syn::Path,
85}
86
87impl<'a> MacroState<'a> {
88 fn val(variant: &syn::Variant) -> Option<syn::Lit> {
89 let val = variant
90 .attrs
91 .iter()
92 .find(|a| a.path.get_ident().map(|i| i == "val").unwrap_or(false))
93 .map(|a| a.tokens.to_string())?;
94 let trimmed = val[1..].trim();
95 syn::parse_str(trimmed).ok()
96 }
97
98 fn rust_type(sql_type: &syn::Ident) -> Result<syn::Ident, proc_macro2::TokenStream> {
99 let name = match sql_type.to_string().as_str() {
100 "SmallInt" => "i16",
101 "Integer" | "Int" => "i32",
102 "BigInt" => "i64",
103 "VarChar" | "Text" => "String",
104 _ => {
105 let sql_types = "`SmallInt`, `Integer`, `Int`, `BigInt`, `VarChar`, `Text`";
106 let message = format!(
107 "`sql_type` must be one of {}, but was {}",
108 sql_types, sql_type,
109 );
110 return Err(error(sql_type.span(), &message));
111 }
112 };
113 let span = proc_macro2::Span::call_site();
114 Ok(syn::Ident::new(name, span))
115 }
116
117 fn try_from(&self) -> proc_macro2::TokenStream {
118 let span = proc_macro2::Span::call_site();
119 let variants = self.variants.iter().map(|f| &f.ident);
120 let error_fn = &self.error_fn;
121 let name = self.name.to_string();
122 let conversion = match self.rust_type.to_string().as_str() {
123 "i16" | "i32" | "i64" => {
124 let nums = self
125 .variants
126 .iter()
127 .enumerate()
128 .map(|(idx, &var)| (syn::LitInt::new(&idx.to_string(), span), var))
129 .map(|(idx, var)| (syn::Lit::Int(idx), var))
130 .map(|(idx, var)| Self::val(var).unwrap_or(idx));
131 quote! {
132 match inp {
133 #(#nums => Ok(Self::#variants),)*
134 otherwise => {
135 Err(#error_fn(format!("Unexpected `{}`: {}", #name, otherwise)))
136 },
137 }
138 }
139 }
140 "String" => {
141 let field_names = self.variants.iter().map(|v| {
142 use syn::{Lit::Str, LitStr};
143 let fallback = v.ident.to_string().to_lowercase();
144 Self::val(v).unwrap_or_else(|| Str(LitStr::new(&fallback, span)))
145 });
146
147 quote! {
148 match inp.as_str() {
149 #(#field_names => Ok(Self::#variants),)*
150 otherwise => {
151 Err(#error_fn(format!("Unexpected `{}`: {}", #name, otherwise)))
152 },
153 }
154 }
155 }
156 _ => panic!(),
157 };
158
159 let error_type = &self.error_type;
160 let rust_type = &self.rust_type;
161 let name = &self.name;
162 quote! {
163 impl TryFrom<#rust_type> for #name {
164 type Error = #error_type;
165
166 fn try_from(inp: #rust_type) -> std::result::Result<Self, Self::Error> {
167 #conversion
168 }
169 }
170 }
171 }
172
173 fn as_impl(&self) -> proc_macro2::TokenStream {
174 let span = proc_macro2::Span::call_site();
175 let rust_type = &self.rust_type;
176 let name = &self.name;
177 let variants = self.variants.iter().map(|f| &f.ident);
178 let conversion = match self.rust_type.to_string().as_str() {
179 "i16" | "i32" | "i64" => {
180 let nums = self
181 .variants
182 .iter()
183 .enumerate()
184 .map(|(idx, &var)| (syn::LitInt::new(&idx.to_string(), span), var))
185 .map(|(idx, var)| (syn::Lit::Int(idx), var))
186 .map(|(idx, var)| Self::val(var).unwrap_or(idx));
187 quote! {
188 match self {
189 #(Self::#variants => #nums as #rust_type,)*
190 }
191 }
192 }
193 "String" => {
194 let field_names = self.variants.iter().map(|v| {
195 use syn::{Lit::Str, LitStr};
196 let fallback = v.ident.to_string().to_lowercase();
197 Self::val(v).unwrap_or_else(|| Str(LitStr::new(&fallback, span)))
198 });
199
200 quote! {
201 match self {
202 #(Self::#variants => #field_names,)*
203 }
204 }
205 }
206 _ => panic!(),
207 };
208
209 quote! {
210 impl Into<#rust_type> for #name {
211 fn into(self) -> #rust_type {
212 #conversion.into()
213 }
214 }
215 }
216 }
217
218 fn impl_for_from_sql(&self) -> proc_macro2::TokenStream {
219 let sql_type = &self.sql_type;
220 let rust_type = &self.rust_type;
221 let name = &self.name;
222
223 quote! {
224 impl<Db> FromSql<#sql_type, Db> for #name
225 where
226 Db: diesel::backend::Backend,
227 #rust_type: FromSql<#sql_type, Db>
228 {
229 fn from_sql(bytes: <Db as diesel::backend::Backend>::RawValue<'_>) -> deserialize::Result<Self> {
230 let s = <#rust_type as FromSql<#sql_type, Db>>::from_sql(bytes)?;
231 let v = s.try_into()?;
232 Ok(v)
233 }
234 }
235 }
236 }
237
238 fn to_sql(&self) -> proc_macro2::TokenStream {
239 let span = proc_macro2::Span::call_site();
240 let sql_type = &self.sql_type;
241 let rust_type = &self.rust_type;
242 let rust_type_borrowed = if rust_type == "String" {
243 quote! { str }
244 } else {
245 quote! { #rust_type }
246 };
247 let name = &self.name;
248 let conversion = match self.rust_type.to_string().as_str() {
249 "i16" | "i32" | "i64" => {
250 let variants = self.variants.iter().map(|f| &f.ident);
251 let values = self.variants.iter().map(|&v| {
252 let ident = &v.ident;
253 quote! {
254 (Self::#ident as #rust_type).to_sql(out)
255 }
256 });
257
258 quote! {
259 match self {
260 #(Self::#variants => #values,)*
261 }
262 }
263 }
264 "String" => {
265 let variants = self.variants.iter().map(|f| &f.ident);
266 let field_names = self.variants.iter().map(|&v| {
267 use syn::{Lit::Str, LitStr};
268 let fallback = v.ident.to_string().to_lowercase();
269 let val = Self::val(v).unwrap_or_else(|| Str(LitStr::new(&fallback, span)));
270 quote! {
271 #val.to_sql(out)
272 }
273 });
274
275 quote! {
276 match self {
277 #(Self::#variants => #field_names,)*
278 }
279 }
280 }
281 _ => panic!(),
282 };
283
284 quote! {
285 impl<Db> ToSql<#sql_type, Db> for #name
286 where
287 Db: diesel::backend::Backend,
288 #rust_type_borrowed: ToSql<#sql_type, Db>
289 {
290 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Db>) -> serialize::Result {
291 #conversion
292 }
293 }
294 }
295 }
296}
297
298fn get_attr_ident(
299 attrs: &[syn::Attribute],
300 outer: &str,
301 inner: &str,
302) -> Result<syn::Ident, proc_macro2::TokenStream> {
303 let stream = attrs
304 .iter()
305 .filter(|a| a.path.get_ident().map(|i| i == outer).unwrap_or(false))
306 .map(|a| &a.tokens)
307 .find(|s| s.to_string().contains(inner))
308 .ok_or_else(|| {
309 let span = proc_macro2::Span::call_site();
310 let msg = format!(
311 "Usage of the `DbEnum` macro requires the `{}` attribute to be present",
312 outer
313 );
314 error(span, &msg)
315 })?;
316 let s = stream.to_string();
317 let s = s
318 .split('=')
319 .nth(1)
320 .ok_or_else(|| error(stream.span(), "malformed attribute"))?
321 .trim_matches(|c| " )".contains(c));
322 Ok(syn::Ident::new(s, stream.span()))
323}
324
325fn get_attr_path(
326 attrs: &[syn::Attribute],
327 outer: &str,
328 inner: &str,
329) -> Result<syn::Path, proc_macro2::TokenStream> {
330 let stream = attrs
331 .iter()
332 .filter(|a| a.path.get_ident().map(|i| i == outer).unwrap_or(false))
333 .map(|a| &a.tokens)
334 .find(|s| s.to_string().contains(inner))
335 .ok_or_else(|| {
336 let span = proc_macro2::Span::call_site();
337 let msg = format!(
338 "Usage of the `DbEnum` macro requires the `{}` attribute to be present",
339 outer
340 );
341 error(span, &msg)
342 })?;
343 let s = stream.to_string();
344 let s = s
345 .split('=')
346 .nth(1)
347 .ok_or_else(|| error(stream.span(), "malformed attribute"))?
348 .trim_matches(|c| " )".contains(c));
349 syn::parse_str(s).map_err(|_| error(stream.span(), "Invalid path"))
350}
351
352fn error(span: proc_macro2::Span, message: &str) -> proc_macro2::TokenStream {
353 syn::Error::new(span, message).into_compile_error()
354}
355
356#[proc_macro_derive(DbEnum, attributes(diesel, diesel_enum))]
357pub fn db_enum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
358 let input = syn::parse_macro_input!(input as syn::DeriveInput);
359 let name = input.ident;
360 let sql_type = try_or_return!(get_attr_ident(&input.attrs, "diesel", "sql_type"));
361 let error_fn = try_or_return!(get_attr_path(&input.attrs, "diesel_enum", "error_fn"));
362 let error_type = try_or_return!(get_attr_path(&input.attrs, "diesel_enum", "error_type"));
363 let rust_type = try_or_return!(MacroState::rust_type(&sql_type));
364 let span = proc_macro2::Span::call_site();
365 let data = match input.data {
366 syn::Data::Enum(data) => data,
367 _ => return error(span, "DbEnum should be called on an Enum").into(),
368 };
369 let variants = data.variants.iter().collect();
370 let state = MacroState {
371 name,
372 variants,
373 sql_type,
374 rust_type,
375 error_fn,
376 error_type,
377 };
378 let impl_for_from_sql = state.impl_for_from_sql();
379 let to_sql = state.to_sql();
380 let try_from = state.try_from();
381 let into = state.as_impl();
382 let name = state.name;
383 let mod_name = syn::Ident::new(
384 &format!("__impl_db_enum_{}", name),
385 proc_macro2::Span::call_site(),
386 );
387 let sql_type = state.sql_type;
388 let error_type = state.error_type;
389 let error_mod = state.error_fn.segments.first().expect("need `error_fn`");
390 let error_type_str = error_type
391 .segments
392 .iter()
393 .fold(String::new(), |a, b| a + &b.ident.to_string() + "::");
394 let error_type_str = &error_type_str[..error_type_str.len() - 2];
395 let error_import = if error_mod.ident == error_type_str {
396 quote! {}
397 } else {
398 quote! { use super::#error_mod; }
399 };
400
401 (quote! {
402 #[allow(non_snake_case, unused_extern_crates, unused_imports)]
403 mod #mod_name {
404 use super::{#name, #error_type};
405 #error_import
406
407 use diesel::{
408 self,
409 deserialize::{self, FromSql},
410 serialize::{self, Output, ToSql},
411 sql_types::#sql_type,
412 };
413 use std::{
414 convert::{TryFrom, TryInto},
415 io::Write,
416 };
417
418 #[automatically_derived]
419 #impl_for_from_sql
420
421 #[automatically_derived]
422 #to_sql
423
424 #[automatically_derived]
425 #try_from
426
427 #[automatically_derived]
428 #into
429 }
430 })
431 .into()
432}