morok-ir 0.1.0-alpha.2

Intermediate representation for the Morok ML compiler
Documentation
//! Data creation: constants, buffers, device specifications.
//!
//! This module contains constructors for creating data primitives:
//! - Constants (scalar, native, index, vector)
//! - Buffers (new, param, view)
//! - Device specifications
//! - No-op, cast, and bitcast operations

use std::sync::Arc;

use morok_dtype::DType;
use morok_dtype::DeviceSpec;
use morok_dtype::ext::HasDType;

use crate::IntoUOp;
use crate::op::Op;
use crate::types::{ConstValue, ConstValueHash};
use crate::uop::core::UOp;
use crate::uop::hash_consing::next_unique_id;

impl UOp {
    // =========================================================================
    // Constants
    // =========================================================================

    /// Create a constant UOp with explicit dtype and value.
    ///
    /// Normalizes the value to match the target dtype (e.g., `Float(5.0)` becomes
    /// `Int(5)` when dtype is Int32). This prevents codegen from emitting
    /// mismatched literals.
    ///
    /// Use `native_const` for type-inferred constants from Rust values.
    pub fn const_(dtype: DType, value: ConstValue) -> Arc<Self> {
        let normalized = value.cast(&dtype).unwrap_or(value);
        Self::new(Op::Const(ConstValueHash(normalized)), dtype)
    }

    /// Create a constant UOp from a Rust native value with automatic dtype inference.
    pub fn native_const<T: HasDType + IntoUOp>(value: T) -> Arc<Self> {
        value.into_uop(T::DTYPE)
    }

    /// Create an index constant.
    pub fn index_const(value: i64) -> Arc<Self> {
        Self::const_(DType::Index, ConstValue::Int(value))
    }

    /// Create a constant with the same dtype as self.
    ///
    /// This is the Rust equivalent of Tinygrad's `x.const_like(value)`.
    /// Useful for creating identity elements, zeros, or other constants
    /// that match an existing UOp's type.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::sync::Arc;
    /// # use morok_ir::UOp;
    /// # use morok_dtype::DType;
    /// let x = UOp::const_(DType::Float32, morok_ir::ConstValue::Float(5.0));
    /// let zero = x.const_like(0.0);
    /// assert_eq!(zero.dtype(), DType::Float32);
    /// ```
    pub fn const_like<T: crate::IntoUOp>(self: &Arc<Self>, value: T) -> Arc<Self> {
        value.into_uop(self.dtype())
    }

    /// Create a vector constant from multiple values.
    ///
    /// Dtype is inferred from the first value; all values must be same type.
    pub fn vconst(values: Vec<ConstValue>, scalar_dtype: DType) -> Arc<Self> {
        let vec_dtype = scalar_dtype.vec(values.len());
        Self::new(Op::VConst { values }, vec_dtype)
    }

    // =========================================================================
    // Buffers
    // =========================================================================

    /// Create a unique buffer identifier.
    pub fn buffer_id(num: Option<usize>) -> Arc<Self> {
        let id = num.unwrap_or_else(next_unique_id);
        Self::new(Op::Unique(id), DType::Void)
    }

    /// Create a new buffer.
    ///
    /// Equivalent to: `UOp(Ops.BUFFER, dtype, (unique(), device(device_spec)), size)`
    pub fn new_buffer(device: DeviceSpec, size: usize, dtype: DType) -> Arc<Self> {
        let unique = Self::buffer_id(None);
        let dev = Self::device(device);
        Self::new(Op::Buffer { unique, device: dev, size }, dtype)
    }

    /// Create a normalized buffer parameter with positional slot.
    /// Used by pre-schedule normalization (BUFFER→PARAM) to erase buffer identity.
    /// Matches Tinygrad's `UOp.param(slot, dtype, shape, device)` (ops.py:817-819).
    pub fn param(slot: usize, size: usize, dtype: DType, device: Option<Arc<Self>>) -> Arc<Self> {
        Self::new(Op::Param { slot, size, device }, dtype)
    }

    /// Create a buffer view.
    pub fn view(self: &Arc<Self>, size: usize, offset: usize) -> Arc<Self> {
        let dtype = self.dtype.clone();
        Self::new(Op::BufferView { buffer: self.clone(), size, offset }, dtype)
    }

    // =========================================================================
    // Device
    // =========================================================================

    /// Create a device specification.
    pub fn device(device: DeviceSpec) -> Arc<Self> {
        Self::new(Op::Device(device), DType::Void)
    }

    // =========================================================================
    // Type Operations
    // =========================================================================

    /// Create a no-op.
    pub fn noop() -> Arc<Self> {
        Self::new(Op::Noop, DType::Void)
    }

    /// Cast to a different dtype.
    ///
    /// If casting a vector to a scalar type, automatically promotes the target
    /// dtype to a matching vector type. This prevents invalid scalar-to-vector
    /// casts in the IR. (Matches Tinygrad's cast behavior.)
    pub fn cast(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
        let src_vcount = self.dtype().vcount();
        let dst_vcount = dtype.vcount();

        // Auto-promote scalar target to vector if source is vector
        let dtype = if dst_vcount == 1 && src_vcount > 1 { dtype.vec(src_vcount) } else { dtype };

        // No-op if types match
        if self.dtype() == dtype {
            return self.clone();
        }

        Self::new(Op::Cast { src: self.clone(), dtype: dtype.clone() }, dtype)
    }

    /// Bitcast: reinterpret bits as different type.
    pub fn bitcast(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
        Self::new(Op::BitCast { src: self.clone(), dtype: dtype.clone() }, dtype)
    }
}