#![forbid(unsafe_code)]
#![warn(missing_docs)]
mod ast_registry;
mod define_op;
use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;
use syn::{
parse_macro_input, Attribute, Data, DeriveInput, ExprArray, Fields, ItemStruct, LitStr, Meta,
Token,
};
/// Function-like procedural macro.
///
/// The expansion is implemented by `define_op::define_op_impl`; see that module for
/// the accepted input syntax.
#[proc_macro]
pub fn define_op(item: TokenStream) -> TokenStream {
    define_op::define_op_impl(item)
}
/// Function-like procedural macro.
///
/// The expansion is implemented by `ast_registry::vyre_ast_registry_impl`; see that
/// module for the accepted input syntax.
#[proc_macro]
pub fn vyre_ast_registry(item: TokenStream) -> TokenStream {
    ast_registry::vyre_ast_registry_impl(item)
}
/// Attribute macro that ignores its arguments and returns the annotated item unchanged.
///
/// NOTE(review): this appears to be a marker attribute that other tooling looks for
/// (the name suggests builder generation); confirm where it is consumed.
#[proc_macro_attribute]
pub fn skip_builder(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
/// Parsed arguments of `#[vyre_pass(...)]`.
struct PassArgs {
    /// Pass name, emitted as `PassMetadata::name`.
    name: LitStr,
    /// Names listed under `requires`, emitted as `PassMetadata::requires`.
    requires: Vec<LitStr>,
    /// Names listed under `invalidates`, emitted as `PassMetadata::invalidates`.
    invalidates: Vec<LitStr>,
}
impl Parse for PassArgs {
    /// Parses `key = value` pairs, each optionally followed by a comma.
    ///
    /// Accepted keys:
    /// - `name`: a string literal, required.
    /// - `requires`, `invalidates`: arrays of string literals, empty when omitted.
    ///
    /// # Errors
    ///
    /// Fails on an unknown key, on a key given more than once (previously the
    /// later value silently replaced the earlier one), on a malformed value, or
    /// when `name` is missing.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let mut name: Option<LitStr> = None;
        let mut requires: Option<Vec<LitStr>> = None;
        let mut invalidates: Option<Vec<LitStr>> = None;
        while !input.is_empty() {
            let key: syn::Ident = input.parse()?;
            input.parse::<Token![=]>()?;
            // `Option::replace` returns the previous value, which tells us whether
            // this key was already seen.
            let duplicate = match key.to_string().as_str() {
                "name" => name.replace(input.parse()?).is_some(),
                "requires" => requires.replace(parse_string_array(input)?).is_some(),
                "invalidates" => invalidates
                    .replace(parse_string_array(input)?)
                    .is_some(),
                _ => {
                    return Err(syn::Error::new(
                        key.span(),
                        "unsupported vyre_pass argument. Fix: use name, requires, or invalidates.",
                    ));
                }
            };
            if duplicate {
                return Err(syn::Error::new(
                    key.span(),
                    format!("duplicate vyre_pass argument `{key}`. Fix: specify each argument once."),
                ));
            }
            if input.peek(Token![,]) {
                input.parse::<Token![,]>()?;
            }
        }
        Ok(Self {
            name: name.ok_or_else(|| input.error("missing pass name. Fix: add name = \"...\"."))?,
            requires: requires.unwrap_or_default(),
            invalidates: invalidates.unwrap_or_default(),
        })
    }
}
fn parse_string_array(input: ParseStream<'_>) -> syn::Result<Vec<LitStr>> {
let array: ExprArray = input.parse()?;
array
.elems
.into_iter()
.map(|expr| match expr {
syn::Expr::Lit(lit) => match lit.lit {
syn::Lit::Str(value) => Ok(value),
other => Err(syn::Error::new_spanned(
other,
"pass metadata arrays accept only string literals. Fix: use [\"analysis_name\"].",
)),
},
other => Err(syn::Error::new_spanned(
other,
"pass metadata arrays accept only string literals. Fix: use [\"analysis_name\"].",
)),
})
.collect()
}
#[proc_macro_attribute]
pub fn vyre_pass(args: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as PassArgs);
let item = parse_macro_input!(item as ItemStruct);
let ident = &item.ident;
let name = args.name;
let requires = args.requires;
let invalidates = args.invalidates;
quote! {
#item
impl ::vyre::optimizer::private::Sealed for #ident {}
impl ::vyre::optimizer::ProgramPass for #ident {
#[inline]
fn metadata(&self) -> ::vyre::optimizer::PassMetadata {
::vyre::optimizer::PassMetadata {
name: #name,
requires: &[#(#requires),*],
invalidates: &[#(#invalidates),*],
}
}
#[inline]
fn analyze(&self, program: &::vyre::ir::Program) -> ::vyre::optimizer::PassAnalysis {
Self::analyze(program)
}
#[inline]
fn transform(
&self,
program: ::vyre::ir::Program,
) -> ::vyre::optimizer::PassResult {
Self::transform(program)
}
#[inline]
fn fingerprint(&self, program: &::vyre::ir::Program) -> u64 {
Self::fingerprint(program)
}
}
::inventory::submit! {
::vyre::optimizer::ProgramPassRegistration {
metadata: ::vyre::optimizer::PassMetadata {
name: #name,
requires: &[#(#requires),*],
invalidates: &[#(#invalidates),*],
},
factory: || ::std::boxed::Box::new(#ident),
}
}
}
.into()
}
#[proc_macro_derive(AlgebraicLaws, attributes(vyre))]
pub fn derive_algebraic_laws(item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let ident = &input.ident;
let laws = match extract_laws_attribute(&input.attrs) {
Ok(v) => v,
Err(e) => return e.to_compile_error().into(),
};
let law_exprs = laws.iter().map(|lit| {
let src = lit.value();
let trimmed = src.trim();
let path: syn::Expr = match syn::parse_str(&format!("::vyre::ops::AlgebraicLaw::{trimmed}"))
{
Ok(e) => e,
Err(err) => {
return syn::Error::new_spanned(
lit,
format!("failed to parse AlgebraicLaw variant `{trimmed}`: {err}"),
)
.to_compile_error();
}
};
quote! { #path }
});
match &input.data {
Data::Struct(_) | Data::Enum(_) => {}
Data::Union(_) => {
return syn::Error::new_spanned(
ident,
"#[derive(AlgebraicLaws)] does not support unions.",
)
.to_compile_error()
.into();
}
}
let law_exprs_vec: Vec<_> = law_exprs.collect();
quote! {
impl #ident {
pub const LAWS: &'static [::vyre::ops::AlgebraicLaw] = &[
#(#law_exprs_vec),*
];
}
impl ::vyre::ops::AlgebraicLawProvider for #ident {
fn laws() -> &'static [::vyre::ops::AlgebraicLaw] {
Self::LAWS
}
}
}
.into()
}
fn extract_laws_attribute(attrs: &[Attribute]) -> syn::Result<Vec<LitStr>> {
for attr in attrs {
if !attr.path().is_ident("vyre") {
continue;
}
let mut laws: Option<Vec<LitStr>> = None;
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("laws") {
let value = meta.value()?;
let lookahead = value.lookahead1();
if lookahead.peek(syn::token::Bracket) {
let content;
syn::bracketed!(content in value);
let mut collected = Vec::new();
while !content.is_empty() {
if content.peek(LitStr) {
let lit: LitStr = content.parse()?;
collected.push(lit);
} else {
let expr: syn::Expr = content.parse()?;
let rendered = quote! { #expr }.to_string();
collected.push(LitStr::new(&rendered, expr.span()));
}
if content.peek(Token![,]) {
content.parse::<Token![,]>()?;
}
}
laws = Some(collected);
Ok(())
} else {
Err(meta.error("expected `laws = [..]`"))
}
} else {
Err(meta.error("unknown vyre() argument; expected `laws = [..]`"))
}
})?;
if let Some(l) = laws {
return Ok(l);
}
}
Ok(Vec::new())
}
// Never called: it mentions `Fields` and `Meta` so their `use` entries stay referenced.
// `dead_code` is allowed because nothing calls this function.
#[allow(dead_code)]
fn _keep_imports_alive(_fields: Fields, _meta: Meta) {}