morok-ir 0.1.0-alpha.2

Intermediate representation for the Morok ML compiler
Documentation
//! Graph organization: sink, group, assign, contiguous.
//!
//! This module contains graph organization and optimization operations:
//! - Graph structure: sink, group
//! - Assignment: assign
//! - Dependencies: after
//! - Materialization: detach, contiguous, contiguous_backward
//! - Optimization hints: precast
//! - Custom code: custom, customi

use std::sync::Arc;

use morok_dtype::DType;
use smallvec::SmallVec;

use crate::op::Op;
use crate::uop::UOp;

impl UOp {
    // =========================================================================
    // Graph Structure
    // =========================================================================

    /// Create a sink operation (graph termination).
    ///
    /// Sink marks outputs that must be evaluated; every entry in `sources`
    /// becomes a dependency. The resulting node always has dtype [`DType::Void`].
    pub fn sink(sources: Vec<Arc<Self>>) -> Arc<Self> {
        Self::new(Op::Sink { sources: SmallVec::from_vec(sources) }, DType::Void)
    }

    /// Create a group operation (merging/organizing related ops).
    ///
    /// Group is a NOOP that helps organize related operations together.
    /// It passes through the first source while ensuring all sources are dependencies.
    ///
    /// The resulting dtype is the dtype of the first source, or [`DType::Void`]
    /// when `sources` is empty.
    pub fn group(sources: Vec<Arc<Self>>) -> Arc<Self> {
        let dtype = sources.first().map_or(DType::Void, |first| first.dtype.clone());
        Self::new(Op::Group { sources: SmallVec::from_vec(sources) }, dtype)
    }

    // =========================================================================
    // Assignment
    // =========================================================================

    /// In-place assignment.
    ///
    /// Equivalent to [`Self::assign_with_mops`] with no movement-ops chain.
    /// The result takes the dtype of `target`.
    ///
    /// # Arguments
    /// * `target` - The INDEX operation for the assignment destination
    /// * `value` - The value to assign
    pub fn assign(target: Arc<Self>, value: Arc<Self>) -> Arc<Self> {
        Self::assign_with_mops(target, value, None)
    }

    /// In-place assignment with movement ops chain.
    ///
    /// The `movement_ops` parameter captures shape transformations from the
    /// original target, used during bufferize_to_store to apply the same
    /// transformations to the result buffer. The result takes the dtype of `target`.
    pub fn assign_with_mops(target: Arc<Self>, value: Arc<Self>, movement_ops: Option<Arc<Self>>) -> Arc<Self> {
        let dtype = target.dtype();
        Self::new(Op::Assign { target, value, movement_ops }, dtype)
    }

    // =========================================================================
    // Dependencies
    // =========================================================================

    /// Ordering constraint: self depends on deps.
    ///
    /// The returned AFTER node passes `self` through (same dtype) while
    /// recording `deps` as dependencies.
    ///
    /// # Arguments
    /// * `deps` - Dependencies that must complete before this value is used
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `self` is a control-flow node
    /// (`Op::Range` or `Op::End`). The check is compiled out in release builds.
    pub fn after(self: &Arc<Self>, deps: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
        // `debug_assert!` is already a no-op without `debug_assertions`, so no
        // additional `#[cfg]` gate is required.
        debug_assert!(
            !matches!(self.op(), Op::Range { .. } | Op::End { .. }),
            "AFTER passthrough must be data-producing node, got {:?} (id={})",
            self.op(),
            self.id
        );

        let dtype = self.dtype();
        Self::new(Op::After { passthrough: Arc::clone(self), deps }, dtype)
    }

    // =========================================================================
    // Materialization
    // =========================================================================

    /// Detach from gradient flow / force materialization.
    ///
    /// The result has the same dtype as `self`.
    pub fn detach(self: &Arc<Self>) -> Arc<Self> {
        let dtype = self.dtype();
        Self::new(Op::Detach { src: Arc::clone(self) }, dtype)
    }

    /// Ensure contiguous memory layout.
    ///
    /// Elides the CONTIGUOUS wrapper (returning `self` unchanged) when the
    /// source is already contiguous:
    /// - Already a CONTIGUOUS node (no double wrapping)
    /// - Has buffer identity (BUFFER, or RESHAPE/MULTI chain to BUFFER)
    ///
    /// Otherwise a new CONTIGUOUS node with no optimization hints is created.
    ///
    /// Based on Tinygrad's `UOp.contiguous()` (ops.py:463-466).
    pub fn contiguous(self: &Arc<Self>) -> Arc<Self> {
        // `||` short-circuits, so `has_buffer_identity` is only consulted for
        // non-CONTIGUOUS nodes, matching the original check order.
        if matches!(self.op(), Op::Contiguous { .. }) || self.has_buffer_identity() {
            return Arc::clone(self);
        }
        let dtype = self.dtype();
        Self::new(Op::Contiguous { src: Arc::clone(self), opts: SmallVec::new() }, dtype)
    }

    /// Ensure contiguous memory layout with optimization hints.
    ///
    /// The hints are extracted during rangeify and passed to the optimizer.
    /// Based on Tinygrad's CONTIGUOUS.arg which carries Opt tuples.
    ///
    /// Unlike [`Self::contiguous`], this never elides the wrapper: a new
    /// CONTIGUOUS node carrying `opts` is always created.
    pub fn contiguous_with_opts(
        self: &Arc<Self>,
        opts: SmallVec<[crate::types::ContiguousHint; 4]>,
    ) -> Arc<Self> {
        let dtype = self.dtype();
        Self::new(Op::Contiguous { src: Arc::clone(self), opts }, dtype)
    }

    /// Contiguous backward pass.
    ///
    /// The result has the same dtype as `self`.
    pub fn contiguous_backward(self: &Arc<Self>) -> Arc<Self> {
        let dtype = self.dtype();
        Self::new(Op::ContiguousBackward { src: Arc::clone(self) }, dtype)
    }

    // =========================================================================
    // Optimization Hints
    // =========================================================================

    /// Optimizer hint to force materialization before type conversion.
    ///
    /// Inserted before BITCAST to ensure the source is rendered separately
    /// in codegen (prevents invalid cast fusion). The result has the same
    /// dtype as `self`.
    pub fn precast(self: &Arc<Self>) -> Arc<Self> {
        let dtype = self.dtype();
        Self::new(Op::Precast { src: Arc::clone(self) }, dtype)
    }

    // =========================================================================
    // Custom Code
    // =========================================================================

    /// Inject custom code as a statement in the generated kernel.
    ///
    /// `deps` are UOps whose rendered names can be referenced in `code`.
    /// `dtype` specifies the result type (often Void for statements).
    pub fn custom(deps: SmallVec<[Arc<Self>; 4]>, code: String, dtype: DType) -> Arc<Self> {
        Self::new(Op::Custom { deps, code }, dtype)
    }

    /// Inject custom code as an inline expression.
    ///
    /// Unlike `custom` (statement), this is substituted directly into expressions.
    /// `deps` provide values to reference; result has specified `dtype`.
    pub fn customi(deps: SmallVec<[Arc<Self>; 4]>, code: String, dtype: DType) -> Arc<Self> {
        Self::new(Op::CustomI { deps, code }, dtype)
    }
}