rsnark_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{Data, DeriveInput, Fields, Visibility, parse_macro_input};
4
5/// Circuit derive macro
6///
7/// Automatically generates CircuitElement implementation and related helper code
8/// for the annotated struct.
9///
10/// Usage:
11/// ```rust, ignore
12/// use rsnark_macros::Circuit;
13///
14/// #[derive(Circuit)]
15/// pub struct MyCircuit {
16///     a: u32,        // private field
17///     pub b: u32,    // public field
18/// }
19/// ```
20#[proc_macro_derive(Circuit)]
21pub fn circuit_derive(input: TokenStream) -> TokenStream {
22    let input = parse_macro_input!(input as DeriveInput);
23
24    match generate_circuit_impl(&input) {
25        Ok(tokens) => tokens,
26        Err(err) => err.to_compile_error().into(),
27    }
28}
29
30fn generate_circuit_impl(input: &DeriveInput) -> syn::Result<TokenStream> {
31    let name = &input.ident;
32    let name_str = name.to_string().to_lowercase();
33
34    // Generate module name: __rsnark_generated_{name}
35    let module_name = format_ident!("__rsnark_generated_{}", name_str);
36
37    // Generate CircuitDefine struct name: {Name}CircuitDefine
38    let define_name = format_ident!("{}CircuitDefine", name);
39
40    // Generate PublicWitness struct name: {Name}PublicWitness
41    let public_witness_name = format_ident!("{}PublicWitness", name);
42
43    // Parse struct fields
44    let fields = match &input.data {
45        Data::Struct(data) => match &data.fields {
46            Fields::Named(fields) => &fields.named,
47            _ => {
48                return Err(syn::Error::new_spanned(
49                    input,
50                    "Only named fields are supported",
51                ));
52            }
53        },
54        _ => return Err(syn::Error::new_spanned(input, "Only structs are supported")),
55    };
56
57    // Separate public and private fields
58    let mut private_fields = Vec::new();
59    let mut public_fields = Vec::new();
60
61    for field in fields {
62        let field_name = field.ident.as_ref().unwrap();
63        let field_type = &field.ty;
64
65        match &field.vis {
66            Visibility::Public(_) => {
67                public_fields.push((field_name, field_type));
68            }
69            _ => {
70                private_fields.push((field_name, field_type));
71            }
72        }
73    }
74
75    // Generate fields for CircuitDefine struct
76    let define_fields = fields.iter().map(|field| {
77        let field_name = field.ident.as_ref().unwrap();
78        let field_type = &field.ty;
79
80        match &field.vis {
81            Visibility::Public(_) => {
82                quote! {
83                    pub #field_name: ::rsnark_core::PublicCircuitElement<#field_type>
84                }
85            }
86            _ => {
87                quote! {
88                    pub #field_name: ::rsnark_core::PrivateCircuitElement<#field_type>
89                }
90            }
91        }
92    });
93
94    // Generate field initialization in new method
95    let new_field_inits = fields.iter().map(|field| {
96        let field_name = field.ident.as_ref().unwrap();
97        let field_type = &field.ty;
98
99        match &field.vis {
100            Visibility::Public(_) => {
101                quote! {
102                    let #field_name = #field_type::create_public(initer, is_private);
103                }
104            }
105            _ => {
106                quote! {
107                    let #field_name = #field_type::create_private(initer);
108                }
109            }
110        }
111    });
112
113    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
114
115    // Generate append_public method implementation for original struct
116    let append_public_impl_orig = public_fields.iter().map(|(field_name, _)| {
117        quote! {
118            self.#field_name.append_public_witness(witness, false);
119        }
120    });
121
122    // Generate append_public method implementation for PublicWitness struct
123    let append_public_impl_witness = public_fields.iter().map(|(field_name, _)| {
124        quote! {
125            self.#field_name.append_public_witness(witness, false);
126        }
127    });
128
129    // Generate append_witness method implementation for all fields
130    let append_witness_impl = fields.iter().map(|field| {
131        let field_name = field.ident.as_ref().unwrap();
132
133        match &field.vis {
134            Visibility::Public(_) => {
135                // Public fields: is_private = false
136                quote! {
137                    self.#field_name.append_witness(public, private, false || _is_private);
138                }
139            }
140            _ => {
141                // Private fields: is_private = true
142                quote! {
143                    self.#field_name.append_witness(public, private, true);
144                }
145            }
146        }
147    });
148
149    // Generate public witness fields for into_public_witness method
150    let public_witness_fields = public_fields.iter().map(|(field_name, _)| {
151        quote! {
152            #field_name: self.#field_name.into_public_witness()
153        }
154    });
155
156    // Generate PublicWitness struct fields
157    let public_witness_struct_fields = public_fields.iter().map(|(field_name, field_type)| {
158        quote! {
159            pub #field_name: ::rsnark_core::PublicWitness<#field_type>
160        }
161    });
162
163    let expanded = quote! {
164        mod #module_name {
165            use super::*;
166
167            use ::rsnark_core::{
168                CircuitWitness, CircuitPublicWitness, BigInt, VariableIniter,
169            };
170
171            impl CircuitWitness for #name {
172                type PrivateElement = #define_name;
173                type PublicElement = #define_name;
174                type PublicWitness = #public_witness_name;
175
176                fn create_public(initer: &mut VariableIniter, is_private: bool) -> Self::PublicElement {
177                    #define_name::new(initer, is_private)
178                }
179
180                fn create_private(initer: &mut VariableIniter) -> Self::PrivateElement {
181                    #define_name::new(initer, true)
182                }
183
184                fn append_witness(&self, public: &mut Vec<BigInt>, private: &mut Vec<BigInt>, _is_private: bool) {
185                    #(#append_witness_impl)*
186                }
187
188                fn into_public_witness(self) -> Self::PublicWitness {
189                    #public_witness_name {
190                        #(#public_witness_fields,)*
191                    }
192                }
193            }
194
195            #[doc(hidden)]
196            pub struct #define_name {
197                #(#define_fields,)*
198            }
199
200            impl #define_name {
201                fn new(initer: &mut VariableIniter, is_private: bool) -> Self {
202                    #(#new_field_inits)*
203
204                    Self {
205                        #(#field_names,)*
206                    }
207                }
208            }
209
210            impl CircuitPublicWitness for #name {
211                fn append_public_witness(&self, witness: &mut Vec<BigInt>, _is_private: bool) {
212                    #(#append_public_impl_orig)*
213                }
214            }
215
216            #[doc(hidden)]
217            pub struct #public_witness_name {
218                #(#public_witness_struct_fields,)*
219            }
220
221            impl CircuitPublicWitness for #public_witness_name {
222                fn append_public_witness(&self, witness: &mut Vec<BigInt>, _is_private: bool) {
223                    #(#append_public_impl_witness)*
224                }
225            }
226        }
227    };
228
229    Ok(TokenStream::from(expanded))
230}