Skip to main content

vyre_macros/
lib.rs

1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3//! Procedural macros for the [`vyre`](https://docs.rs/vyre) GPU compute IR
4//! compiler.
5//!
6//! This crate is compile-time only. Downstream users import from
7//! `vyre::optimizer::vyre_pass` rather than depending on this crate directly.
8//!
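//! The main pass-registration macro is [`macro@vyre_pass`] — see that item for
//! the full usage contract, argument shape, and a worked example. The crate also
//! exports `define_op!`, `vyre_ast_registry!`, `#[skip_builder]`, and
//! `#[derive(AlgebraicLaws)]`, each documented on its own item. A high-level
//! narrative lives in the crate README.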
12
13mod ast_registry;
14mod define_op;
15
16use proc_macro::TokenStream;
17use quote::quote;
18use syn::parse::{Parse, ParseStream};
19use syn::spanned::Spanned;
20use syn::{
21    parse_macro_input, Attribute, Data, DeriveInput, ExprArray, Fields, ItemStruct, LitStr, Meta,
22    Token,
23};
24
25/// Function-like `define_op!` — single-site op registration via inventory.
26///
27/// See [`define_op`](define_op/index.html) for the full argument contract.
28#[proc_macro]
29pub fn define_op(item: TokenStream) -> TokenStream {
30    define_op::define_op_impl(item)
31}
32
33/// Generates the declarative IR AST core (Expr and Node enums)
34/// plus serialization and visitor traits.
35#[proc_macro]
36pub fn vyre_ast_registry(item: TokenStream) -> TokenStream {
37    ast_registry::vyre_ast_registry_impl(item)
38}
39
40/// A generic marker attribute used exclusively to instruct `vyre_ast_registry!`
41/// to skip generating a builder method for a specific struct field.
42#[proc_macro_attribute]
43pub fn skip_builder(_attr: TokenStream, item: TokenStream) -> TokenStream {
44    item
45}
46
47struct PassArgs {
48    name: LitStr,
49    requires: Vec<LitStr>,
50    invalidates: Vec<LitStr>,
51}
52
53impl Parse for PassArgs {
54    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
55        let mut name = None;
56        let mut requires = Vec::new();
57        let mut invalidates = Vec::new();
58
59        while !input.is_empty() {
60            let key: syn::Ident = input.parse()?;
61            input.parse::<Token![=]>()?;
62            match key.to_string().as_str() {
63                "name" => name = Some(input.parse()?),
64                "requires" => requires = parse_string_array(input)?,
65                "invalidates" => invalidates = parse_string_array(input)?,
66                _ => {
67                    return Err(syn::Error::new(
68                        key.span(),
69                        "unsupported vyre_pass argument. Fix: use name, requires, or invalidates.",
70                    ));
71                }
72            }
73            if input.peek(Token![,]) {
74                input.parse::<Token![,]>()?;
75            }
76        }
77
78        Ok(Self {
79            name: name.ok_or_else(|| input.error("missing pass name. Fix: add name = \"...\"."))?,
80            requires,
81            invalidates,
82        })
83    }
84}
85
86fn parse_string_array(input: ParseStream<'_>) -> syn::Result<Vec<LitStr>> {
87    let array: ExprArray = input.parse()?;
88    array
89        .elems
90        .into_iter()
91        .map(|expr| match expr {
92            syn::Expr::Lit(lit) => match lit.lit {
93                syn::Lit::Str(value) => Ok(value),
94                other => Err(syn::Error::new_spanned(
95                    other,
96                    "pass metadata arrays accept only string literals. Fix: use [\"analysis_name\"].",
97                )),
98            },
99            other => Err(syn::Error::new_spanned(
100                other,
101                "pass metadata arrays accept only string literals. Fix: use [\"analysis_name\"].",
102            )),
103        })
104        .collect()
105}
106
107/// Register a unit struct as a `vyre::optimizer::ProgramPass`.
108///
109/// Expands to (a) a full `ProgramPass` trait impl that forwards to your inherent
110/// `analyze` / `transform` / `fingerprint` methods and (b) an
111/// `inventory::submit!` that adds the pass to the global registry so
112/// `vyre::optimize()` picks it up automatically.
113///
114/// # Arguments
115///
116/// | Argument       | Type        | Meaning                                                             |
117/// |----------------|-------------|---------------------------------------------------------------------|
118/// | `name`         | string lit  | Stable pass name used in diagnostics / ordering.                    |
119/// | `requires`     | `[&str]`    | Pass names that must fire before this one.                          |
120/// | `invalidates`  | `[&str]`    | Analyses invalidated when this pass rewrites the program.           |
121///
122/// # Required inherent methods on the annotated type
123///
124/// ```ignore
125/// fn analyze(program: &Program) -> PassAnalysis;
126/// fn transform(program: Program) -> PassResult;
127/// fn fingerprint(program: &Program) -> u64;
128/// ```
129///
130/// # Example
131///
132/// ```ignore
133/// use vyre::optimizer::{vyre_pass, PassAnalysis, PassResult, fingerprint_program};
134/// use vyre::ir::Program;
135///
136/// #[vyre_pass(name = "fold_zero_add", requires = [], invalidates = [])]
137/// pub struct FoldZeroAdd;
138///
139/// impl FoldZeroAdd {
140///     fn analyze(_program: &Program) -> PassAnalysis { PassAnalysis::RUN }
141///     fn transform(program: Program) -> PassResult {
142///         // ... real rewrite ...
143///         PassResult::from_programs(&program.clone(), program)
144///     }
145///     fn fingerprint(program: &Program) -> u64 { fingerprint_program(program) }
146/// }
147/// ```
148///
149/// After expansion, `vyre::optimize(p)` will pick up `FoldZeroAdd` through
150/// the `inventory::collect!(ProgramPassRegistration)` entry emitted by the macro.
151/// No manual registration needed.
152#[proc_macro_attribute]
153pub fn vyre_pass(args: TokenStream, item: TokenStream) -> TokenStream {
154    let args = parse_macro_input!(args as PassArgs);
155    let item = parse_macro_input!(item as ItemStruct);
156    let ident = &item.ident;
157    let name = args.name;
158    let requires = args.requires;
159    let invalidates = args.invalidates;
160
161    quote! {
162        #item
163
164        impl ::vyre::optimizer::private::Sealed for #ident {}
165
166        impl ::vyre::optimizer::ProgramPass for #ident {
167            #[inline]
168            fn metadata(&self) -> ::vyre::optimizer::PassMetadata {
169                ::vyre::optimizer::PassMetadata {
170                    name: #name,
171                    requires: &[#(#requires),*],
172                    invalidates: &[#(#invalidates),*],
173                }
174            }
175
176            #[inline]
177            fn analyze(&self, program: &::vyre::ir::Program) -> ::vyre::optimizer::PassAnalysis {
178                Self::analyze(program)
179            }
180
181            #[inline]
182            fn transform(
183                &self,
184                program: ::vyre::ir::Program,
185            ) -> ::vyre::optimizer::PassResult {
186                Self::transform(program)
187            }
188
189            #[inline]
190            fn fingerprint(&self, program: &::vyre::ir::Program) -> u64 {
191                Self::fingerprint(program)
192            }
193        }
194
195        ::inventory::submit! {
196            ::vyre::optimizer::ProgramPassRegistration {
197                metadata: ::vyre::optimizer::PassMetadata {
198                    name: #name,
199                    requires: &[#(#requires),*],
200                    invalidates: &[#(#invalidates),*],
201                },
202                factory: || ::std::boxed::Box::new(#ident),
203            }
204        }
205    }
206    .into()
207}
208
209/// Derive `vyre::AlgebraicLawProvider` from a `#[vyre(laws = [...])]` attribute.
210///
211/// Attach the derive to a unit struct (or any struct) that represents an op
212/// type. List its algebraic laws in the attribute; the macro emits the trait
213/// impl plus a `const LAWS: &[AlgebraicLaw]` associated item.
214///
215/// # Example
216///
217/// ```ignore
218/// use vyre_macros::AlgebraicLaws;
219///
220/// #[derive(AlgebraicLaws)]
221/// #[vyre(laws = [Commutative, Associative, "Identity { element: 0 }"])]
222/// pub struct Xor;
223/// ```
224///
225/// Expands to:
226///
227/// ```ignore
228/// impl Xor {
229///     pub const LAWS: &'static [::vyre::ops::AlgebraicLaw] = &[
230///         ::vyre::ops::AlgebraicLaw::Commutative,
231///         ::vyre::ops::AlgebraicLaw::Associative,
232///         ::vyre::ops::AlgebraicLaw::Identity { element: 0 },
233///     ];
234/// }
235/// impl ::vyre::ops::AlgebraicLawProvider for Xor {
236///     fn laws() -> &'static [::vyre::ops::AlgebraicLaw] { Self::LAWS }
237/// }
238/// ```
239#[proc_macro_derive(AlgebraicLaws, attributes(vyre))]
240pub fn derive_algebraic_laws(item: TokenStream) -> TokenStream {
241    let input = parse_macro_input!(item as DeriveInput);
242    let ident = &input.ident;
243    let laws = match extract_laws_attribute(&input.attrs) {
244        Ok(v) => v,
245        Err(e) => return e.to_compile_error().into(),
246    };
247
248    // Parse each law string as an AlgebraicLaw variant expression.
249    let law_exprs = laws.iter().map(|lit| {
250        let src = lit.value();
251        let trimmed = src.trim();
252        let path: syn::Expr = match syn::parse_str(&format!("::vyre::ops::AlgebraicLaw::{trimmed}"))
253        {
254            Ok(e) => e,
255            Err(err) => {
256                return syn::Error::new_spanned(
257                    lit,
258                    format!("failed to parse AlgebraicLaw variant `{trimmed}`: {err}"),
259                )
260                .to_compile_error();
261            }
262        };
263        quote! { #path }
264    });
265
266    // ensure the input type is a struct/enum we can attach impls to
267    match &input.data {
268        Data::Struct(_) | Data::Enum(_) => {}
269        Data::Union(_) => {
270            return syn::Error::new_spanned(
271                ident,
272                "#[derive(AlgebraicLaws)] does not support unions.",
273            )
274            .to_compile_error()
275            .into();
276        }
277    }
278
279    let law_exprs_vec: Vec<_> = law_exprs.collect();
280
281    quote! {
282        impl #ident {
283            /// Algebraic laws declared on this op type.
284            pub const LAWS: &'static [::vyre::ops::AlgebraicLaw] = &[
285                #(#law_exprs_vec),*
286            ];
287        }
288
289        impl ::vyre::ops::AlgebraicLawProvider for #ident {
290            fn laws() -> &'static [::vyre::ops::AlgebraicLaw] {
291                Self::LAWS
292            }
293        }
294    }
295    .into()
296}
297
298fn extract_laws_attribute(attrs: &[Attribute]) -> syn::Result<Vec<LitStr>> {
299    for attr in attrs {
300        if !attr.path().is_ident("vyre") {
301            continue;
302        }
303        let mut laws: Option<Vec<LitStr>> = None;
304        attr.parse_nested_meta(|meta| {
305            if meta.path.is_ident("laws") {
306                let value = meta.value()?;
307                // Accept both [Commutative, Identity{element:0}] bracketed
308                // identifier lists and [ "Commutative", "Identity{element:0}" ]
309                // string-literal arrays.
310                let lookahead = value.lookahead1();
311                if lookahead.peek(syn::token::Bracket) {
312                    let content;
313                    syn::bracketed!(content in value);
314                    let mut collected = Vec::new();
315                    while !content.is_empty() {
316                        if content.peek(LitStr) {
317                            let lit: LitStr = content.parse()?;
318                            collected.push(lit);
319                        } else {
320                            // parse as raw token stream up to the next comma
321                            let expr: syn::Expr = content.parse()?;
322                            let rendered = quote! { #expr }.to_string();
323                            collected.push(LitStr::new(&rendered, expr.span()));
324                        }
325                        if content.peek(Token![,]) {
326                            content.parse::<Token![,]>()?;
327                        }
328                    }
329                    laws = Some(collected);
330                    Ok(())
331                } else {
332                    Err(meta.error("expected `laws = [..]`"))
333                }
334            } else {
335                Err(meta.error("unknown vyre() argument; expected `laws = [..]`"))
336            }
337        })?;
338        if let Some(l) = laws {
339            return Ok(l);
340        }
341    }
342    Ok(Vec::new())
343}
344
345// Keep unused imports alive (silence the compiler's unused warnings; `Fields`
346// and `Meta` are referenced through docs/future use, and removing them here
347// risks churn during the open-IR migration).
348#[allow(dead_code)]
349fn _keep_imports_alive(_: Fields, _: Meta) {}