morok_ir/uop/constructors/data.rs
1//! Data creation: constants, buffers, device specifications.
2//!
3//! This module contains constructors for creating data primitives:
//! - Constants (scalar, native, index, dtype-matched, vector)
//! - Buffers (unique ids, new buffers, positional params, views)
//! - Device specifications
//! - No-op, cast, and bitcast operations
8
9use std::sync::Arc;
10
11use morok_dtype::DType;
12use morok_dtype::DeviceSpec;
13use morok_dtype::ext::HasDType;
14
15use crate::IntoUOp;
16use crate::op::Op;
17use crate::types::{ConstValue, ConstValueHash};
18use crate::uop::core::UOp;
19use crate::uop::hash_consing::next_unique_id;
20
21impl UOp {
22 // =========================================================================
23 // Constants
24 // =========================================================================
25
26 /// Create a constant UOp with explicit dtype and value.
27 ///
28 /// Normalizes the value to match the target dtype (e.g., `Float(5.0)` becomes
29 /// `Int(5)` when dtype is Int32). This prevents codegen from emitting
30 /// mismatched literals.
31 ///
32 /// Use `native_const` for type-inferred constants from Rust values.
33 pub fn const_(dtype: DType, value: ConstValue) -> Arc<Self> {
34 let normalized = value.cast(&dtype).unwrap_or(value);
35 Self::new(Op::Const(ConstValueHash(normalized)), dtype)
36 }
37
38 /// Create a constant UOp from a Rust native value with automatic dtype inference.
39 pub fn native_const<T: HasDType + IntoUOp>(value: T) -> Arc<Self> {
40 value.into_uop(T::DTYPE)
41 }
42
43 /// Create an index constant.
44 pub fn index_const(value: i64) -> Arc<Self> {
45 Self::const_(DType::Index, ConstValue::Int(value))
46 }
47
48 /// Create a constant with the same dtype as self.
49 ///
50 /// This is the Rust equivalent of Tinygrad's `x.const_like(value)`.
51 /// Useful for creating identity elements, zeros, or other constants
52 /// that match an existing UOp's type.
53 ///
54 /// # Examples
55 ///
56 /// ```rust
57 /// # use std::sync::Arc;
58 /// # use morok_ir::UOp;
59 /// # use morok_dtype::DType;
60 /// let x = UOp::const_(DType::Float32, morok_ir::ConstValue::Float(5.0));
61 /// let zero = x.const_like(0.0);
62 /// assert_eq!(zero.dtype(), DType::Float32);
63 /// ```
64 pub fn const_like<T: crate::IntoUOp>(self: &Arc<Self>, value: T) -> Arc<Self> {
65 value.into_uop(self.dtype())
66 }
67
68 /// Create a vector constant from multiple values.
69 ///
70 /// Dtype is inferred from the first value; all values must be same type.
71 pub fn vconst(values: Vec<ConstValue>, scalar_dtype: DType) -> Arc<Self> {
72 let vec_dtype = scalar_dtype.vec(values.len());
73 Self::new(Op::VConst { values }, vec_dtype)
74 }
75
76 // =========================================================================
77 // Buffers
78 // =========================================================================
79
80 /// Create a unique buffer identifier.
81 pub fn buffer_id(num: Option<usize>) -> Arc<Self> {
82 let id = num.unwrap_or_else(next_unique_id);
83 Self::new(Op::Unique(id), DType::Void)
84 }
85
86 /// Create a new buffer.
87 ///
88 /// Equivalent to: `UOp(Ops.BUFFER, dtype, (unique(), device(device_spec)), size)`
89 pub fn new_buffer(device: DeviceSpec, size: usize, dtype: DType) -> Arc<Self> {
90 let unique = Self::buffer_id(None);
91 let dev = Self::device(device);
92 Self::new(Op::Buffer { unique, device: dev, size }, dtype)
93 }
94
95 /// Create a normalized buffer parameter with positional slot.
96 /// Used by pre-schedule normalization (BUFFER→PARAM) to erase buffer identity.
97 /// Matches Tinygrad's `UOp.param(slot, dtype, shape, device)` (ops.py:817-819).
98 pub fn param(slot: usize, size: usize, dtype: DType, device: Option<Arc<Self>>) -> Arc<Self> {
99 Self::new(Op::Param { slot, size, device }, dtype)
100 }
101
102 /// Create a buffer view.
103 pub fn view(self: &Arc<Self>, size: usize, offset: usize) -> Arc<Self> {
104 let dtype = self.dtype.clone();
105 Self::new(Op::BufferView { buffer: self.clone(), size, offset }, dtype)
106 }
107
108 // =========================================================================
109 // Device
110 // =========================================================================
111
112 /// Create a device specification.
113 pub fn device(device: DeviceSpec) -> Arc<Self> {
114 Self::new(Op::Device(device), DType::Void)
115 }
116
117 // =========================================================================
118 // Type Operations
119 // =========================================================================
120
121 /// Create a no-op.
122 pub fn noop() -> Arc<Self> {
123 Self::new(Op::Noop, DType::Void)
124 }
125
126 /// Cast to a different dtype.
127 ///
128 /// If casting a vector to a scalar type, automatically promotes the target
129 /// dtype to a matching vector type. This prevents invalid scalar-to-vector
130 /// casts in the IR. (Matches Tinygrad's cast behavior.)
131 pub fn cast(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
132 let src_vcount = self.dtype().vcount();
133 let dst_vcount = dtype.vcount();
134
135 // Auto-promote scalar target to vector if source is vector
136 let dtype = if dst_vcount == 1 && src_vcount > 1 { dtype.vec(src_vcount) } else { dtype };
137
138 // No-op if types match
139 if self.dtype() == dtype {
140 return self.clone();
141 }
142
143 Self::new(Op::Cast { src: self.clone(), dtype: dtype.clone() }, dtype)
144 }
145
146 /// Bitcast: reinterpret bits as different type.
147 pub fn bitcast(self: &Arc<Self>, dtype: DType) -> Arc<Self> {
148 Self::new(Op::BitCast { src: self.clone(), dtype: dtype.clone() }, dtype)
149 }
150}