Skip to main content

lens_macros/
lib.rs

1//
2// Copyright (c) 2015-2019 Plausible Labs Cooperative, Inc.
3// All rights reserved.
4//
5// Copyright (c) 2025 Julius Foitzik on derivative work
6// All rights reserved.
7//
8
9use proc_macro::TokenStream;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{format_ident, quote};
12use syn::spanned::Spanned;
13use syn::{Expr, ExprPath, Member, parse_macro_input};
14
15#[derive(Clone, Debug)]
16enum LensStep {
17    Field(syn::Ident),
18    Index(Box<Expr>),
19}
20
21#[derive(Debug)]
22struct ParsedLens {
23    root: syn::Ident,
24    steps: Vec<LensStep>,
25}
26
27#[proc_macro]
28pub fn lens(input: TokenStream) -> TokenStream {
29    let expr = parse_macro_input!(input as Expr);
30
31    match parse_lens_expression(&expr).and_then(expand_lens_expression) {
32        Ok(expanded) => TokenStream::from(expanded),
33        Err(error) => error.to_compile_error().into(),
34    }
35}
36
37fn parse_lens_expression(expr: &Expr) -> Result<ParsedLens, syn::Error> {
38    let mut steps = Vec::new();
39    let root = collect_steps(expr, &mut steps)?;
40    if steps.is_empty() {
41        return Err(syn::Error::new(
42            expr.span(),
43            "lens!() expression must include at least one field access",
44        ));
45    }
46
47    Ok(ParsedLens { root, steps })
48}
49
50fn collect_steps(expr: &Expr, steps: &mut Vec<LensStep>) -> Result<syn::Ident, syn::Error> {
51    match expr {
52        Expr::Field(field_access) => {
53            let root = collect_steps(&field_access.base, steps)?;
54            let field_name = match &field_access.member {
55                Member::Named(field_ident) => field_ident.clone(),
56                Member::Unnamed(_) => {
57                    return Err(syn::Error::new(
58                        field_access.span(),
59                        "lens!() only works with named fields, not tuple indexing",
60                    ));
61                }
62            };
63            steps.push(LensStep::Field(field_name));
64            Ok(root)
65        }
66        Expr::Index(index_expr) => {
67            let root = collect_steps(&index_expr.expr, steps)?;
68            steps.push(LensStep::Index(index_expr.index.clone()));
69            Ok(root)
70        }
71        Expr::Path(path) => parse_root(path),
72        _ => Err(syn::Error::new(
73            expr.span(),
74            "lens!() expression must look like `Root.field` or `Root.field[index].child`",
75        )),
76    }
77}
78
79fn parse_root(path: &ExprPath) -> Result<syn::Ident, syn::Error> {
80    if path.path.segments.len() != 1 {
81        return Err(syn::Error::new(
82            path.span(),
83            "lens!() expression must start with an unqualified struct name",
84        ));
85    }
86
87    Ok(path.path.segments[0].ident.clone())
88}
89
90fn expand_lens_expression(parsed: ParsedLens) -> Result<TokenStream2, syn::Error> {
91    let mut lens_exprs = Vec::with_capacity(parsed.steps.len());
92    let root_lenses = format_ident!("_{}Lenses", parsed.root);
93    let mut current_lenses_expr = quote!(#root_lenses);
94    let mut last_field_context: Option<FieldContext> = None;
95
96    for step in parsed.steps {
97        match step {
98            LensStep::Field(field_name) => {
99                let field_lenses_expr = current_lenses_expr.clone();
100                let field_expr = quote!(#field_lenses_expr.#field_name);
101                lens_exprs.push(field_expr.clone());
102                let item_marker_name = vec_item_marker_name(&field_name);
103                let item_lenses_name = vec_item_lenses_field_name(&field_name);
104                let nested_lenses_name = nested_lenses_field_name(&field_name);
105
106                last_field_context = Some(FieldContext {
107                    item_marker_expr: quote!(#field_lenses_expr.#item_marker_name),
108                    item_lenses_expr: quote!(#field_lenses_expr.#item_lenses_name),
109                });
110                current_lenses_expr = quote!(#field_lenses_expr.#nested_lenses_name);
111            }
112            LensStep::Index(index_expr) => {
113                let Some(field_context) = &last_field_context else {
114                    return Err(syn::Error::new(
115                        index_expr.span(),
116                        "lens!() indexing is only supported immediately after a named field",
117                    ));
118                };
119                let item_marker_expr = field_context.item_marker_expr.clone();
120                let vec_expr = quote!(lens::vec_lens_from_marker(#item_marker_expr, #index_expr));
121                lens_exprs.push(vec_expr);
122                current_lenses_expr = field_context.item_lenses_expr.clone();
123                last_field_context = None;
124            }
125        }
126    }
127
128    Ok(quote! {
129        lens::compose_lens!(#(#lens_exprs),*)
130    })
131}
132
133struct FieldContext {
134    item_marker_expr: TokenStream2,
135    item_lenses_expr: TokenStream2,
136}
137
138fn nested_lenses_field_name(field_name: &syn::Ident) -> syn::Ident {
139    format_ident!("{field_name}_lenses")
140}
141
142fn vec_item_marker_name(field_name: &syn::Ident) -> syn::Ident {
143    format_ident!("{field_name}_item")
144}
145
146fn vec_item_lenses_field_name(field_name: &syn::Ident) -> syn::Ident {
147    format_ident!("{field_name}_item_lenses")
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Parses `tokens` as an expression and runs the lens parser on it.
    fn parse(tokens: TokenStream2) -> Result<ParsedLens, syn::Error> {
        let expr: Expr = syn::parse2(tokens).expect("valid expression");
        parse_lens_expression(&expr)
    }

    #[test]
    fn parses_field_and_index_steps() {
        let parsed = parse(quote!(Root.items[1].value)).expect("parsed lens expression");
        assert_eq!(parsed.root.to_string(), "Root");
        assert_eq!(parsed.steps.len(), 3);
        assert!(matches!(parsed.steps[0], LensStep::Field(_)));
        assert!(matches!(parsed.steps[1], LensStep::Index(_)));
        assert!(matches!(parsed.steps[2], LensStep::Field(_)));
    }

    #[test]
    fn rejects_qualified_roots() {
        let error = parse(quote!(crate::Root.value)).expect_err("qualified root should fail");
        assert!(
            error
                .to_string()
                .contains("must start with an unqualified struct name")
        );
    }

    #[test]
    fn rejects_tuple_fields() {
        let error = parse(quote!(Root.0)).expect_err("tuple field should fail");
        assert!(error.to_string().contains("named fields"));
    }

    #[test]
    fn rejects_bare_root() {
        let error = parse(quote!(Root)).expect_err("bare root should fail");
        assert!(error.to_string().contains("at least one field access"));
    }

    #[test]
    fn expands_field_and_index_steps() {
        let parsed = parse(quote!(Root.items[1].value)).expect("parsed lens expression");
        let expanded = expand_lens_expression(parsed).expect("expanded lens expression");
        let expected = quote! {
            lens::compose_lens!(
                _RootLenses.items,
                lens::vec_lens_from_marker(_RootLenses.items_item, 1),
                _RootLenses.items_item_lenses.value
            )
        };
        assert_eq!(expanded.to_string(), expected.to_string());
    }

    #[test]
    fn rejects_index_not_preceded_by_field() {
        let parsed = parse(quote!(Root.items[0][1])).expect("parsed lens expression");
        let error = expand_lens_expression(parsed).expect_err("double index should fail");
        assert!(error.to_string().contains("immediately after a named field"));
    }
}