sp1_derive/
lib.rs

1// The `aligned_borrow_derive` macro is taken from valida-xyz/valida under MIT license
2//
3// The MIT License (MIT)
4//
5// Copyright (c) 2023 The Valida Authors
6//
7// Permission is hereby granted, free of charge, to any person obtaining a copy
8// of this software and associated documentation files (the "Software"), to deal
9// in the Software without restriction, including without limitation the rights
10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11// copies of the Software, and to permit persons to whom the Software is
12// furnished to do so, subject to the following conditions:
13//
14// The above copyright notice and this permission notice shall be included in
15// all copies or substantial portions of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23// THE SOFTWARE.
24
25extern crate proc_macro;
26
27use proc_macro::TokenStream;
28use quote::quote;
29use syn::{
30    parse_macro_input, parse_quote, Data, DeriveInput, GenericParam, ItemFn, WherePredicate,
31};
32
33#[proc_macro_derive(AlignedBorrow)]
34pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
35    let ast = parse_macro_input!(input as DeriveInput);
36    let name = &ast.ident;
37
38    // Get first generic which must be type (ex. `T`) for input <T, N: NumLimbs, const M: usize>
39    let type_generic = ast
40        .generics
41        .params
42        .iter()
43        .map(|param| match param {
44            GenericParam::Type(type_param) => &type_param.ident,
45            _ => panic!("Expected first generic to be a type"),
46        })
47        .next()
48        .expect("Expected at least one generic");
49
50    // Get generics after the first (ex. `N: NumLimbs, const M: usize`)
51    // We need this because when we assert the size, we want to substitute u8 for T.
52    let non_first_generics = ast
53        .generics
54        .params
55        .iter()
56        .skip(1)
57        .filter_map(|param| match param {
58            GenericParam::Type(type_param) => Some(&type_param.ident),
59            GenericParam::Const(const_param) => Some(&const_param.ident),
60            _ => None,
61        })
62        .collect::<Vec<_>>();
63
64    // Get impl generics (`<T, N: NumLimbs, const M: usize>`), type generics (`<T, N>`), where
65    // clause (`where T: Clone`)
66    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
67
68    let methods = quote! {
69        impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
70            fn borrow(&self) -> &#name #type_generics {
71                debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
72                let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
73                debug_assert!(prefix.is_empty(), "Alignment should match");
74                debug_assert_eq!(shorts.len(), 1);
75                &shorts[0]
76            }
77        }
78
79        impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
80            fn borrow_mut(&mut self) -> &mut #name #type_generics {
81                debug_assert_eq!(self.len(), std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>());
82                let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
83                debug_assert!(prefix.is_empty(), "Alignment should match");
84                debug_assert_eq!(shorts.len(), 1);
85                &mut shorts[0]
86            }
87        }
88    };
89
90    TokenStream::from(methods)
91}
92
93#[proc_macro_derive(
94    MachineAir,
95    attributes(sp1_core_path, execution_record_path, program_path, builder_path, eval_trait_bound)
96)]
97pub fn machine_air_derive(input: TokenStream) -> TokenStream {
98    let ast: syn::DeriveInput = syn::parse(input).unwrap();
99
100    let name = &ast.ident;
101    let generics = &ast.generics;
102    let execution_record_path = find_execution_record_path(&ast.attrs);
103    let program_path = find_program_path(&ast.attrs);
104    let builder_path = find_builder_path(&ast.attrs);
105    let eval_trait_bound = find_eval_trait_bound(&ast.attrs);
106    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
107
108    match &ast.data {
109        Data::Struct(_) => unimplemented!("Structs are not supported yet"),
110        Data::Enum(e) => {
111            let variants = e
112                .variants
113                .iter()
114                .map(|variant| {
115                    let variant_name = &variant.ident;
116
117                    let mut fields = variant.fields.iter();
118                    let field = fields.next().unwrap();
119                    assert!(fields.next().is_none(), "Only one field is supported");
120                    (variant_name, field)
121                })
122                .collect::<Vec<_>>();
123
124            let width_arms = variants.iter().map(|(variant_name, field)| {
125                let field_ty = &field.ty;
126                quote! {
127                    #name::#variant_name(x) => <#field_ty as p3_air::BaseAir<F>>::width(x)
128                }
129            });
130
131            let base_air = quote! {
132                impl #impl_generics p3_air::BaseAir<F> for #name #ty_generics #where_clause {
133                    fn width(&self) -> usize {
134                        match self {
135                            #(#width_arms,)*
136                        }
137                    }
138
139                    fn preprocessed_trace(&self) -> Option<p3_matrix::dense::RowMajorMatrix<F>> {
140                        unreachable!("A machine air should use the preprocessed trace from the `MachineAir` trait")
141                    }
142                }
143            };
144
145            let name_arms = variants.iter().map(|(variant_name, field)| {
146                let field_ty = &field.ty;
147                quote! {
148                    #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::name(x)
149                }
150            });
151
152            let preprocessed_width_arms = variants.iter().map(|(variant_name, field)| {
153                let field_ty = &field.ty;
154                quote! {
155                    #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::preprocessed_width(x)
156                }
157            });
158
159            let generate_preprocessed_trace_arms = variants.iter().map(|(variant_name, field)| {
160                let field_ty = &field.ty;
161                quote! {
162                    #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::generate_preprocessed_trace(x, program)
163                }
164            });
165
166            let generate_trace_arms = variants.iter().map(|(variant_name, field)| {
167                let field_ty = &field.ty;
168                quote! {
169                    #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::generate_trace(x, input, output)
170                }
171            });
172
173            let generate_dependencies_arms = variants.iter().map(|(variant_name, field)| {
174                let field_ty = &field.ty;
175                quote! {
176                    #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::generate_dependencies(x, input, output)
177                }
178            });
179
180            let included_arms = variants.iter().map(|(variant_name, field)| {
181                let field_ty = &field.ty;
182                quote! {
183                    #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::included(x, shard)
184                }
185            });
186
187            let commit_scope_arms = variants.iter().map(|(variant_name, field)| {
188                let field_ty = &field.ty;
189                quote! {
190                    #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::commit_scope(x)
191                }
192            });
193
194            let local_only_arms = variants.iter().map(|(variant_name, field)| {
195                let field_ty = &field.ty;
196                quote! {
197                    #name::#variant_name(x) => <#field_ty as sp1_stark::air::MachineAir<F>>::local_only(x)
198                }
199            });
200
201            let machine_air = quote! {
202                impl #impl_generics sp1_stark::air::MachineAir<F> for #name #ty_generics #where_clause {
203                    type Record = #execution_record_path;
204
205                    type Program = #program_path;
206
207                    fn name(&self) -> String {
208                        match self {
209                            #(#name_arms,)*
210                        }
211                    }
212
213                    fn preprocessed_width(&self) -> usize {
214                        match self {
215                            #(#preprocessed_width_arms,)*
216                        }
217                    }
218
219                    fn generate_preprocessed_trace(
220                        &self,
221                        program: &#program_path,
222                    ) -> Option<p3_matrix::dense::RowMajorMatrix<F>> {
223                        match self {
224                            #(#generate_preprocessed_trace_arms,)*
225                        }
226                    }
227
228                    fn generate_trace(
229                        &self,
230                        input: &#execution_record_path,
231                        output: &mut #execution_record_path,
232                    ) -> p3_matrix::dense::RowMajorMatrix<F> {
233                        match self {
234                            #(#generate_trace_arms,)*
235                        }
236                    }
237
238                    fn generate_dependencies(
239                        &self,
240                        input: &#execution_record_path,
241                        output: &mut #execution_record_path,
242                    ) {
243                        match self {
244                            #(#generate_dependencies_arms,)*
245                        }
246                    }
247
248                    fn included(&self, shard: &Self::Record) -> bool {
249                        match self {
250                            #(#included_arms,)*
251                        }
252                    }
253
254                    fn commit_scope(&self) -> InteractionScope {
255                        match self {
256                            #(#commit_scope_arms,)*
257                        }
258                    }
259
260                    fn local_only(&self) -> bool {
261                        match self {
262                            #(#local_only_arms,)*
263                        }
264                    }
265                }
266            };
267
268            let eval_arms = variants.iter().map(|(variant_name, field)| {
269                let field_ty = &field.ty;
270                quote! {
271                    #name::#variant_name(x) => <#field_ty as p3_air::Air<AB>>::eval(x, builder)
272                }
273            });
274
275            // Attach an extra generic AB : crate::air::SP1AirBuilder to the generics of the enum
276            let generics = &ast.generics;
277            let mut new_generics = generics.clone();
278            new_generics.params.push(syn::parse_quote! { AB: p3_air::PairBuilder + #builder_path });
279
280            let (air_impl_generics, _, _) = new_generics.split_for_impl();
281
282            let mut new_generics = generics.clone();
283            let where_clause = new_generics.make_where_clause();
284            if let Some(eval_trait_bound) = &eval_trait_bound {
285                let predicate: WherePredicate = syn::parse_str(eval_trait_bound).unwrap();
286                where_clause.predicates.push(predicate);
287            }
288
289            let air = quote! {
290                impl #air_impl_generics p3_air::Air<AB> for #name #ty_generics #where_clause {
291                    fn eval(&self, builder: &mut AB) {
292                        match self {
293                            #(#eval_arms,)*
294                        }
295                    }
296                }
297            };
298
299            quote! {
300                #base_air
301
302                #machine_air
303
304                #air
305            }
306            .into()
307        }
308        Data::Union(_) => unimplemented!("Unions are not supported"),
309    }
310}
311
312#[proc_macro_attribute]
313pub fn cycle_tracker(_attr: TokenStream, item: TokenStream) -> TokenStream {
314    let input = parse_macro_input!(item as ItemFn);
315    let visibility = &input.vis;
316    let name = &input.sig.ident;
317    let inputs = &input.sig.inputs;
318    let output = &input.sig.output;
319    let block = &input.block;
320    let generics = &input.sig.generics;
321    let where_clause = &input.sig.generics.where_clause;
322
323    let result = quote! {
324        #visibility fn #name #generics (#inputs) #output #where_clause {
325            eprintln!("cycle-tracker-start: {}", stringify!(#name));
326            let result = (|| #block)();
327            eprintln!("cycle-tracker-end: {}", stringify!(#name));
328            result
329        }
330    };
331
332    result.into()
333}
334
335#[proc_macro_attribute]
336pub fn cycle_tracker_recursion(_attr: TokenStream, item: TokenStream) -> TokenStream {
337    let input = parse_macro_input!(item as ItemFn);
338    let visibility = &input.vis;
339    let name = &input.sig.ident;
340    let inputs = &input.sig.inputs;
341    let output = &input.sig.output;
342    let block = &input.block;
343    let generics = &input.sig.generics;
344    let where_clause = &input.sig.generics.where_clause;
345
346    let result = quote! {
347        #visibility fn #name #generics (#inputs) #output #where_clause {
348            sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_enter(builder, stringify!(#name));
349            let result = #block;
350            sp1_recursion_compiler::circuit::CircuitV2Builder::cycle_tracker_v2_exit(builder);
351            result
352        }
353    };
354
355    result.into()
356}
357
358fn find_execution_record_path(attrs: &[syn::Attribute]) -> syn::Path {
359    for attr in attrs {
360        if attr.path.is_ident("execution_record_path") {
361            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
362                if let syn::Lit::Str(lit_str) = &meta.lit {
363                    if let Ok(path) = lit_str.parse::<syn::Path>() {
364                        return path;
365                    }
366                }
367            }
368        }
369    }
370    parse_quote!(sp1_core_executor::ExecutionRecord)
371}
372
373fn find_program_path(attrs: &[syn::Attribute]) -> syn::Path {
374    for attr in attrs {
375        if attr.path.is_ident("program_path") {
376            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
377                if let syn::Lit::Str(lit_str) = &meta.lit {
378                    if let Ok(path) = lit_str.parse::<syn::Path>() {
379                        return path;
380                    }
381                }
382            }
383        }
384    }
385    parse_quote!(sp1_core_executor::Program)
386}
387
388fn find_builder_path(attrs: &[syn::Attribute]) -> syn::Path {
389    for attr in attrs {
390        if attr.path.is_ident("builder_path") {
391            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
392                if let syn::Lit::Str(lit_str) = &meta.lit {
393                    if let Ok(path) = lit_str.parse::<syn::Path>() {
394                        return path;
395                    }
396                }
397            }
398        }
399    }
400    parse_quote!(crate::air::SP1CoreAirBuilder<F = F>)
401}
402
403fn find_eval_trait_bound(attrs: &[syn::Attribute]) -> Option<String> {
404    for attr in attrs {
405        if attr.path.is_ident("eval_trait_bound") {
406            if let Ok(syn::Meta::NameValue(meta)) = attr.parse_meta() {
407                if let syn::Lit::Str(lit_str) = &meta.lit {
408                    return Some(lit_str.value());
409                }
410            }
411        }
412    }
413
414    None
415}