use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::types::{AxisType, ConstValue};
use morok_ir::{Op, UOp};
use smallvec::SmallVec;
use crate::optimizer::Renderer;
use crate::pattern::TypedPatternMatcher;
/// Builds the pattern matcher for the GPU-dimension pass.
///
/// The matcher has a single rule: every `Sink` node is handed to `add_gpudims`
/// together with the `Renderer` context declared via `@context`.
// NOTE(review): `ctx` is not declared in this function; presumably the
// `patterns!` macro binds it from `@context Renderer` — confirm against the macro.
pub fn pm_add_gpudims() -> TypedPatternMatcher<Renderer> {
crate::patterns! {
@context Renderer;
sink @ Sink { sources: _sources } => |sink| add_gpudims(ctx, sink),
}
}
fn add_gpudims(ctx: &Renderer, sink: &Arc<UOp>) -> Option<Arc<UOp>> {
let Op::Sink { sources } = sink.op() else {
return None;
};
let topo = sink.toposort();
if topo.iter().any(|u| matches!(u.op(), Op::Special { .. })) {
return None;
}
let mut all_ranges: HashMap<(usize, AxisType), Arc<UOp>> = HashMap::new();
for u in &topo {
if let Op::Range { axis_id, axis_type, .. } = u.op() {
all_ranges.insert((axis_id.value(), *axis_type), u.clone());
}
}
if all_ranges.is_empty() {
return None;
}
let mut global_dims: Vec<(usize, AxisType)> = Vec::new();
let mut local_dims: Vec<(usize, AxisType)> = Vec::new();
for (axis_id, axis_type) in all_ranges.keys() {
match axis_type {
AxisType::Global | AxisType::Thread => {
if !global_dims.iter().any(|(id, _)| *id == *axis_id) {
global_dims.push((*axis_id, *axis_type));
}
}
AxisType::Local | AxisType::Warp | AxisType::GroupReduce => {
if !local_dims.iter().any(|(id, _)| *id == *axis_id) {
local_dims.push((*axis_id, *axis_type));
}
}
_ => {}
}
}
global_dims.sort_by_key(|(id, _)| *id);
local_dims.sort_by_key(|(id, _)| *id);
if global_dims.is_empty() && local_dims.is_empty() {
return None;
}
let get_ranges_for_dims = |dims: &[(usize, AxisType)]| -> Vec<Arc<UOp>> {
dims.iter().filter_map(|(axis_id, axis_type)| all_ranges.get(&(*axis_id, *axis_type))).cloned().collect()
};
let global_ranges = get_ranges_for_dims(&global_dims);
let local_ranges = get_ranges_for_dims(&local_dims);
let extract_shape = |ranges: &[Arc<UOp>]| -> Vec<Arc<UOp>> {
ranges
.iter()
.filter_map(|r| match r.op() {
Op::Range { end, .. } => Some(end.clone()),
_ => None,
})
.collect()
};
let global_shape = extract_shape(&global_ranges);
let local_shape = extract_shape(&local_ranges);
let global_max = ctx.global_max.as_deref();
let local_max_product = ctx.local_max;
let local_max: Option<Vec<usize>> = local_max_product.map(|max| {
let n = local_shape.len().max(1);
let per_dim = (max as f64).powf(1.0 / n as f64).floor() as usize;
vec![per_dim.max(1); n]
});
let local_max_slice = local_max.as_deref();
let global_idxs = get_grouped_dims("gidx", &global_shape, global_max, true);
let local_idxs = get_grouped_dims("lidx", &local_shape, local_max_slice, false);
let local_idxs_for_masks = local_idxs.clone();
let all_idxs: Vec<Arc<UOp>> = global_idxs.into_iter().chain(local_idxs).collect();
let mut subs: HashMap<u64, Arc<UOp>> = HashMap::new();
let all_dims: Vec<(usize, AxisType)> = global_dims.iter().chain(local_dims.iter()).cloned().collect();
for (i, (axis_id, axis_type)) in all_dims.iter().enumerate() {
if *axis_type == AxisType::Reduce {
continue;
}
if let Some(range_uop) = all_ranges.get(&(*axis_id, *axis_type))
&& i < all_idxs.len()
{
subs.insert(range_uop.id, all_idxs[i].clone());
}
}
let store_subs = compute_store_masks(&topo, &all_ranges, &local_dims, &local_idxs_for_masks);
for (id, masked_idx) in store_subs {
subs.insert(id, masked_idx);
}
if subs.is_empty() {
return None;
}
let new_sources: SmallVec<[Arc<UOp>; 4]> = sources.iter().map(|s| substitute(s, &subs)).collect();
Some(UOp::new(Op::Sink { sources: new_sources }, sink.dtype().clone()))
}
/// Builds gated replacements for global stores that ignore some local ranges.
///
/// For every `Store` in `topo` whose target is an `Index` into a global buffer (a buffer whose
/// dtype is not a pointer is treated as global), the local ranges that are not among the
/// index's in-scope ranges are collected. If there are any, the index is rebuilt with a gate
/// requiring each corresponding local index to equal zero, AND-ed with any existing gate.
///
/// `local_idxs[i]` is taken to be the launch index of `local_dims[i]`; entries without a
/// counterpart are ignored.
///
/// Returns a map from the id of the original `Index` node to its gated replacement.
// NOTE(review): presumably the gate ensures only one work-item per unused local axis performs
// the store — confirm against the renderer.
fn compute_store_masks(
    topo: &[Arc<UOp>],
    all_ranges: &HashMap<(usize, AxisType), Arc<UOp>>,
    local_dims: &[(usize, AxisType)],
    local_idxs: &[Arc<UOp>],
) -> HashMap<u64, Arc<UOp>> {
    let mut gated: HashMap<u64, Arc<UOp>> = HashMap::new();
    for node in topo {
        let Op::Store { index, .. } = node.op() else {
            continue;
        };
        let Op::Index { buffer, indices, gate } = index.op() else {
            continue;
        };
        let targets_global = match buffer.dtype() {
            DType::Ptr { addrspace, .. } => addrspace == morok_dtype::AddrSpace::Global,
            _ => true,
        };
        if !targets_global {
            continue;
        }

        // Local launch indices whose range does not take part in this store's index.
        let ranges_in_index: HashSet<u64> = index.in_scope_ranges().iter().map(|key| key.0.id).collect();
        let unused: Vec<Arc<UOp>> = local_dims
            .iter()
            .zip(local_idxs)
            .filter(|((axis_id, axis_type), _)| {
                all_ranges.get(&(*axis_id, *axis_type)).is_some_and(|range| !ranges_in_index.contains(&range.id))
            })
            .map(|(_, idx)| idx.clone())
            .collect();
        if unused.is_empty() {
            continue;
        }

        // mask = (idx_a == 0) & (idx_b == 0) & ...
        let zero = UOp::index_const(0);
        let Some(mask) = unused.into_iter().map(|idx| idx.eq(&zero)).reduce(|acc, term| acc.and_(&term)) else {
            continue;
        };
        let combined_gate = match gate {
            Some(existing) => existing.and_(&mask),
            None => mask,
        };
        let replacement = UOp::index()
            .buffer(buffer.clone())
            .indices(indices.clone())
            .gate(combined_gate)
            .call()
            .expect("gpudims: INDEX gate construction failed");
        gated.insert(index.id, replacement);
    }
    gated
}
/// Returns `uop` with every node whose id is a key of `subs` replaced by the mapped node.
///
/// A replacement is returned as-is (it is not searched for further substitutions). Nodes
/// whose children are all unchanged are reused rather than rebuilt.
fn substitute(uop: &Arc<UOp>, subs: &HashMap<u64, Arc<UOp>>) -> Arc<UOp> {
    let mut done: HashMap<u64, Arc<UOp>> = HashMap::new();
    substitute_memo(uop, subs, &mut done)
}

/// Recursive worker for `substitute`; `done` caches the result for each visited node id.
///
/// Without the cache a node reachable through several parents is rewritten once per path,
/// which is exponential on graphs with stacked shared sub-expressions.
fn substitute_memo(
    uop: &Arc<UOp>,
    subs: &HashMap<u64, Arc<UOp>>,
    done: &mut HashMap<u64, Arc<UOp>>,
) -> Arc<UOp> {
    if let Some(replacement) = subs.get(&uop.id) {
        return replacement.clone();
    }
    if let Some(cached) = done.get(&uop.id) {
        return cached.clone();
    }
    let children = uop.op().children();
    let rebuilt = if children.is_empty() {
        uop.clone()
    } else {
        let new_children: Vec<Arc<UOp>> = children.iter().map(|c| substitute_memo(c, subs, done)).collect();
        let changed = children.iter().zip(&new_children).any(|(old, new)| old.id != new.id);
        if changed { uop.replace().src(new_children).call() } else { uop.clone() }
    };
    done.insert(uop.id, rebuilt.clone());
    rebuilt
}
/// Converts a constant to `i64` using plain `as` casts.
///
/// Every variant currently yields `Some`. The casts follow Rust's `as` rules: a float is
/// truncated toward zero (saturating at the `i64` bounds) and an unsigned value that does not
/// fit in `i64` is reinterpreted rather than rejected.
fn const_to_i64(cv: &ConstValue) -> Option<i64> {
    let widened = match cv {
        ConstValue::Int(i) => *i,
        ConstValue::UInt(u) => *u as i64,
        ConstValue::Bool(b) => *b as i64,
        ConstValue::Float(f) => *f as i64,
    };
    Some(widened)
}
/// Creates one launch index expression per entry of `dims`, named `{prefix}0`, `{prefix}1`, ...
///
/// The returned vector is aligned with `dims`: element `i` is the index expression that
/// ranges over `dims[i]`.
///
/// * `max_sizes` — optional per-dimension limits. When all `dims` are constants and a limit
///   is given, the launch shape is regrouped/split by `limit_dims` and the per-dimension
///   indices are reconstructed from the resulting `Special`s.
/// * `reverse` — number the launch dimensions from the last entry of `dims` backwards, so the
///   last dim is served by `{prefix}0`.
///
/// If any dim is not a constant, no limiting is attempted and each dim gets its own `Special`.
fn get_grouped_dims(prefix: &str, dims: &[Arc<UOp>], max_sizes: Option<&[usize]>, reverse: bool) -> Vec<Arc<UOp>> {
    if dims.is_empty() {
        return vec![];
    }

    // With `reverse`, all work happens on the reversed dims and the result is flipped back at
    // the end. Reversing only the result would pair the index at position `i` with the extent
    // of `dims[n - 1 - i]`.
    let ordered: Vec<&Arc<UOp>> = if reverse { dims.iter().rev().collect() } else { dims.iter().collect() };
    let restore_order =
        |idxs: Vec<Arc<UOp>>| -> Vec<Arc<UOp>> { if reverse { idxs.into_iter().rev().collect() } else { idxs } };

    let concrete: Option<Vec<usize>> = ordered
        .iter()
        .map(|d| match d.op() {
            Op::Const(c) => const_to_i64(&c.0).map(|v| v as usize),
            _ => None,
        })
        .collect();
    let Some(concrete) = concrete else {
        // Symbolic extents: one `Special` per dim, sized by the dim expression itself.
        let symbolic: Vec<Arc<UOp>> = ordered
            .iter()
            .enumerate()
            .map(|(i, d)| UOp::special(Arc::clone(d), format!("{}{}", prefix, i)))
            .collect();
        return restore_order(symbolic);
    };

    let limited: Vec<usize> = match max_sizes {
        Some(max) => limit_dims(&concrete, max),
        None => concrete.clone(),
    };
    let raw_idxs: Vec<Arc<UOp>> = limited
        .iter()
        .enumerate()
        .map(|(i, &size)| UOp::special(UOp::index_const(size as i64), format!("{}{}", prefix, i)))
        .collect();

    // Map the launch indices back onto the requested dims.
    let result = if limited.len() < concrete.len() {
        decompose_contracted_dims(&raw_idxs, &limited, &concrete)
    } else if limited.len() > concrete.len() {
        combine_expanded_dims(&raw_idxs, &limited, &concrete)
    } else if limited != concrete {
        flatten_unflatten_dims(&raw_idxs, &limited, &concrete)
    } else {
        raw_idxs
    };
    restore_order(result)
}
/// Fits `dims` under `max_sizes`: merging adjacent dims is tried first, and splitting is the
/// fallback when merging cannot satisfy the limits.
fn limit_dims(dims: &[usize], max_sizes: &[usize]) -> Vec<usize> {
    group_dims(dims, max_sizes).unwrap_or_else(|| split_dims(dims, max_sizes))
}
/// Merges adjacent dims until there are at most `max_sizes.len()` of them and each is within
/// its limit.
///
/// Each round merges the first adjacent pair `(i, i + 1)` whose product is at most
/// `max_sizes[i]`. Returns `None` when the dims still do not fit and no pair can be merged;
/// dims that already fit are returned unchanged.
fn group_dims(dims: &[usize], max_sizes: &[usize]) -> Option<Vec<usize>> {
    let mut merged = dims.to_vec();
    loop {
        let fits = merged.len() <= max_sizes.len() && merged.iter().zip(max_sizes).all(|(d, m)| d <= m);
        if fits {
            return Some(merged);
        }
        // `windows(2)` yields the adjacent pairs; zipping with the limits restricts the search
        // to slots that have a limit.
        let slot = merged
            .windows(2)
            .zip(max_sizes)
            .position(|(pair, &limit)| pair[0].saturating_mul(pair[1]) <= limit)?;
        merged[slot] = merged[slot].saturating_mul(merged[slot + 1]);
        merged.remove(slot + 1);
    }
}
/// Moves factors out of oversized dims into the following dim.
///
/// The dims are padded with 1s to at least three entries. Each dim that exceeds its limit is
/// repeatedly divided by its smallest divisor, with that factor multiplied into the next dim
/// (wrapping around to the first). Dims without a limit are unbounded. A dim for which no
/// divisor is found is left above its limit. Trailing 1s are dropped from the result, which
/// always keeps at least one entry.
fn split_dims(dims: &[usize], max_sizes: &[usize]) -> Vec<usize> {
    let mut sizes = dims.to_vec();
    if sizes.len() < 3 {
        sizes.resize(3, 1);
    }
    let count = sizes.len();
    for pos in 0..count {
        let limit = max_sizes.get(pos).copied().unwrap_or(usize::MAX);
        let spill = (pos + 1) % count;
        while sizes[pos] > limit {
            let factor = find_smallest_divisor(sizes[pos]);
            if factor == 1 {
                break;
            }
            sizes[pos] /= factor;
            sizes[spill] *= factor;
        }
    }
    // Keep everything up to the last entry that is not 1 (or just the first entry).
    let keep = sizes.iter().rposition(|&d| d != 1).map_or(1, |last| last + 1);
    sizes.truncate(keep);
    sizes
}
/// Returns the smallest divisor of `n` in `2..=ceil(sqrt(n))`, or 1 if there is none.
///
/// Consequently 0, 1 and every prime other than 2 yield 1 (2 yields 2 because
/// `ceil(sqrt(2))` is 2).
fn find_smallest_divisor(n: usize) -> usize {
    if n < 2 {
        return 1;
    }
    let limit = (n as f64).sqrt().ceil() as usize;
    (2..=limit).find(|&d| n.is_multiple_of(d)).unwrap_or(1)
}
/// Recovers one index per original dim after several dims were merged into fewer launch dims.
///
/// `get_contraction` tells which original dims each launch dim covers. Within a group the
/// launch index is peeled apart with `mod`/`idiv` by each member's size, first member first;
/// the last member receives the remaining quotient.
///
/// If no contraction can be derived, `raw_idxs` is returned unchanged (and therefore has
/// fewer entries than `original_dims`).
fn decompose_contracted_dims(raw_idxs: &[Arc<UOp>], limited_dims: &[usize], original_dims: &[usize]) -> Vec<Arc<UOp>> {
    let Some(groups) = get_contraction(original_dims, limited_dims) else {
        return raw_idxs.to_vec();
    };
    let mut unpacked: Vec<Arc<UOp>> = Vec::new();
    for (fused, members) in raw_idxs.iter().zip(&groups) {
        let mut remaining = fused.clone();
        // Every member except the last contributes a mod/div step.
        let leading = &members[..members.len().saturating_sub(1)];
        for &member in leading {
            let extent = UOp::index_const(original_dims[member] as i64);
            unpacked.push(remaining.mod_(&extent));
            remaining = remaining.idiv(&extent);
        }
        unpacked.push(remaining);
    }
    unpacked
}
/// Describes how `original_dims` were merged to form `limited_dims`.
///
/// Entry `k` of the result lists the indices of the original dims covered by
/// `limited_dims[k]`. Group boundaries are found by matching running products: a new dim ends
/// where the running product of the original dims first equals the running product of the
/// new dims (a running product of 1 maps to an empty prefix). The last group always extends
/// to the end of `original_dims`.
///
/// Returns `None` if some running product of `limited_dims` has no match, or if
/// `limited_dims` is empty while `original_dims` is not.
fn get_contraction(original_dims: &[usize], limited_dims: &[usize]) -> Option<Vec<Vec<usize>>> {
    if limited_dims.is_empty() {
        return original_dims.is_empty().then(Vec::new);
    }

    let running_products = |dims: &[usize]| -> Vec<usize> {
        let mut acc = 1usize;
        dims.iter()
            .map(|&d| {
                acc = acc.saturating_mul(d);
                acc
            })
            .collect()
    };
    let old_products = running_products(original_dims);
    let new_products = running_products(limited_dims);

    // Exclusive end (in original-dim indices) of every new dim.
    let boundaries: Vec<usize> = new_products
        .iter()
        .map(|&target| {
            if target == 1 { Some(0) } else { old_products.iter().position(|&p| p == target).map(|at| at + 1) }
        })
        .collect::<Option<_>>()?;

    // Group k spans [boundary k-1, boundary k); the final group ignores its own boundary and
    // runs to the end of the original dims.
    let inner = &boundaries[..boundaries.len() - 1];
    let starts = std::iter::once(0).chain(inner.iter().copied());
    let ends = inner.iter().copied().chain(std::iter::once(original_dims.len()));
    Some(starts.zip(ends).map(|(lo, hi)| (lo..hi).collect()).collect())
}
/// Recombines launch indices after one dim was split across several launch dims.
///
/// Handles the shapes 2→1, 3→1 and 3→2 (launch dims → original dims); in the 3→2 case the
/// first two launch indices are fused and the third is passed through. Any other combination
/// returns `raw_idxs` unchanged.
fn combine_expanded_dims(raw_idxs: &[Arc<UOp>], limited_dims: &[usize], original_dims: &[usize]) -> Vec<Arc<UOp>> {
    // acc * limited_dims[k] + raw_idxs[k]
    let fuse = |acc: &Arc<UOp>, k: usize| -> Arc<UOp> {
        acc.mul(&UOp::index_const(limited_dims[k] as i64)).add(&raw_idxs[k])
    };
    match (limited_dims.len(), original_dims.len()) {
        (2, 1) => vec![fuse(&raw_idxs[0], 1)],
        (3, 1) => {
            let first_two = fuse(&raw_idxs[0], 1);
            vec![fuse(&first_two, 2)]
        }
        (3, 2) => vec![fuse(&raw_idxs[0], 1), raw_idxs[2].clone()],
        _ => raw_idxs.to_vec(),
    }
}
/// Re-derives indices for `original_dims` when the launch shape has the same rank but
/// different sizes.
///
/// The launch indices are linearised row-major over `limited_dims`, then the linear index is
/// split again row-major over `original_dims`. Only ranks 2 and 3 are handled; any other
/// rank returns `raw_idxs` unchanged.
fn flatten_unflatten_dims(raw_idxs: &[Arc<UOp>], limited_dims: &[usize], original_dims: &[usize]) -> Vec<Arc<UOp>> {
    let constant = |value: usize| UOp::index_const(value as i64);
    match limited_dims.len() {
        2 => {
            let linear = raw_idxs[0].mul(&constant(limited_dims[1])).add(&raw_idxs[1]);
            let inner = constant(original_dims[1]);
            vec![linear.idiv(&inner), linear.mod_(&inner)]
        }
        3 => {
            // Linearise: i0 * (l1 * l2) + i1 * l2 + i2.
            let stride0 = constant(limited_dims[1] * limited_dims[2]);
            let stride1 = constant(limited_dims[2]);
            let term0 = raw_idxs[0].mul(&stride0);
            let term1 = raw_idxs[1].mul(&stride1);
            let linear = term0.add(&term1).add(&raw_idxs[2]);
            // Split over the original sizes.
            let (mid, inner) = (original_dims[1], original_dims[2]);
            let mid_times_inner = constant(mid * inner);
            let mid_extent = constant(mid);
            let inner_extent = constant(inner);
            vec![
                linear.idiv(&mid_times_inner),
                linear.idiv(&inner_extent).mod_(&mid_extent),
                linear.mod_(&inner_extent),
            ]
        }
        _ => raw_idxs.to_vec(),
    }
}
// Unit tests for the pure shape helpers (no `UOp` graph is needed for these).
#[cfg(test)]
mod tests {
use super::*;
// Dims that already satisfy the limits are returned untouched.
#[test]
fn test_group_dims_already_fits() {
let result = group_dims(&[4, 4], &[16, 16, 16]);
assert_eq!(result, Some(vec![4, 4]));
}
// Four dims against three limits: adjacent dims have to be merged until at most three remain.
#[test]
fn test_group_dims_needs_grouping() {
let result = group_dims(&[4, 4, 4, 4], &[256, 256, 256]);
assert!(result.is_some());
let result = result.unwrap();
assert!(result.len() <= 3);
}
// Same rank as the limits and every dim within bounds: nothing to merge.
#[test]
fn test_group_dims_no_change() {
let result = group_dims(&[8, 8, 8], &[256, 256, 256]);
assert_eq!(result, Some(vec![8, 8, 8]));
}
// A single oversized dim has no neighbour to merge with, so grouping reports failure.
#[test]
fn test_group_dims_impossible() {
let result = group_dims(&[1000], &[10]);
assert_eq!(result, None);
}
// Splitting moves factors of an oversized dim into the following dims.
#[test]
fn test_split_dims_simple() {
let result = split_dims(&[100], &[64, 64, 64]);
assert!(result.iter().all(|&d| d <= 64));
}
// Only divisors in 2..=ceil(sqrt(n)) are searched, so 3 (prime) yields 1 while 2 yields 2.
#[test]
fn test_find_smallest_divisor() {
assert_eq!(find_smallest_divisor(1), 1);
assert_eq!(find_smallest_divisor(2), 2); assert_eq!(find_smallest_divisor(3), 1); assert_eq!(find_smallest_divisor(4), 2);
assert_eq!(find_smallest_divisor(9), 3);
assert_eq!(find_smallest_divisor(100), 2);
}
// [2, 5, 2] -> [10, 2]: the first two dims are fused, the third stands alone.
#[test]
fn test_get_contraction_non_consecutive() {
let result = get_contraction(&[2, 5, 2], &[10, 2]);
assert_eq!(result, Some(vec![vec![0, 1], vec![2]]));
}
// Identical shapes map every dim onto itself.
#[test]
fn test_get_contraction_identity() {
let result = get_contraction(&[4, 4, 4], &[4, 4, 4]);
assert_eq!(result, Some(vec![vec![0], vec![1], vec![2]]));
}
// A single new dim covers all original dims.
#[test]
fn test_get_contraction_all_fused() {
let result = get_contraction(&[2, 3, 4], &[24]);
assert_eq!(result, Some(vec![vec![0, 1, 2]]));
}
// Two empty shapes contract to an empty grouping.
#[test]
fn test_get_contraction_empty() {
let result = get_contraction(&[], &[]);
assert_eq!(result, Some(vec![]));
}
// The running product 5 never occurs among the running products of [2, 3, 4].
#[test]
fn test_get_contraction_invalid() {
let result = get_contraction(&[2, 3, 4], &[5, 4]);
assert_eq!(result, None);
}
// [2, 4, 3] -> [8, 3]: only the leading pair is fused.
#[test]
fn test_get_contraction_partial() {
let result = get_contraction(&[2, 4, 3], &[8, 3]);
assert_eq!(result, Some(vec![vec![0, 1], vec![2]]));
}
}