morok_ir/uop/constructors/memory.rs
//! Memory operations: load, store, index, copy, bufferize.
//!
//! This module contains constructors for memory access:
//! - Indexing: index, pointer_index, slice, slice_gated, plus the `valid` index helper
//! - Memory access: load, store, store_with_ranges (gates live on INDEX, not on LOAD/STORE)
//! - Device operations: copy, copy_to_device
//! - Bufferization: bufferize, bufferize_global, bufferize_local
//! - Memory definitions: define_local, define_reg, define_reg_typed
9
10use std::sync::Arc;
11
12use bon::bon;
13use morok_dtype::DType;
14use morok_dtype::DeviceSpec;
15use smallvec::SmallVec;
16use snafu::ensure;
17
18use crate::Result;
19use crate::error::IndexTypeMismatchSnafu;
20use crate::indexing::IndexSpec;
21use crate::op::Op;
22use crate::types::{AddrSpace, BufferizeOpts};
23use crate::uop::UOp;
24
25#[bon]
26impl UOp {
27 // =========================================================================
28 // Indexing Operations
29 // =========================================================================
30
31 /// Create a buffer index operation for multi-dimensional access.
32 ///
33 /// All indices must have Index dtype.
34 ///
35 /// # Dtype behavior (matches Tinygrad's `buf.index(idx, ptr=False, dtype=None)`)
36 /// - If `dtype` is provided: use it directly (explicit dtype takes precedence)
37 /// - If `ptr` is true: keep the buffer's Ptr dtype (for STORE targets)
38 /// - Otherwise (ptr=false, default): extract element type from buffer (for LOAD sources)
39 ///
40 /// # Examples
41 /// ```ignore
42 /// // Element dtype (default) - for LOAD
43 /// UOp::index().buffer(buf).indices(vec![idx]).call()?
44 ///
45 /// // Ptr dtype via ptr=true - for STORE (preferred, Tinygrad-aligned)
46 /// UOp::index().buffer(buf).indices(vec![idx]).ptr(true).call()?
47 ///
48 /// // Explicit Ptr dtype - for STORE (legacy, works but prefer .ptr(true))
49 /// let ptr_dtype = DType::Float32.ptr(Some(size), AddrSpace::Global);
50 /// UOp::index().buffer(buf).indices(vec![idx]).dtype(ptr_dtype).call()?
51 ///
52 /// // With gate
53 /// UOp::index().buffer(buf).indices(vec![idx]).gate(gate_uop).call()?
54 /// ```
55 #[builder]
56 pub fn index<I: Into<SmallVec<[Arc<Self>; 4]>>>(
57 buffer: Arc<Self>,
58 indices: I,
59 gate: Option<Arc<Self>>,
60 dtype: Option<DType>,
61 /// When true, keep buffer's Ptr dtype (for STORE targets).
62 /// When false (default), extract element type (for LOAD sources).
63 /// Matches Tinygrad's `buf.index(idx, ptr=True/False)`.
64 ptr: Option<bool>,
65 ) -> Result<Arc<Self>> {
66 let indices = indices.into();
67 // Validate that all indices have integer/index base dtype.
68 // Allows both scalar (Index, Int64, Int32) and vector (Index.vec(N), Int64.vec(N))
69 // for devectorized register/local buffer indexing.
70 for idx in &indices {
71 let base = idx.dtype().base();
72 ensure!(
73 matches!(
74 base,
75 morok_dtype::ScalarDType::Index | morok_dtype::ScalarDType::Int64 | morok_dtype::ScalarDType::Int32
76 ),
77 IndexTypeMismatchSnafu { actual: idx.dtype() }
78 );
79 }
80
81 // Determine result dtype based on (dtype, ptr) parameters
82 // Priority: explicit dtype > ptr flag > default (element type)
83 let result_dtype = match (dtype, ptr.unwrap_or(false)) {
84 (Some(d), _) => d, // Explicit dtype takes precedence
85 (None, true) => buffer.dtype(), // ptr=true: keep Ptr dtype
86 (None, false) => match buffer.dtype() {
87 // ptr=false: extract element type
88 DType::Ptr { base, .. } => base.as_ref().clone(),
89 other => other,
90 },
91 };
92
93 Ok(Self::new(Op::Index { buffer, indices, gate }, result_dtype))
94 }
95
96 /// Create a pointer index operation (pointer arithmetic).
97 ///
98 /// Performs pointer + offset arithmetic for address calculation in kernels.
99 /// Both self (ptr) and offset should have Index dtype.
100 pub fn pointer_index(self: &Arc<Self>, offset: Arc<Self>) -> Result<Arc<Self>> {
101 let ptr_dtype = self.dtype();
102 let offset_dtype = offset.dtype();
103 ensure!(ptr_dtype == DType::Index, IndexTypeMismatchSnafu { actual: ptr_dtype });
104 ensure!(offset_dtype == DType::Index, IndexTypeMismatchSnafu { actual: offset_dtype });
105 Ok(Self::new(Op::PointerIndex { ptr: self.clone(), offset }, DType::Index))
106 }
107
108 /// Multi-dimensional slicing with IndexSpec.
109 ///
110 /// **Note**: Range and NewAxis specs are not fully implemented;
111 /// currently only Single indices are properly supported.
112 pub fn slice(buffer: Arc<Self>, specs: Vec<IndexSpec>) -> Result<Arc<Self>> {
113 let mut indices = Vec::new();
114
115 for spec in specs {
116 match spec {
117 IndexSpec::Single(idx) => {
118 // Single index - just use it directly
119 indices.push(idx);
120 }
121 IndexSpec::Range { start, end: _, step: _ } => {
122 // Range indexing - for now, just use start as a simple index
123 // TODO: Proper range expansion requires loop IR and range operations
124 indices.push(start);
125 }
126 IndexSpec::Full => {
127 // Full slice - skip (means "all elements")
128 // TODO: Proper handling requires understanding dimension size
129 }
130 IndexSpec::NewAxis => {
131 // NewAxis - adds dimension
132 // TODO: Requires reshape operation
133 }
134 }
135 }
136
137 if indices.is_empty() {
138 // No actual indexing, just return buffer
139 Ok(buffer)
140 } else {
141 Self::index().buffer(buffer).indices(indices).call()
142 }
143 }
144
145 /// Gated slicing - conditional access with gate.
146 pub fn slice_gated(buffer: Arc<Self>, specs: Vec<IndexSpec>, gate: Arc<Self>) -> Result<Arc<Self>> {
147 let mut indices = Vec::new();
148
149 for spec in specs {
150 match spec {
151 IndexSpec::Single(idx) => indices.push(idx),
152 IndexSpec::Range { start, .. } => indices.push(start),
153 IndexSpec::Full | IndexSpec::NewAxis => {}
154 }
155 }
156
157 if indices.is_empty() { Ok(buffer) } else { Self::index().buffer(buffer).indices(indices).gate(gate).call() }
158 }
159
160 // =========================================================================
161 // Index Helpers
162 // =========================================================================
163
164 /// Wrap index with validity condition.
165 ///
166 /// This is the Rust equivalent of Tinygrad's `idx.valid(cond)`.
167 /// Creates WHERE(cond, self, Invalid) to mark conditional index validity.
168 ///
169 /// # Examples
170 ///
171 /// ```ignore
172 /// // Create a conditionally valid index
173 /// let valid_idx = idx.valid(cond);
174 /// // Equivalent to: WHERE(cond, idx, INVALID)
175 /// ```
176 pub fn valid(self: &Arc<Self>, cond: Arc<Self>) -> Arc<Self> {
177 UOp::try_where(cond, self.clone(), UOp::invalid_marker()).expect("valid: WHERE construction failed")
178 }
179
180 // =========================================================================
181 // Memory Access Operations
182 // =========================================================================
183
184 /// Create a LOAD operation.
185 ///
186 /// # Example
187 /// ```ignore
188 /// // Infer dtype from buffer
189 /// UOp::load().buffer(buf).index(idx).call()
190 ///
191 /// // Explicit dtype for vector loads
192 /// UOp::load().buffer(buf).index(idx).dtype(vec4_dtype).call()
193 ///
194 /// // With alt value for gated loads
195 /// UOp::load().buffer(buf).index(idx).alt(zero).call()
196 /// ```
197 #[builder]
198 pub fn load(buffer: Arc<Self>, index: Arc<Self>, dtype: Option<DType>, alt: Option<Arc<Self>>) -> Arc<Self> {
199 let dtype = dtype.unwrap_or_else(|| match &buffer.dtype {
200 DType::Ptr { base, .. } => (**base).clone(),
201 other => other.clone(),
202 });
203 Self::new(Op::Load { buffer, index, alt }, dtype)
204 }
205
206 /// Create a STORE operation without ranges.
207 ///
208 /// Stores a value at self (INDEX location).
209 /// The buffer is accessed indirectly through the INDEX node.
210 /// For stores with ranges (e.g., output upcasting), use `store_with_ranges`.
211 ///
212 /// For gated stores, use an INDEX with a gate (INDEX has optional gate field).
213 pub fn store(self: &Arc<Self>, value: Arc<Self>) -> Arc<Self> {
214 self.store_with_ranges(value, SmallVec::new())
215 }
216
217 /// Create a STORE operation with ranges.
218 ///
219 /// Stores a value at self (INDEX location), with explicit ranges
220 /// that define the scope of the store operation. This matches Tinygrad's
221 /// architecture where STORE sources are `(index, value, *ranges)`.
222 ///
223 /// Ranges are used for output upcasting: Range(Upcast) becomes UNROLL
224 /// during expansion, which `fix_store_unroll` contracts via CONTRACT.
225 ///
226 /// For gated stores, use an INDEX with a gate (INDEX has optional gate field).
227 pub fn store_with_ranges(self: &Arc<Self>, value: Arc<Self>, ranges: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
228 Self::new(Op::Store { index: self.clone(), value, ranges }, DType::Void)
229 }
230
231 // =========================================================================
232 // Device Operations
233 // =========================================================================
234
235 /// Copy to a different device.
236 pub fn copy_to_device(self: &Arc<Self>, device: DeviceSpec) -> Arc<Self> {
237 let dev = Self::device(device);
238 Self::new(Op::Copy { src: self.clone(), device: dev }, self.dtype.clone())
239 }
240
241 /// Create a COPY operation with explicit device UOp.
242 ///
243 /// Unlike `copy_to_device` which takes a `DeviceSpec`, this takes
244 /// a device UOp directly (useful when you already have one).
245 pub fn copy(self: &Arc<Self>, device: Arc<Self>) -> Arc<Self> {
246 let dtype = self.dtype.clone();
247 Self::new(Op::Copy { src: self.clone(), device }, dtype)
248 }
249
250 // =========================================================================
251 // Bufferization Operations
252 // =========================================================================
253
254 /// Create a BUFFERIZE operation.
255 ///
256 /// Marks a computation to be materialized into a buffer.
257 /// The computation is evaluated over the given ranges and stored.
258 pub fn bufferize(compute: Arc<Self>, ranges: Vec<Arc<Self>>, opts: BufferizeOpts) -> Arc<Self> {
259 let dtype = compute.dtype.clone();
260 Self::new(Op::Bufferize { compute, ranges: SmallVec::from_vec(ranges), opts }, dtype)
261 }
262
263 /// Create a BUFFERIZE operation with Global address space.
264 ///
265 /// This is the most common pattern - bufferize to global memory.
266 pub fn bufferize_global(compute: Arc<Self>, ranges: Vec<Arc<Self>>) -> Arc<Self> {
267 Self::bufferize(compute, ranges, BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true })
268 }
269
270 /// Create a BUFFERIZE operation with Local address space.
271 ///
272 /// For shared/local memory bufferization.
273 pub fn bufferize_local(compute: Arc<Self>, ranges: Vec<Arc<Self>>) -> Arc<Self> {
274 Self::bufferize(compute, ranges, BufferizeOpts { device: None, addrspace: AddrSpace::Local, removable: true })
275 }
276
277 // =========================================================================
278 // Memory Definition Operations
279 // =========================================================================
280
281 /// Create a DEFINE_LOCAL operation.
282 ///
283 /// Defines a local (shared) memory allocation with the given ID.
284 pub fn define_local(id: usize, dtype: DType) -> Arc<Self> {
285 Self::new(Op::DefineLocal(id), dtype)
286 }
287
288 /// Define register memory (void pointer - type determined by usage).
289 pub fn define_reg(size: usize) -> Arc<Self> {
290 use morok_dtype::AddrSpace;
291 let id = crate::uop::hash_consing::next_unique_id();
292 let ptr_dtype = DType::Void.ptr(Some(size), AddrSpace::Reg);
293 Self::new(Op::DefineReg { size, id }, ptr_dtype)
294 }
295
296 /// Define register memory with explicit element type.
297 ///
298 /// Creates a typed register accumulator for use in reductions.
299 /// The element_dtype specifies the type of each element (e.g., Float32 for a float accumulator).
300 pub fn define_reg_typed(size: usize, element_dtype: DType) -> Arc<Self> {
301 use morok_dtype::AddrSpace;
302 let id = crate::uop::hash_consing::next_unique_id();
303 let ptr_dtype =
304 DType::Ptr { base: Box::new(element_dtype), addrspace: AddrSpace::Reg, size: Some(size), vcount: 1 };
305 Self::new(Op::DefineReg { size, id }, ptr_dtype)
306 }
307}