diesel_sort_struct_fields/
lib.rs

1//! Macro to sort struct fields and `table!` columns to avoid subtle bugs.
2//!
3//! The way Diesel maps a response from a query into a struct is by treating a row as a tuple and
4//! assigning the fields in the order of the fields in code. Something like (not real code):
5//!
6//! ```rust,ignore
7//! struct User {
8//!     id: i32,
9//!     name: String,
10//! }
11//!
12//! fn user_from_row(row: (i32, String)) -> User {
13//!     User {
14//!         id: row.0,
15//!         name: row.1,
16//!     }
17//! }
18//! ```
19//!
//! This works well, but it breaks in subtle ways if `id` and `name` are not declared in the same
//! order in `table!` and in `struct User { ... }`. For example, this code doesn't compile:
22//!
23//! ```rust,ignore
24//! #[macro_use]
25//! extern crate diesel;
26//!
27//! use diesel::prelude::*;
28//!
29//! table! {
30//!     users {
31//!         // order here doesn't match order in the struct
32//!         name -> VarChar,
33//!         id -> Integer,
34//!     }
35//! }
36//!
37//! #[derive(Queryable)]
38//! struct User {
39//!     id: i32,
40//!     name: String,
41//! }
42//!
43//! fn main() {
44//!     let db = connect_to_db();
45//!
46//!     users::table
47//!         .select(users::all_columns)
48//!         .load::<User>(&db)
49//!         .unwrap();
50//! }
51//!
52//! fn connect_to_db() -> PgConnection {
53//!     PgConnection::establish("postgres://localhost/diesel-sort-struct-fields").unwrap()
54//! }
55//! ```
56//!
//! Luckily you get a type error here, so Diesel is clearly telling you that something is wrong.
//! However, if `id` and `name` had the same type you wouldn't get a type error at all. You would
//! just have subtle bugs that could take hours to track down (it did for me).
60//!
//! This crate prevents that with two small attribute macros, `#[sort_fields]` and
//! `#[sort_columns]`, that sort the fields of your model struct and the columns of your `table!`
//! by name. You can define them in any order, but once the code reaches the compiler the order
//! will always be the same.
64//!
65//! Example:
66//!
67//! ```rust
68//! #[macro_use]
69//! extern crate diesel;
70//!
71//! use diesel_sort_struct_fields::{sort_columns, sort_fields};
72//! use diesel::prelude::*;
73//!
74//! #[sort_columns]
75//! table! {
76//!     users {
77//!         name -> VarChar,
78//!         id -> Integer,
79//!     }
80//! }
81//!
82//! #[sort_fields]
83//! #[derive(Queryable)]
84//! struct User {
85//!     id: i32,
86//!     name: String,
87//! }
88//!
89//! fn main() {
90//!     let db = connect_to_db();
91//!
92//!     let users = users::table
93//!         .select(users::all_columns)
94//!         .load::<User>(&db)
95//!         .unwrap();
96//!
97//!     assert_eq!(0, users.len());
98//! }
99//!
100//! fn connect_to_db() -> PgConnection {
101//!     PgConnection::establish("postgres://localhost/diesel-sort-struct-fields").unwrap()
102//! }
103//! ```
104
105#![recursion_limit = "1024"]
106#![deny(unused_imports, dead_code, unused_variables, unused_must_use)]
107#![doc(html_root_url = "https://docs.rs/diesel-sort-struct-fields/0.1.3")]
108
109extern crate proc_macro;
110
111use proc_macro2::{Span, TokenStream};
112use quote::{quote, ToTokens};
113use syn::{
114    parse::{Parse, ParseBuffer, ParseStream},
115    parse2,
116    parse_macro_input::parse,
117    punctuated::Punctuated,
118    spanned::Spanned,
119    DeriveInput, Ident, Token,
120};
121
122type Result<A, B = syn::Error> = std::result::Result<A, B>;
123
124/// Sort fields in a model struct.
125///
126/// See crate level docs for more info.
127#[proc_macro_attribute]
128pub fn sort_fields(
129    attr: proc_macro::TokenStream,
130    item: proc_macro::TokenStream,
131) -> proc_macro::TokenStream {
132    let ast = match syn::parse_macro_input::parse::<DeriveInput>(item) {
133        Ok(ast) => ast,
134        Err(err) => return err.to_compile_error().into(),
135    };
136
137    match expand_sorted(attr.into(), ast) {
138        Ok(out) => out.into(),
139        Err(err) => err.to_compile_error().into(),
140    }
141}
142
143/// Sort columns in a `table!` macro.
144///
145/// See crate level docs for more info.
146#[proc_macro_attribute]
147pub fn sort_columns(
148    attr: proc_macro::TokenStream,
149    item: proc_macro::TokenStream,
150) -> proc_macro::TokenStream {
151    if !attr.is_empty() {
152        let attr: TokenStream = attr.into();
153        return syn::Error::new(
154            attr.span(),
155            "`#[sort_columns]` doesn't support any attributes",
156        )
157        .to_compile_error()
158        .into();
159    }
160
161    let ast = match parse::<syn::Macro>(item) {
162        Ok(ast) => ast,
163        Err(err) => return sort_columns_on_wrong_item_error(err.span()).into(),
164    };
165
166    let ident = &ast.path.segments.last().unwrap().value().ident;
167    if ident != "table" {
168        return sort_columns_on_wrong_item_error(ident.span()).into();
169    }
170
171    match parse2::<TableDsl>(ast.tts) {
172        Ok(table_dsl) => {
173            let tokens = quote! { #table_dsl };
174
175            tokens.into()
176        }
177        Err(err) => err.to_compile_error().into(),
178    }
179}
180
181fn sort_columns_on_wrong_item_error(span: Span) -> TokenStream {
182    syn::Error::new(
183        span,
184        "`#[sort_columns]` only works on the `diesel::table!` macro",
185    )
186    .to_compile_error()
187}
188
189#[derive(Debug)]
190struct TableDsl {
191    name: Ident,
192    id_columns: Option<Punctuated<Ident, Token![,]>>,
193    columns: Punctuated<ColumnDsl, Token![,]>,
194    use_statements: Vec<syn::ItemUse>,
195    attributes: Vec<syn::Attribute>,
196}
197
198impl Parse for TableDsl {
199    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
200        let mut use_statements = Vec::new();
201
202        while let Some(stmt) = input.parse::<syn::ItemUse>().ok() {
203            use_statements.push(stmt)
204        }
205
206        let attributes = input.call(syn::Attribute::parse_outer)?;
207        let name = input.parse::<Ident>()?;
208
209        let id_columns = match try_parse_parens(input) {
210            Ok(inside_parens) => {
211                let id_columns = Punctuated::<Ident, Token![,]>::parse_terminated(&inside_parens)?;
212                Some(id_columns)
213            }
214            Err(_) => None,
215        };
216
217        let inside_braces;
218        syn::braced!(inside_braces in input);
219        let columns = Punctuated::<ColumnDsl, Token![,]>::parse_terminated(&inside_braces)?;
220
221        Ok(TableDsl {
222            name,
223            id_columns,
224            columns,
225            use_statements,
226            attributes,
227        })
228    }
229}
230
231impl ToTokens for TableDsl {
232    fn to_tokens(&self, tokens: &mut TokenStream) {
233        let table_name = &self.name;
234        let attributes = &self.attributes;
235
236        let id_column = if let Some(id_columns) = &self.id_columns {
237            quote! { ( #id_columns ) }
238        } else {
239            quote! {}
240        };
241        let use_statements = &self.use_statements;
242
243        let columns = sort_punctuated(&self.columns, |column| &column.name);
244
245        tokens.extend(quote! {
246            diesel::table! {
247                #(#use_statements)*
248
249                #( #attributes )*
250                #table_name #id_column {
251                    #( #columns )*
252                }
253            }
254        })
255    }
256}
257
258#[derive(Debug)]
259struct ColumnDsl {
260    name: Ident,
261    ty: ColumnType,
262    attributes: Vec<syn::Attribute>,
263}
264
265impl ToTokens for ColumnDsl {
266    fn to_tokens(&self, tokens: &mut TokenStream) {
267        let name = &self.name;
268        let ty = &self.ty;
269        let attributes = &self.attributes;
270
271        tokens.extend(quote! {
272            #(#attributes)*
273            #name -> #ty,
274        })
275    }
276}
277
278impl Parse for ColumnDsl {
279    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
280        let attributes = input.call(syn::Attribute::parse_outer)?;
281
282        let name = input.parse::<Ident>()?;
283        input.parse::<Token![-]>()?;
284        input.parse::<Token![>]>()?;
285
286        let outer_ty = input.parse::<Ident>()?;
287        let ty = if input.peek(Token![<]) {
288            input.parse::<Token![<]>()?;
289            let ty = input.parse::<Ident>()?;
290            input.parse::<Token![>]>()?;
291            ColumnType::Wrapped(outer_ty, ty)
292        } else {
293            ColumnType::Bare(outer_ty)
294        };
295
296        Ok(ColumnDsl {
297            name,
298            ty,
299            attributes,
300        })
301    }
302}
303
304#[derive(Debug)]
305enum ColumnType {
306    Bare(Ident),
307    Wrapped(Ident, Ident),
308}
309
310impl ToTokens for ColumnType {
311    fn to_tokens(&self, tokens: &mut TokenStream) {
312        match self {
313            ColumnType::Bare(ty) => tokens.extend(quote! { #ty }),
314            ColumnType::Wrapped(constructor, ty) => tokens.extend(quote! { #constructor<#ty> }),
315        }
316    }
317}
318
319fn try_parse_parens<'a>(input: ParseStream<'a>) -> syn::parse::Result<ParseBuffer<'a>> {
320    (|| {
321        let inside_parens;
322        syn::parenthesized!(inside_parens in input);
323        Ok(inside_parens)
324    })()
325}
326
327fn expand_sorted(
328    attr: proc_macro2::TokenStream,
329    ast: DeriveInput,
330) -> Result<proc_macro2::TokenStream> {
331    if !attr.is_empty() {
332        return Err(syn::Error::new(
333            attr.span(),
334            "`#[sort_fields]` doesn't support any attributes",
335        ));
336    }
337
338    let attrs = ast.attrs;
339    let vis = ast.vis;
340    let ident = ast.ident;
341    let generics = ast.generics;
342
343    let sorted_fieds = find_and_sort_struct_fields(&ast.data, ident.span())?;
344
345    let tokens = quote! {
346        #(#attrs)*
347        #vis struct #ident #generics {
348            #( #sorted_fieds ),*
349        }
350    };
351
352    Ok(tokens)
353}
354
355fn sort_punctuated<A, B, F, K>(punctuated: &Punctuated<A, B>, f: F) -> Vec<&A>
356where
357    F: Fn(&A) -> &K,
358    K: Ord,
359{
360    let mut items = punctuated.iter().collect::<Vec<_>>();
361    items.sort_unstable_by_key(|item| f(item));
362    items
363}
364
365fn find_and_sort_struct_fields(data: &syn::Data, ident_span: Span) -> Result<Vec<&syn::Field>> {
366    match data {
367        syn::Data::Struct(data_struct) => match &data_struct.fields {
368            syn::Fields::Named(fields) => {
369                let fields = sort_punctuated(&fields.named, |field| &field.ident);
370                Ok(fields)
371            }
372            syn::Fields::Unnamed(fields) => Err(syn::Error::new(
373                fields.span(),
374                "`#[sort_fields]` is not allowed on tuple structs, only structs with named fields",
375            )),
376            syn::Fields::Unit => Err(syn::Error::new(
377                ident_span,
378                "`#[sort_fields]` is not allowed on unit structs, only structs with named fields",
379            )),
380        },
381        syn::Data::Enum(data) => Err(syn::Error::new(
382            data.enum_token.span(),
383            "`#[sort_fields]` is not allowed on enums, only structs",
384        )),
385        syn::Data::Union(data) => Err(syn::Error::new(
386            data.union_token.span(),
387            "`#[sort_fields]` is not allowed on unions, only structs",
388        )),
389    }
390}
391
/// UI test suite driven by trybuild.
///
/// Every file matching `tests/compile_pass/*.rs` must compile, and every file matching
/// `tests/compile_fail/*.rs` must fail to compile.
#[test]
fn ui() {
    let cases = trybuild::TestCases::new();
    cases.pass("tests/compile_pass/*.rs");
    cases.compile_fail("tests/compile_fail/*.rs");
}