plutus_parser_derive/
lib.rs

1use std::collections::HashSet;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::{quote, quote_spanned};
6use syn::{
7    Attribute, Data, DeriveInput, Error, Expr, ExprLit, Fields, Ident, Lit, Meta,
8    parse_macro_input, spanned::Spanned,
9};
10
11#[proc_macro_derive(AsPlutus, attributes(variant))]
12pub fn derive_as_plutus(input: TokenStream) -> TokenStream {
13    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
14    let name = &input.ident;
15
16    let implementation = match &input.data {
17        Data::Struct(s) => {
18            let n = get_variant(&input.attrs).unwrap_or_default();
19            let from_plutus;
20            let to_plutus;
21            match &s.fields {
22                Fields::Named(named) => {
23                    let names: Vec<_> = named
24                        .named
25                        .iter()
26                        .map(|n| n.ident.as_ref().unwrap())
27                        .collect();
28                    let assignments = names.iter().map(|n| {
29                        quote! {
30                            #n: plutus_parser::AsPlutus::from_plutus(#n)?,
31                        }
32                    });
33                    let casts: Vec<_> = names
34                        .iter()
35                        .map(|n| {
36                            quote! {
37                                self.#n.to_plutus(),
38                            }
39                        })
40                        .collect();
41
42                    from_plutus = quote! {
43                        let (variant, fields) = plutus_parser::parse_constr(data)?;
44                        if variant == #n {
45                            let [#(#names),*] = plutus_parser::parse_variant(variant, fields)?;
46                            return Ok(Self {
47                                #(#assignments)*
48                            });
49                        }
50                        Err(plutus_parser::DecodeError::UnexpectedVariant { variant })
51                    };
52                    to_plutus = quote! {
53                        plutus_parser::create_constr(#n, vec![
54                            #(#casts)*
55                        ])
56                    };
57                }
58                Fields::Unit => {
59                    from_plutus = quote! {
60                        let (variant, fields) = plutus_parser::parse_constr(data)?;
61                        if variant == #n {
62                            let [] = plutus_parser::parse_variant(variant, fields)?;
63                            return Ok(Self);
64                        }
65                        Err(plutus_parser::DecodeError::UnexpectedVariant { variant })
66                    };
67                    to_plutus = quote! {
68                        plutus_parser::create_constr(#n, vec![])
69                    }
70                }
71                Fields::Unnamed(fields) => {
72                    let names: Vec<_> = fields
73                        .unnamed
74                        .iter()
75                        .enumerate()
76                        .map(|(i, field)| {
77                            let name = format!("f{i}");
78                            let span = field.span();
79                            Ident::new(&name, span)
80                        })
81                        .collect();
82                    let assignments: Vec<_> = names
83                        .iter()
84                        .map(|n| {
85                            quote! {
86                                plutus_parser::AsPlutus::from_plutus(#n)?,
87                            }
88                        })
89                        .collect();
90                    let casts: Vec<_> = names
91                        .iter()
92                        .map(|n| {
93                            quote! {
94                                #n.to_plutus(),
95                            }
96                        })
97                        .collect();
98                    from_plutus = quote! {
99                        let (variant, fields) = plutus_parser::parse_constr(data)?;
100                        if variant == #n {
101                            let [#(#names),*] = plutus_parser::parse_variant(variant, fields)?;
102                            return Ok(Self(#(#assignments)*));
103                        }
104                        Err(plutus_parser::DecodeError::UnexpectedVariant { variant })
105                    };
106                    to_plutus = quote! {
107                        let Self(#(#names),*) = self;
108                        plutus_parser::create_constr(#n, vec![
109                            #(#casts)*
110                        ])
111                    }
112                }
113            };
114
115            quote! {
116                fn from_plutus(data: plutus_parser::PlutusData) -> Result<Self, plutus_parser::DecodeError> {
117                    #from_plutus
118                }
119
120                fn to_plutus(self) -> plutus_parser::PlutusData {
121                    #to_plutus
122                }
123            }
124        }
125        Data::Enum(e) => {
126            let mut from_plutus = quote! {
127                let (variant, fields) = plutus_parser::parse_constr(data)?;
128            };
129            let mut to_plutus = quote! {};
130            let mut seen_variants = HashSet::new();
131            for variant in &e.variants {
132                let name = &variant.ident;
133                let n = get_variant(&variant.attrs).unwrap_or(seen_variants.len() as u64);
134                seen_variants.insert(n);
135                let (from_clause, to_clause) = match &variant.fields {
136                    Fields::Named(named) => {
137                        let names: Vec<_> = named
138                            .named
139                            .iter()
140                            .map(|n| n.ident.as_ref().unwrap())
141                            .collect();
142                        let assignments = names.iter().map(|n| {
143                            quote! {
144                                #n: plutus_parser::AsPlutus::from_plutus(#n)?,
145                            }
146                        });
147                        let casts: Vec<_> = names
148                            .iter()
149                            .map(|n| {
150                                quote! {
151                                    #n.to_plutus(),
152                                }
153                            })
154                            .collect();
155                        (
156                            quote! {
157                                let [#(#names),*] = plutus_parser::parse_variant(variant, fields)?;
158                                return Ok(Self::#name {
159                                    #(#assignments)*
160                                });
161                            },
162                            quote! {
163                                Self::#name { #(#names),* } => plutus_parser::create_constr(#n, vec![
164                                    #(#casts)*
165                                ]),
166                            },
167                        )
168                    }
169                    Fields::Unit => (
170                        quote! {
171                            let [] = plutus_parser::parse_variant(variant, fields)?;
172                            return Ok(Self::#name);
173                        },
174                        quote! {
175                            Self::#name => plutus_parser::create_constr(#n, vec![]),
176                        },
177                    ),
178                    Fields::Unnamed(fields) => {
179                        let names: Vec<_> = fields
180                            .unnamed
181                            .iter()
182                            .enumerate()
183                            .map(|(i, field)| {
184                                let name = format!("f{i}");
185                                let span = field.span();
186                                Ident::new(&name, span)
187                            })
188                            .collect();
189                        let assignments: Vec<_> = names
190                            .iter()
191                            .map(|n| {
192                                quote! {
193                                    plutus_parser::AsPlutus::from_plutus(#n)?,
194                                }
195                            })
196                            .collect();
197                        let casts: Vec<_> = names
198                            .iter()
199                            .map(|n| {
200                                quote! {
201                                    #n.to_plutus(),
202                                }
203                            })
204                            .collect();
205                        (
206                            quote! {
207                                let [#(#names),*] = plutus_parser::parse_variant(variant, fields)?;
208                                return Ok(Self::#name(#(#assignments),*));
209                            },
210                            quote! {
211                                Self::#name(#(#names)*) => plutus_parser::create_constr(#n, vec![
212                                    #(#casts)*
213                                ]),
214                            },
215                        )
216                    }
217                };
218                from_plutus.extend(quote_spanned! { variant.span() =>
219                    if variant == #n {
220                        #from_clause
221                    }
222                });
223                to_plutus.extend(quote_spanned! {variant.span() =>
224                    #to_clause
225                });
226            }
227            from_plutus.extend(quote! {
228                Err(plutus_parser::DecodeError::UnexpectedVariant { variant })
229            });
230
231            quote! {
232                fn from_plutus(data: plutus_parser::PlutusData) -> Result<Self, plutus_parser::DecodeError> {
233                    #from_plutus
234                }
235
236                fn to_plutus(self) -> plutus_parser::PlutusData {
237                    match self {
238                        #to_plutus
239                    }
240                }
241            }
242        }
243        _ => {
244            return Error::new(Span::call_site(), "Unsupported type")
245                .into_compile_error()
246                .into();
247        }
248    };
249
250    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
251    let expanded = quote! {
252        impl #impl_generics plutus_parser::AsPlutus for #name #ty_generics #where_clause {
253            #implementation
254        }
255    };
256
257    TokenStream::from(expanded)
258}
259
260fn get_variant(attrs: &[Attribute]) -> Option<u64> {
261    attrs.iter().find_map(|a| {
262        let Meta::NameValue(name_value) = &a.meta else {
263            return None;
264        };
265        if !name_value.path.is_ident("variant") {
266            return None;
267        }
268        let Expr::Lit(ExprLit {
269            lit: Lit::Int(int), ..
270        }) = &name_value.value
271        else {
272            return None;
273        };
274        int.base10_parse().ok()
275    })
276}