Skip to main content

lens_derive/
lib.rs

1//
2// Copyright (c) 2015-2019 Plausible Labs Cooperative, Inc.
3// All rights reserved.
4// Copyright (c) 2025 Julius Foitzik on derivative work.
5// All rights reserved.
6//
7
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{format_ident, quote};
11use syn::{
12    Data, DeriveInput, Field, Fields, GenericArgument, PathArguments, Type, Visibility,
13    parse_macro_input,
14};
15
16/// Handles the `#[derive(Lenses)]` applied to a struct by generating a `Lens` implementation for
17/// each field in the struct.
18#[proc_macro_derive(Lenses)]
19pub fn lenses_derive(input: TokenStream) -> TokenStream {
20    let input = parse_macro_input!(input as DeriveInput);
21
22    match expand_lenses(&input) {
23        Ok(expanded) => TokenStream::from(expanded),
24        Err(error) => error.to_compile_error().into(),
25    }
26}
27
28fn expand_lenses(input: &DeriveInput) -> Result<TokenStream2, syn::Error> {
29    let data_struct = match &input.data {
30        Data::Struct(data_struct) => data_struct,
31        _ => {
32            return Err(syn::Error::new_spanned(
33                input,
34                "`#[derive(Lenses)]` may only be applied to structs",
35            ));
36        }
37    };
38
39    let fields = match &data_struct.fields {
40        Fields::Named(fields) => &fields.named,
41        _ => {
42            return Err(syn::Error::new_spanned(
43                input,
44                "`#[derive(Lenses)]` may only be applied to structs with named fields",
45            ));
46        }
47    };
48
49    let struct_name = &input.ident;
50    let lens_visibility = &input.vis;
51
52    let lens_items = fields
53        .iter()
54        .enumerate()
55        .map(|(index, field)| expand_field_lens(struct_name, lens_visibility, index as u64, field))
56        .collect::<Result<Vec<_>, _>>()?;
57
58    let lenses_struct_name = format_ident!("{struct_name}Lenses");
59    let lenses_struct_fields = fields
60        .iter()
61        .map(|field| expand_lenses_struct_fields(struct_name, field))
62        .collect::<Result<Vec<_>, _>>()?;
63
64    let lenses_const_name = format_ident!("_{struct_name}Lenses");
65    let lenses_const_fields = fields
66        .iter()
67        .map(|field| expand_lenses_const_fields(struct_name, field))
68        .collect::<Result<Vec<_>, _>>()?;
69
70    Ok(quote! {
71        #(#lens_items)*
72
73        #[allow(dead_code)]
74        #[doc(hidden)]
75        #lens_visibility struct #lenses_struct_name {
76            #(#lenses_struct_fields),*
77        }
78
79        #[allow(dead_code)]
80        #[allow(non_upper_case_globals)]
81        #[doc(hidden)]
82        #lens_visibility const #lenses_const_name: #lenses_struct_name = #lenses_struct_name {
83            #(#lenses_const_fields),*
84        };
85    })
86}
87
88fn expand_field_lens(
89    struct_name: &syn::Ident,
90    lens_visibility: &Visibility,
91    field_index: u64,
92    field: &Field,
93) -> Result<TokenStream2, syn::Error> {
94    let field_name = field_name(field)?;
95    let field_type = &field.ty;
96    let lens_name = lens_type_name(struct_name, field_name);
97    let value_lens = if is_value_lens_type(field_type) {
98        quote! {
99            #[allow(dead_code)]
100            impl lens::ValueLens for #lens_name {
101                #[inline(always)]
102                fn get(&self, source: &#struct_name) -> #field_type {
103                    (*source).#field_name.clone()
104                }
105            }
106        }
107    } else {
108        quote!()
109    };
110
111    Ok(quote! {
112        #[allow(dead_code)]
113        #[doc(hidden)]
114        #lens_visibility struct #lens_name;
115
116        #[allow(dead_code)]
117        impl lens::Lens for #lens_name {
118            type Source = #struct_name;
119            type Target = #field_type;
120
121            #[inline(always)]
122            fn path(&self) -> lens::LensPath {
123                lens::LensPath::new(#field_index)
124            }
125
126            #[inline(always)]
127            fn mutate(&self, source: &mut #struct_name, target: #field_type) {
128                source.#field_name = target
129            }
130        }
131
132        #[allow(dead_code)]
133        impl lens::RefLens for #lens_name {
134            #[inline(always)]
135            fn get_ref<'a>(&self, source: &'a #struct_name) -> &'a #field_type {
136                &(*source).#field_name
137            }
138
139            #[inline(always)]
140            fn get_mut_ref<'a>(&self, source: &'a mut #struct_name) -> &'a mut #field_type {
141                &mut (*source).#field_name
142            }
143        }
144
145        #value_lens
146    })
147}
148
149fn expand_lenses_struct_fields(
150    struct_name: &syn::Ident,
151    field: &Field,
152) -> Result<TokenStream2, syn::Error> {
153    let field_name = field_name(field)?;
154    let field_lens_name = lens_type_name(struct_name, field_name);
155    let mut generated = vec![quote!(#field_name: #field_lens_name)];
156
157    if let Some(item_type) = vec_item_type(&field.ty) {
158        let item_marker_name = vec_item_marker_name(field_name);
159        generated.push(quote!(#item_marker_name: std::marker::PhantomData<#item_type>));
160        if !is_value_lens_type(item_type) {
161            let item_lenses_name = vec_item_lenses_field_name(field_name);
162            let item_lenses_type_name = nested_lenses_type_name(item_type)?;
163            generated.push(quote!(#item_lenses_name: #item_lenses_type_name));
164        }
165    } else if !is_value_lens_type(&field.ty) {
166        let field_parent_lenses_field_name = nested_lenses_field_name(field_name);
167        let field_parent_lenses_type_name = nested_lenses_type_name(&field.ty)?;
168        generated.push(quote!(
169            #field_parent_lenses_field_name: #field_parent_lenses_type_name
170        ));
171    }
172
173    Ok(quote!(#(#generated),*))
174}
175
176fn expand_lenses_const_fields(
177    struct_name: &syn::Ident,
178    field: &Field,
179) -> Result<TokenStream2, syn::Error> {
180    let field_name = field_name(field)?;
181    let field_lens_name = lens_type_name(struct_name, field_name);
182    let mut generated = vec![quote!(#field_name: #field_lens_name)];
183
184    if let Some(item_type) = vec_item_type(&field.ty) {
185        let item_marker_name = vec_item_marker_name(field_name);
186        generated.push(quote!(#item_marker_name: std::marker::PhantomData));
187        if !is_value_lens_type(item_type) {
188            let item_lenses_name = vec_item_lenses_field_name(field_name);
189            let item_lenses_type_name = nested_lenses_const_name(item_type)?;
190            generated.push(quote!(#item_lenses_name: #item_lenses_type_name));
191        }
192    } else if !is_value_lens_type(&field.ty) {
193        let field_parent_lenses_field_name = nested_lenses_field_name(field_name);
194        let field_parent_lenses_type_name = nested_lenses_const_name(&field.ty)?;
195        generated.push(quote!(
196            #field_parent_lenses_field_name: #field_parent_lenses_type_name
197        ));
198    }
199
200    Ok(quote!(#(#generated),*))
201}
202
203fn field_name(field: &Field) -> Result<&syn::Ident, syn::Error> {
204    field.ident.as_ref().ok_or_else(|| {
205        syn::Error::new_spanned(
206            field,
207            "`#[derive(Lenses)]` may only be applied to structs with named fields",
208        )
209    })
210}
211
212fn lens_type_name(struct_name: &syn::Ident, field_name: &syn::Ident) -> syn::Ident {
213    format_ident!(
214        "{}{}Lens",
215        struct_name,
216        to_camel_case(&field_name.to_string())
217    )
218}
219
220fn nested_lenses_field_name(field_name: &syn::Ident) -> syn::Ident {
221    format_ident!("{field_name}_lenses")
222}
223
224fn vec_item_marker_name(field_name: &syn::Ident) -> syn::Ident {
225    format_ident!("{field_name}_item")
226}
227
228fn vec_item_lenses_field_name(field_name: &syn::Ident) -> syn::Ident {
229    format_ident!("{field_name}_item_lenses")
230}
231
232fn nested_lenses_type_name(ty: &Type) -> Result<syn::Ident, syn::Error> {
233    let ident = terminal_type_ident(ty)?;
234    Ok(format_ident!("{ident}Lenses"))
235}
236
237fn nested_lenses_const_name(ty: &Type) -> Result<syn::Ident, syn::Error> {
238    let ident = terminal_type_ident(ty)?;
239    Ok(format_ident!("_{ident}Lenses"))
240}
241
242fn terminal_type_ident(ty: &Type) -> Result<syn::Ident, syn::Error> {
243    match ty {
244        Type::Path(type_path) => type_path
245            .path
246            .segments
247            .last()
248            .map(|segment| segment.ident.clone())
249            .ok_or_else(|| syn::Error::new_spanned(ty, "unsupported field type for `Lenses`")),
250        _ => Err(syn::Error::new_spanned(
251            ty,
252            "unsupported field type for `Lenses`",
253        )),
254    }
255}
256
257fn vec_item_type(ty: &Type) -> Option<&Type> {
258    let Type::Path(type_path) = ty else {
259        return None;
260    };
261    let segment = type_path.path.segments.last()?;
262    if segment.ident != "Vec" {
263        return None;
264    }
265
266    let PathArguments::AngleBracketed(arguments) = &segment.arguments else {
267        return None;
268    };
269    if arguments.args.len() != 1 {
270        return None;
271    }
272
273    match arguments.args.first()? {
274        GenericArgument::Type(ty) => Some(ty),
275        _ => None,
276    }
277}
278
279fn is_value_lens_type(ty: &Type) -> bool {
280    let Type::Path(type_path) = ty else {
281        return false;
282    };
283    let Some(segment) = type_path.path.segments.last() else {
284        return false;
285    };
286    matches!(
287        segment.ident.to_string().as_str(),
288        "bool"
289            | "char"
290            | "i8"
291            | "i16"
292            | "i32"
293            | "i64"
294            | "i128"
295            | "isize"
296            | "u8"
297            | "u16"
298            | "u32"
299            | "u64"
300            | "u128"
301            | "usize"
302            | "f32"
303            | "f64"
304            | "String"
305    )
306}
307
/// Converts a `snake_case` identifier into `UpperCamelCase`.
///
/// The input is split on `_`. Empty pieces (from leading, trailing or doubled underscores)
/// contribute nothing. In every other piece the first character is upper-cased and the
/// rest lower-cased, so `"my_URL_field"` becomes `"MyUrlField"`. Case mapping is
/// Unicode-aware and may change the character count (`'ß'` upper-cases to `"SS"`).
fn to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut chars = word.chars();
        if let Some(head) = chars.next() {
            camel.extend(head.to_uppercase());
            camel.extend(chars.flat_map(char::to_lowercase));
        }
    }
    camel
}
322
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Parses `tokens` as a `syn::Type`, failing the test on malformed input.
    fn parse_type(tokens: TokenStream2) -> Type {
        syn::parse2(tokens).expect("valid type")
    }

    #[test]
    fn to_camel_case_should_work() {
        let converted = to_camel_case("this_is_snake_case");
        assert_eq!(converted, "ThisIsSnakeCase");
    }

    #[test]
    fn vec_item_type_should_detect_vec_fields() {
        let vec_of_struct = parse_type(quote!(Vec<MyStruct>));
        let item = vec_item_type(&vec_of_struct).expect("vec item type");
        let rendered = quote!(#item).to_string();
        assert_eq!(rendered, "MyStruct");
    }

    #[test]
    fn scalar_types_should_get_value_lenses() {
        let string_type = parse_type(quote!(String));
        assert!(is_value_lens_type(&string_type));
    }

    #[test]
    fn nested_type_name_should_use_the_actual_field_type() {
        let qualified = parse_type(quote!(crate::models::Address));
        let lenses_name = nested_lenses_type_name(&qualified).expect("nested lenses type");
        assert_eq!(lenses_name.to_string(), "AddressLenses");
    }
}