use std::sync::Arc;
use morok_ir::{AxisType, BinaryOp, Op, ReduceOp, TernaryOp};
use crate::optimizer::config::HeuristicsConfig;
use crate::optimizer::{Opt, Scheduler, apply_opt};
/// Upcast amount used by [`apply_default_upcast`]; that function also requires a
/// constant axis size to be divisible by this factor.
pub const DEFAULT_UPCAST_FACTOR: usize = 4;
/// Runs the hand-written optimization heuristics on `scheduler` in a fixed order.
///
/// 1. [`try_tensor_cores`]; when it reports success nothing else is attempted.
/// 2. [`apply_image_upcasts`] and [`try_grouped_reduction`]; when the scheduler then
///    reports a positive `group_for_reduces()`, the remaining steps are skipped.
/// 3. [`apply_masked_upcasts`], [`apply_heuristic_upcasts`], [`apply_unroll`].
/// 4. [`apply_default_upcast`], only when `scheduler.upcasted()` is still false.
/// 5. [`apply_local_dims`], then [`apply_threading`] with `config.thread_count`.
///
/// The boolean results of the individual steps are not returned to the caller;
/// progress is only visible through `tracing` debug events and the scheduler state.
pub fn hand_coded_optimizations(scheduler: &mut Scheduler, config: &HeuristicsConfig) {
    use tracing::debug;

    debug!("hand_coded_optimizations: starting");

    let tensor_cores_used = try_tensor_cores(scheduler, config);
    if tensor_cores_used {
        debug!("hand_coded_optimizations: tensor cores applied, skipping remaining opts");
        return;
    }

    apply_image_upcasts(scheduler);
    try_grouped_reduction(scheduler, config);
    let grouping_active = scheduler.group_for_reduces() > 0;
    if grouping_active {
        debug!("hand_coded_optimizations: group_for_reduces active, skipping remaining opts");
        return;
    }

    apply_masked_upcasts(scheduler);
    apply_heuristic_upcasts(scheduler);
    apply_unroll(scheduler);

    // Fall back to the default upcast only when none of the steps above upcast anything.
    let already_upcasted = scheduler.upcasted();
    if !already_upcasted {
        apply_default_upcast(scheduler);
    }

    apply_local_dims(scheduler, config);

    let max_threads = config.thread_count;
    debug!("hand_coded_optimizations: calling apply_threading with max_threads={}", max_threads);
    let threading_applied = apply_threading(scheduler, max_threads);
    debug!(threading_applied, "hand_coded_optimizations: apply_threading completed");
}
/// Reports whether the scheduler's reduce op looks like a matrix multiplication.
///
/// The pattern is an `Add` reduction whose source is a `Mul` of two operands, each
/// of which is either an `Index` op or a `Cast` directly wrapping an `Index` op.
/// Returns `false` when there is no reduce op or any part of the shape differs.
pub fn has_matmul_pattern(scheduler: &Scheduler) -> bool {
    let Some(reduceop) = scheduler.reduceop() else { return false };
    let Op::Reduce { src, reduce_op, .. } = reduceop.op() else { return false };
    if *reduce_op != ReduceOp::Add {
        return false;
    }
    let Op::Binary(BinaryOp::Mul, left, right) = src.op() else { return false };

    // Both multiplicands must be an index, optionally behind a single cast.
    [left, right].into_iter().all(|operand| match operand.op() {
        Op::Index { .. } => true,
        Op::Cast { src, .. } => matches!(src.op(), Op::Index { .. }),
        _ => false,
    })
}
/// Reports whether the range at position `axis` feeds the condition of a `Where` op.
///
/// Every node of the AST's toposort is inspected; the axis counts as masked when the
/// backward slice of some `Where` condition contains the range's id. An `axis` that
/// is out of bounds for `scheduler.rngs()` yields `false`.
pub fn is_masked(scheduler: &Scheduler, axis: usize) -> bool {
    let rngs = scheduler.rngs();
    let Some(target_rng) = rngs.get(axis) else { return false };

    let masked = scheduler.ast().toposort().into_iter().any(|node| {
        matches!(
            node.op(),
            Op::Ternary(TernaryOp::Where, cond, _, _)
                if cond.backward_slice_ids().contains(&target_rng.id)
        )
    });
    masked
}
/// Reports whether some buffer depends on the range at `axis` without indexing by it.
///
/// A buffer matches when its backward slice contains the range's id, it is an
/// `Index` op, and none of its index expressions has that id in its own backward
/// slice. An `axis` that is out of bounds for `scheduler.rngs()` yields `false`.
pub fn has_broadcast_pattern(scheduler: &Scheduler, axis: usize) -> bool {
    let rngs = scheduler.rngs();
    let Some(target_rng) = rngs.get(axis) else { return false };
    let rng_id = &target_rng.id;

    let found = scheduler.bufs().into_iter().any(|buf| {
        buf.backward_slice_ids().contains(rng_id)
            && matches!(
                buf.op(),
                Op::Index { indices, .. }
                    if indices.iter().all(|idx| !idx.backward_slice_ids().contains(rng_id))
            )
    });
    found
}
/// Collects stride statistics for the range at position `axis` over all `Index` buffers.
///
/// Returns `(num_strides, sum_strides)`:
/// * `num_strides` — how many `Index` buffers have the range in the backward slice of
///   their index expression (`get_idx()` of the first index, or the buffer itself when
///   there are no indices).
/// * `sum_strides` — the sum of the range's coefficients across the `Add` terms of those
///   index expressions: `1` for a bare occurrence, `c` for `range * c` or `c * range`
///   with an integer constant `c`.
///
/// An out-of-bounds `axis` yields `(0, 0)`.
///
/// Negative coefficients are accumulated with two's-complement wrapping. A plain `+=`
/// on the `usize` cast would overflow (and panic in overflow-checked builds) as soon as
/// a negative stride is combined with any other stride; wrapping produces exactly the
/// value unchecked builds computed before, and the exact signed sum whenever the total
/// is non-negative.
pub fn count_strides(scheduler: &Scheduler, axis: usize) -> (usize, usize) {
    let rngs = scheduler.rngs();
    if axis >= rngs.len() {
        return (0, 0);
    }
    let target_rng = &rngs[axis];
    let mut num_strides = 0;
    let mut sum_strides: usize = 0;
    for buf in scheduler.bufs() {
        let Op::Index { indices, .. } = buf.op() else { continue };
        let idx = indices.first().map(|i| i.get_idx()).unwrap_or_else(|| buf.clone());
        if idx.backward_slice_ids().contains(&target_rng.id) {
            num_strides += 1;
        }
        for term in idx.split_uop(BinaryOp::Add) {
            // A bare occurrence of the range has an implicit coefficient of 1.
            if Arc::ptr_eq(&term, target_rng) {
                sum_strides = sum_strides.wrapping_add(1);
                continue;
            }
            let Op::Binary(BinaryOp::Mul, lhs, rhs) = term.op() else { continue };
            // The coefficient is whichever operand sits opposite the range.
            let coefficient = if Arc::ptr_eq(lhs, target_rng) {
                rhs
            } else if Arc::ptr_eq(rhs, target_rng) {
                lhs
            } else {
                continue;
            };
            if let Op::Const(cv) = coefficient.op()
                && let morok_ir::ConstValue::Int(v) = cv.0
            {
                // Intentional wrapping cast: a negative stride is carried as its
                // two's-complement bit pattern so the running sum cannot panic.
                sum_strides = sum_strides.wrapping_add(v as usize);
            }
        }
    }
    (num_strides, sum_strides)
}
/// Placeholder for image-specific upcasts: the scheduler argument is unused and the
/// function always returns `false` (nothing applied).
pub fn apply_image_upcasts(_scheduler: &mut Scheduler) -> bool {
    false
}
pub fn apply_default_upcast(scheduler: &mut Scheduler) -> bool {
use morok_ir::Op;
use tracing::debug;
if scheduler.upcasted() {
debug!("apply_default_upcast: skipping (already upcasted)");
return false;
}
let upcastable = scheduler.upcastable_dims();
debug!(upcastable_dims = ?upcastable, "apply_default_upcast: checking upcastable dims");
if upcastable.is_empty() {
debug!("apply_default_upcast: no upcastable dims");
return false;
}
let axis_idx = *upcastable.last().unwrap();
let rngs = scheduler.rngs();
if axis_idx < rngs.len()
&& let Op::Range { end, .. } = rngs[axis_idx].op()
&& let Op::Const(cv) = end.op()
&& let morok_ir::ConstValue::Int(size) = cv.0
&& size % DEFAULT_UPCAST_FACTOR as i64 != 0
{
debug!(axis_idx, size, factor = DEFAULT_UPCAST_FACTOR, "apply_default_upcast: skipping (size not divisible)");
return false;
}
let result = apply_opt(scheduler, &Opt::upcast(axis_idx, DEFAULT_UPCAST_FACTOR), true);
debug!(?result, axis = axis_idx, factor = DEFAULT_UPCAST_FACTOR, "apply_default_upcast: apply_opt result");
result.is_ok()
}
/// Unrolls the last unrollable axis, fully when it is small and by 4 otherwise.
///
/// Nothing is done (and `false` is returned) when there is no unrollable axis, when
/// `upcast_size()` is 64 or more, when it exceeds 4 while an `Unroll` axis already
/// exists, or when the axis size is not an integer constant.
///
/// * Size <= 32: a full unroll (amount `0`) is attempted. On success, and if the size
///   was <= 3, the new last unrollable axis is fully unrolled as well when its constant
///   size is also <= 3; the result of that second attempt is ignored.
/// * Otherwise, or when the full unroll was rejected: an unroll by 4 is attempted if
///   the size is a multiple of 4.
///
/// Returns `true` when the first successful unroll was applied.
pub fn apply_unroll(scheduler: &mut Scheduler) -> bool {
    use tracing::debug;

    // Constant integer extent of the range at `idx`; `None` when out of bounds or symbolic.
    let const_extent = |sched: &Scheduler, idx: usize| {
        let ranges = sched.rngs();
        let extent = if let Some(rng) = ranges.get(idx)
            && let Op::Range { end, .. } = rng.op()
            && let Op::Const(cv) = end.op()
            && let morok_ir::ConstValue::Int(extent) = cv.0
        {
            Some(extent)
        } else {
            None
        };
        extent
    };

    let unrollable = scheduler.unrollable_dims();
    let Some(&last_unrollable) = unrollable.last() else { return false };

    let upcast_size = scheduler.upcast_size();
    let has_unroll = !scheduler.axes_of(&[AxisType::Unroll]).is_empty();
    let within_budget = upcast_size < 64 && (upcast_size <= 4 || !has_unroll);
    if !within_budget {
        debug!(upcast_size, has_unroll, "apply_unroll: skipping (upcast_size guard)");
        return false;
    }

    let Some(extent) = const_extent(&*scheduler, last_unrollable) else { return false };
    let size = extent as usize;
    // Unroll opts are addressed by position within the unrollable list.
    let logical_idx = unrollable.len() - 1;

    if size <= 32 {
        debug!(last_unrollable, size, "apply_unroll: full unroll");
        if apply_opt(scheduler, &Opt::unroll(logical_idx, 0), true).is_ok() {
            if size <= 3 {
                let remaining = scheduler.unrollable_dims();
                if let Some(&next_axis) = remaining.last()
                    && let Some(next_extent) = const_extent(&*scheduler, next_axis)
                    && next_extent <= 3
                {
                    let next_logical_idx = remaining.len() - 1;
                    let _ = apply_opt(scheduler, &Opt::unroll(next_logical_idx, 0), true);
                }
            }
            return true;
        }
    }

    const SPLITS: usize = 4;
    if size % SPLITS == 0 {
        debug!(last_unrollable, size, splits = SPLITS, "apply_unroll: partial unroll");
        return apply_opt(scheduler, &Opt::unroll(logical_idx, SPLITS), true).is_ok();
    }
    false
}
/// Fully upcasts small masked axes.
///
/// Walks the upcastable axes in order and selects those that are masked (see
/// [`is_masked`]) and have a constant size in `2..=7`, as long as the product of all
/// selected sizes stays at or below 49. The selected axes are then upcast by their
/// full size, highest selection first. Returns `true` if at least one upcast succeeded.
pub fn apply_masked_upcasts(scheduler: &mut Scheduler) -> bool {
    let mut selected: Vec<(usize, usize)> = Vec::new();
    let mut selected_product: i64 = 1;

    for axis_idx in scheduler.upcastable_dims() {
        if !is_masked(scheduler, axis_idx) {
            continue;
        }
        let rngs = scheduler.rngs();
        let Some(rng) = rngs.get(axis_idx) else { continue };
        if let Op::Range { end, .. } = rng.op()
            && let Op::Const(cv) = end.op()
            && let morok_ir::ConstValue::Int(size) = cv.0
            && (2..=7).contains(&size)
            && selected_product * size <= 49
        {
            selected.push((axis_idx, size as usize));
            selected_product *= size;
        }
    }

    // Apply in reverse selection order; every candidate is attempted regardless of
    // earlier failures.
    let mut applied = false;
    for (axis_idx, size) in selected.into_iter().rev() {
        applied |= apply_opt(scheduler, &Opt::upcast(axis_idx, size), true).is_ok();
    }
    applied
}
/// Tries to apply a `grouptop` opt of size 16 on one of the first three axes.
///
/// Requires a renderer with both `has_local` and `has_shared`, and locals not disabled
/// in `config`. The attempt is skipped when the product of the `full_shape` entries of
/// the upcastable axes (missing entries count as 1) exceeds 2048. Axes 0, 1 and 2 are
/// probed in order and the first accepted opt ends the search. Returns whether an opt
/// was applied.
pub fn try_grouped_reduction(scheduler: &mut Scheduler, config: &HeuristicsConfig) -> bool {
    let locals_usable =
        scheduler.renderer().has_local && !config.disable_locals && scheduler.renderer().has_shared;
    if !locals_usable {
        return false;
    }

    let upcastable = scheduler.upcastable_dims();
    let full_shape = scheduler.full_shape();
    let upcastable_elements: i64 =
        upcastable.iter().map(|&i| full_shape.get(i).copied().unwrap_or(1)).product();
    if upcastable_elements > 2048 {
        return false;
    }

    (0..3).any(|axis| apply_opt(scheduler, &Opt::grouptop(axis, 16), true).is_ok())
}
/// Upcasts the first two output axes of a matmul-shaped kernel.
///
/// Preconditions (otherwise `false` is returned and nothing changes):
/// * [`has_matmul_pattern`] holds,
/// * `config.output_upcast` is enabled,
/// * the first two `Outer`/`Global`/`Loop` axes both have a constant size >= 4.
///
/// For each of the two axes the largest factor in `[8, 7, 6, 5, 4]` that divides the
/// size is upcast. Returns `true` if at least one upcast was accepted.
///
/// Axis indices are captured before any opt is applied. When an upcast uses the whole
/// axis (`factor == size`) the axis is presumed to disappear from `scheduler.rngs()`,
/// which shifts later axes down by one — the same assumption `apply_local_dims` makes
/// with its `deleted_shape` counter. The second index is therefore corrected for every
/// fully consumed axis that preceded it.
pub fn apply_matmul_tiling(scheduler: &mut Scheduler, config: &HeuristicsConfig) -> bool {
    use tracing::debug;
    if !has_matmul_pattern(scheduler) {
        return false;
    }
    if !config.output_upcast {
        debug!("apply_matmul_tiling: skipped (output_upcast disabled)");
        return false;
    }
    let output_axes = scheduler.axes_of(&[AxisType::Outer, AxisType::Global, AxisType::Loop]);
    debug!(output_axes = ?output_axes, "apply_matmul_tiling: output axes");
    if output_axes.len() < 2 {
        debug!("apply_matmul_tiling: not enough output axes (need 2)");
        return false;
    }
    const UPCAST_FACTORS: [usize; 5] = [8, 7, 6, 5, 4];
    let rngs = scheduler.rngs();
    let mut axes_with_sizes: Vec<(usize, usize)> = Vec::new();
    for &axis_idx in output_axes.iter().take(2) {
        if axis_idx >= rngs.len() {
            continue;
        }
        if let Op::Range { end, .. } = rngs[axis_idx].op()
            && let Op::Const(cv) = end.op()
            && let morok_ir::ConstValue::Int(size) = cv.0
            && size >= 4
        {
            axes_with_sizes.push((axis_idx, size as usize));
        }
    }
    if axes_with_sizes.len() < 2 {
        debug!(found = axes_with_sizes.len(), "apply_matmul_tiling: not enough output axes");
        return false;
    }
    let mut applied = false;
    // Original indices of axes that an earlier upcast consumed entirely.
    let mut consumed_axes: Vec<usize> = Vec::new();
    for (axis_idx, size) in axes_with_sizes {
        let Some(&factor) = UPCAST_FACTORS.iter().find(|&&f| size >= f && size % f == 0) else {
            continue;
        };
        // Every consumed axis that sat before this one moved it down by one slot.
        let shift = consumed_axes.iter().filter(|&&gone| gone < axis_idx).count();
        let current_axis = axis_idx - shift;
        if apply_opt(scheduler, &Opt::upcast(current_axis, factor), true).is_ok() {
            debug!(axis = current_axis, factor, size, "apply_matmul_tiling: applied UPCAST");
            applied = true;
            if factor == size {
                consumed_axes.push(axis_idx);
            }
        }
    }
    applied
}
/// Alias for [`apply_matmul_tiling`]: forwards both arguments and returns its result
/// unchanged.
pub fn apply_matmul_output_upcasting(scheduler: &mut Scheduler, config: &HeuristicsConfig) -> bool {
    apply_matmul_tiling(scheduler, config)
}
/// Splits one `Loop` axis across threads, preferring the largest usable thread count.
///
/// Does nothing when the renderer lacks `has_threads` or `max_threads <= 1`. Thread
/// counts from `[32, 16, 12, 8, 6, 5, 4, 3, 2]` are tried in that order; a count is
/// skipped when it exceeds `max_threads` or when the product of `full_shape` is below
/// `count * 131072`. For a usable count, only the first `Loop` axis whose constant size
/// is a multiple of the count is attempted; if the opt is rejected the next (smaller)
/// count is tried. Returns `true` as soon as a thread opt is accepted.
pub fn apply_threading(scheduler: &mut Scheduler, max_threads: usize) -> bool {
    use tracing::debug;

    const THREAD_LIST: [usize; 9] = [32, 16, 12, 8, 6, 5, 4, 3, 2];
    const MIN_ELEMENTS_PER_THREAD: i64 = 131072;

    if max_threads <= 1 || !scheduler.renderer().has_threads {
        return false;
    }
    let total_elements: i64 = scheduler.full_shape().iter().copied().product();

    for threads in THREAD_LIST {
        let enough_work = total_elements >= (threads as i64) * MIN_ELEMENTS_PER_THREAD;
        if threads > max_threads || !enough_work {
            continue;
        }

        let loop_axes = scheduler.axes_of(&[AxisType::Loop]);
        for &axis_idx in &loop_axes {
            let rngs = scheduler.rngs();
            let Some(rng) = rngs.get(axis_idx) else { continue };
            if let Op::Range { end, .. } = rng.op()
                && let Op::Const(cv) = end.op()
                && let morok_ir::ConstValue::Int(size) = cv.0
                && (size as usize).is_multiple_of(threads)
            {
                if apply_opt(scheduler, &Opt::thread(axis_idx, threads), true).is_ok() {
                    debug!(axis = axis_idx, threads, "apply_threading: applied THREAD");
                    return true;
                }
                // Only the first divisible axis is attempted for a given thread count.
                break;
            }
        }
    }
    false
}
/// Repeatedly upcasts the "cheapest" eligible axis by 3 or 4.
///
/// Each round:
/// 1. Stops when there is no upcastable axis, when the product of the constant sizes
///    of the upcastable axes is below 1024, or when `upcast_size()` has reached 32.
/// 2. Collects candidates `(num_strides, sum_strides, axis, amount)` for every
///    upcastable axis that has not been upcast by this function yet, has a constant
///    size divisible by the amount (3 and/or 4), and for which some `Index` buffer does
///    not reference the axis while referencing every existing `Upcast`/`Unroll` range.
///    The stride numbers come from [`count_strides`].
/// 3. Applies the smallest candidate in tuple order; stops if there is none or the opt
///    is rejected.
///
/// Returns `true` if at least one upcast was applied.
///
/// Per-axis work (size lookup, stride statistics) is computed once per axis rather
/// than once per candidate amount, and the cheap divisibility test runs before the
/// buffer scan; the resulting candidate list is the same.
pub fn apply_heuristic_upcasts(scheduler: &mut Scheduler) -> bool {
    use tracing::debug;
    const UPCAST_AMOUNTS: [usize; 2] = [3, 4];

    let mut applied = false;
    let mut upcasted_axes: Vec<usize> = Vec::new();
    loop {
        let upcastable = scheduler.upcastable_dims();
        if upcastable.is_empty() {
            break;
        }
        let output_shape_product: i64 = {
            let rngs = scheduler.rngs();
            upcastable
                .iter()
                .filter_map(|&idx| {
                    if idx < rngs.len()
                        && let Op::Range { end, .. } = rngs[idx].op()
                        && let Op::Const(cv) = end.op()
                        && let morok_ir::ConstValue::Int(sz) = cv.0
                    {
                        Some(sz)
                    } else {
                        None
                    }
                })
                .product()
        };
        if output_shape_product < 1024 || scheduler.upcast_size() >= 32 {
            debug!(
                output_shape_product,
                upcast_size = scheduler.upcast_size(),
                "apply_heuristic_upcasts: terminating (threshold)"
            );
            break;
        }
        let mut choices: Vec<(usize, usize, usize, usize)> = Vec::new();
        let upcast_and_unroll_ranges = scheduler.ranges_of(&[AxisType::Upcast, AxisType::Unroll]);
        for &axis_idx in &upcastable {
            if upcasted_axes.contains(&axis_idx) {
                continue;
            }
            let rngs = scheduler.rngs();
            if axis_idx >= rngs.len() {
                continue;
            }
            let rng = &rngs[axis_idx];

            // Axes without a constant size can never satisfy the divisibility test.
            let size = if let Op::Range { end, .. } = rng.op()
                && let Op::Const(cv) = end.op()
                && let morok_ir::ConstValue::Int(sz) = cv.0
            {
                sz
            } else {
                continue;
            };
            // Skip the buffer scan below when no amount can apply to this axis.
            if !UPCAST_AMOUNTS.iter().any(|&amount| size % (amount as i64) == 0) {
                continue;
            }

            let has_stride0 = {
                let bufs = scheduler.bufs();
                bufs.iter().any(|buf| {
                    if let Op::Index { indices, .. } = buf.op() {
                        let rng_not_in_idx = !indices.iter().any(|idx| idx.backward_slice_ids().contains(&rng.id));
                        let all_upcast_in_idx = upcast_and_unroll_ranges
                            .iter()
                            .all(|r2| indices.iter().any(|idx| idx.backward_slice_ids().contains(&r2.id)));
                        rng_not_in_idx && all_upcast_in_idx
                    } else {
                        false
                    }
                })
            };
            if !has_stride0 {
                continue;
            }

            // The stride statistics depend only on the axis, not on the amount.
            let (num_strides, sum_strides) = count_strides(scheduler, axis_idx);
            for &upcast_amount in &UPCAST_AMOUNTS {
                if size % (upcast_amount as i64) == 0 {
                    choices.push((num_strides, sum_strides, axis_idx, upcast_amount));
                }
            }
        }
        // Smallest tuple wins: fewest strides, then lowest stride sum, axis, amount.
        let Some(&(_, _, best_axis, best_amount)) = choices.iter().min() else {
            debug!("apply_heuristic_upcasts: no valid choices, breaking");
            break;
        };
        debug!(best_axis, best_amount, "apply_heuristic_upcasts: applying upcast");
        if apply_opt(scheduler, &Opt::upcast(best_axis, best_amount), true).is_ok() {
            upcasted_axes.push(best_axis);
            applied = true;
        } else {
            break;
        }
    }
    applied
}
/// Chooses up to three `Global`/`Loop` axes and applies a local opt to each.
///
/// Requires a renderer with `has_local` and locals not disabled in `config`.
///
/// * Only axes whose range has a constant end are considered. They are ranked in
///   descending `(has_broadcast_pattern, axis)` order, so broadcast axes come first.
/// * For each ranked axis with a positive size, the largest candidate size that divides
///   the axis and keeps the running product of chosen sizes at or below 128 is taken
///   (candidates are 16, 8, 4, 3, 2, plus 32 for axis 0).
/// * The first three choices are applied in ascending axis order. When a local size
///   equals the full axis size and the opt succeeds, later axis indices are shifted
///   down by one.
///
/// Returns `true` if at least one local opt was accepted.
pub fn apply_local_dims(scheduler: &mut Scheduler, config: &HeuristicsConfig) -> bool {
    const CANDIDATE_SIZES: [usize; 6] = [32, 16, 8, 4, 3, 2];

    if config.disable_locals || !scheduler.renderer().has_local {
        return false;
    }

    let eligible_axes = scheduler.axes_of(&[AxisType::Global, AxisType::Loop]);
    let full_shape = scheduler.full_shape();

    let mut ranking: Vec<(bool, usize)> = Vec::new();
    for &axis in &eligible_axes {
        let rngs = scheduler.rngs();
        let has_const_end = axis < rngs.len()
            && matches!(rngs[axis].op(), Op::Range { end, .. } if matches!(end.op(), Op::Const(..)));
        if has_const_end {
            ranking.push((has_broadcast_pattern(scheduler, axis), axis));
        }
    }
    ranking.sort_by_key(|&entry| std::cmp::Reverse(entry));

    let mut to_local: Vec<(usize, usize)> = Vec::new();
    let mut local_product: usize = 1;
    for &(_, axis) in &ranking {
        let axis_size = full_shape[axis];
        if axis_size <= 0 {
            continue;
        }
        // 32 is only offered to axis 0.
        let candidates = if axis == 0 { &CANDIDATE_SIZES[..] } else { &CANDIDATE_SIZES[1..] };
        let chosen = candidates
            .iter()
            .copied()
            .find(|&x| (axis_size as usize).is_multiple_of(x) && local_product * x <= 128);
        if let Some(sz) = chosen {
            to_local.push((axis, sz));
            local_product *= sz;
        }
    }

    to_local.truncate(3);
    to_local.sort_unstable();

    let mut applied = false;
    let mut deleted_shape = 0usize;
    for (axis, local_sz) in to_local {
        let adjusted_axis = axis - deleted_shape;
        let will_delete = local_sz == full_shape[axis] as usize;
        if apply_opt(scheduler, &Opt::local(adjusted_axis, local_sz), true).is_ok() {
            applied = true;
            if will_delete {
                deleted_shape += 1;
            }
        }
    }
    applied
}
/// Tries to switch the kernel to tensor cores, working on a clone of `scheduler`.
///
/// Bails out with `false` (leaving `scheduler` untouched) when tensor cores are
/// disabled in `config`, the renderer lists none, the kernel does not match
/// [`has_matmul_pattern`], the number of `GroupReduce`/`Reduce` axes differs from one
/// while `tc_opt` is below 1, or `tc::apply` fails on the clone.
///
/// After a successful `tc::apply` the TC opt is recorded and, unless the renderer
/// reports `is_amx()`, follow-up opts are attempted on the returned axes:
/// * for axis 1 and then axis 0, an upcast by the first of 5, 4, 3, 2 that divides it;
///   on success the axis is replaced by the range `shift_to` returns,
/// * with `has_local`, a local split of axis 0 by the first of 4, 2 that divides it.
///
/// Failed follow-ups are silently skipped. Finally the clone replaces `scheduler` and
/// `true` is returned.
pub fn try_tensor_cores(scheduler: &mut Scheduler, config: &HeuristicsConfig) -> bool {
    use crate::optimizer::config::TcUsage;
    use crate::optimizer::tc;

    let eligible = config.tc_enabled != TcUsage::Disabled
        && !scheduler.renderer().tensor_cores.is_empty()
        && has_matmul_pattern(scheduler);
    if !eligible {
        return false;
    }

    let tc_opt = config.tc_opt.as_usize();
    let reduce_count = scheduler.axes_of(&[AxisType::GroupReduce, AxisType::Reduce]).len();
    if reduce_count != 1 && tc_opt < 1 {
        return false;
    }
    let tc_select = config.tc_select.as_i32();
    let use_tc = config.tc_enabled.as_usize();

    let mut trial = scheduler.clone();
    let Ok(axes) = tc::apply(&mut trial, tc_select, tc_opt, use_tc) else {
        return false;
    };
    trial.applied_opts.push(Opt::tc(None, tc_select, tc_opt, use_tc));

    if !trial.renderer().is_amx() {
        let mut tc_rngs = [axes[0].clone(), axes[1].clone()];

        for tc_dim in [1usize, 0] {
            // Only the first dividing size is attempted for each axis.
            let first_divisor =
                [5usize, 4, 3, 2].into_iter().find(|&sz| tc_rngs[tc_dim].divisible_by(sz).is_some());
            let Some(sz) = first_divisor else { continue };
            if let Some(rng_idx) = trial.rngs().iter().position(|r| Arc::ptr_eq(r, &tc_rngs[tc_dim]))
                && let Ok((replaced, _)) =
                    trial.shift_to(tc_rngs[tc_dim].clone(), sz, AxisType::Upcast, false, None)
            {
                trial.applied_opts.push(Opt::upcast(rng_idx, sz));
                tc_rngs[tc_dim] = replaced;
            }
        }

        if trial.renderer().has_local {
            let first_divisor = [4usize, 2].into_iter().find(|&sz| tc_rngs[0].divisible_by(sz).is_some());
            if let Some(sz) = first_divisor
                && let Some(rng_idx) = trial.rngs().iter().position(|r| Arc::ptr_eq(r, &tc_rngs[0]))
                && trial.shift_to(tc_rngs[0].clone(), sz, AxisType::Local, false, None).is_ok()
            {
                trial.applied_opts.push(Opt::local(rng_idx, sz));
            }
        }
    }

    *scheduler = trial;
    true
}