rovv_derive/
lib.rs

1use proc_macro2::Span;
2use quote::*;
3use syn::{
4    parse::{Parse, ParseStream},
5    parse_macro_input,
6    punctuated::Punctuated,
7    Token,
8};
9
10#[derive(Clone, Debug)]
11enum Mutability {
12    Ref(Token![ref]),
13    Mut(Token![mut]),
14    Move,
15}
16
17#[derive(Clone, Debug)]
18enum TypeSuffix {
19    Empty,
20    Star(Token![*]),
21    Question(Token![?]),
22}
23
24#[derive(Clone, Debug)]
25enum Key {
26    Ident(syn::Ident),
27    Type {
28        _bracket_token: syn::token::Bracket,
29        key_type: syn::Type,
30    },
31}
32
33/// ref a: A?, mut b: B*, c: C, [K]: V
34#[derive(Clone, Debug)]
35struct RowTypeField {
36    mutability: Mutability,
37    key: Key,
38    _colon_token: Token![:],
39    field_type: syn::Type,
40    suffix: TypeSuffix,
41}
42
43/// row! { a: A, b: B, c: C, .. : Trait1 + Trait2 + 'a }
44#[derive(Clone, Debug)]
45struct RowType {
46    fields: Punctuated<RowTypeField, Token![,]>,
47    _dot2token: Token![..],
48    _colon_token: Option<Token![:]>,
49    bounds: Punctuated<syn::TypeParamBound, Token![+]>,
50}
51
52impl Parse for Mutability {
53    fn parse(input: ParseStream) -> syn::Result<Self> {
54        let lookahead = input.lookahead1();
55        if lookahead.peek(Token![ref]) {
56            Ok(Mutability::Ref(input.parse()?))
57        } else if lookahead.peek(Token![mut]) {
58            Ok(Mutability::Mut(input.parse()?))
59        } else if lookahead.peek(syn::Ident) || lookahead.peek(syn::token::Bracket) {
60            Ok(Mutability::Move)
61        } else {
62            Err(syn::Error::new(
63                proc_macro2::Span::call_site(),
64                "expected `ref`, `mut` or nothing",
65            ))
66        }
67    }
68}
69
70impl Parse for TypeSuffix {
71    fn parse(input: ParseStream) -> syn::Result<Self> {
72        let lookahead = input.lookahead1();
73        if lookahead.peek(Token![*]) {
74            Ok(TypeSuffix::Star(input.parse()?))
75        } else if lookahead.peek(Token![?]) {
76            Ok(TypeSuffix::Question(input.parse()?))
77        } else {
78            Ok(TypeSuffix::Empty)
79        }
80    }
81}
82
83impl Parse for Key {
84    fn parse(input: ParseStream) -> syn::Result<Self> {
85        if input.peek(syn::token::Bracket) {
86            let content;
87            let _bracket_token = syn::bracketed!(content in input);
88            Ok(Self::Type {
89                _bracket_token,
90                key_type: content.parse()?,
91            })
92        } else {
93            Ok(Self::Ident(input.parse()?))
94        }
95    }
96}
97
98impl ToTokens for Key {
99    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
100        match self {
101            Key::Ident(id) => tokens.extend(quote! { lens_rs::Optics![#id] }),
102            Key::Type { key_type, .. } => key_type.to_tokens(tokens),
103        }
104    }
105}
106
107impl Parse for RowTypeField {
108    fn parse(input: ParseStream) -> syn::Result<Self> {
109        Ok(Self {
110            mutability: input.parse()?,
111            key: input.parse()?,
112            _colon_token: input.parse()?,
113            field_type: input.parse()?,
114            suffix: input.parse()?,
115        })
116    }
117}
118
119impl Parse for RowType {
120    fn parse(input: ParseStream) -> syn::Result<Self> {
121        let mut fields = Punctuated::new();
122        while !input.is_empty() && !input.peek(Token![..]) {
123            let row_field = input.call(RowTypeField::parse)?;
124            fields.push_value(row_field);
125            if input.is_empty() {
126                break;
127            }
128            let punct: Token![,] = input.parse()?;
129            fields.push_punct(punct);
130        }
131
132        let _dot2token = if fields.empty_or_trailing() && input.peek(Token![..]) {
133            input.parse()?
134        } else {
135            return Err(syn::Error::new(
136                proc_macro2::Span::call_site(),
137                "expected `..` token",
138            ));
139        };
140
141        let _colon_token = if input.peek(Token![:]) {
142            Some(input.parse()?)
143        } else {
144            return Ok(Self {
145                fields,
146                _dot2token,
147                _colon_token: None,
148                bounds: Default::default(),
149            });
150        };
151
152        Ok(Self {
153            fields,
154            _dot2token,
155            _colon_token,
156            bounds: Punctuated::parse_terminated(&input)?,
157        })
158    }
159}
160
161/// transform
162///
163/// ```rust
164/// row! { ref a: A, mut b: B, c: C, .. : Trait1 + Trait2 + 'a }
165/// ```
166///
167/// to
168///
169/// ```rust
170/// impl LensRef<Optic![a], A> + LensMut<Optic![b], B> + Lens<Optic![c], C> + Trait1 + Trait2 + 'a
171/// ```
172#[proc_macro]
173pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
174    let row_type = parse_macro_input!(input as RowType);
175    let bound = row_type
176        .fields
177        .into_iter()
178        .map(|x: RowTypeField| {
179            let key = x.key;
180            let field_type = x.field_type;
181            match (x.suffix, x.mutability) {
182                (TypeSuffix::Empty, Mutability::Ref(_)) => {
183                    quote! { lens_rs:: LensRef<#key, #field_type> }
184                }
185                (TypeSuffix::Empty, Mutability::Mut(_)) => {
186                    quote! { lens_rs:: LensMut<#key, #field_type> }
187                }
188                (TypeSuffix::Empty, Mutability::Move) => {
189                    quote! { lens_rs:: Lens<#key, #field_type> }
190                }
191                (TypeSuffix::Star(_), Mutability::Ref(_)) => {
192                    quote! { lens_rs:: TraversalRef<#key, #field_type> }
193                }
194                (TypeSuffix::Star(_), Mutability::Mut(_)) => {
195                    quote! { lens_rs:: TraversalMut<#key, #field_type> }
196                }
197                (TypeSuffix::Star(_), Mutability::Move) => {
198                    quote! { lens_rs:: Traversal<#key, #field_type> }
199                }
200                (TypeSuffix::Question(_), Mutability::Ref(_)) => {
201                    quote! { lens_rs:: PrismRef<#key, #field_type> }
202                }
203                (TypeSuffix::Question(_), Mutability::Mut(_)) => {
204                    quote! { lens_rs:: PrismMut<#key, #field_type> }
205                }
206                (TypeSuffix::Question(_), Mutability::Move) => {
207                    quote! { lens_rs:: Prism<#key, #field_type> }
208                }
209            }
210        })
211        .chain(Some(quote! { rovv::Empty }))
212        .chain(
213            row_type
214                .bounds
215                .into_iter()
216                .map(|bound: syn::TypeParamBound| quote! { #bound }),
217        )
218        .collect::<Punctuated<proc_macro2::TokenStream, Token![+]>>();
219
220    // println!("{}", impl_ty.to_string());
221    proc_macro::TokenStream::from(quote! { impl #bound })
222}
223
224/// transform
225///
226/// ```rust
227/// dyn_row! { ref a: A, mut b: B, c: C, .. : Trait1 + Trait2 + 'a }
228/// ```
229///
230/// to
231///
232/// ```rust
233/// dyn LensRef<Optic![a],  A> + LensMut<Optic![b], B> + Lens<Optic![c], C> + Trait1 + Trait2 + 'a
234/// ```
235#[proc_macro]
236pub fn dyn_row(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
237    let row_type = parse_macro_input!(input as RowType);
238    let fields: Vec<RowTypeField> = row_type.fields.into_iter().collect::<Vec<_>>();
239    let (dyn_row_name, key, fields_ty, _) = join_dyn_row_field(fields);
240    let dyn_row_ident = syn::Ident::new(&dyn_row_name, Span::call_site());
241    let bounds = row_type.bounds.into_iter().collect::<Vec<_>>();
242
243    proc_macro::TokenStream::from(quote! {
244        dyn rovv::#dyn_row_ident<#(#key, #fields_ty),*> #(+ #bounds)*
245    })
246}
247
248fn join_dyn_row_field(
249    mut fields: Vec<RowTypeField>,
250) -> (String, Vec<Key>, Vec<syn::Type>, Vec<syn::Ident>) {
251    fields.sort_by_key(|field| map_trait(&field.suffix, &field.mutability));
252
253    let mut dyn_row_name = "_dyn_row".to_string();
254    let mut fields_key = Vec::new();
255    let mut fields_ty = Vec::new();
256    let mut optics_trait = Vec::new();
257    for field in fields {
258        let trait_name = map_trait(&field.suffix, &field.mutability);
259        dyn_row_name += &format!("_{}_", trait_name);
260        optics_trait.push(syn::Ident::new(trait_name, Span::call_site()));
261        fields_ty.push(field.field_type);
262        fields_key.push(field.key);
263    }
264
265    (dyn_row_name, fields_key, fields_ty, optics_trait)
266}
267
268fn map_trait(suffix: &TypeSuffix, mutability: &Mutability) -> &'static str {
269    match (suffix, mutability) {
270        (TypeSuffix::Empty, Mutability::Ref(_)) => "LensRef",
271        (TypeSuffix::Empty, Mutability::Mut(_)) => "LensMut",
272        (TypeSuffix::Empty, Mutability::Move) => "Lens",
273        (TypeSuffix::Star(_), Mutability::Ref(_)) => "TraversalRef",
274        (TypeSuffix::Star(_), Mutability::Mut(_)) => "TraversalMut",
275        (TypeSuffix::Star(_), Mutability::Move) => "Traversal",
276        (TypeSuffix::Question(_), Mutability::Ref(_)) => "PrismRef",
277        (TypeSuffix::Question(_), Mutability::Mut(_)) => "PrismMut",
278        (TypeSuffix::Question(_), Mutability::Move) => "Prism",
279    }
280}