morok-schedule 0.1.0-alpha.2

Optimization passes and pattern engine for the Morok ML compiler
Documentation
//! Index dtype lowering patterns.
//!
//! Converts abstract Index dtype to concrete integer types (i32 or i64)
//! based on value bounds analysis. Follows Tinygrad's cascade approach.
//!
//! ## Cascade Pattern (from Tinygrad)
//!
//! Phase 1 - Create wrappers:
//!   CONST(Index) → CONST(concrete).cast(Index)
//!   DEFINE_VAR(Index) → DEFINE_VAR(concrete).cast(Index)
//!
//! Phase 2 - Process wrapped values (Unary, Binary, WHERE, RANGE, SPECIAL, VECTORIZE, BIND):
//!   Binary(x.cast(Index), y.cast(Index)) → Binary(x, y, concrete).cast(Index)
//!   RANGE(end.cast(Index)) → RANGE(end, concrete).cast(Index)
//!
//! Phase 3 - Strip at terminals:
//!   INDEX(idx.cast(Index)) → INDEX(idx)
//!   INDEX(WHERE(cond, idx, Invalid)) → INDEX(idx, gate=cond)
//!   SINK/END: strip .cast(Index) from their sources

use std::sync::Arc;

use morok_dtype::{DType, ScalarDType};
use morok_ir::types::ConstValue;
use morok_ir::uop::cached_property::CachedProperty;
use morok_ir::uop::properties::VminVmaxProperty;
use morok_ir::{Op, UOp};

use crate::TypedPatternMatcher;

/// Pick the concrete integer dtype that an Index-typed node is lowered to.
///
/// The choice is driven by the node's `(vmin, vmax)` bounds:
/// - signed bounds inside `[i32::MIN, i32::MAX]` → `Int32`
/// - unsigned bounds whose endpoints are both at most `i32::MAX` → `Int32`
/// - boolean bounds → `Int32`
/// - any other combination of bound values → `Int64`
fn select_dtype(uop: &Arc<UOp>) -> DType {
    const LO_I32: i64 = i32::MIN as i64;
    const HI_I32: i64 = i32::MAX as i64;

    let (vmin, vmax) = VminVmaxProperty::get(uop);
    // Phrased as "does anything force 64 bits?" (the negation of "fits in i32").
    let needs_i64 = match (vmin, vmax) {
        (ConstValue::Int(lo), ConstValue::Int(hi)) => *lo < LO_I32 || *hi > HI_I32,
        (ConstValue::UInt(lo), ConstValue::UInt(hi)) => *lo > HI_I32 as u64 || *hi > HI_I32 as u64,
        (ConstValue::Bool(_), ConstValue::Bool(_)) => false,
        _ => true,
    };
    let scalar = if needs_i64 { ScalarDType::Int64 } else { ScalarDType::Int32 };
    DType::Scalar(scalar)
}

/// Least upper bound of two lowered index dtypes, restricted to `Int32`/`Int64`.
///
/// Returns `Scalar(Int64)` when either input is exactly `Scalar(Int64)`.
/// Every other pair returns `Scalar(Int32)`.
///
/// NOTE(review): that rule also maps unsigned or otherwise non-`Int64` wide
/// integer inputs to `Int32`. Confirm such operands never reach this helper
/// with ranges beyond i32.
fn least_upper_dtype(a: &DType, b: &DType) -> DType {
    let is_i64 = |dt: &DType| matches!(dt, DType::Scalar(ScalarDType::Int64));
    let widest = if is_i64(a) || is_i64(b) { ScalarDType::Int64 } else { ScalarDType::Int32 };
    DType::Scalar(widest)
}

/// Pattern matcher for lowering Index dtype to concrete i32/i64.
/// Based on Tinygrad's pm_lower_index_dtype.
///
/// The rules fall into three phases:
/// 1. Leaf nodes (`CONST`, `VCONST`, `DEFINE_VAR`) with Index dtype are rebuilt with
///    a concrete dtype chosen from their bounds, then wrapped in `.cast(Index)`.
/// 2. Interior nodes whose operands are `.cast(Index)` wrappers are rebuilt over the
///    unwrapped operands. The result is cast back to the node's original dtype.
/// 3. Terminal nodes (`INDEX`, `SINK`, `END`) strip the `.cast(Index)` wrapper.
///    `INDEX` additionally folds `WHERE(cond, idx, Invalid)` indices into its gate.
pub fn pm_lower_index_dtype() -> TypedPatternMatcher {
    crate::patterns! {
        // ================================================================
        // PHASE 1: Create wrappers (leaf nodes)
        // ================================================================

        // CONST(Index) → CONST(concrete).cast(Index)
        // Tinygrad: u.replace(dtype=select_dtype(u)).cast(u.dtype)
        c @const(cv) if c.dtype() == DType::Index => |c, cv| {
            let dt = select_dtype(c);
            Some(UOp::const_(dt, cv).cast(DType::Index))
        },

        // VCONST(Vector<Index, N>) → VCONST(Vector<concrete, N>).cast(Vector<Index, N>)
        // Tinygrad: (CONST, VCONST) with dtype=index → u.replace(dtype=select_dtype(u)).cast(u.dtype)
        vc @ VConst { values } if vc.dtype().base() == ScalarDType::Index => |vc, values| {
            let dt = select_dtype(vc);
            let vcount = vc.dtype().vcount();
            let vec_dt = dt.vec(vcount);
            let vec_index_dt = DType::Vector { scalar: ScalarDType::Index, count: vcount };
            let new_vc = UOp::new(Op::VConst { values: values.clone() }, vec_dt);
            Some(new_vc.cast(vec_index_dt))
        },

        // DEFINE_VAR(Index) → DEFINE_VAR(concrete).cast(Index)
        dv @ DefineVar { name, min_val, max_val } if dv.dtype() == DType::Index => |dv, name, min_val, max_val| {
            let dt = select_dtype(dv);
            let var = UOp::new(Op::DefineVar { name: name.clone(), min_val: *min_val, max_val: *max_val }, dt);
            Some(var.cast(DType::Index))
        },

        // ================================================================
        // PHASE 2: Process wrapped values
        // ================================================================

        // Unary(x.cast(Index)) → Unary(x.cast(dt), dt).cast(Index)
        // Handles Neg and other unary ops on Index-wrapped values.
        node if matches!(node.op(), Op::Unary(_, _)) && node.dtype() == DType::Index => |node| {
            let Op::Unary(op, x) = node.op() else { return None };
            let Op::Cast { src, dtype } = x.op() else { return None };
            if *dtype != DType::Index { return None; }

            let dt = least_upper_dtype(&select_dtype(node), &src.dtype());
            let result = UOp::new(Op::Unary(*op, src.cast(dt.clone())), dt);
            Some(result.cast(DType::Index))
        },

        // Binary(x.cast(Index), y.cast(Index)) → alu(op, x.cast(dt), y.cast(dt)).cast(u.dtype)
        // Tinygrad: x.cast(dt:=least_upper_dtype(select_dtype(u), x.dtype, y.dtype)).alu(u.op, y.cast(dt)).cast(u.dtype)
        // No dtype guard on result — comparisons (Lt, Ge, etc.) produce Bool, not Index,
        // but their operands still need unwrapping from .cast(Index).
        node if matches!(node.op(), Op::Binary(_, _, _)) => |node| {
            let Op::Binary(op, lhs, rhs) = node.op() else { return None };

            // Both operands must be .cast(Index) wrappers
            let (Op::Cast { src: x, dtype: lhs_dt }, Op::Cast { src: y, dtype: rhs_dt }) = (lhs.op(), rhs.op()) else {
                return None;
            };
            if *lhs_dt != DType::Index || *rhs_dt != DType::Index {
                return None;
            }

            // dt = least_upper_dtype(select_dtype(result), x.dtype, y.dtype)
            let result_dt = select_dtype(node);
            let dt = least_upper_dtype(&result_dt, &least_upper_dtype(&x.dtype(), &y.dtype()));

            // alu() auto-selects Bool for comparisons, dt for arithmetic
            let result = UOp::alu(*op, x.cast(dt.clone()), y.cast(dt));
            Some(result.cast(node.dtype()))
        },

        // WHERE(cond, x.cast(Index), y.cast(Index)) → WHERE(cond, x.cast(dt), y.cast(dt)).cast(Index)
        Where(cond, true_val, false_val)
            if true_val.dtype() == DType::Index && false_val.dtype() == DType::Index => |cond, true_val, false_val| {
            let (Op::Cast { src: x, dtype: t_dt }, Op::Cast { src: y, dtype: f_dt }) = (true_val.op(), false_val.op()) else {
                return None;
            };
            if *t_dt != DType::Index || *f_dt != DType::Index {
                return None;
            }

            let dt = least_upper_dtype(&x.dtype(), &y.dtype());
            let result = UOp::try_where(cond.clone(), x.cast(dt.clone()), y.cast(dt)).ok()?;
            Some(result.cast(DType::Index))
        },

        // RANGE(end.cast(Index)) → RANGE(end, end.dtype).cast(Index)
        // Tinygrad: r.replace(dtype=end.dtype, src=(end,)).cast(dtypes.index)
        // NOTE(review): the rebuilt RANGE always gets empty `deps`; any deps on the
        // original node are not carried over. Confirm deps are always empty at this
        // stage (Tinygrad's replace likewise keeps only `end` as src).
        range @ Range { end, axis_id, axis_type } if range.dtype() == DType::Index => |end, axis_id, axis_type| {
            let Op::Cast { src: end_inner, dtype: end_dt } = end.op() else {
                return None;
            };
            if *end_dt != DType::Index {
                return None;
            }

            let dt = end_inner.dtype();
            let result = UOp::new(Op::Range { end: end_inner.clone(), axis_id: *axis_id, axis_type: *axis_type, deps: smallvec::SmallVec::new() }, dt);
            Some(result.cast(DType::Index))
        },

        // SPECIAL(end.cast(Index)) → SPECIAL(end, i32).cast(Index)
        // Tinygrad: u.replace(dtype=dtypes.int, src=(var,)).cast(dtypes.index)
        special @ Special { name, end } if special.dtype() == DType::Index => |name, end| {
            let Op::Cast { src: end_inner, dtype: end_dt } = end.op() else {
                return None;
            };
            if *end_dt != DType::Index {
                return None;
            }

            let i32_dt = DType::Scalar(ScalarDType::Int32);
            let result = UOp::new(Op::Special { end: end_inner.clone(), name: name.clone() }, i32_dt);
            Some(result.cast(DType::Index))
        },

        // VECTORIZE(e0.cast(Index), ...) → VECTORIZE(e0.cast(dt), ...).cast(Vector<Index>)
        vec @ Vectorize { elements } if vec.dtype().base() == ScalarDType::Index => |vec, elements| {
            let inner: Option<Vec<_>> = elements.iter().map(|e| {
                match e.op() {
                    Op::Cast { src, dtype } if *dtype == DType::Index => Some(src.clone()),
                    _ => None,
                }
            }).collect();
            let inner = inner?;

            let dt = select_dtype(vec);
            let casted: Vec<_> = inner.iter().map(|e| e.cast(dt.clone())).collect();
            let vec_index_dt = DType::Vector { scalar: ScalarDType::Index, count: elements.len() };
            Some(UOp::vectorize(casted.into()).cast(vec_index_dt))
        },

        // BIND(var.cast(Index), val.cast(Index)) → var.bind(val).cast(Index)
        // Tinygrad: (UPat(Ops.BIND, src=(var.cast(index), val.cast(index))), lambda var,val: var.bind(val).cast(index))
        Bind { var, value } if var.dtype() == DType::Index => |var, value| {
            let Op::Cast { src: var_inner, dtype: var_dt } = var.op() else { return None };
            let Op::Cast { src: val_inner, dtype: val_dt } = value.op() else { return None };
            if *var_dt != DType::Index || *val_dt != DType::Index { return None; }

            // Bind the unwrapped variable node itself, as Tinygrad does, instead of a
            // cast of it. BIND's first operand then stays the variable (normally the
            // concrete DEFINE_VAR produced by the phase-1 rule) and never becomes a CAST.
            // Only the value is converted, to the variable's concrete dtype.
            let bound = var_inner.clone().bind(val_inner.cast(var_inner.dtype()));
            Some(bound.cast(DType::Index))
        },

        // ================================================================
        // PHASE 3: Cleanup - strip wrappers at terminal nodes
        // ================================================================

        // INDEX(buf, ...idx.cast(Index)...) → INDEX(buf, ...idx...) — strip Cast(Index) from all per-dim indices
        // Tinygrad: ops.py:1308-1310 — buf.index(idx, ptr=True)
        // Generalized for multi-index INDEX (Tinygrad keeps multi-index through the pipeline).
        node @ Index { buffer, indices, gate } => |node, buffer, indices, gate| {
            let mut changed = false;
            let new_indices: smallvec::SmallVec<[std::sync::Arc<UOp>; 4]> = indices.iter().map(|idx| {
                if let Op::Cast { src, dtype } = idx.op()
                    && *dtype == DType::Index && src.dtype().is_int() {
                        changed = true;
                        return src.clone();
                    }
                idx.clone()
            }).collect();
            if !changed { return None; }
            Some(UOp::new(Op::Index { buffer: buffer.clone(), indices: new_indices, gate: gate.clone() }, node.dtype()))
        },

        // INDEX(buf, ...WHERE(cond, idx, Invalid)...) → INDEX(buf, ...idx..., gate=AND(conds...))
        // Tinygrad: ops.py:1306 — extract WHERE-Invalid from index, merge conds into gate
        // Generalized for multi-index: extracts validity from ALL per-dimension indices.
        node @ Index { buffer, indices, gate } => |node, buffer, indices, gate| {
            let mut new_indices: smallvec::SmallVec<[std::sync::Arc<UOp>; 4]> = smallvec::SmallVec::new();
            let mut conds: Vec<std::sync::Arc<UOp>> = Vec::new();
            let mut changed = false;
            for idx in indices.iter() {
                if let Op::Ternary(morok_ir::TernaryOp::Where, cond, true_val, false_val) = idx.op()
                    && UOp::is_invalid_marker(false_val) {
                        new_indices.push(true_val.clone());
                        conds.push(cond.clone());
                        changed = true;
                        continue;
                    }
                new_indices.push(idx.clone());
            }
            if !changed { return None; }
            let extracted = conds.into_iter().reduce(|a, b| a.try_and_op(&b).expect("ICE: AND gate merge"));
            let new_gate = match (gate, extracted) {
                (Some(existing), Some(ext)) => Some(existing.try_and_op(&ext).expect("ICE: AND gate merge")),
                (Some(existing), None) => Some(existing.clone()),
                (None, ext) => ext,
            };
            Some(UOp::new(Op::Index { buffer: buffer.clone(), indices: new_indices, gate: new_gate }, node.dtype()))
        },

        // SINK/END - strip .cast(Index) from sources
        // Tinygrad (ops.py:1311) also includes NOOP here, but Morok's Op::Noop has no sources,
        // so stripping .cast(Index) from NOOP sources is a no-op.

        // SINK - strip .cast(Index) from sources
        Sink { sources } => |sources| {
            let mut changed = false;
            let new_sources: Vec<Arc<UOp>> = sources.iter().map(|s| {
                if let Op::Cast { src, dtype } = s.op() && *dtype == DType::Index {
                    changed = true;
                    src.clone()
                } else {
                    s.clone()
                }
            }).collect();
            if !changed { return None; }
            Some(UOp::sink(new_sources))
        },

        // END - strip .cast(Index) from computation
        End { computation, ranges } => |computation, ranges| {
            let Op::Cast { src, dtype } = computation.op() else { return None };
            if *dtype != DType::Index { return None; }
            Some(UOp::new(Op::End { computation: src.clone(), ranges: ranges.clone() }, DType::Void))
        },
    }
}