pl_lens_derive/
lib.rs

1//
2// Copyright (c) 2015-2019 Plausible Labs Cooperative, Inc.
3// All rights reserved.
4//
5
6extern crate proc_macro;
7
8use proc_macro::TokenStream;
9use quote::{format_ident, quote};
10use syn::{parse_macro_input, Data, DataStruct, DeriveInput, Fields, Visibility};
11
12/// Handles the `#derive(Lenses)` applied to a struct by generating a `Lens` implementation for
13/// each field in the struct.
14#[proc_macro_derive(Lenses)]
15pub fn lenses_derive(input: TokenStream) -> TokenStream {
16    // Parse the input tokens into a syntax tree
17    let input = parse_macro_input!(input as DeriveInput);
18
19    // Check that the input type is a struct
20    let data_struct: DataStruct;
21    if let Data::Struct(s) = input.data {
22        data_struct = s
23    } else {
24        panic!("`#[derive(Lenses)]` may only be applied to structs")
25    }
26
27    // Check that the struct has named fields, since that's the only
28    // type we support at the moment
29    let fields: Fields;
30    if let Fields::Named(_) = data_struct.fields {
31        fields = data_struct.fields
32    } else {
33        panic!("`#[derive(Lenses)]` may only be applied to structs with named fields")
34    }
35
36    // Extract the struct name
37    let struct_name = &input.ident;
38
39    // Determine the visibility of the lens struct
40    let lens_visibility = match input.vis {
41        Visibility::Public(..) => quote!(pub),
42        // TODO: Handle `Crate` and `Restricted` visibliity
43        Visibility::Crate(..) => quote!(),
44        Visibility::Restricted(..) => quote!(),
45        Visibility::Inherited => quote!(),
46    };
47
48    // Generate lenses for each field in the struct
49    let lens_items = fields.iter().enumerate().map(|(index, field)| {
50        if let Some(field_name) = &field.ident {
51            let field_index = index as u64;
52            let field_type = &field.ty;
53
54            // Build the Lens name from the struct name and field name (for example, "StructFieldLens")
55            let lens_name = format_ident!(
56                "{}{}Lens",
57                struct_name.to_string(),
58                to_camel_case(&field_name.to_string())
59            );
60
61            // Build a `ValueLens` impl if the target is a primitive
62            // TODO: Should do this automatically for any target type that implements `Clone`
63            let value_lens = if is_primitive(&field.ty) {
64                quote!(
65                    #[allow(dead_code)]
66                    impl pl_lens::ValueLens for #lens_name {
67                        #[inline(always)]
68                        fn get(&self, source: &#struct_name) -> #field_type {
69                            (*source).#field_name.clone()
70                        }
71                    }
72                )
73            } else {
74                quote!()
75            };
76
77            quote!(
78                // Include the lens struct declaration
79                #[allow(dead_code)]
80                #[doc(hidden)]
81                #lens_visibility struct #lens_name;
82
83                // Include the `Lens` impl
84                #[allow(dead_code)]
85                impl pl_lens::Lens for #lens_name {
86                    type Source = #struct_name;
87                    type Target = #field_type;
88
89                    #[inline(always)]
90                    fn path(&self) -> pl_lens::LensPath {
91                        pl_lens::LensPath::new(#field_index)
92                    }
93
94                    #[inline(always)]
95                    fn mutate<'a>(&self, source: &'a mut #struct_name, target: #field_type) {
96                        source.#field_name = target
97                    }
98                }
99
100                // Include the `RefLens` impl
101                #[allow(dead_code)]
102                impl pl_lens::RefLens for #lens_name {
103                    #[inline(always)]
104                    fn get_ref<'a>(&self, source: &'a #struct_name) -> &'a #field_type {
105                        &(*source).#field_name
106                    }
107
108                    #[inline(always)]
109                    fn get_mut_ref<'a>(&self, source: &'a mut #struct_name) -> &'a mut #field_type {
110                        &mut (*source).#field_name
111                    }
112                }
113
114                // Include the `ValueLens` impl (only if it should be defined)
115                #value_lens
116            )
117        } else {
118            // This should be unreachable, since we already verified above that the struct
119            // only contains named fields
120            panic!("`#[derive(Lenses)]` may only be applied to structs with named fields")
121        }
122    });
123
124    // Build a `<StructName>Lenses` struct that enumerates the available lenses
125    // for each field in the struct, for example:
126    //     struct Struct2Lenses {
127    //         int32: Struct2Int32Lens,
128    //         struct1: Struct2Struct1Lens,
129    //         struct1_lenses: Struct1Lenses
130    //     }
131    let lenses_struct_name = format_ident!("{}Lenses", struct_name);
132    let lenses_struct_fields = fields.iter().map(|field| {
133        if let Some(field_name) = &field.ident {
134            let field_lens_name = format_ident!(
135                "{}{}Lens",
136                struct_name,
137                to_camel_case(&field_name.to_string())
138            );
139            if is_primitive(&field.ty) {
140                quote!(#field_name: #field_lens_name)
141            } else {
142                let field_parent_lenses_field_name = format_ident!("{}_lenses", field_name);
143                let field_parent_lenses_type_name =
144                    format_ident!("{}Lenses", to_camel_case(&field_name.to_string()));
145                quote!(
146                    #field_name: #field_lens_name,
147                    #field_parent_lenses_field_name: #field_parent_lenses_type_name
148                )
149            }
150        } else {
151            // This should be unreachable, since we already verified above that the struct
152            // only contains named fields
153            panic!("`#[derive(Lenses)]` may only be applied to structs with named fields")
154        }
155    });
156    let lenses_struct = quote!(
157        #[allow(dead_code)]
158        #[doc(hidden)]
159        #lens_visibility struct #lenses_struct_name {
160            #(#lenses_struct_fields),*
161        }
162    );
163
164    // Declare a `_<StructName>Lenses` instance that holds the available lenses
165    // for each field in the struct, for example:
166    //     const _Struct2Lenses: Struct2Lenses = Struct2Lenses {
167    //         int32: Struct2Int32Lens,
168    //         struct1: Struct2Struct1Lens,
169    //         struct1_lenses: _Struct1Lenses
170    //     };
171    let lenses_const_name = format_ident!("_{}Lenses", struct_name);
172    let lenses_const_fields = fields.iter().map(|field|
173        // TODO: Most of this is nearly identical to how the "Lenses" struct is declared,
174        // except for the underscore prefix in a couple places; might be good to consolidate
175        if let Some(field_name) = &field.ident {
176            let field_lens_name = format_ident!("{}{}Lens", struct_name, to_camel_case(&field_name.to_string()));
177            if is_primitive(&field.ty) {
178                quote!(#field_name: #field_lens_name)
179            } else {
180                let field_parent_lenses_field_name = format_ident!("{}_lenses", field_name);
181                let field_parent_lenses_type_name = format_ident!("_{}Lenses", to_camel_case(&field_name.to_string()));
182                quote!(
183                    #field_name: #field_lens_name,
184                    #field_parent_lenses_field_name: #field_parent_lenses_type_name
185                )
186            }
187        } else {
188            // This should be unreachable, since we already verified above that the struct
189            // only contains named fields
190            panic!("`#[derive(Lenses)]` may only be applied to structs with named fields")
191        }
192    );
193    let lenses_const = quote!(
194        #[allow(dead_code)]
195        #[allow(non_upper_case_globals)]
196        #[doc(hidden)]
197        #lens_visibility const #lenses_const_name: #lenses_struct_name = #lenses_struct_name {
198            #(#lenses_const_fields),*
199        };
200    );
201
202    // Build the output
203    let expanded = quote! {
204        #(#lens_items)*
205
206        #lenses_struct
207
208        #lenses_const
209    };
210
211    // Hand the output tokens back to the compiler
212    TokenStream::from(expanded)
213}
214
215/// Return true if the given type should be considered a primitive, i.e., whether
216/// it doesn't have lenses defined for it.
217fn is_primitive(ty: &syn::Type) -> bool {
218    let type_str = quote!(#ty).to_string();
219    match type_str.as_ref() {
220        // XXX: This is quick and dirty; we need a more reliable way to
221        // know whether the field is a struct type for which there are
222        // lenses derived
223        "i8" | "i16" | "i32" | "i64" | "u8" | "u16" | "u32" | "u64" | "f32" | "f64" | "String" => {
224            true
225        }
226        _ => false,
227    }
228}
229
// XXX: Originally lifted from librustc_lint/builtin.rs
/// Converts a `snake_case` name into `CamelCase`.
///
/// The input is split on `_`. Each piece has its first character uppercased and every
/// following character lowercased. Empty pieces, produced by leading, trailing, or repeated
/// underscores, contribute nothing. For example, `"this_is_snake_case"` becomes
/// `"ThisIsSnakeCase"`.
fn to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            camel.extend(first.to_uppercase());
            for rest in chars {
                camel.extend(rest.to_lowercase());
            }
        }
    }
    camel
}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn to_camel_case_should_work() {
        assert_eq!(to_camel_case("this_is_snake_case"), "ThisIsSnakeCase");
    }

    #[test]
    fn to_camel_case_should_handle_single_words_with_digits() {
        assert_eq!(to_camel_case("struct1"), "Struct1");
        assert_eq!(to_camel_case("int32"), "Int32");
    }

    #[test]
    fn to_camel_case_should_skip_empty_segments() {
        assert_eq!(to_camel_case(""), "");
        assert_eq!(to_camel_case("_leading"), "Leading");
        assert_eq!(to_camel_case("trailing_"), "Trailing");
        assert_eq!(to_camel_case("double__underscore"), "DoubleUnderscore");
    }

    #[test]
    fn to_camel_case_should_lowercase_the_rest_of_each_word() {
        assert_eq!(to_camel_case("mixedCASE_word"), "MixedcaseWord");
    }
}