use darling::FromMeta;
use syn::{
Attribute, Expr, ExprReference, Stmt, parse_quote,
visit_mut::{self, VisitMut},
};
use crate::{
expression::Expression,
parse::statement::parse_define_macro,
paths::{prelude_path, prelude_type},
scope::Context,
statement::DefineKind,
};
pub struct Unroll {
pub value: Expression,
pub always_true: bool,
}
impl Unroll {
pub fn from_attributes(
attrs: &[Attribute],
context: &mut Context,
) -> syn::Result<Option<Self>> {
#[derive(FromMeta)]
struct NameVal {
pub value: Expr,
}
let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll"));
let attr = match attr {
Some(attr) => attr,
None => return Ok(None),
};
let res = match &attr.meta {
syn::Meta::Path(_) => Self {
value: Expression::from_expr(parse_quote![true], context).unwrap(),
always_true: true,
},
syn::Meta::List(list) => {
let expr = syn::parse2(list.tokens.clone())?;
let expr = Expression::from_expr(expr, context)?;
Self {
value: expr,
always_true: false,
}
}
meta => {
let expr = NameVal::from_meta(meta)?;
let expr = Expression::from_expr(expr.value, context)?;
Self {
value: expr,
always_true: false,
}
}
};
Ok(Some(res))
}
pub fn unroll_expr(attrs: &[Attribute]) -> Option<Expr> {
#[derive(FromMeta)]
struct NameVal {
pub value: Expr,
}
let attr = attrs.iter().find(|attr| attr.path().is_ident("unroll"))?;
match &attr.meta {
syn::Meta::Path(_) => None,
syn::Meta::List(list) => syn::parse2(list.tokens.clone()).ok(),
meta => Some(NameVal::from_meta(meta).ok()?.value),
}
}
}
pub struct RemoveHelpers;
impl VisitMut for RemoveHelpers {
fn visit_fn_arg_mut(&mut self, i: &mut syn::FnArg) {
match i {
syn::FnArg::Receiver(recv) => recv
.attrs
.retain(|it| !is_comptime_attr(it) && !is_define_attribute(it)),
syn::FnArg::Typed(typed) => typed
.attrs
.retain(|it| !is_comptime_attr(it) && !is_define_attribute(it)),
}
visit_mut::visit_fn_arg_mut(self, i);
}
fn visit_expr_for_loop_mut(&mut self, i: &mut syn::ExprForLoop) {
let unroll = Unroll::unroll_expr(&i.attrs);
i.attrs.retain(|attr| !is_unroll_attr(attr));
if let Some(unroll) = unroll {
i.body
.stmts
.insert(0, parse_quote![let __unroll = #unroll;])
}
visit_mut::visit_expr_for_loop_mut(self, i);
}
fn visit_local_mut(&mut self, i: &mut syn::Local) {
i.attrs.retain(|attr| !is_comptime_attr(attr));
visit_mut::visit_local_mut(self, i);
}
fn visit_expr_match_mut(&mut self, i: &mut syn::ExprMatch) {
i.attrs.retain(|attr| !is_comptime_attr(attr));
visit_mut::visit_expr_match_mut(self, i);
}
fn visit_expr_if_mut(&mut self, i: &mut syn::ExprIf) {
i.attrs.retain(|attr| !is_comptime_attr(attr));
visit_mut::visit_expr_if_mut(self, i);
}
fn visit_type_param_mut(&mut self, i: &mut syn::TypeParam) {
i.attrs.retain(|attr| !is_helper(attr));
visit_mut::visit_type_param_mut(self, i);
}
}
/// Visitor that rewrites indexing expressions, choosing the mutable accessor for
/// assigned or `&mut`-borrowed places and the immutable accessor otherwise.
pub struct ReplaceIndices;
/// Visitor that rewrites `a[b]` into `*a.cube_idx(b)`; `&mut` borrows are delegated
/// to `ReplaceIndexMut`.
pub struct ReplaceIndex;
/// Visitor that rewrites `a[b]` into `*a.cube_idx_mut(b)`.
pub struct ReplaceIndexMut;
/// Visitor that expands `let` statements recognized by `parse_define_macro` into a
/// prelude `define_scalar!`/`define_size!` call followed by `let _ = <init>;`.
pub struct ReplaceDefines;
impl VisitMut for ReplaceIndices {
fn visit_expr_assign_mut(&mut self, i: &mut syn::ExprAssign) {
ReplaceIndexMut.visit_expr_mut(&mut i.left);
ReplaceIndex.visit_expr_mut(&mut i.right);
visit_mut::visit_expr_assign_mut(self, i);
}
fn visit_expr_binary_mut(&mut self, i: &mut syn::ExprBinary) {
match i.op {
syn::BinOp::AddAssign(_)
| syn::BinOp::SubAssign(_)
| syn::BinOp::MulAssign(_)
| syn::BinOp::DivAssign(_)
| syn::BinOp::RemAssign(_)
| syn::BinOp::BitXorAssign(_)
| syn::BinOp::BitAndAssign(_)
| syn::BinOp::BitOrAssign(_)
| syn::BinOp::ShlAssign(_)
| syn::BinOp::ShrAssign(_) => {
ReplaceIndexMut.visit_expr_mut(&mut i.left);
ReplaceIndex.visit_expr_mut(&mut i.right);
}
_ => {}
}
visit_mut::visit_expr_binary_mut(self, i);
}
fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
match i {
Expr::Reference(ExprReference {
mutability: Some(_),
expr,
..
}) => {
ReplaceIndexMut.visit_expr_mut(expr);
}
Expr::Index(_) => ReplaceIndex.visit_expr_mut(i),
_ => {}
}
visit_mut::visit_expr_mut(self, i);
}
fn visit_item_fn_mut(&mut self, i: &mut syn::ItemFn) {
let prelude_path = prelude_path();
let import = parse_quote![use #prelude_path::{
CubeIndex as _, CubeIndexMut as _,
ComptimeIndex as _, ComptimeIndexMut as _
};];
i.block.stmts.insert(0, import);
visit_mut::visit_item_fn_mut(self, i);
}
fn visit_impl_item_fn_mut(&mut self, i: &mut syn::ImplItemFn) {
let prelude_path = prelude_path();
let import = parse_quote![use #prelude_path::{
CubeIndex as _, CubeIndexMut as _,
ComptimeIndex as _, ComptimeIndexMut as _
};];
i.block.stmts.insert(0, import);
visit_mut::visit_impl_item_fn_mut(self, i);
}
fn visit_trait_item_fn_mut(&mut self, i: &mut syn::TraitItemFn) {
if let Some(block) = &mut i.default {
let prelude_path = prelude_path();
let import = parse_quote![use #prelude_path::{
CubeIndex as _, CubeIndexMut as _,
ComptimeIndex as _, ComptimeIndexMut as _
};];
block.stmts.insert(0, import);
}
visit_mut::visit_trait_item_fn_mut(self, i);
}
}
impl VisitMut for ReplaceIndex {
    /// Rewrites read-only index expressions `a[b]` into `*a.cube_idx(b)`.
    ///
    /// A `&mut expr` borrow is handed to [`ReplaceIndexMut`] instead, so the place it
    /// borrows uses the mutable accessor. The rewrite is applied outermost-first and
    /// then recurses into the new expression.
    ///
    /// NOTE(review): for nested indexing (`a[i][j]`) the inner rewrite ends up as a
    /// method-call receiver. Correct output relies on syn printing the needed
    /// parentheses; verify with the syn version in use.
    fn visit_expr_mut(&mut self, i: &mut Expr) {
        if let Expr::Reference(ExprReference {
            mutability: Some(_),
            expr: borrowed,
            ..
        }) = i
        {
            ReplaceIndexMut.visit_expr_mut(borrowed);
        } else if let Expr::Index(indexed) = i {
            let base = &indexed.expr;
            let position = &indexed.index;
            *i = parse_quote![*#base.cube_idx(#position)];
        }
        visit_mut::visit_expr_mut(self, i);
    }
}
impl VisitMut for ReplaceIndexMut {
    /// Rewrites every index expression `a[b]` in the visited subtree into
    /// `*a.cube_idx_mut(b)`, then keeps recursing into the rewritten expression.
    fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
        match i {
            Expr::Index(syn::ExprIndex {
                expr: base,
                index: position,
                ..
            }) => {
                *i = parse_quote![*#base.cube_idx_mut(#position)];
            }
            _ => {}
        }
        visit_mut::visit_expr_mut(self, i);
    }
}
impl VisitMut for ReplaceDefines {
    /// Expands `define` locals in this block, then recurses into nested blocks.
    ///
    /// A `let` statement recognized by `parse_define_macro` becomes two statements:
    /// 1. a prelude macro call declaring the name: `define_scalar!` for
    ///    [`DefineKind::Type`], `define_size!` for [`DefineKind::Size`];
    /// 2. `let _ = <init>;`.
    ///
    /// Every other statement is kept unchanged and in its original order.
    fn visit_block_mut(&mut self, i: &mut syn::Block) {
        let previous = core::mem::take(&mut i.stmts);
        let mut expanded: Vec<Stmt> = Vec::with_capacity(previous.len());
        for stmt in previous {
            let local = match stmt {
                Stmt::Local(local) => local,
                other => {
                    expanded.push(other);
                    continue;
                }
            };
            if let Some((name, kind, init)) = parse_define_macro(&local) {
                let define_macro = match kind {
                    DefineKind::Type => prelude_type("define_scalar"),
                    DefineKind::Size => prelude_type("define_size"),
                };
                expanded.push(parse_quote![#define_macro!(#name);]);
                expanded.push(parse_quote!(let _ = #init;));
            } else {
                expanded.push(Stmt::Local(local));
            }
        }
        i.stmts = expanded;
        visit_mut::visit_block_mut(self, i);
    }
}
/// Returns `true` when the attribute's path is exactly the single identifier
/// `comptime` (arguments, if any, are ignored).
pub fn is_comptime_attr(attr: &Attribute) -> bool {
    matches!(attr.path().get_ident(), Some(ident) if ident == "comptime")
}
/// Returns `true` when the attribute's path is exactly the single identifier
/// `unroll` (arguments, if any, are ignored).
pub fn is_unroll_attr(attr: &Attribute) -> bool {
    matches!(attr.path().get_ident(), Some(ident) if ident == "unroll")
}
/// Returns `true` when the attribute's path is exactly the single identifier
/// `expr` (arguments, if any, are ignored).
pub fn is_expr_attribute(attr: &Attribute) -> bool {
    matches!(attr.path().get_ident(), Some(ident) if ident == "expr")
}
/// Returns `true` when the attribute's path is exactly the single identifier
/// `define` (arguments, if any, are ignored).
pub fn is_define_attribute(attr: &Attribute) -> bool {
    matches!(attr.path().get_ident(), Some(ident) if ident == "define")
}
pub fn is_helper(attr: &Attribute) -> bool {
is_comptime_attr(attr)
|| is_unroll_attr(attr)
|| is_expr_attribute(attr)
|| is_define_attribute(attr)
}