Skip to main content

morok_ir/uop/constructors/
data.rs

1//! Data creation: constants, buffers, device specifications.
2//!
3//! This module contains constructors for creating data primitives:
4//! - Constants (scalar, native, index)
5//! - Buffers (new, view)
6//! - Device specifications
7//! - No-op and cast operations
8
9use std::sync::Arc;
10
11use morok_dtype::DType;
12use morok_dtype::DeviceSpec;
13use morok_dtype::ext::HasDType;
14
15use crate::IntoUOp;
16use crate::op::Op;
17use crate::types::{ConstValue, ConstValueHash};
18use crate::uop::core::UOp;
19use crate::uop::hash_consing::next_unique_id;
20
21impl UOp {
22    // =========================================================================
23    // Constants
24    // =========================================================================
25
26    /// Create a constant UOp with explicit dtype and value.
27    ///
28    /// Normalizes the value to match the target dtype (e.g., `Float(5.0)` becomes
29    /// `Int(5)` when dtype is Int32). This prevents codegen from emitting
30    /// mismatched literals.
31    ///
32    /// Use `native_const` for type-inferred constants from Rust values.
33    pub fn const_(dtype: DType, value: ConstValue) -> Arc<Self> {
34        let normalized = value.cast(&dtype).unwrap_or(value);
35        Self::new(Op::Const(ConstValueHash(normalized)), dtype)
36    }
37
38    /// Create a constant UOp from a Rust native value with automatic dtype inference.
39    pub fn native_const<T: HasDType + IntoUOp>(value: T) -> Arc<Self> {
40        value.into_uop(T::DTYPE)
41    }
42
43    /// Create an index constant.
44    pub fn index_const(value: i64) -> Arc<Self> {
45        Self::const_(DType::Index, ConstValue::Int(value))
46    }
47
48    /// Create a constant with the same dtype as self.
49    ///
50    /// This is the Rust equivalent of Tinygrad's `x.const_like(value)`.
51    /// Useful for creating identity elements, zeros, or other constants
52    /// that match an existing UOp's type.
53    ///
54    /// # Examples
55    ///
56    /// ```rust
57    /// # use std::sync::Arc;
58    /// # use morok_ir::UOp;
59    /// # use morok_dtype::DType;
60    /// let x = UOp::const_(DType::Float32, morok_ir::ConstValue::Float(5.0));
61    /// let zero = x.const_like(0.0);
62    /// assert_eq!(zero.dtype(), DType::Float32);
63    /// ```
64    pub fn const_like<T: crate::IntoUOp>(self: &Arc<Self>, value: T) -> Arc<Self> {
65        value.into_uop(self.dtype())
66    }
67
68    /// Create a vector constant from multiple values.
69    ///
70    /// Dtype is inferred from the first value; all values must be same type.
71    pub fn vconst(values: Vec<ConstValue>, scalar_dtype: DType) -> Arc<Self> {
72        let vec_dtype = scalar_dtype.vec(values.len());
73        Self::new(Op::VConst { values }, vec_dtype)
74    }
75
76    // =========================================================================
77    // Buffers
78    // =========================================================================
79
80    /// Create a unique buffer identifier.
81    pub fn buffer_id(num: Option<usize>) -> Arc<Self> {
82        let id = num.unwrap_or_else(next_unique_id);
83        Self::new(Op::Unique(id), DType::Void)
84    }
85
86    /// Create a new buffer.
87    ///
88    /// Equivalent to: `UOp(Ops.BUFFER, dtype, (unique(), device(device_spec)), size)`
89    pub fn new_buffer(device: DeviceSpec, size: usize, dtype: DType) -> Arc<Self> {
90        let unique = Self::buffer_id(None);
91        let dev = Self::device(device);
92        Self::new(Op::Buffer { unique, device: dev, size }, dtype)
93    }
94
95    /// Create a normalized buffer parameter with positional slot.
96    /// Used by pre-schedule normalization (BUFFER→PARAM) to erase buffer identity.
97    /// Matches Tinygrad's `UOp.param(slot, dtype, shape, device)` (ops.py:817-819).
98    pub fn param(slot: usize, size: usize, dtype: DType, device: Option<Arc<Self>>) -> Arc<Self> {
99        Self::new(Op::Param { slot, size, device }, dtype)
100    }
101
102    /// Create a buffer view.
103    pub fn view(self: &Arc<Self>, size: usize, offset: usize) -> Arc<Self> {
104        let dtype = self.dtype.clone();
105        Self::new(Op::BufferView { buffer: self.clone(), size, offset }, dtype)
106    }
107
108    // =========================================================================
109    // Device
110    // =========================================================================
111
112    /// Create a device specification.
113    pub fn device(device: DeviceSpec) -> Arc<Self> {
114        Self::new(Op::Device(device), DType::Void)
115    }
116
117    // =========================================================================
118    // Type Operations
119    // =========================================================================
120
121    /// Create a no-op.
122    pub fn noop() -> Arc<Self> {
123        Self::new(Op::Noop, DType::Void)
124    }
125
126    /// Cast to a different dtype.
127    ///
128    /// If casting a vector to a scalar type, automatically promotes the target
129    /// dtype to a matching vector type. This prevents invalid scalar-to-vector
130    /// casts in the IR. (Matches Tinygrad's cast behavior.)
131    pub fn cast(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
132        let src_vcount = self.dtype().vcount();
133        let dst_vcount = dtype.vcount();
134
135        // Auto-promote scalar target to vector if source is vector
136        let dtype = if dst_vcount == 1 && src_vcount > 1 { dtype.vec(src_vcount) } else { dtype };
137
138        // No-op if types match
139        if self.dtype() == dtype {
140            return self.clone();
141        }
142
143        Self::new(Op::Cast { src: self.clone(), dtype: dtype.clone() }, dtype)
144    }
145
146    /// Bitcast: reinterpret bits as different type.
147    pub fn bitcast(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
148        Self::new(Op::BitCast { src: self.clone(), dtype: dtype.clone() }, dtype)
149    }
150}