1use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{format_ident, quote};
11use syn::{
12 Data, DeriveInput, Field, Fields, GenericArgument, PathArguments, Type, Visibility,
13 parse_macro_input,
14};
15
16#[proc_macro_derive(Lenses)]
19pub fn lenses_derive(input: TokenStream) -> TokenStream {
20 let input = parse_macro_input!(input as DeriveInput);
21
22 match expand_lenses(&input) {
23 Ok(expanded) => TokenStream::from(expanded),
24 Err(error) => error.to_compile_error().into(),
25 }
26}
27
28fn expand_lenses(input: &DeriveInput) -> Result<TokenStream2, syn::Error> {
29 let data_struct = match &input.data {
30 Data::Struct(data_struct) => data_struct,
31 _ => {
32 return Err(syn::Error::new_spanned(
33 input,
34 "`#[derive(Lenses)]` may only be applied to structs",
35 ));
36 }
37 };
38
39 let fields = match &data_struct.fields {
40 Fields::Named(fields) => &fields.named,
41 _ => {
42 return Err(syn::Error::new_spanned(
43 input,
44 "`#[derive(Lenses)]` may only be applied to structs with named fields",
45 ));
46 }
47 };
48
49 let struct_name = &input.ident;
50 let lens_visibility = &input.vis;
51
52 let lens_items = fields
53 .iter()
54 .enumerate()
55 .map(|(index, field)| expand_field_lens(struct_name, lens_visibility, index as u64, field))
56 .collect::<Result<Vec<_>, _>>()?;
57
58 let lenses_struct_name = format_ident!("{struct_name}Lenses");
59 let lenses_struct_fields = fields
60 .iter()
61 .map(|field| expand_lenses_struct_fields(struct_name, field))
62 .collect::<Result<Vec<_>, _>>()?;
63
64 let lenses_const_name = format_ident!("_{struct_name}Lenses");
65 let lenses_const_fields = fields
66 .iter()
67 .map(|field| expand_lenses_const_fields(struct_name, field))
68 .collect::<Result<Vec<_>, _>>()?;
69
70 Ok(quote! {
71 #(#lens_items)*
72
73 #[allow(dead_code)]
74 #[doc(hidden)]
75 #lens_visibility struct #lenses_struct_name {
76 #(#lenses_struct_fields),*
77 }
78
79 #[allow(dead_code)]
80 #[allow(non_upper_case_globals)]
81 #[doc(hidden)]
82 #lens_visibility const #lenses_const_name: #lenses_struct_name = #lenses_struct_name {
83 #(#lenses_const_fields),*
84 };
85 })
86}
87
88fn expand_field_lens(
89 struct_name: &syn::Ident,
90 lens_visibility: &Visibility,
91 field_index: u64,
92 field: &Field,
93) -> Result<TokenStream2, syn::Error> {
94 let field_name = field_name(field)?;
95 let field_type = &field.ty;
96 let lens_name = lens_type_name(struct_name, field_name);
97 let value_lens = if is_value_lens_type(field_type) {
98 quote! {
99 #[allow(dead_code)]
100 impl lens::ValueLens for #lens_name {
101 #[inline(always)]
102 fn get(&self, source: &#struct_name) -> #field_type {
103 (*source).#field_name.clone()
104 }
105 }
106 }
107 } else {
108 quote!()
109 };
110
111 Ok(quote! {
112 #[allow(dead_code)]
113 #[doc(hidden)]
114 #lens_visibility struct #lens_name;
115
116 #[allow(dead_code)]
117 impl lens::Lens for #lens_name {
118 type Source = #struct_name;
119 type Target = #field_type;
120
121 #[inline(always)]
122 fn path(&self) -> lens::LensPath {
123 lens::LensPath::new(#field_index)
124 }
125
126 #[inline(always)]
127 fn mutate(&self, source: &mut #struct_name, target: #field_type) {
128 source.#field_name = target
129 }
130 }
131
132 #[allow(dead_code)]
133 impl lens::RefLens for #lens_name {
134 #[inline(always)]
135 fn get_ref<'a>(&self, source: &'a #struct_name) -> &'a #field_type {
136 &(*source).#field_name
137 }
138
139 #[inline(always)]
140 fn get_mut_ref<'a>(&self, source: &'a mut #struct_name) -> &'a mut #field_type {
141 &mut (*source).#field_name
142 }
143 }
144
145 #value_lens
146 })
147}
148
149fn expand_lenses_struct_fields(
150 struct_name: &syn::Ident,
151 field: &Field,
152) -> Result<TokenStream2, syn::Error> {
153 let field_name = field_name(field)?;
154 let field_lens_name = lens_type_name(struct_name, field_name);
155 let mut generated = vec![quote!(#field_name: #field_lens_name)];
156
157 if let Some(item_type) = vec_item_type(&field.ty) {
158 let item_marker_name = vec_item_marker_name(field_name);
159 generated.push(quote!(#item_marker_name: std::marker::PhantomData<#item_type>));
160 if !is_value_lens_type(item_type) {
161 let item_lenses_name = vec_item_lenses_field_name(field_name);
162 let item_lenses_type_name = nested_lenses_type_name(item_type)?;
163 generated.push(quote!(#item_lenses_name: #item_lenses_type_name));
164 }
165 } else if !is_value_lens_type(&field.ty) {
166 let field_parent_lenses_field_name = nested_lenses_field_name(field_name);
167 let field_parent_lenses_type_name = nested_lenses_type_name(&field.ty)?;
168 generated.push(quote!(
169 #field_parent_lenses_field_name: #field_parent_lenses_type_name
170 ));
171 }
172
173 Ok(quote!(#(#generated),*))
174}
175
176fn expand_lenses_const_fields(
177 struct_name: &syn::Ident,
178 field: &Field,
179) -> Result<TokenStream2, syn::Error> {
180 let field_name = field_name(field)?;
181 let field_lens_name = lens_type_name(struct_name, field_name);
182 let mut generated = vec![quote!(#field_name: #field_lens_name)];
183
184 if let Some(item_type) = vec_item_type(&field.ty) {
185 let item_marker_name = vec_item_marker_name(field_name);
186 generated.push(quote!(#item_marker_name: std::marker::PhantomData));
187 if !is_value_lens_type(item_type) {
188 let item_lenses_name = vec_item_lenses_field_name(field_name);
189 let item_lenses_type_name = nested_lenses_const_name(item_type)?;
190 generated.push(quote!(#item_lenses_name: #item_lenses_type_name));
191 }
192 } else if !is_value_lens_type(&field.ty) {
193 let field_parent_lenses_field_name = nested_lenses_field_name(field_name);
194 let field_parent_lenses_type_name = nested_lenses_const_name(&field.ty)?;
195 generated.push(quote!(
196 #field_parent_lenses_field_name: #field_parent_lenses_type_name
197 ));
198 }
199
200 Ok(quote!(#(#generated),*))
201}
202
203fn field_name(field: &Field) -> Result<&syn::Ident, syn::Error> {
204 field.ident.as_ref().ok_or_else(|| {
205 syn::Error::new_spanned(
206 field,
207 "`#[derive(Lenses)]` may only be applied to structs with named fields",
208 )
209 })
210}
211
212fn lens_type_name(struct_name: &syn::Ident, field_name: &syn::Ident) -> syn::Ident {
213 format_ident!(
214 "{}{}Lens",
215 struct_name,
216 to_camel_case(&field_name.to_string())
217 )
218}
219
220fn nested_lenses_field_name(field_name: &syn::Ident) -> syn::Ident {
221 format_ident!("{field_name}_lenses")
222}
223
224fn vec_item_marker_name(field_name: &syn::Ident) -> syn::Ident {
225 format_ident!("{field_name}_item")
226}
227
228fn vec_item_lenses_field_name(field_name: &syn::Ident) -> syn::Ident {
229 format_ident!("{field_name}_item_lenses")
230}
231
232fn nested_lenses_type_name(ty: &Type) -> Result<syn::Ident, syn::Error> {
233 let ident = terminal_type_ident(ty)?;
234 Ok(format_ident!("{ident}Lenses"))
235}
236
237fn nested_lenses_const_name(ty: &Type) -> Result<syn::Ident, syn::Error> {
238 let ident = terminal_type_ident(ty)?;
239 Ok(format_ident!("_{ident}Lenses"))
240}
241
242fn terminal_type_ident(ty: &Type) -> Result<syn::Ident, syn::Error> {
243 match ty {
244 Type::Path(type_path) => type_path
245 .path
246 .segments
247 .last()
248 .map(|segment| segment.ident.clone())
249 .ok_or_else(|| syn::Error::new_spanned(ty, "unsupported field type for `Lenses`")),
250 _ => Err(syn::Error::new_spanned(
251 ty,
252 "unsupported field type for `Lenses`",
253 )),
254 }
255}
256
257fn vec_item_type(ty: &Type) -> Option<&Type> {
258 let Type::Path(type_path) = ty else {
259 return None;
260 };
261 let segment = type_path.path.segments.last()?;
262 if segment.ident != "Vec" {
263 return None;
264 }
265
266 let PathArguments::AngleBracketed(arguments) = &segment.arguments else {
267 return None;
268 };
269 if arguments.args.len() != 1 {
270 return None;
271 }
272
273 match arguments.args.first()? {
274 GenericArgument::Type(ty) => Some(ty),
275 _ => None,
276 }
277}
278
279fn is_value_lens_type(ty: &Type) -> bool {
280 let Type::Path(type_path) = ty else {
281 return false;
282 };
283 let Some(segment) = type_path.path.segments.last() else {
284 return false;
285 };
286 matches!(
287 segment.ident.to_string().as_str(),
288 "bool"
289 | "char"
290 | "i8"
291 | "i16"
292 | "i32"
293 | "i64"
294 | "i128"
295 | "isize"
296 | "u8"
297 | "u16"
298 | "u32"
299 | "u64"
300 | "u128"
301 | "usize"
302 | "f32"
303 | "f64"
304 | "String"
305 )
306}
307
/// Converts a `snake_case` name into `UpperCamelCase`.
///
/// The input is split on `_`. In each piece the first character is uppercased
/// and the rest are lowercased. Empty pieces (from leading, trailing or repeated
/// underscores) contribute nothing, e.g. `"HTTP_server"` gives `"HttpServer"`.
fn to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());

    for word in s.split('_') {
        let mut chars = word.chars();
        if let Some(head) = chars.next() {
            // Case mappings can expand to several chars (e.g. 'ß' -> "SS").
            camel.extend(head.to_uppercase());
            for rest in chars {
                camel.extend(rest.to_lowercase());
            }
        }
    }

    camel
}
322
#[cfg(test)]
mod tests {
    use super::*;
    use quote::{format_ident, quote};

    /// Parses a type from tokens, panicking on malformed test input.
    fn parse_type(tokens: TokenStream2) -> Type {
        syn::parse2(tokens).expect("valid type")
    }

    #[test]
    fn to_camel_case_should_work() {
        assert_eq!(to_camel_case("this_is_snake_case"), "ThisIsSnakeCase");
    }

    #[test]
    fn to_camel_case_should_skip_empty_segments_and_lowercase_the_rest() {
        assert_eq!(to_camel_case(""), "");
        assert_eq!(to_camel_case("_leading__double_"), "LeadingDouble");
        assert_eq!(to_camel_case("HTTP_server"), "HttpServer");
    }

    #[test]
    fn lens_type_name_should_combine_struct_and_field_names() {
        let name = lens_type_name(&format_ident!("Person"), &format_ident!("first_name"));
        assert_eq!(name.to_string(), "PersonFirstNameLens");
    }

    #[test]
    fn vec_item_type_should_detect_vec_fields() {
        let ty = parse_type(quote!(Vec<MyStruct>));
        let item_type = vec_item_type(&ty).expect("vec item type");
        assert_eq!(quote!(#item_type).to_string(), "MyStruct");
    }

    #[test]
    fn vec_item_type_should_accept_qualified_vec_paths() {
        let ty = parse_type(quote!(std::vec::Vec<u32>));
        let item_type = vec_item_type(&ty).expect("vec item type");
        assert_eq!(quote!(#item_type).to_string(), "u32");
    }

    #[test]
    fn vec_item_type_should_reject_non_vec_fields() {
        assert!(vec_item_type(&parse_type(quote!(Option<MyStruct>))).is_none());
        assert!(vec_item_type(&parse_type(quote!(Vec))).is_none());
        assert!(vec_item_type(&parse_type(quote!(MyStruct))).is_none());
    }

    #[test]
    fn scalar_types_should_get_value_lenses() {
        assert!(is_value_lens_type(&parse_type(quote!(String))));
        assert!(is_value_lens_type(&parse_type(quote!(u64))));
    }

    #[test]
    fn non_scalar_types_should_not_get_value_lenses() {
        assert!(!is_value_lens_type(&parse_type(quote!(MyStruct))));
        assert!(!is_value_lens_type(&parse_type(quote!(Vec<u32>))));
        assert!(!is_value_lens_type(&parse_type(quote!(&str))));
    }

    #[test]
    fn nested_type_name_should_use_the_actual_field_type() {
        let ty = parse_type(quote!(crate::models::Address));
        assert_eq!(
            nested_lenses_type_name(&ty)
                .expect("nested lenses type")
                .to_string(),
            "AddressLenses"
        );
    }

    #[test]
    fn nested_const_name_should_use_the_actual_field_type() {
        let ty = parse_type(quote!(crate::models::Address));
        assert_eq!(
            nested_lenses_const_name(&ty)
                .expect("nested lenses const")
                .to_string(),
            "_AddressLenses"
        );
    }

    #[test]
    fn terminal_type_ident_should_reject_non_path_types() {
        assert!(terminal_type_ident(&parse_type(quote!((u32, u32)))).is_err());
        assert!(terminal_type_ident(&parse_type(quote!(&str))).is_err());
    }
}
356}