use cjc_ast::BinOp;
use crate::loop_analysis::{LoopId, LoopTree};
use crate::{MirBody, MirExpr, MirExprKind, MirFunction, MirProgram, MirStmt};
/// Identifier attached to each entry of a `ReductionReport`.
///
/// `detect_reductions` numbers reductions sequentially from 0 in discovery
/// order, and `ReductionReport::get` uses the wrapped value as an index into
/// the report's `reductions` vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ReductionId(pub u32);
/// Classification of a detected reduction.
///
/// Each variant's doc line states how the three predicates
/// (`is_strict`, `is_reorderable`, `is_parallelizable`) classify it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionKind {
    /// Strict; neither reorderable nor parallelizable. This is the kind the
    /// loop scanner assigns to `acc = acc <op> x` accumulations.
    StrictFold,
    /// Strict; neither reorderable nor parallelizable.
    KahanFold,
    /// Reorderable and parallelizable; not strict.
    BinnedFold,
    /// Parallelizable, but neither reorderable nor strict.
    FixedTree,
    /// Kind recorded for calls to names in `BUILTIN_REDUCTIONS`. None of the
    /// three predicates returns `true` for it.
    BuiltinReduction,
    /// Strict; neither reorderable nor parallelizable.
    Unknown,
}

impl ReductionKind {
    /// Returns `true` only for [`ReductionKind::BinnedFold`].
    pub fn is_reorderable(&self) -> bool {
        match self {
            ReductionKind::BinnedFold => true,
            ReductionKind::StrictFold
            | ReductionKind::KahanFold
            | ReductionKind::FixedTree
            | ReductionKind::BuiltinReduction
            | ReductionKind::Unknown => false,
        }
    }

    /// Returns `true` for [`ReductionKind::BinnedFold`] and
    /// [`ReductionKind::FixedTree`].
    pub fn is_parallelizable(&self) -> bool {
        match self {
            ReductionKind::BinnedFold | ReductionKind::FixedTree => true,
            ReductionKind::StrictFold
            | ReductionKind::KahanFold
            | ReductionKind::BuiltinReduction
            | ReductionKind::Unknown => false,
        }
    }

    /// Returns `true` for [`ReductionKind::StrictFold`],
    /// [`ReductionKind::KahanFold`] and [`ReductionKind::Unknown`].
    pub fn is_strict(&self) -> bool {
        match self {
            ReductionKind::StrictFold | ReductionKind::KahanFold | ReductionKind::Unknown => true,
            ReductionKind::BinnedFold
            | ReductionKind::FixedTree
            | ReductionKind::BuiltinReduction => false,
        }
    }
}
/// Operation folded into the accumulator by a detected reduction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionOp {
/// `acc = acc + x` (matched from `BinOp::Add`).
Add,
/// `acc = acc * x` (matched from `BinOp::Mul`).
Mul,
/// `acc = acc - x` (matched from `BinOp::Sub`).
Sub,
/// Minimum reduction. Not produced by `match_accumulation_pattern`.
Min,
/// Maximum reduction. Not produced by `match_accumulation_pattern`.
Max,
/// `acc = acc | x` (matched from `BinOp::BitOr`).
BitwiseOr,
/// `acc = acc & x` (matched from `BinOp::BitAnd`).
BitwiseAnd,
/// Placeholder operation recorded for calls to builtin reduction functions.
BuiltinCall,
}
/// Accumulation strategy recorded for a reduction.
///
/// The `Display` implementation renders each variant as a lowercase
/// snake_case label (`plain`, `kahan`, `binned`, `runtime_defined`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccumulatorSemantics {
    /// Recorded for loop accumulations and for the `trapz` / `simps` builtins.
    Plain,
    /// Recorded for the `kahan_sum` builtin.
    Kahan,
    /// Recorded for the `binned_sum` builtin.
    Binned,
    /// Recorded for every other builtin reduction.
    RuntimeDefined,
}

impl std::fmt::Display for AccumulatorSemantics {
    /// Writes the variant's label verbatim; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            AccumulatorSemantics::Plain => "plain",
            AccumulatorSemantics::Kahan => "kahan",
            AccumulatorSemantics::Binned => "binned",
            AccumulatorSemantics::RuntimeDefined => "runtime_defined",
        };
        f.write_str(label)
    }
}
/// One detected reduction together with the ordering constraints recorded for it.
#[derive(Debug, Clone)]
pub struct ReductionInfo {
/// Sequential identifier assigned by `detect_reductions`.
pub id: ReductionId,
/// Name of the accumulator variable for loop accumulations; the empty
/// string for builtin reduction calls.
pub accumulator_var: String,
/// Operation applied to the accumulator (`BuiltinCall` for builtin calls).
pub op: ReductionOp,
/// Classification of the reduction.
pub kind: ReductionKind,
/// Loop the reduction belongs to. Both detectors in this file currently
/// store `None` here.
pub loop_id: Option<LoopId>,
/// Name of the function in which the reduction was found.
pub function_name: String,
/// Callee name for builtin reduction calls; `None` for loop accumulations.
pub builtin_name: Option<String>,
/// Whether reassociation is forbidden. Always `true` for loop accumulations;
/// for builtins the value comes from `classify_builtin_reduction`.
pub reassociation_forbidden: bool,
/// Whether strict evaluation order is required. Always `true` for loop
/// accumulations; for builtins the value comes from
/// `classify_builtin_reduction`.
pub strict_order_required: bool,
/// Accumulation strategy recorded for this reduction.
pub accumulator_semantics: AccumulatorSemantics,
}
/// Result of reduction detection: every reduction found, in discovery order.
#[derive(Debug, Clone)]
pub struct ReductionReport {
    /// All recorded reductions.
    pub reductions: Vec<ReductionInfo>,
}

impl ReductionReport {
    /// Number of recorded reductions.
    pub fn len(&self) -> usize {
        let entries = &self.reductions;
        entries.len()
    }

    /// Returns `true` when no reduction was recorded.
    pub fn is_empty(&self) -> bool {
        let entries = &self.reductions;
        entries.is_empty()
    }

    /// Returns the reduction stored at position `id.0` of `reductions`.
    ///
    /// # Panics
    ///
    /// Panics if `id.0` is not a valid index into `reductions`.
    pub fn get(&self, id: ReductionId) -> &ReductionInfo {
        let index = id.0 as usize;
        &self.reductions[index]
    }

    /// Collects the reductions whose `loop_id` equals `Some(loop_id)`.
    pub fn reductions_in_loop(&self, loop_id: LoopId) -> Vec<&ReductionInfo> {
        let wanted = Some(loop_id);
        let mut hits = Vec::new();
        for info in &self.reductions {
            if info.loop_id == wanted {
                hits.push(info);
            }
        }
        hits
    }

    /// Collects the reductions whose `function_name` equals `fn_name`.
    pub fn reductions_in_function(&self, fn_name: &str) -> Vec<&ReductionInfo> {
        let mut hits = Vec::new();
        for info in &self.reductions {
            if info.function_name == fn_name {
                hits.push(info);
            }
        }
        hits
    }

    /// Returns `true` if any recorded reduction has a strict *kind*.
    ///
    /// Only `ReductionKind::is_strict` is consulted; the per-reduction
    /// `strict_order_required` flag does not influence the result.
    pub fn has_strict_reductions(&self) -> bool {
        for info in &self.reductions {
            if info.kind.is_strict() {
                return true;
            }
        }
        false
    }
}
/// Callee names that are treated as builtin reductions.
///
/// A call is recorded only when its callee is a plain variable whose name
/// appears in this list; `classify_builtin_reduction` supplies the ordering
/// flags and accumulator semantics attached to each name.
const BUILTIN_REDUCTIONS: &[&str] = &[
"sum",
"mean",
"dot",
"prod",
"variance",
"sd",
"norm",
"min",
"max",
"median",
"binned_sum",
"kahan_sum",
"trapz",
"simps",
];
pub fn detect_reductions(program: &MirProgram, loop_trees: &[(String, LoopTree)]) -> ReductionReport {
let mut reductions = Vec::new();
let mut next_id: u32 = 0;
let loop_tree_map: Vec<(&str, &LoopTree)> = loop_trees
.iter()
.map(|(name, tree)| (name.as_str(), tree))
.collect();
for func in &program.functions {
let tree = loop_tree_map
.iter()
.find(|(name, _)| *name == func.name)
.map(|(_, t)| *t);
detect_loop_reductions_body(
&func.body,
&func.name,
tree,
&mut reductions,
&mut next_id,
);
detect_builtin_reductions_body(
&func.body,
&func.name,
tree,
&mut reductions,
&mut next_id,
);
}
ReductionReport { reductions }
}
/// Walks the statements of `body` looking for `while` loops and records the
/// accumulations found in each loop body.
///
/// Recurses into loop bodies (for nested loops), both branches of `if`
/// statements and no-GC blocks. Other statement kinds are ignored, and
/// `body.result` is not inspected.
fn detect_loop_reductions_body(
    body: &MirBody,
    fn_name: &str,
    loop_tree: Option<&LoopTree>,
    reductions: &mut Vec<ReductionInfo>,
    next_id: &mut u32,
) {
    for stmt in body.stmts.iter() {
        // A loop's own body is scanned for accumulations before descending
        // into it, so outer-loop accumulators receive lower ids.
        if let MirStmt::While { body: loop_body, .. } = stmt {
            scan_body_for_accumulations(loop_body, fn_name, loop_tree, reductions, next_id);
        }

        // Descend into every nested body that may itself contain loops.
        match stmt {
            MirStmt::While { body: nested, .. } => {
                detect_loop_reductions_body(nested, fn_name, loop_tree, reductions, next_id);
            }
            MirStmt::If {
                then_body,
                else_body,
                ..
            } => {
                detect_loop_reductions_body(then_body, fn_name, loop_tree, reductions, next_id);
                if let Some(alternative) = else_body {
                    detect_loop_reductions_body(
                        alternative,
                        fn_name,
                        loop_tree,
                        reductions,
                        next_id,
                    );
                }
            }
            MirStmt::NoGcBlock(nested) => {
                detect_loop_reductions_body(nested, fn_name, loop_tree, reductions, next_id);
            }
            _ => {}
        }
    }
}
/// Records every `acc = acc <op> x` accumulation found in a loop body.
///
/// Expression statements directly in `body` are checked with
/// `match_accumulation_pattern`. The scan also descends into both branches of
/// `if` statements and into no-GC blocks, so that a conditionally executed
/// accumulation such as `while c { if p { acc = acc + x } }` is reported too.
/// Nested `while` loops are deliberately *not* entered here: the caller
/// (`detect_loop_reductions_body`) scans each loop body in its own right.
///
/// Each match is recorded as a `StrictFold` with reassociation forbidden,
/// strict order required and plain accumulator semantics. An accumulator that
/// already has a `StrictFold` entry for the same function is not recorded
/// again.
///
/// `_loop_tree` is accepted for signature symmetry with the other detectors
/// but is not consulted; `loop_id` is therefore always stored as `None`.
fn scan_body_for_accumulations(
    body: &MirBody,
    fn_name: &str,
    _loop_tree: Option<&LoopTree>,
    reductions: &mut Vec<ReductionInfo>,
    next_id: &mut u32,
) {
    for stmt in &body.stmts {
        match stmt {
            MirStmt::Expr(expr) => {
                if let Some((acc_name, op)) = match_accumulation_pattern(expr) {
                    // One entry per (function, accumulator) pair.
                    let already_recorded = reductions.iter().any(|r| {
                        r.accumulator_var == acc_name
                            && r.function_name == fn_name
                            && r.kind == ReductionKind::StrictFold
                    });
                    if !already_recorded {
                        reductions.push(ReductionInfo {
                            id: ReductionId(*next_id),
                            accumulator_var: acc_name,
                            op,
                            kind: ReductionKind::StrictFold,
                            loop_id: None,
                            function_name: fn_name.to_string(),
                            builtin_name: None,
                            reassociation_forbidden: true,
                            strict_order_required: true,
                            accumulator_semantics: AccumulatorSemantics::Plain,
                        });
                        *next_id += 1;
                    }
                }
            }
            // Conditionally executed accumulations still belong to this loop.
            MirStmt::If {
                then_body,
                else_body,
                ..
            } => {
                scan_body_for_accumulations(then_body, fn_name, _loop_tree, reductions, next_id);
                if let Some(eb) = else_body {
                    scan_body_for_accumulations(eb, fn_name, _loop_tree, reductions, next_id);
                }
            }
            MirStmt::NoGcBlock(inner) => {
                scan_body_for_accumulations(inner, fn_name, _loop_tree, reductions, next_id);
            }
            // Nested loops are scanned by the caller's own recursion.
            _ => {}
        }
    }
}
/// Recognises an assignment of the form `acc = acc <op> rhs`.
///
/// Returns the accumulator name and the matching [`ReductionOp`] when
/// * `expr` is an assignment whose target is a plain variable,
/// * the assigned value is a binary expression whose **left** operand is that
///   same variable, and
/// * the operator is `Add`, `Mul`, `Sub`, `BitOr` or `BitAnd`.
///
/// Anything else yields `None`; in particular `acc = rhs <op> acc` is not
/// matched.
fn match_accumulation_pattern(expr: &MirExpr) -> Option<(String, ReductionOp)> {
    let (target, value) = match &expr.kind {
        MirExprKind::Assign { target, value } => (target, value),
        _ => return None,
    };

    let (acc_name, op, left) = match (&target.kind, &value.kind) {
        (MirExprKind::Var(acc_name), MirExprKind::Binary { op, left, .. }) => (acc_name, op, left),
        _ => return None,
    };

    // The accumulator must reappear as the left operand of the update.
    match &left.kind {
        MirExprKind::Var(left_name) if left_name == acc_name => {}
        _ => return None,
    }

    let reduction_op = match op {
        BinOp::Add => ReductionOp::Add,
        BinOp::Mul => ReductionOp::Mul,
        BinOp::Sub => ReductionOp::Sub,
        BinOp::BitOr => ReductionOp::BitwiseOr,
        BinOp::BitAnd => ReductionOp::BitwiseAnd,
        _ => return None,
    };

    Some((acc_name.clone(), reduction_op))
}
/// Searches every expression reachable from the statements of `body`, and
/// from its trailing result expression, for calls to builtin reductions.
///
/// Visits `let` initialisers, expression statements, `return` values, the
/// conditions and bodies of `if` / `while`, and no-GC blocks. Other statement
/// kinds are skipped.
fn detect_builtin_reductions_body(
    body: &MirBody,
    fn_name: &str,
    loop_tree: Option<&LoopTree>,
    reductions: &mut Vec<ReductionInfo>,
    next_id: &mut u32,
) {
    for stmt in body.stmts.iter() {
        match stmt {
            // Statements that carry exactly one expression to inspect.
            MirStmt::Let { init: expr, .. } | MirStmt::Expr(expr) => {
                detect_builtin_reductions_expr(expr, fn_name, loop_tree, reductions, next_id);
            }
            MirStmt::Return(Some(value)) => {
                detect_builtin_reductions_expr(value, fn_name, loop_tree, reductions, next_id);
            }
            // Control flow: condition first, then the nested bodies in order.
            MirStmt::While {
                cond,
                body: loop_body,
            } => {
                detect_builtin_reductions_expr(cond, fn_name, loop_tree, reductions, next_id);
                detect_builtin_reductions_body(loop_body, fn_name, loop_tree, reductions, next_id);
            }
            MirStmt::If {
                cond,
                then_body,
                else_body,
            } => {
                detect_builtin_reductions_expr(cond, fn_name, loop_tree, reductions, next_id);
                detect_builtin_reductions_body(then_body, fn_name, loop_tree, reductions, next_id);
                if let Some(alternative) = else_body {
                    detect_builtin_reductions_body(
                        alternative,
                        fn_name,
                        loop_tree,
                        reductions,
                        next_id,
                    );
                }
            }
            MirStmt::NoGcBlock(nested) => {
                detect_builtin_reductions_body(nested, fn_name, loop_tree, reductions, next_id);
            }
            _ => {}
        }
    }

    // The body's trailing value expression is inspected last.
    if let Some(tail) = &body.result {
        detect_builtin_reductions_expr(tail, fn_name, loop_tree, reductions, next_id);
    }
}
/// Recursively searches `expr` for calls to builtin reductions.
///
/// A call is recorded when its callee is a plain variable named in
/// `BUILTIN_REDUCTIONS`; the recorded entry uses `ReductionOp::BuiltinCall`,
/// `ReductionKind::BuiltinReduction`, an empty accumulator name, and the flags
/// returned by `classify_builtin_reduction`. The call itself is recorded
/// before any reductions nested in its arguments.
///
/// The walk descends into call arguments and into the operands of binary,
/// unary, assignment and index expressions. A call's callee expression and
/// all other expression kinds are not descended into.
fn detect_builtin_reductions_expr(
    expr: &MirExpr,
    fn_name: &str,
    loop_tree: Option<&LoopTree>,
    reductions: &mut Vec<ReductionInfo>,
    next_id: &mut u32,
) {
    match &expr.kind {
        MirExprKind::Call { callee, args } => {
            match &callee.kind {
                MirExprKind::Var(name) if BUILTIN_REDUCTIONS.contains(&name.as_str()) => {
                    let (reassociation_forbidden, strict_order_required, accumulator_semantics) =
                        classify_builtin_reduction(name);
                    let id = ReductionId(*next_id);
                    *next_id += 1;
                    reductions.push(ReductionInfo {
                        id,
                        accumulator_var: String::new(),
                        op: ReductionOp::BuiltinCall,
                        kind: ReductionKind::BuiltinReduction,
                        loop_id: None,
                        function_name: fn_name.to_string(),
                        builtin_name: Some(name.clone()),
                        reassociation_forbidden,
                        strict_order_required,
                        accumulator_semantics,
                    });
                }
                _ => {}
            }
            for arg in args.iter() {
                detect_builtin_reductions_expr(arg, fn_name, loop_tree, reductions, next_id);
            }
        }
        // Two-operand forms: visit the first operand, then the second.
        MirExprKind::Binary {
            left: first,
            right: second,
            ..
        }
        | MirExprKind::Assign {
            target: first,
            value: second,
        } => {
            detect_builtin_reductions_expr(first, fn_name, loop_tree, reductions, next_id);
            detect_builtin_reductions_expr(second, fn_name, loop_tree, reductions, next_id);
        }
        MirExprKind::Index { object, index } => {
            detect_builtin_reductions_expr(object, fn_name, loop_tree, reductions, next_id);
            detect_builtin_reductions_expr(index, fn_name, loop_tree, reductions, next_id);
        }
        MirExprKind::Unary { operand, .. } => {
            detect_builtin_reductions_expr(operand, fn_name, loop_tree, reductions, next_id);
        }
        _ => {}
    }
}
/// Maps a builtin reduction name to
/// `(reassociation_forbidden, strict_order_required, accumulator_semantics)`.
///
/// * `"kahan_sum"` → Kahan semantics,
/// * `"binned_sum"` → binned semantics,
/// * `"trapz"` / `"simps"` → plain semantics,
/// * any other name → runtime-defined semantics.
///
/// Both boolean flags are `false` for `"binned_sum"` and `true` for every
/// other name.
fn classify_builtin_reduction(name: &str) -> (bool, bool, AccumulatorSemantics) {
    let semantics = match name {
        "kahan_sum" => AccumulatorSemantics::Kahan,
        "binned_sum" => AccumulatorSemantics::Binned,
        "trapz" | "simps" => AccumulatorSemantics::Plain,
        _ => AccumulatorSemantics::RuntimeDefined,
    };
    // Binned accumulation is the only case with both ordering flags cleared.
    let order_sensitive = semantics != AccumulatorSemantics::Binned;
    (order_sensitive, order_sensitive, semantics)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MirBody, MirExpr, MirExprKind, MirFunction, MirFnId, MirParam, MirProgram, MirStmt};
    use cjc_ast::{BinOp, Visibility};

    // ---- MIR construction helpers -------------------------------------

    /// Integer literal expression.
    fn int_expr(v: i64) -> MirExpr {
        let kind = MirExprKind::IntLit(v);
        MirExpr { kind }
    }

    /// Variable reference expression.
    fn var_expr(name: &str) -> MirExpr {
        let kind = MirExprKind::Var(name.to_string());
        MirExpr { kind }
    }

    /// Boolean literal expression.
    fn bool_expr(b: bool) -> MirExpr {
        let kind = MirExprKind::BoolLit(b);
        MirExpr { kind }
    }

    /// `left <op> right`.
    fn binary_expr(op: BinOp, left: MirExpr, right: MirExpr) -> MirExpr {
        MirExpr {
            kind: MirExprKind::Binary {
                op,
                left: Box::new(left),
                right: Box::new(right),
            },
        }
    }

    /// `target = value`, where `target` is a plain variable.
    fn assign_expr(target: &str, value: MirExpr) -> MirExpr {
        MirExpr {
            kind: MirExprKind::Assign {
                target: Box::new(var_expr(target)),
                value: Box::new(value),
            },
        }
    }

    /// `acc = acc + rhs`.
    fn assign_acc_add(acc: &str, rhs: MirExpr) -> MirExpr {
        assign_expr(acc, binary_expr(BinOp::Add, var_expr(acc), rhs))
    }

    /// `name(args...)` with a plain-variable callee.
    fn call_expr(name: &str, args: Vec<MirExpr>) -> MirExpr {
        MirExpr {
            kind: MirExprKind::Call {
                callee: Box::new(var_expr(name)),
                args,
            },
        }
    }

    /// `let name = init` without an allocation hint.
    fn let_stmt(name: &str, mutable: bool, init: MirExpr) -> MirStmt {
        MirStmt::Let {
            name: name.into(),
            mutable,
            init,
            alloc_hint: None,
        }
    }

    /// A body made of `stmts` with no trailing result expression.
    fn stmts_only(stmts: Vec<MirStmt>) -> MirBody {
        MirBody {
            stmts,
            result: None,
        }
    }

    /// `while true { stmts }`.
    fn while_true(stmts: Vec<MirStmt>) -> MirStmt {
        MirStmt::While {
            cond: bool_expr(true),
            body: stmts_only(stmts),
        }
    }

    /// Program whose entry is the id of the last function (or id 0 if empty).
    fn make_program(functions: Vec<MirFunction>) -> MirProgram {
        let entry = match functions.last() {
            Some(last) => last.id,
            None => MirFnId(0),
        };
        MirProgram {
            functions,
            struct_defs: vec![],
            enum_defs: vec![],
            entry,
        }
    }

    /// Private, parameterless function named `name` wrapping `body`.
    fn make_fn(name: &str, body: MirBody) -> MirFunction {
        MirFunction {
            id: MirFnId(0),
            name: name.to_string(),
            type_params: vec![],
            params: vec![],
            return_type: None,
            body,
            is_nogc: false,
            cfg_body: None,
            decorators: vec![],
            vis: Visibility::Private,
        }
    }

    /// Runs detection on a single function called `test`, with no loop trees.
    fn analyze(body: MirBody) -> ReductionReport {
        let program = make_program(vec![make_fn("test", body)]);
        detect_reductions(&program, &[])
    }

    // ---- tests --------------------------------------------------------

    #[test]
    fn test_detect_strict_fold_add() {
        let report = analyze(stmts_only(vec![
            let_stmt("acc", true, int_expr(0)),
            while_true(vec![MirStmt::Expr(assign_acc_add("acc", var_expr("x")))]),
        ]));

        let found = report.reductions.iter().any(|r| {
            r.accumulator_var == "acc"
                && r.op == ReductionOp::Add
                && r.kind == ReductionKind::StrictFold
        });
        assert!(found, "should detect acc = acc + x as StrictFold");
    }

    #[test]
    fn test_detect_strict_fold_mul() {
        let update = assign_expr(
            "prod",
            binary_expr(BinOp::Mul, var_expr("prod"), var_expr("x")),
        );
        let report = analyze(stmts_only(vec![while_true(vec![MirStmt::Expr(update)])]));

        let found = report.reductions.iter().any(|r| {
            r.accumulator_var == "prod"
                && r.op == ReductionOp::Mul
                && r.kind == ReductionKind::StrictFold
        });
        assert!(found, "should detect prod = prod * x as StrictFold");
    }

    #[test]
    fn test_detect_builtin_sum() {
        let report = analyze(stmts_only(vec![let_stmt(
            "total",
            false,
            call_expr("sum", vec![var_expr("arr")]),
        )]));

        let found = report.reductions.iter().any(|r| {
            r.kind == ReductionKind::BuiltinReduction && r.builtin_name.as_deref() == Some("sum")
        });
        assert!(found, "should detect sum() as BuiltinReduction");
    }

    #[test]
    fn test_detect_multiple_builtins() {
        let report = analyze(stmts_only(vec![
            let_stmt("s", false, call_expr("sum", vec![var_expr("a")])),
            let_stmt("m", false, call_expr("mean", vec![var_expr("a")])),
            let_stmt(
                "d",
                false,
                call_expr("dot", vec![var_expr("a"), var_expr("b")]),
            ),
        ]));

        assert_eq!(report.len(), 3, "should detect 3 builtin reductions");

        let mut names: Vec<&str> = Vec::new();
        for info in &report.reductions {
            if let Some(name) = info.builtin_name.as_deref() {
                names.push(name);
            }
        }
        assert!(names.contains(&"sum"));
        assert!(names.contains(&"mean"));
        assert!(names.contains(&"dot"));
    }

    #[test]
    fn test_no_false_positives() {
        // A non-reduction call plus `y = x + 1`, which is not an accumulation.
        let report = analyze(stmts_only(vec![
            let_stmt("r", false, call_expr("print", vec![var_expr("x")])),
            MirStmt::Expr(assign_expr(
                "y",
                binary_expr(BinOp::Add, var_expr("x"), int_expr(1)),
            )),
        ]));

        assert!(
            report.is_empty(),
            "should not detect reductions in non-reduction code"
        );
    }

    #[test]
    fn test_reduction_kind_properties() {
        let strict_fold = ReductionKind::StrictFold;
        assert!(strict_fold.is_strict());
        assert!(!strict_fold.is_reorderable());
        assert!(!strict_fold.is_parallelizable());

        let binned = ReductionKind::BinnedFold;
        assert!(binned.is_reorderable());
        assert!(binned.is_parallelizable());
        assert!(!binned.is_strict());

        let fixed_tree = ReductionKind::FixedTree;
        assert!(!fixed_tree.is_reorderable());
        assert!(fixed_tree.is_parallelizable());

        let unknown = ReductionKind::Unknown;
        assert!(unknown.is_strict());
        assert!(!unknown.is_parallelizable());
    }

    #[test]
    fn test_reductions_in_function() {
        let compute = make_fn(
            "compute",
            stmts_only(vec![let_stmt(
                "s",
                false,
                call_expr("sum", vec![var_expr("a")]),
            )]),
        );
        let other = make_fn(
            "other",
            stmts_only(vec![let_stmt(
                "m",
                false,
                call_expr("mean", vec![var_expr("b")]),
            )]),
        );
        let program = make_program(vec![compute, other]);
        let report = detect_reductions(&program, &[]);

        let compute_reds = report.reductions_in_function("compute");
        assert_eq!(compute_reds.len(), 1);
        assert_eq!(compute_reds[0].builtin_name.as_deref(), Some("sum"));

        let other_reds = report.reductions_in_function("other");
        assert_eq!(other_reds.len(), 1);
        assert_eq!(other_reds[0].builtin_name.as_deref(), Some("mean"));
    }

    #[test]
    fn test_reduction_detection_determinism() {
        let body = stmts_only(vec![
            while_true(vec![MirStmt::Expr(assign_acc_add("acc", var_expr("x")))]),
            let_stmt("s", false, call_expr("sum", vec![var_expr("arr")])),
            let_stmt(
                "d",
                false,
                call_expr("dot", vec![var_expr("a"), var_expr("b")]),
            ),
        ]);
        let program = make_program(vec![make_fn("test", body)]);

        // Two runs over the same program must agree entry by entry.
        let first = detect_reductions(&program, &[]);
        let second = detect_reductions(&program, &[]);

        assert_eq!(first.len(), second.len());
        for (lhs, rhs) in first.reductions.iter().zip(second.reductions.iter()) {
            assert_eq!(lhs.id, rhs.id);
            assert_eq!(lhs.accumulator_var, rhs.accumulator_var);
            assert_eq!(lhs.op, rhs.op);
            assert_eq!(lhs.kind, rhs.kind);
            assert_eq!(lhs.function_name, rhs.function_name);
            assert_eq!(lhs.builtin_name, rhs.builtin_name);
        }
    }
}