use std::borrow::Borrow;
use syn::{
parse_quote, token::Brace, Block, Expr, ExprBlock, ExprForLoop, ExprIf, ExprLet, ExprRange,
Pat, PatIdent, RangeLimits, Stmt,
};
/// Returns a copy of `block` in which every expression statement has been
/// passed through [`unroll`] with the given `unroll_by` factor.
///
/// Only `Stmt::Expr` statements are rewritten; every other kind of statement
/// is cloned unchanged. The brace token of the original block and the
/// trailing-semicolon marker of each rewritten statement are preserved.
pub(crate) fn unroll_in_block(block: &Block, unroll_by: usize) -> Block {
    let stmts = block
        .stmts
        .iter()
        .map(|stmt| match stmt {
            Stmt::Expr(expr, semi) => Stmt::Expr(unroll(expr, unroll_by), *semi),
            other => other.clone(),
        })
        .collect();

    Block {
        brace_token: block.brace_token,
        stmts,
    }
}
fn unroll(expr: &Expr, unroll_by: usize) -> Expr {
if let Expr::ForLoop(for_loop) = expr {
let ExprForLoop {
ref attrs,
ref label,
ref pat,
expr: ref range,
ref body,
..
} = *for_loop;
let new_body = unroll_in_block(body, unroll_by);
let forloop_with_body = |body| {
Expr::ForLoop(ExprForLoop {
body,
..(*for_loop).clone()
})
};
if let Pat::Ident(PatIdent {
by_ref,
mutability,
ident,
subpat,
..
}) = *pat.clone()
{
if by_ref.is_some() || mutability.is_some() || subpat.is_some() {
return forloop_with_body(new_body);
}
let idx = ident;
if let Expr::Range(ExprRange {
start, limits, end, ..
}) = range.borrow()
{
let begin = match start {
Some(e) => e.clone(),
_ => Box::new(parse_quote!(0usize)),
};
let end = match end {
Some(e) => e.clone(),
_ => return forloop_with_body(new_body),
};
let end_is_closed = if let RangeLimits::Closed(_) = limits {
1usize
} else {
0
};
let end: Expr = parse_quote!(#end + #end_is_closed);
let preamble: Vec<Stmt> = parse_quote! {
let total_iters: usize = (#end).checked_sub(#begin).unwrap_or(0);
let num_loops = total_iters / #unroll_by;
let remainder = total_iters % #unroll_by;
};
let mut block = Block {
brace_token: Brace::default(),
stmts: preamble,
};
let mut loop_expr: ExprForLoop = parse_quote! {
for #idx in (0..num_loops) {
let mut #idx = #begin + #idx * #unroll_by;
}
};
let loop_block: Vec<Stmt> = parse_quote! {
if #idx < #end {
#new_body
}
#idx += 1;
};
let loop_body = (0..unroll_by).flat_map(|_| loop_block.clone());
loop_expr.body.stmts.extend(loop_body);
block.stmts.push(Stmt::Expr(Expr::ForLoop(loop_expr), None));
block
.stmts
.push(parse_quote! { let mut #idx = #begin + num_loops * #unroll_by; });
let post_loop_block: Vec<Stmt> = parse_quote! {
if #idx < #end {
#new_body
}
#idx += 1;
};
let post_loop = (0..unroll_by).flat_map(|_| post_loop_block.clone());
block.stmts.extend(post_loop);
let mut attrs = attrs.clone();
attrs.extend(vec![parse_quote!(#[allow(unused)])]);
Expr::Block(ExprBlock {
attrs,
label: label.clone(),
block,
})
} else {
forloop_with_body(new_body)
}
} else {
forloop_with_body(new_body)
}
} else if let Expr::If(if_expr) = expr {
let ExprIf {
ref cond,
ref then_branch,
ref else_branch,
..
} = *if_expr;
Expr::If(ExprIf {
cond: Box::new(unroll(cond, unroll_by)),
then_branch: unroll_in_block(then_branch, unroll_by),
else_branch: else_branch
.as_ref()
.map(|x| (x.0, Box::new(unroll(&x.1, unroll_by)))),
..(*if_expr).clone()
})
} else if let Expr::Let(let_expr) = expr {
let ExprLet { ref expr, .. } = *let_expr;
Expr::Let(ExprLet {
expr: Box::new(unroll(expr, unroll_by)),
..(*let_expr).clone()
})
} else if let Expr::Block(expr_block) = expr {
let ExprBlock { ref block, .. } = *expr_block;
Expr::Block(ExprBlock {
block: unroll_in_block(block, unroll_by),
..(*expr_block).clone()
})
} else {
(*expr).clone()
}
}
/// Smoke test: unrolling a block that contains a simple range loop keeps the
/// leading `let` statement and replaces the `for` loop with a block
/// expression holding the unrolled code.
#[test]
fn test_expand() {
    use quote::ToTokens;
    let input: Block = parse_quote! {{
        let mut sum = 0;
        for i in 0..8 {
            sum += i;
        }
    }};
    let unrolled = unroll_in_block(&input, 12);
    // Printed for manual inspection when running with `--nocapture`.
    println!("{}", unrolled.to_token_stream());

    // The statement count is unchanged: the loop is replaced in place.
    assert_eq!(unrolled.stmts.len(), 2);
    // The `let` statement is not an expression statement and must be kept.
    assert!(matches!(unrolled.stmts[0], Stmt::Local(_)));
    // The `for` loop must have been rewritten into a block expression.
    assert!(matches!(
        unrolled.stmts[1],
        Stmt::Expr(Expr::Block(_), None)
    ));
}