1use proc_macro::TokenStream;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{format_ident, quote};
12use syn::spanned::Spanned;
13use syn::{Expr, ExprPath, Member, parse_macro_input};
14
15#[derive(Clone, Debug)]
16enum LensStep {
17 Field(syn::Ident),
18 Index(Box<Expr>),
19}
20
21#[derive(Debug)]
22struct ParsedLens {
23 root: syn::Ident,
24 steps: Vec<LensStep>,
25}
26
27#[proc_macro]
28pub fn lens(input: TokenStream) -> TokenStream {
29 let expr = parse_macro_input!(input as Expr);
30
31 match parse_lens_expression(&expr).and_then(expand_lens_expression) {
32 Ok(expanded) => TokenStream::from(expanded),
33 Err(error) => error.to_compile_error().into(),
34 }
35}
36
37fn parse_lens_expression(expr: &Expr) -> Result<ParsedLens, syn::Error> {
38 let mut steps = Vec::new();
39 let root = collect_steps(expr, &mut steps)?;
40 if steps.is_empty() {
41 return Err(syn::Error::new(
42 expr.span(),
43 "lens!() expression must include at least one field access",
44 ));
45 }
46
47 Ok(ParsedLens { root, steps })
48}
49
50fn collect_steps(expr: &Expr, steps: &mut Vec<LensStep>) -> Result<syn::Ident, syn::Error> {
51 match expr {
52 Expr::Field(field_access) => {
53 let root = collect_steps(&field_access.base, steps)?;
54 let field_name = match &field_access.member {
55 Member::Named(field_ident) => field_ident.clone(),
56 Member::Unnamed(_) => {
57 return Err(syn::Error::new(
58 field_access.span(),
59 "lens!() only works with named fields, not tuple indexing",
60 ));
61 }
62 };
63 steps.push(LensStep::Field(field_name));
64 Ok(root)
65 }
66 Expr::Index(index_expr) => {
67 let root = collect_steps(&index_expr.expr, steps)?;
68 steps.push(LensStep::Index(index_expr.index.clone()));
69 Ok(root)
70 }
71 Expr::Path(path) => parse_root(path),
72 _ => Err(syn::Error::new(
73 expr.span(),
74 "lens!() expression must look like `Root.field` or `Root.field[index].child`",
75 )),
76 }
77}
78
79fn parse_root(path: &ExprPath) -> Result<syn::Ident, syn::Error> {
80 if path.path.segments.len() != 1 {
81 return Err(syn::Error::new(
82 path.span(),
83 "lens!() expression must start with an unqualified struct name",
84 ));
85 }
86
87 Ok(path.path.segments[0].ident.clone())
88}
89
90fn expand_lens_expression(parsed: ParsedLens) -> Result<TokenStream2, syn::Error> {
91 let mut lens_exprs = Vec::with_capacity(parsed.steps.len());
92 let root_lenses = format_ident!("_{}Lenses", parsed.root);
93 let mut current_lenses_expr = quote!(#root_lenses);
94 let mut last_field_context: Option<FieldContext> = None;
95
96 for step in parsed.steps {
97 match step {
98 LensStep::Field(field_name) => {
99 let field_lenses_expr = current_lenses_expr.clone();
100 let field_expr = quote!(#field_lenses_expr.#field_name);
101 lens_exprs.push(field_expr.clone());
102 let item_marker_name = vec_item_marker_name(&field_name);
103 let item_lenses_name = vec_item_lenses_field_name(&field_name);
104 let nested_lenses_name = nested_lenses_field_name(&field_name);
105
106 last_field_context = Some(FieldContext {
107 item_marker_expr: quote!(#field_lenses_expr.#item_marker_name),
108 item_lenses_expr: quote!(#field_lenses_expr.#item_lenses_name),
109 });
110 current_lenses_expr = quote!(#field_lenses_expr.#nested_lenses_name);
111 }
112 LensStep::Index(index_expr) => {
113 let Some(field_context) = &last_field_context else {
114 return Err(syn::Error::new(
115 index_expr.span(),
116 "lens!() indexing is only supported immediately after a named field",
117 ));
118 };
119 let item_marker_expr = field_context.item_marker_expr.clone();
120 let vec_expr = quote!(lens::vec_lens_from_marker(#item_marker_expr, #index_expr));
121 lens_exprs.push(vec_expr);
122 current_lenses_expr = field_context.item_lenses_expr.clone();
123 last_field_context = None;
124 }
125 }
126 }
127
128 Ok(quote! {
129 lens::compose_lens!(#(#lens_exprs),*)
130 })
131}
132
133struct FieldContext {
134 item_marker_expr: TokenStream2,
135 item_lenses_expr: TokenStream2,
136}
137
138fn nested_lenses_field_name(field_name: &syn::Ident) -> syn::Ident {
139 format_ident!("{field_name}_lenses")
140}
141
142fn vec_item_marker_name(field_name: &syn::Ident) -> syn::Ident {
143 format_ident!("{field_name}_item")
144}
145
146fn vec_item_lenses_field_name(field_name: &syn::Ident) -> syn::Ident {
147 format_ident!("{field_name}_item_lenses")
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    #[test]
    fn parses_field_and_index_steps() {
        let expr: Expr = syn::parse2(quote!(Root.items[1].value)).expect("valid expression");
        let parsed = parse_lens_expression(&expr).expect("parsed lens expression");
        assert_eq!(parsed.root.to_string(), "Root");
        assert_eq!(parsed.steps.len(), 3);
        assert!(matches!(parsed.steps[0], LensStep::Field(_)));
        assert!(matches!(parsed.steps[1], LensStep::Index(_)));
        assert!(matches!(parsed.steps[2], LensStep::Field(_)));
    }

    #[test]
    fn rejects_qualified_roots() {
        let expr: Expr = syn::parse2(quote!(crate::Root.value)).expect("valid expression");
        let error = parse_lens_expression(&expr).expect_err("qualified root should fail");
        assert!(
            error
                .to_string()
                .contains("must start with an unqualified struct name")
        );
    }

    #[test]
    fn rejects_tuple_fields() {
        let expr: Expr = syn::parse2(quote!(Root.0)).expect("valid expression");
        let error = parse_lens_expression(&expr).expect_err("tuple field should fail");
        assert!(error.to_string().contains("named fields"));
    }

    #[test]
    fn rejects_root_without_steps() {
        let expr: Expr = syn::parse2(quote!(Root)).expect("valid expression");
        let error = parse_lens_expression(&expr).expect_err("bare root should fail");
        assert!(error.to_string().contains("at least one field access"));
    }

    #[test]
    fn rejects_index_not_directly_after_field() {
        // The second `[1]` follows an index, not a named field.
        let expr: Expr = syn::parse2(quote!(Root.grid[0][1])).expect("valid expression");
        let parsed = parse_lens_expression(&expr).expect("parsed lens expression");
        let error = expand_lens_expression(parsed).expect_err("double index should fail");
        assert!(error.to_string().contains("immediately after a named field"));
    }

    #[test]
    fn expands_field_index_field_chain() {
        let expr: Expr = syn::parse2(quote!(Root.items[1].value)).expect("valid expression");
        let parsed = parse_lens_expression(&expr).expect("parsed lens expression");
        let expanded = expand_lens_expression(parsed).expect("expanded lens expression");
        let expected = quote! {
            lens::compose_lens!(
                _RootLenses.items,
                lens::vec_lens_from_marker(_RootLenses.items_item, 1),
                _RootLenses.items_item_lenses.value
            )
        };
        assert_eq!(expanded.to_string(), expected.to_string());
    }
}