diesel_selectable_macro/
lib.rs1use std::ops::Deref;
6
7use darling::ast;
8use darling::FromDeriveInput;
9use darling::FromField;
10use proc_macro::TokenStream as TokenStream1;
11use proc_macro2::TokenStream;
12use quote::quote;
13use quote::ToTokens;
14use syn::parse::Parse;
15use syn::parse::ParseStream;
16use syn::parse_macro_input;
17use syn::Attribute;
18use syn::DeriveInput;
19use syn::Generics;
20use syn::Ident;
21
22#[proc_macro_derive(Selectable, attributes(diesel))]
24#[cfg(not(tarpaulin_include))]
25pub fn derive_selectable(input: TokenStream1) -> TokenStream1 {
26 SelectableStruct::from_derive_input(&parse_macro_input!(
27 input as DeriveInput
28 ))
29 .map(|recv| quote!(#recv))
30 .unwrap_or_else(|err| err.write_errors())
31 .into()
32}
33
34#[derive(FromDeriveInput)]
37#[darling(supports(struct_named), forward_attrs(diesel))]
38pub(crate) struct SelectableStruct {
39 ident: Ident,
41
42 generics: Generics,
44
45 data: ast::Data<(), SelectableField>,
47
48 attrs: Vec<Attribute>,
50}
51
52impl SelectableStruct {
53 #[cfg(not(tarpaulin_include))]
55 fn field_names(&self) -> Vec<&Ident> {
56 self
57 .data
58 .as_ref()
59 .take_struct()
60 .expect("Selectable only supports named structs")
61 .into_iter()
62 .map(|field| field.name())
63 .collect()
64 }
65}
66
67impl ToTokens for SelectableStruct {
68 #[cfg(not(tarpaulin_include))]
70 fn to_tokens(&self, tokens: &mut TokenStream) {
71 let ident = &self.ident;
76 let (impl_generics, type_generics, where_clause) =
77 self.generics.split_for_impl();
78
79 let diesel = self
81 .attrs
82 .iter()
83 .find(|attr| attr.path.is_ident("diesel"))
84 .expect("The `diesel` attribute is required");
85 let args = syn::parse_macro_input::parse::<CommaSeparatedArguments>(
86 diesel.into_token_stream().into(),
87 )
88 .expect("Unable to parse arguments.");
89 let table_name = args
90 .iter()
91 .find_map(|arg| {
92 syn::parse_macro_input::parse::<TableNameParser>(arg.clone().into())
93 .ok()
94 })
95 .expect("No `table_name` argument found.")
96 .table_name;
97 let table = quote! { crate::schema::#table_name::dsl::#table_name };
98
99 let fields: Vec<TokenStream> = self
101 .field_names()
102 .iter()
103 .map(|f| quote! { crate::schema::#table_name::dsl::#f })
104 .collect();
105
106 tokens.extend(quote! {
108 #[automatically_derived]
109 impl #impl_generics #ident #type_generics #where_clause {
110 pub fn fields() -> (#(#fields),*) {
112 (#(#fields),*)
113 }
114
115 pub fn select() -> diesel::dsl::Select<#table, (#(#fields),*)> {
118 #table.select(Self::fields())
119 }
120 }
121 })
122 }
123}
124
125#[derive(FromField)]
127#[darling(attributes(field_names))]
128struct SelectableField {
129 ident: Option<Ident>,
131}
132
133impl SelectableField {
134 #[cfg(not(tarpaulin_include))]
136 fn name(&self) -> &Ident {
137 self.ident.as_ref().expect("Selectable only supports named fields")
138 }
139}
140
141struct CommaSeparatedArguments(Vec<TokenStream>);
142
143impl Parse for CommaSeparatedArguments {
144 fn parse(input: ParseStream) -> syn::Result<Self> {
145 let bracketed;
146 let content;
147 input.parse::<syn::Token![#]>()?;
148 syn::bracketed!(bracketed in input);
149 bracketed.parse::<syn::Ident>()?;
150 syn::parenthesized!(content in bracketed);
151
152 Ok(Self(
154 content
155 .parse_terminated::<TokenStream, syn::Token![,]>(TokenStream::parse)
156 .expect("Failed to parse comma-separated args")
157 .into_iter()
158 .collect(),
159 ))
160 }
161}
162
163impl Deref for CommaSeparatedArguments {
164 type Target = Vec<TokenStream>;
165
166 fn deref(&self) -> &Self::Target {
167 &self.0
168 }
169}
170
171struct TableNameParser {
172 table_name: Ident,
173}
174
175impl Parse for TableNameParser {
176 fn parse(input: ParseStream) -> syn::Result<Self> {
177 let key_val: Vec<Ident> = input
180 .parse_terminated::<Ident, syn::Token![=]>(Ident::parse)
181 .expect("Not = separated.")
182 .into_iter()
183 .collect();
184 if key_val.len() != 2 {
185 return Err(input.error("Incorrect token length."));
186 }
187 match key_val[0] == "table_name" {
188 true => Ok(Self { table_name: key_val[1].clone() }),
189 false => Err(input.error("Wrong attribute,")),
190 }
191 }
192}