use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use indexmap::IndexMap;
use morok_ir::{AxisType, Op, SInt, UOp, UOpKey};
use smallvec::SmallVec;
use tracing::{debug, trace};
/// Configuration values for the "pcontig" stage.
///
/// NOTE(review): none of these fields are read in this part of the file, so the
/// per-field notes below only restate names and defaults — confirm the exact
/// semantics at the use sites.
#[derive(Debug, Clone)]
pub struct PcontigConfig {
    /// Level selector; the default is `2`.
    pub level: u8,
    /// Threshold on a buffer count; the default is `3`.
    pub max_buffers_threshold: usize,
    /// Threshold on an output/input ratio; the default is `10.0`.
    pub out_in_ratio_threshold: f64,
}

impl Default for PcontigConfig {
    /// Returns the stock configuration: level `2`, buffer threshold `3`, ratio threshold `10.0`.
    fn default() -> Self {
        let level = 2;
        let max_buffers_threshold = 3;
        let out_in_ratio_threshold = 10.0;
        Self { level, max_buffers_threshold, out_in_ratio_threshold }
    }
}
/// Settings that control when and how a large reduction is split into two stages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SplitReduceOpConfig {
    /// Minimum `input_size / output_size` ratio a reduction must reach before it is split.
    pub split_threshold: usize,
    /// Base-2 logarithm of the largest intermediate output size a split may produce
    /// (see [`SplitReduceOpConfig::max_output_size`]).
    pub output_size_bits: u32,
    /// Largest divisor tried when splitting a reduced dimension (inclusive).
    pub max_divisor: usize,
    /// Smallest divisor tried when splitting a reduced dimension (inclusive).
    pub min_divisor: usize,
    /// Master switch; when `false` no reduction is ever split.
    pub enabled: bool,
}

impl Default for SplitReduceOpConfig {
    /// Default settings: ratio threshold `32768`, `2^22` output elements, divisors `8..=256`, enabled.
    fn default() -> Self {
        Self { split_threshold: 32768, output_size_bits: 22, max_divisor: 256, min_divisor: 8, enabled: true }
    }
}

impl SplitReduceOpConfig {
    /// Returns `2^output_size_bits`, the largest intermediate output size allowed.
    ///
    /// The result saturates at `usize::MAX` when `output_size_bits` is too large to
    /// represent, instead of panicking (debug builds) or wrapping to `0` (release builds).
    pub fn max_output_size(&self) -> usize {
        2_usize.saturating_pow(self.output_size_bits)
    }
}
/// Mutable state carried through the kernel-split rewrite: id counters, a buffer
/// replacement table and the set of named variables seen so far.
#[derive(Clone, Default)]
pub struct KernelContext {
    /// Next id handed out by [`KernelContext::next_global`].
    pub global_counter: usize,
    /// Next id handed out by [`KernelContext::next_local`].
    pub local_counter: usize,
    /// Maps an original buffer node to the node that replaces it.
    pub buffer_map: HashMap<UOpKey, Arc<UOp>>,
    /// Variables keyed by their `DefineVar` name, each with an optional value.
    pub vars: HashMap<String, (Arc<UOp>, Option<i64>)>,
    /// Next id handed out by [`KernelContext::next_range`].
    pub range_counter: usize,
}

impl KernelContext {
    /// Creates an empty context: every counter at zero, no buffers or variables recorded.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the current global id, then advances the counter by one.
    pub fn next_global(&mut self) -> usize {
        let issued = self.global_counter;
        self.global_counter = issued + 1;
        issued
    }

    /// Returns the current local id, then advances the counter by one.
    pub fn next_local(&mut self) -> usize {
        let issued = self.local_counter;
        self.local_counter = issued + 1;
        issued
    }

    /// Returns the current range id, then advances the counter by one.
    pub fn next_range(&mut self) -> usize {
        let issued = self.range_counter;
        self.range_counter = issued + 1;
        issued
    }

    /// Reports whether a replacement has been recorded for `buf`.
    pub fn has_buffer(&self, buf: &Arc<UOp>) -> bool {
        self.get_buffer(buf).is_some()
    }

    /// Looks up the replacement recorded for `buf`, if any.
    pub fn get_buffer(&self, buf: &Arc<UOp>) -> Option<&Arc<UOp>> {
        // The map is keyed by an owning wrapper, so a temporary key is built for the lookup.
        let key = UOpKey(Arc::clone(buf));
        self.buffer_map.get(&key)
    }

    /// Records `replacement` as the stand-in for `original`, overwriting any earlier entry.
    pub fn map_buffer(&mut self, original: Arc<UOp>, replacement: Arc<UOp>) {
        let key = UOpKey(original);
        self.buffer_map.insert(key, replacement);
    }

    /// Registers `var` under its `DefineVar` name together with `value`.
    ///
    /// Nodes that are not `DefineVar` are ignored. A later registration with the
    /// same name replaces the earlier one.
    pub fn add_var(&mut self, var: Arc<UOp>, value: Option<i64>) {
        let Op::DefineVar { name, .. } = var.op() else {
            return;
        };
        let key = name.clone();
        self.vars.insert(key, (var, value));
    }
}
/// Per-kernel state collected while a single store is rewritten into a kernel.
#[derive(Default)]
pub struct LocalAddBufferContext {
    /// Next value handed out by [`LocalAddBufferContext::next_param_slot`].
    pub param_slot: usize,
    /// Insertion-ordered map from a buffer node to the node recorded for it.
    pub map: IndexMap<UOpKey, Arc<UOp>>,
    /// Variables keyed by their `DefineVar` name, each with an optional value.
    pub vars: HashMap<String, (Arc<UOp>, Option<i64>)>,
    /// Next value handed out by [`LocalAddBufferContext::next_range`].
    pub range: usize,
    /// Contiguity hints gathered for the kernel.
    /// NOTE(review): not read or written in this part of the file — confirm usage.
    pub opts: Vec<morok_ir::ContiguousHint>,
}

impl LocalAddBufferContext {
    /// Creates an empty context (identical to `Default::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the current parameter slot, then advances it by one.
    pub fn next_param_slot(&mut self) -> usize {
        let slot = self.param_slot;
        self.param_slot = slot + 1;
        slot
    }

    /// Returns the current range id, then advances it by one.
    pub fn next_range(&mut self) -> usize {
        let issued = self.range;
        self.range = issued + 1;
        issued
    }

    /// Registers `var` under its `DefineVar` name together with `value`.
    ///
    /// Nodes that are not `DefineVar` are ignored. A later registration with the
    /// same name replaces the earlier one.
    pub fn add_var(&mut self, var: Arc<UOp>, value: Option<i64>) {
        let Op::DefineVar { name, .. } = var.op() else {
            return;
        };
        let key = name.clone();
        self.vars.insert(key, (var, value));
    }

    /// Records `after` as the node associated with `buf`.
    pub fn map_buffer(&mut self, buf: Arc<UOp>, after: Arc<UOp>) {
        let key = UOpKey(buf);
        self.map.insert(key, after);
    }

    /// Reports whether `buf` already has an entry in the map.
    pub fn has_buffer(&self, buf: &Arc<UOp>) -> bool {
        let key = UOpKey(Arc::clone(buf));
        self.map.contains_key(&key)
    }
}
/// Metadata tag attached to the `Sink` that wraps a finished kernel AST.
///
/// `split_store` attaches it via `with_metadata`, and the `Sink` rules in
/// `run_kernel_split_pipeline` and `split_all_stores` return `RewriteResult::Gate`
/// for any sink that carries it.
#[derive(Debug, Clone)]
pub struct KernelAstMarker;
/// Returns the value written by `ret` when `ret` is a `Store`, or an `End` whose
/// computation is a `Store`; in every other case `ret` itself is returned.
fn extract_stored_value(ret: &Arc<UOp>) -> &Arc<UOp> {
    // Look through a single `End` wrapper, if there is one.
    let candidate = match ret.op() {
        Op::End { computation, .. } => computation,
        _ => ret,
    };
    if let Op::Store { value, .. } = candidate.op() { value } else { ret }
}
/// Tries to turn a top-level store into a `Kernel` node.
///
/// `x` must be a `Store`, or an `End` whose computation is a `Store`. The function
/// returns `None` (leaving `x` untouched) when:
/// - any range in scope of `x` has an axis type other than `AxisType::Outer`;
/// - `x` is an `End` whose first range is an `AxisType::Outer` range;
/// - `x` is neither of the accepted shapes.
///
/// Otherwise `x` is rewritten bottom-up with `local_to_param_patterns()` +
/// `rangeify_codegen_patterns()` using a fresh [`LocalAddBufferContext`]. The kernel
/// AST is the stored value itself when that value is a `Copy` or `BufferView`, and a
/// `Sink` tagged with [`KernelAstMarker`] otherwise. The kernel sources are the
/// buffers recorded in the context (in insertion order) followed by the recorded
/// variables ordered by name.
///
/// `_ctx` is unused; it is accepted so the function can be called from the pattern
/// callbacks in `split_all_stores`.
pub fn split_store(_ctx: &mut Vec<Arc<UOp>>, x: &Arc<UOp>) -> Option<Arc<UOp>> {
    use super::patterns::{local_to_param_patterns, rangeify_codegen_patterns};
    use crate::rewrite::graph_rewrite_bottom_up;
    use std::sync::LazyLock;

    trace!(uop_id = x.id, op = ?std::mem::discriminant(x.op()), "split_store: entering");

    // Only split once every enclosing non-outer range has been closed.
    #[allow(clippy::mutable_key_type)]
    let in_scope = x.in_scope_ranges();
    let has_non_outer =
        in_scope.iter().any(|r| matches!(r.0.op(), Op::Range { axis_type, .. } if *axis_type != AxisType::Outer));
    if has_non_outer {
        return None;
    }

    // An `End` that closes an outer range is not split here.
    if let Op::End { ranges, .. } = x.op()
        && let Some(r) = ranges.first()
        && matches!(r.op(), Op::Range { axis_type: AxisType::Outer, .. })
    {
        return None;
    }

    let is_valid = match x.op() {
        Op::Store { .. } => true,
        Op::End { computation, .. } => matches!(computation.op(), Op::Store { .. }),
        _ => false,
    };
    if !is_valid {
        return None;
    }

    let mut lctx = LocalAddBufferContext::new();
    let ret = {
        // The combined matcher is built once and shared by every call.
        static PM_CTX_DEP: LazyLock<crate::TypedPatternMatcher<LocalAddBufferContext>> =
            LazyLock::new(|| local_to_param_patterns() + rangeify_codegen_patterns());
        graph_rewrite_bottom_up(&*PM_CTX_DEP, x.clone(), &mut lctx)
    };

    let stored = extract_stored_value(&ret);
    let ast = if matches!(stored.op(), Op::Copy { .. } | Op::BufferView { .. }) {
        stored.clone()
    } else {
        UOp::sink(vec![ret]).with_metadata(KernelAstMarker)
    };

    // `lctx.vars` is a `HashMap`, whose iteration order is unspecified and can differ
    // between runs. Sort the variables by name so the kernel's operand order is stable.
    let mut named_vars: Vec<_> = lctx.vars.iter().collect();
    named_vars.sort_unstable_by(|a, b| a.0.cmp(b.0));
    let sources: SmallVec<[Arc<UOp>; 4]> = lctx
        .map
        .values()
        .cloned()
        .chain(named_vars.into_iter().map(|(_, (uop, _))| Arc::clone(uop)))
        .collect();

    // Capture the count before the sources are moved into the kernel.
    let num_sources = sources.len();
    let kernel = UOp::kernel(sources, ast);
    debug!(
        kernel_id = kernel.id,
        num_sources = num_sources,
        map_size = lctx.map.len(),
        vars_size = lctx.vars.len(),
        "split_store: created kernel"
    );
    Some(kernel)
}
/// Adds extra dependencies between `After` nodes so that a kernel reading a buffer
/// is linked to the `After` recorded for that buffer.
///
/// The graph is walked in `toposort()` order and only `After` nodes are considered.
/// For each one, the kernel found among its deps is inspected: every `Buffer`/`Param`
/// source of that kernel which refers to a *different* buffer and already has an
/// `After` recorded gets that `After` rebuilt with the current node appended to its
/// deps. All rebuilt nodes are applied to `root` in one `substitute` call at the end;
/// when nothing was rebuilt a clone of `root` is returned.
///
/// NOTE(review): presumably this orders the write to a buffer after other kernels
/// that read it — confirm against the scheduler's expectations.
///
/// # Panics
///
/// Panics when the current `After` node already has, in its own subtree, an `After`
/// for the buffer it reads; adding the dependency would then create a cycle.
fn fix_assign(root: &Arc<UOp>) -> Arc<UOp> {
// Buffer id -> most recent `After` node seen (or rebuilt) for that buffer.
let mut kernel_assign: HashMap<u64, Arc<UOp>> = HashMap::new();
// Original `After` node -> rebuilt node with the extra dependency.
#[allow(clippy::mutable_key_type)]
let mut assign_rep: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
for u in root.toposort() {
let Op::After { passthrough, deps } = u.op() else {
continue;
};
let buf_id = passthrough.buf_uop().id;
kernel_assign.insert(buf_id, u.clone());
// Only `After` nodes that depend on a kernel are examined further.
let Some(kernel) = deps.iter().find(|d| matches!(d.op(), Op::Kernel { .. })).cloned() else {
continue;
};
let Op::Kernel { sources, .. } = kernel.op() else {
continue;
};
for s in sources {
if !matches!(s.op(), Op::Buffer { .. } | Op::Param { .. }) {
continue;
}
let s_buf_id = s.buf_uop().id;
// A kernel reading the buffer it writes needs no extra edge.
if s_buf_id == buf_id {
continue;
}
let Some(a) = kernel_assign.get(&s_buf_id) else {
continue;
};
// Skip when the recorded `After` already shares a dependency (by pointer) with `u`.
if let Op::After { deps: a_deps, .. } = a.op()
&& a_deps.iter().any(|ad| deps.iter().any(|ud| Arc::ptr_eq(ad, ud)))
{
continue;
}
// `u` already reaches an `After` of the source buffer, so the new edge would close a loop.
if u.any_in_subtree(|x| matches!(x.op(), Op::After { .. }) && x.buf_uop().id == s_buf_id) {
panic!(
"cycle detected in graph: kernel for buffer {} reads buffer {} which has AFTER in its tree",
buf_id, s_buf_id
);
}
if let Op::After { passthrough: a_passthrough, deps: a_deps } = a.op() {
let mut new_deps = a_deps.clone();
new_deps.push(u.clone());
let new_a = a_passthrough.after(new_deps);
assign_rep.insert(UOpKey(a.clone()), new_a.clone());
// Later readers of this buffer extend the rebuilt node, not the original.
kernel_assign.insert(s_buf_id, new_a);
}
}
}
if assign_rep.is_empty() { root.clone() } else { root.substitute(&assign_rep) }
}
/// Runs the full kernel-split pipeline on `root`.
///
/// Four stages run in order, each followed by a debug log of its elapsed time:
/// 1. bottom-up rewrite with `pm_add_buffers_patterns()` plus a `Sink` rule that
///    returns `RewriteResult::Gate` for sinks tagged with [`KernelAstMarker`];
/// 2. bottom-up rewrite with `pm_flatten_range()`;
/// 3. `split_all_stores`;
/// 4. `fix_assign`.
///
/// Returns the rewritten graph together with the [`KernelContext`] used in stage 1.
pub fn run_kernel_split_pipeline(root: Arc<UOp>) -> (Arc<UOp>, KernelContext) {
    use super::transforms::pm_add_buffers_patterns;
    use crate::rewrite::graph_rewrite_bottom_up;

    let mut ctx = KernelContext::new();

    // Stage 1: add buffers.
    let buffers_timer = std::time::Instant::now();
    let after_buffers = {
        use morok_ir::op::pattern_derived::OpKey;
        use morok_ir::pattern::RewriteResult;

        let mut add_buffers = pm_add_buffers_patterns();
        add_buffers.add(&[OpKey::Sink], |node, _ctx| {
            if node.metadata::<KernelAstMarker>().is_some() {
                RewriteResult::Gate(node.clone())
            } else {
                RewriteResult::NoMatch
            }
        });
        graph_rewrite_bottom_up(&add_buffers, root, &mut ctx)
    };
    tracing::debug!(
        elapsed_ms = buffers_timer.elapsed().as_millis() as u64,
        "kernel split: pm_add_buffers complete"
    );
    trace!(tree = %after_buffers.tree_full(), "after pm_add_buffers");

    // Stage 2: flatten ranges; this matcher needs no context.
    let flatten_timer = std::time::Instant::now();
    let flattened = graph_rewrite_bottom_up(super::transforms::pm_flatten_range(), after_buffers, &mut ());
    tracing::debug!(
        elapsed_ms = flatten_timer.elapsed().as_millis() as u64,
        "kernel split: pm_flatten_range pre-pass complete"
    );

    // Stage 3: turn eligible stores into kernels.
    let split_timer = std::time::Instant::now();
    let with_kernels = split_all_stores(&flattened);
    tracing::debug!(
        elapsed_ms = split_timer.elapsed().as_millis() as u64,
        "kernel split: split_all_stores complete"
    );

    // Stage 4: add the missing dependencies between `After` nodes.
    let assign_timer = std::time::Instant::now();
    let finished = fix_assign(&with_kernels);
    tracing::debug!(
        elapsed_ms = assign_timer.elapsed().as_millis() as u64,
        "kernel split: fix_assign complete"
    );

    (finished, ctx)
}
/// Rewrites `root` bottom-up, offering every `Store`, and every `End` whose
/// computation is a `Store` or another `End`, to `split_store`.
///
/// `split_store` decides whether a node actually becomes a kernel; when it returns
/// `None` the node is left as it is.
fn split_all_stores(root: &Arc<UOp>) -> Arc<UOp> {
use morok_ir::op::pattern_derived::OpKey;
use morok_ir::pattern::RewriteResult;
use morok_ir::rewrite::graph_rewrite_bottom_up;
let mut matcher = crate::patterns! {
@context Vec<Arc<UOp>>;
node @ Store { index: _, value: _ } => |node, ctx| split_store(ctx, node),
node @ End { computation, .. }
if matches!(computation.op(), Op::Store { .. } | Op::End { .. })
=> |node, ctx| split_store(ctx, node),
};
// Sinks tagged with `KernelAstMarker` are answered with `Gate`.
// NOTE(review): presumably this stops the rewriter from descending into kernel
// ASTs that were already built — confirm against `RewriteResult::Gate`.
matcher.add(&[OpKey::Sink], |node, _ctx| {
if node.metadata::<KernelAstMarker>().is_some() {
RewriteResult::Gate(node.clone())
} else {
RewriteResult::NoMatch
}
});
// The context vector starts empty; `split_store` does not use it.
let mut ctx = Vec::new();
graph_rewrite_bottom_up(&matcher, root.clone(), &mut ctx)
}
/// Returns the axis ids of every `Range` node reachable from `indexed`, sorted in
/// ascending order with duplicates removed.
pub fn collect_range_ids(indexed: &Arc<UOp>) -> Vec<usize> {
    // An ordered set yields the ids sorted and de-duplicated in one pass.
    let mut seen = std::collections::BTreeSet::new();
    for node in indexed.toposort() {
        if let Op::Range { axis_id, .. } = node.op() {
            seen.insert(axis_id.value());
        }
    }
    seen.into_iter().collect()
}
/// One possible way to split a reduction: which input dimension to split and by
/// which divisor.
#[derive(Debug, Clone)]
struct SplitCandidate {
// Index of the dimension in the reduce input shape.
dimension: usize,
// Divisor that evenly divides the size of `dimension`.
divisor: usize,
// Reduce output size multiplied by `divisor`; recorded but not read in this file.
#[allow(dead_code)]
output_size: usize,
}
/// Reports, per axis of `input_shape`, whether no range for that axis survives
/// once `source` is indexed and its movement ops are rewritten.
///
/// Every axis with a constant size greater than one gets a fresh `Loop` range whose
/// id is the axis index; all other axes are indexed with the constant `0`. `source`
/// is indexed with these, its base is replaced by a no-op node, and the movement-op
/// and syntactic-sugar patterns are applied. An axis is reported `true` when no
/// range with its id remains in the result — which is always the case for axes that
/// never received a range, unless the graph already contained a range with that id.
///
/// When the index node cannot be built, every axis is reported `false`.
fn detect_expanded_dimensions(source: &Arc<UOp>, input_shape: &[SInt]) -> Vec<bool> {
    use super::patterns::{movement_op_patterns, pm_syntactic_sugar};
    use crate::rewrite::graph_rewrite_bottom_up;
    use std::sync::LazyLock;

    // Built once and shared by every call.
    static PM_MOPS: LazyLock<crate::TypedPatternMatcher> =
        LazyLock::new(|| movement_op_patterns() + pm_syntactic_sugar());

    let mut probes: Vec<Arc<UOp>> = Vec::with_capacity(input_shape.len());
    for (axis_id, dim) in input_shape.iter().enumerate() {
        let probe = match dim {
            SInt::Const(n) if *n > 1 => {
                let end = UOp::index_const(*n as i64);
                UOp::range_axis(end, morok_ir::AxisId::Unrenumbered(axis_id), morok_ir::AxisType::Loop)
            }
            _ => UOp::index_const(0),
        };
        probes.push(probe);
    }

    let Ok(indexed) = UOp::index().buffer(Arc::clone(source)).indices(probes).call() else {
        return vec![false; input_shape.len()];
    };

    // Swap the underlying base for a no-op before rewriting the indexed expression.
    #[allow(clippy::mutable_key_type)]
    let mut substitutions = HashMap::new();
    substitutions.insert(UOpKey(source.base()), UOp::noop());
    let detached = indexed.substitute(&substitutions);

    let transformed = graph_rewrite_bottom_up(&*PM_MOPS, detached, &mut ());
    let surviving: HashSet<usize> = collect_range_ids(&transformed).into_iter().collect();

    (0..input_shape.len()).map(|axis_id| !surviving.contains(&axis_id)).collect()
}
/// Lists every `(dimension, divisor)` pair that could be used to split `reduce`.
///
/// Reduced axes are visited in the order stored in the `ReduceAxis` op, and for each
/// axis the divisors `config.max_divisor ..= config.min_divisor` are tried from
/// largest to smallest. A pair is kept when:
/// - the axis is in range and not flagged in `is_expanded`;
/// - the axis has a constant size that the divisor divides evenly;
/// - the reduce output size multiplied by the divisor does not exceed
///   `config.max_output_size()`.
///
/// Symbolic dimensions of the output shape are ignored when the output size is
/// computed. Returns an empty list when `reduce` is not a `ReduceAxis` or its shape
/// is unavailable.
///
/// A divisor of zero is never tried, and a product that would overflow `usize` is
/// treated as exceeding the limit.
fn find_split_candidates(
    reduce: &Arc<UOp>,
    input_shape: &[SInt],
    is_expanded: &[bool],
    config: &SplitReduceOpConfig,
) -> Vec<SplitCandidate> {
    let Op::ReduceAxis { axes: reduce_axes, .. } = reduce.op() else {
        return vec![];
    };
    let output_shape = match reduce.shape() {
        Ok(Some(shape)) => shape,
        _ => return vec![],
    };
    let output_size: usize = output_shape.iter().filter_map(|s| s.as_const()).product();

    // `dim_size % 0` would panic, and zero can never be a useful divisor.
    let smallest_divisor = config.min_divisor.max(1);

    let mut candidates = Vec::new();
    for &axis in reduce_axes {
        // Out-of-range axes are treated like expanded ones and skipped.
        if is_expanded.get(axis).copied().unwrap_or(true) {
            continue;
        }
        let dim_size = match input_shape.get(axis) {
            Some(SInt::Const(n)) => *n,
            Some(SInt::Symbolic(_) | SInt::Infer) | None => continue,
        };
        for divisor in (smallest_divisor..=config.max_divisor).rev() {
            if dim_size % divisor != 0 {
                continue;
            }
            // An overflowing product is certainly larger than any allowed output size.
            let Some(new_output_size) = output_size.checked_mul(divisor) else {
                continue;
            };
            if new_output_size > config.max_output_size() {
                continue;
            }
            candidates.push(SplitCandidate { dimension: axis, divisor, output_size: new_output_size });
        }
    }
    candidates
}
/// Rewrites `reduce` as a two-stage reduction according to `candidate`.
///
/// Steps:
/// 1. reshape `source` so the chosen dimension becomes two adjacent dimensions,
///    `divisor` followed by `dim_size / divisor`;
/// 2. permute so the `divisor` dimension moves to the last position;
/// 3. reduce over the original axes (remapped to the permuted layout) and mark the
///    result contiguous;
/// 4. reduce over the last dimension of that result;
/// 5. reshape to the shape of the original `reduce`.
///
/// Returns `None` when `reduce` is not a `ReduceAxis`, when the candidate does not
/// fit `input_shape` (dimension out of range, non-constant size, zero divisor,
/// reduce axis out of range), or when any of the graph-building steps fails.
fn apply_split_transformation(
    source: &Arc<UOp>,
    reduce: &Arc<UOp>,
    candidate: &SplitCandidate,
    input_shape: &[SInt],
) -> Option<Arc<UOp>> {
    let Op::ReduceAxis { reduce_op, axes: reduce_axes, .. } = reduce.op() else {
        return None;
    };
    let dim_to_split = candidate.dimension;
    let divisor = candidate.divisor;
    let dim_size = input_shape.get(dim_to_split)?.as_const()?;
    let remainder = dim_size.checked_div(divisor)?;

    let mut splitted_shape: SmallVec<[SInt; 4]> = SmallVec::new();
    for (i, dim) in input_shape.iter().enumerate() {
        if i == dim_to_split {
            splitted_shape.push(SInt::Const(divisor));
            splitted_shape.push(SInt::Const(remainder));
        } else {
            splitted_shape.push(dim.clone());
        }
    }

    // Keep every dimension in order except the `divisor` one, which goes last.
    let mut permutation: Vec<usize> = (0..splitted_shape.len()).filter(|&i| i != dim_to_split).collect();
    permutation.push(dim_to_split);

    // Map each original reduce axis to its position after the reshape and permute.
    // Axes at or after the split point shift by one because of the inserted
    // dimension; the split axis itself maps to its `remainder` half.
    let permuted_axes: Vec<usize> = reduce_axes
        .iter()
        .map(|&axis| {
            let split_index = if axis < dim_to_split { axis } else { axis + 1 };
            permutation.iter().position(|&p| p == split_index)
        })
        .collect::<Option<_>>()?;

    let reshaped = source.try_reshape(&splitted_shape).ok()?;
    let permuted = reshaped.try_permute(permutation).ok()?;

    let first_reduce = permuted.try_reduce_axis(*reduce_op, permuted_axes).ok()?;
    let contiguous = first_reduce.contiguous();

    // The `divisor` dimension was moved last, so the second stage reduces the last axis.
    let output_shape = contiguous.shape().ok()??;
    let split_axis = output_shape.len().checked_sub(1)?;
    let second_reduce = contiguous.try_reduce_axis(*reduce_op, vec![split_axis]).ok()?;

    let final_shape = reduce.shape().ok()??;
    second_reduce.try_reshape(final_shape).ok()
}
/// Splits a large reduction into two stages when that looks worthwhile.
///
/// Returns `None` (meaning "leave `reduce` as it is") when:
/// - splitting is disabled in `config`;
/// - `reduce` is not a `ReduceAxis`;
/// - the input or output shape is unavailable or has a non-constant dimension;
/// - the output has zero elements;
/// - `input_size / output_size` (integer division) is below `config.split_threshold`;
/// - no split candidate exists, or applying the first candidate fails.
///
/// Candidates are tried in the order produced by `find_split_candidates`; only the
/// first one is used.
pub fn split_reduceop(reduce: &Arc<UOp>, config: &SplitReduceOpConfig) -> Option<Arc<UOp>> {
    if !config.enabled {
        return None;
    }
    let Op::ReduceAxis { src: source, .. } = reduce.op() else {
        return None;
    };
    let input_shape = source.shape().ok()??;
    let output_shape = reduce.shape().ok()??;
    if !input_shape.iter().all(|s| s.is_const()) {
        return None;
    }

    // A product over `Option` yields `None` as soon as any dimension is not constant,
    // so neither shape can trigger a panic here.
    let input_size = input_shape.iter().map(|s| s.as_const()).product::<Option<usize>>()?;
    let output_size = output_shape.iter().map(|s| s.as_const()).product::<Option<usize>>()?;
    if output_size == 0 {
        return None;
    }

    let ratio = input_size / output_size;
    if ratio < config.split_threshold {
        return None;
    }

    let is_expanded = detect_expanded_dimensions(source, input_shape);
    let candidates = find_split_candidates(reduce, input_shape, &is_expanded, config);
    let best = candidates.first()?;
    apply_split_transformation(source, reduce, best, input_shape)
}