Skip to main content

morok_ir/uop/constructors/
memory.rs

1//! Memory operations: load, store, index, copy, bufferize.
2//!
3//! This module contains operations for memory access:
//! - Indexing: index (with optional gate), pointer_index, slice, slice_gated, valid
5//! - Memory access: load, store (gate is on INDEX, not LOAD/STORE)
6//! - Device operations: copy, copy_to_device
7//! - Bufferization: bufferize, bufferize_global, bufferize_local
//! - Memory definitions: define_local, define_reg, define_reg_typed
9
10use std::sync::Arc;
11
12use bon::bon;
13use morok_dtype::DType;
14use morok_dtype::DeviceSpec;
15use smallvec::SmallVec;
16use snafu::ensure;
17
18use crate::Result;
19use crate::error::IndexTypeMismatchSnafu;
20use crate::indexing::IndexSpec;
21use crate::op::Op;
22use crate::types::{AddrSpace, BufferizeOpts};
23use crate::uop::UOp;
24
25#[bon]
26impl UOp {
27    // =========================================================================
28    // Indexing Operations
29    // =========================================================================
30
31    /// Create a buffer index operation for multi-dimensional access.
32    ///
33    /// All indices must have Index dtype.
34    ///
35    /// # Dtype behavior (matches Tinygrad's `buf.index(idx, ptr=False, dtype=None)`)
36    /// - If `dtype` is provided: use it directly (explicit dtype takes precedence)
37    /// - If `ptr` is true: keep the buffer's Ptr dtype (for STORE targets)
38    /// - Otherwise (ptr=false, default): extract element type from buffer (for LOAD sources)
39    ///
40    /// # Examples
41    /// ```ignore
42    /// // Element dtype (default) - for LOAD
43    /// UOp::index().buffer(buf).indices(vec![idx]).call()?
44    ///
45    /// // Ptr dtype via ptr=true - for STORE (preferred, Tinygrad-aligned)
46    /// UOp::index().buffer(buf).indices(vec![idx]).ptr(true).call()?
47    ///
48    /// // Explicit Ptr dtype - for STORE (legacy, works but prefer .ptr(true))
49    /// let ptr_dtype = DType::Float32.ptr(Some(size), AddrSpace::Global);
50    /// UOp::index().buffer(buf).indices(vec![idx]).dtype(ptr_dtype).call()?
51    ///
52    /// // With gate
53    /// UOp::index().buffer(buf).indices(vec![idx]).gate(gate_uop).call()?
54    /// ```
55    #[builder]
56    pub fn index<I: Into<SmallVec<[Arc<Self>; 4]>>>(
57        buffer: Arc<Self>,
58        indices: I,
59        gate: Option<Arc<Self>>,
60        dtype: Option<DType>,
61        /// When true, keep buffer's Ptr dtype (for STORE targets).
62        /// When false (default), extract element type (for LOAD sources).
63        /// Matches Tinygrad's `buf.index(idx, ptr=True/False)`.
64        ptr: Option<bool>,
65    ) -> Result<Arc<Self>> {
66        let indices = indices.into();
67        // Validate that all indices have integer/index base dtype.
68        // Allows both scalar (Index, Int64, Int32) and vector (Index.vec(N), Int64.vec(N))
69        // for devectorized register/local buffer indexing.
70        for idx in &indices {
71            let base = idx.dtype().base();
72            ensure!(
73                matches!(
74                    base,
75                    morok_dtype::ScalarDType::Index | morok_dtype::ScalarDType::Int64 | morok_dtype::ScalarDType::Int32
76                ),
77                IndexTypeMismatchSnafu { actual: idx.dtype() }
78            );
79        }
80
81        // Determine result dtype based on (dtype, ptr) parameters
82        // Priority: explicit dtype > ptr flag > default (element type)
83        let result_dtype = match (dtype, ptr.unwrap_or(false)) {
84            (Some(d), _) => d,              // Explicit dtype takes precedence
85            (None, true) => buffer.dtype(), // ptr=true: keep Ptr dtype
86            (None, false) => match buffer.dtype() {
87                // ptr=false: extract element type
88                DType::Ptr { base, .. } => base.as_ref().clone(),
89                other => other,
90            },
91        };
92
93        Ok(Self::new(Op::Index { buffer, indices, gate }, result_dtype))
94    }
95
96    /// Create a pointer index operation (pointer arithmetic).
97    ///
98    /// Performs pointer + offset arithmetic for address calculation in kernels.
99    /// Both self (ptr) and offset should have Index dtype.
100    pub fn pointer_index(self: &Arc<Self>, offset: Arc<Self>) -> Result<Arc<Self>> {
101        let ptr_dtype = self.dtype();
102        let offset_dtype = offset.dtype();
103        ensure!(ptr_dtype == DType::Index, IndexTypeMismatchSnafu { actual: ptr_dtype });
104        ensure!(offset_dtype == DType::Index, IndexTypeMismatchSnafu { actual: offset_dtype });
105        Ok(Self::new(Op::PointerIndex { ptr: self.clone(), offset }, DType::Index))
106    }
107
108    /// Multi-dimensional slicing with IndexSpec.
109    ///
110    /// **Note**: Range and NewAxis specs are not fully implemented;
111    /// currently only Single indices are properly supported.
112    pub fn slice(buffer: Arc<Self>, specs: Vec<IndexSpec>) -> Result<Arc<Self>> {
113        let mut indices = Vec::new();
114
115        for spec in specs {
116            match spec {
117                IndexSpec::Single(idx) => {
118                    // Single index - just use it directly
119                    indices.push(idx);
120                }
121                IndexSpec::Range { start, end: _, step: _ } => {
122                    // Range indexing - for now, just use start as a simple index
123                    // TODO: Proper range expansion requires loop IR and range operations
124                    indices.push(start);
125                }
126                IndexSpec::Full => {
127                    // Full slice - skip (means "all elements")
128                    // TODO: Proper handling requires understanding dimension size
129                }
130                IndexSpec::NewAxis => {
131                    // NewAxis - adds dimension
132                    // TODO: Requires reshape operation
133                }
134            }
135        }
136
137        if indices.is_empty() {
138            // No actual indexing, just return buffer
139            Ok(buffer)
140        } else {
141            Self::index().buffer(buffer).indices(indices).call()
142        }
143    }
144
145    /// Gated slicing - conditional access with gate.
146    pub fn slice_gated(buffer: Arc<Self>, specs: Vec<IndexSpec>, gate: Arc<Self>) -> Result<Arc<Self>> {
147        let mut indices = Vec::new();
148
149        for spec in specs {
150            match spec {
151                IndexSpec::Single(idx) => indices.push(idx),
152                IndexSpec::Range { start, .. } => indices.push(start),
153                IndexSpec::Full | IndexSpec::NewAxis => {}
154            }
155        }
156
157        if indices.is_empty() { Ok(buffer) } else { Self::index().buffer(buffer).indices(indices).gate(gate).call() }
158    }
159
160    // =========================================================================
161    // Index Helpers
162    // =========================================================================
163
164    /// Wrap index with validity condition.
165    ///
166    /// This is the Rust equivalent of Tinygrad's `idx.valid(cond)`.
167    /// Creates WHERE(cond, self, Invalid) to mark conditional index validity.
168    ///
169    /// # Examples
170    ///
171    /// ```ignore
172    /// // Create a conditionally valid index
173    /// let valid_idx = idx.valid(cond);
174    /// // Equivalent to: WHERE(cond, idx, INVALID)
175    /// ```
176    pub fn valid(self: &Arc<Self>, cond: Arc<Self>) -> Arc<Self> {
177        UOp::try_where(cond, self.clone(), UOp::invalid_marker()).expect("valid: WHERE construction failed")
178    }
179
180    // =========================================================================
181    // Memory Access Operations
182    // =========================================================================
183
184    /// Create a LOAD operation.
185    ///
186    /// # Example
187    /// ```ignore
188    /// // Infer dtype from buffer
189    /// UOp::load().buffer(buf).index(idx).call()
190    ///
191    /// // Explicit dtype for vector loads
192    /// UOp::load().buffer(buf).index(idx).dtype(vec4_dtype).call()
193    ///
194    /// // With alt value for gated loads
195    /// UOp::load().buffer(buf).index(idx).alt(zero).call()
196    /// ```
197    #[builder]
198    pub fn load(buffer: Arc<Self>, index: Arc<Self>, dtype: Option<DType>, alt: Option<Arc<Self>>) -> Arc<Self> {
199        let dtype = dtype.unwrap_or_else(|| match &buffer.dtype {
200            DType::Ptr { base, .. } => (**base).clone(),
201            other => other.clone(),
202        });
203        Self::new(Op::Load { buffer, index, alt }, dtype)
204    }
205
206    /// Create a STORE operation without ranges.
207    ///
208    /// Stores a value at self (INDEX location).
209    /// The buffer is accessed indirectly through the INDEX node.
210    /// For stores with ranges (e.g., output upcasting), use `store_with_ranges`.
211    ///
212    /// For gated stores, use an INDEX with a gate (INDEX has optional gate field).
213    pub fn store(self: &Arc<Self>, value: Arc<Self>) -> Arc<Self> {
214        self.store_with_ranges(value, SmallVec::new())
215    }
216
217    /// Create a STORE operation with ranges.
218    ///
219    /// Stores a value at self (INDEX location), with explicit ranges
220    /// that define the scope of the store operation. This matches Tinygrad's
221    /// architecture where STORE sources are `(index, value, *ranges)`.
222    ///
223    /// Ranges are used for output upcasting: Range(Upcast) becomes UNROLL
224    /// during expansion, which `fix_store_unroll` contracts via CONTRACT.
225    ///
226    /// For gated stores, use an INDEX with a gate (INDEX has optional gate field).
227    pub fn store_with_ranges(self: &Arc<Self>, value: Arc<Self>, ranges: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
228        Self::new(Op::Store { index: self.clone(), value, ranges }, DType::Void)
229    }
230
231    // =========================================================================
232    // Device Operations
233    // =========================================================================
234
235    /// Copy to a different device.
236    pub fn copy_to_device(self: &Arc<Self>, device: DeviceSpec) -> Arc<Self> {
237        let dev = Self::device(device);
238        Self::new(Op::Copy { src: self.clone(), device: dev }, self.dtype.clone())
239    }
240
241    /// Create a COPY operation with explicit device UOp.
242    ///
243    /// Unlike `copy_to_device` which takes a `DeviceSpec`, this takes
244    /// a device UOp directly (useful when you already have one).
245    pub fn copy(self: &Arc<Self>, device: Arc<Self>) -> Arc<Self> {
246        let dtype = self.dtype.clone();
247        Self::new(Op::Copy { src: self.clone(), device }, dtype)
248    }
249
250    // =========================================================================
251    // Bufferization Operations
252    // =========================================================================
253
254    /// Create a BUFFERIZE operation.
255    ///
256    /// Marks a computation to be materialized into a buffer.
257    /// The computation is evaluated over the given ranges and stored.
258    pub fn bufferize(compute: Arc<Self>, ranges: Vec<Arc<Self>>, opts: BufferizeOpts) -> Arc<Self> {
259        let dtype = compute.dtype.clone();
260        Self::new(Op::Bufferize { compute, ranges: SmallVec::from_vec(ranges), opts }, dtype)
261    }
262
263    /// Create a BUFFERIZE operation with Global address space.
264    ///
265    /// This is the most common pattern - bufferize to global memory.
266    pub fn bufferize_global(compute: Arc<Self>, ranges: Vec<Arc<Self>>) -> Arc<Self> {
267        Self::bufferize(compute, ranges, BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true })
268    }
269
270    /// Create a BUFFERIZE operation with Local address space.
271    ///
272    /// For shared/local memory bufferization.
273    pub fn bufferize_local(compute: Arc<Self>, ranges: Vec<Arc<Self>>) -> Arc<Self> {
274        Self::bufferize(compute, ranges, BufferizeOpts { device: None, addrspace: AddrSpace::Local, removable: true })
275    }
276
277    // =========================================================================
278    // Memory Definition Operations
279    // =========================================================================
280
281    /// Create a DEFINE_LOCAL operation.
282    ///
283    /// Defines a local (shared) memory allocation with the given ID.
284    pub fn define_local(id: usize, dtype: DType) -> Arc<Self> {
285        Self::new(Op::DefineLocal(id), dtype)
286    }
287
288    /// Define register memory (void pointer - type determined by usage).
289    pub fn define_reg(size: usize) -> Arc<Self> {
290        use morok_dtype::AddrSpace;
291        let id = crate::uop::hash_consing::next_unique_id();
292        let ptr_dtype = DType::Void.ptr(Some(size), AddrSpace::Reg);
293        Self::new(Op::DefineReg { size, id }, ptr_dtype)
294    }
295
296    /// Define register memory with explicit element type.
297    ///
298    /// Creates a typed register accumulator for use in reductions.
299    /// The element_dtype specifies the type of each element (e.g., Float32 for a float accumulator).
300    pub fn define_reg_typed(size: usize, element_dtype: DType) -> Arc<Self> {
301        use morok_dtype::AddrSpace;
302        let id = crate::uop::hash_consing::next_unique_id();
303        let ptr_dtype =
304            DType::Ptr { base: Box::new(element_dtype), addrspace: AddrSpace::Reg, size: Some(size), vcount: 1 };
305        Self::new(Op::DefineReg { size, id }, ptr_dtype)
306    }
307}