diesel_selectable_macro/
lib.rs

//! A `Selectable` derive macro that generates `fields()` and `select()`
//! associated functions on a struct, so that a Diesel query selects columns by
//! field name rather than relying on their position (as
//! [`diesel::Queryable`] does).
4
5use std::ops::Deref;
6
7use darling::ast;
8use darling::FromDeriveInput;
9use darling::FromField;
10use proc_macro::TokenStream as TokenStream1;
11use proc_macro2::TokenStream;
12use quote::quote;
13use quote::ToTokens;
14use syn::parse::Parse;
15use syn::parse::ParseStream;
16use syn::parse_macro_input;
17use syn::Attribute;
18use syn::DeriveInput;
19use syn::Generics;
20use syn::Ident;
21
22/// Provide a `.select()` function based on the struct's fields.
23#[proc_macro_derive(Selectable, attributes(diesel))]
24#[cfg(not(tarpaulin_include))]
25pub fn derive_selectable(input: TokenStream1) -> TokenStream1 {
26  SelectableStruct::from_derive_input(&parse_macro_input!(
27    input as DeriveInput
28  ))
29  .map(|recv| quote!(#recv))
30  .unwrap_or_else(|err| err.write_errors())
31  .into()
32}
33
34/// Struct that receives the input struct to `Selectable` and augments it
35/// with the `.select()` function.
36#[derive(FromDeriveInput)]
37#[darling(supports(struct_named), forward_attrs(diesel))]
38pub(crate) struct SelectableStruct {
39  /// The name of the struct.
40  ident: Ident,
41
42  // Lifetimes and type parameters attached to the struct.
43  generics: Generics,
44
45  /// Data on the individual fields.
46  data: ast::Data<(), SelectableField>,
47
48  /// Attributes on the overall struct.
49  attrs: Vec<Attribute>,
50}
51
52impl SelectableStruct {
53  /// Return the field identifiers on the struct.
54  #[cfg(not(tarpaulin_include))]
55  fn field_names(&self) -> Vec<&Ident> {
56    self
57      .data
58      .as_ref()
59      .take_struct()
60      .expect("Selectable only supports named structs")
61      .into_iter()
62      .map(|field| field.name())
63      .collect()
64  }
65}
66
67impl ToTokens for SelectableStruct {
68  /// Return an automatically generated `Selectable` implementation.
69  #[cfg(not(tarpaulin_include))]
70  fn to_tokens(&self, tokens: &mut TokenStream) {
71    // Put together our basic tokens: the struct identifier, and any generics
72    // (types, lifetimes, etc.) we need to carry forward to the `Selectable`
73    // implementation.
74
75    let ident = &self.ident;
76    let (impl_generics, type_generics, where_clause) =
77      self.generics.split_for_impl();
78
79    // Get the name of the table.
80    let diesel = self
81      .attrs
82      .iter()
83      .find(|attr| attr.path.is_ident("diesel"))
84      .expect("The `diesel` attribute is required");
85    let args = syn::parse_macro_input::parse::<CommaSeparatedArguments>(
86      diesel.into_token_stream().into(),
87    )
88    .expect("Unable to parse arguments.");
89    let table_name = args
90      .iter()
91      .find_map(|arg| {
92        syn::parse_macro_input::parse::<TableNameParser>(arg.clone().into())
93          .ok()
94      })
95      .expect("No `table_name` argument found.")
96      .table_name;
97    let table = quote! { crate::schema::#table_name::dsl::#table_name };
98
99    // Get the list of fields as a tuple.
100    let fields: Vec<TokenStream> = self
101      .field_names()
102      .iter()
103      .map(|f| quote! { crate::schema::#table_name::dsl::#f })
104      .collect();
105
106    // Add the select implementation.
107    tokens.extend(quote! {
108      #[automatically_derived]
109      impl #impl_generics #ident #type_generics #where_clause {
110        /// Return a tuple of the table's fields.
111        pub fn fields() -> (#(#fields),*) {
112          (#(#fields),*)
113        }
114
115        /// Construct a query object to retrieve objects from the corresponding
116        /// database table.
117        pub fn select() -> diesel::dsl::Select<#table, (#(#fields),*)> {
118          #table.select(Self::fields())
119        }
120      }
121    })
122  }
123}
124
125/// A representation of a single field on the struct.
126#[derive(FromField)]
127#[darling(attributes(field_names))]
128struct SelectableField {
129  /// The name of the field, or None for tuple fields.
130  ident: Option<Ident>,
131}
132
133impl SelectableField {
134  /// Return the field's identifier, or panic if there is no identifier.
135  #[cfg(not(tarpaulin_include))]
136  fn name(&self) -> &Ident {
137    self.ident.as_ref().expect("Selectable only supports named fields")
138  }
139}
140
141struct CommaSeparatedArguments(Vec<TokenStream>);
142
143impl Parse for CommaSeparatedArguments {
144  fn parse(input: ParseStream) -> syn::Result<Self> {
145    let bracketed;
146    let content;
147    input.parse::<syn::Token![#]>()?;
148    syn::bracketed!(bracketed in input);
149    bracketed.parse::<syn::Ident>()?;
150    syn::parenthesized!(content in bracketed);
151
152    // There are zero or more arguments, comma separated. Split them up.
153    Ok(Self(
154      content
155        .parse_terminated::<TokenStream, syn::Token![,]>(TokenStream::parse)
156        .expect("Failed to parse comma-separated args")
157        .into_iter()
158        .collect(),
159    ))
160  }
161}
162
163impl Deref for CommaSeparatedArguments {
164  type Target = Vec<TokenStream>;
165
166  fn deref(&self) -> &Self::Target {
167    &self.0
168  }
169}
170
171struct TableNameParser {
172  table_name: Ident,
173}
174
175impl Parse for TableNameParser {
176  fn parse(input: ParseStream) -> syn::Result<Self> {
177    // We're looking for `table_name = foo`, so first we can split on `=` and
178    // collect the result; the length should be 2 if this is a match.
179    let key_val: Vec<Ident> = input
180      .parse_terminated::<Ident, syn::Token![=]>(Ident::parse)
181      .expect("Not = separated.")
182      .into_iter()
183      .collect();
184    if key_val.len() != 2 {
185      return Err(input.error("Incorrect token length."));
186    }
187    match key_val[0] == "table_name" {
188      true => Ok(Self { table_name: key_val[1].clone() }),
189      false => Err(input.error("Wrong attribute,")),
190    }
191  }
192}