Skip to main content

morok_ir/uop/constructors/
hardware.rs

1//! Hardware-specific operations: WMMA, vectorize, kernel.
2//!
3//! This module contains hardware-specific operations:
4//! - Tensor cores: wmma
5//! - Vectorization: vectorize, gep, contract, unroll, cat, ptrcat
6//! - Multi-device: mstack, mselect
7//! - Kernels: kernel
8
9use std::sync::Arc;
10
11use bon::bon;
12use morok_dtype::DType;
13use smallvec::SmallVec;
14use snafu::ensure;
15
16use crate::Result;
17use crate::error::{
18    BroadcastRequiresScalarSnafu, ContractCountMismatchSnafu, GepIndexOutOfBoundsSnafu, GepRequiresVectorSnafu,
19    UnrollCountMismatchSnafu, VectorizeDTypeMismatchSnafu, VectorizeEmptySnafu,
20};
21use crate::op::Op;
22use crate::types::WmmaMetadata;
23use crate::uop::UOp;
24
25#[bon]
26impl UOp {
27    // =========================================================================
28    // Tensor Core Operations
29    // =========================================================================
30
31    /// Warp Matrix Multiply-Accumulate for tensor cores.
32    ///
33    /// Computes D = A × B + C using hardware matrix units.
34    /// `metadata` specifies dimensions, dtypes, and upcast axes for vectorization.
35    pub fn wmma(a: Arc<Self>, b: Arc<Self>, c: Arc<Self>, metadata: WmmaMetadata) -> Arc<Self> {
36        let base_dtype = metadata.dtype_out.clone();
37
38        // Calculate vector size from C (output) upcast axes
39        let vec_size = metadata.upcast_axes.c.iter().map(|(_, size)| size).product::<usize>();
40
41        let dtype = if vec_size > 1 { base_dtype.vec(vec_size) } else { base_dtype };
42
43        Self::new(Op::Wmma { a, b, c, metadata }, dtype)
44    }
45
46    // =========================================================================
47    // Vectorization Operations
48    // =========================================================================
49
50    /// Create vector from scalar elements (fallible version with validation).
51    ///
52    /// # Errors
53    /// - `VectorizeRequiresMultiple` if elements is empty
54    /// - `VectorizeDTypeMismatch` if elements have different scalar dtypes
55    pub fn try_vectorize(elements: SmallVec<[Arc<Self>; 4]>) -> Result<Arc<Self>> {
56        ensure!(!elements.is_empty(), VectorizeEmptySnafu);
57
58        // Use full dtype (not scalar_dtype) to preserve Ptr type for pointer vectors.
59        // This matches Tinygrad's broadcast: `UOp(Ops.VECTORIZE, self.dtype.vec(count), ...)`
60        // For Ptr types: Ptr{vcount:1}.vec(N) → Ptr{vcount:N} (vector of pointers)
61        // For Scalar types: Scalar(Float32).vec(N) → Vector{Float32, N}
62        let expected_dtype = elements[0].dtype();
63        for elem in elements.iter().skip(1) {
64            let actual = elem.dtype();
65            ensure!(expected_dtype == actual, VectorizeDTypeMismatchSnafu { expected: expected_dtype, actual });
66        }
67
68        let vec_dtype = expected_dtype.vec(elements.len());
69        Ok(Self::new(Op::Vectorize { elements }, vec_dtype))
70    }
71
    /// Create vector from scalar elements (panics on violation).
    ///
    /// Thin wrapper that calls `try_vectorize` and unwraps its result.
    ///
    /// # Panics
    /// Panics with the message "vectorize precondition violated" when
    /// `try_vectorize` returns an error, i.e. when `elements` is empty or the
    /// element dtypes are not all equal.
    pub fn vectorize(elements: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
        Self::try_vectorize(elements).expect("vectorize precondition violated")
    }
76
77    /// Broadcast a scalar value to a vector by replication (fallible version).
78    ///
79    /// Creates a VECTORIZE operation with `count` copies of the source.
80    /// If `count == 1`, returns the source unchanged.
81    ///
82    /// # Errors
83    /// - `BroadcastRequiresScalar` if source has vcount > 1
84    pub fn try_broadcast(self: &Arc<Self>, count: usize) -> Result<Arc<Self>> {
85        ensure!(self.dtype().vcount() == 1, BroadcastRequiresScalarSnafu { dtype: self.dtype() });
86
87        if count == 1 {
88            return Ok(self.clone());
89        }
90        let elements: SmallVec<[Arc<Self>; 4]> = (0..count).map(|_| self.clone()).collect();
91        Ok(Self::vectorize(elements))
92    }
93
94    /// Broadcast a scalar value to a vector by replication.
95    ///
96    /// Creates a VECTORIZE operation with `count` copies of the source.
97    /// If `count == 1`, returns the source unchanged.
98    ///
99    /// # Example
100    ///
101    /// ```ignore
102    /// let vector = scalar.broadcast(4);
103    /// ```
104    pub fn broadcast(self: &Arc<Self>, count: usize) -> Arc<Self> {
105        if count == 1 {
106            return self.clone();
107        }
108        let elements: SmallVec<[Arc<Self>; 4]> = (0..count).map(|_| self.clone()).collect();
109        Self::vectorize(elements)
110    }
111
112    /// Extract element(s) from vector (fallible version with validation).
113    ///
114    /// # Errors
115    /// - `GepRequiresVector` if source has vcount <= 1
116    /// - `GepIndexOutOfBounds` if any index >= source vcount
117    pub fn try_gep(self: &Arc<Self>, indices: Vec<usize>) -> Result<Arc<Self>> {
118        let vector_dtype = self.dtype();
119        let vcount = vector_dtype.vcount();
120
121        ensure!(vcount > 1, GepRequiresVectorSnafu { dtype: vector_dtype.clone() });
122
123        for &index in &indices {
124            ensure!(index < vcount, GepIndexOutOfBoundsSnafu { index, vcount });
125        }
126
127        let dtype = if indices.len() == 1 {
128            DType::Scalar(vector_dtype.base())
129        } else {
130            DType::Scalar(vector_dtype.base()).vec(indices.len())
131        };
132
133        Ok(Self::new(Op::Gep { vector: self.clone(), indices }, dtype))
134    }
135
136    /// Extract element(s) from vector (Get Element Pointer).
137    ///
138    /// # Example
139    ///
140    /// ```ignore
141    /// let elem = vector.gep(vec![0]);      // Extract single element
142    /// let sub = vector.gep(vec![0, 2]);    // Extract multiple elements
143    /// ```
144    pub fn gep(self: &Arc<Self>, indices: Vec<usize>) -> Arc<Self> {
145        let vector_dtype = self.dtype();
146        let dtype = if indices.len() == 1 {
147            DType::Scalar(vector_dtype.base())
148        } else {
149            DType::Scalar(vector_dtype.base()).vec(indices.len())
150        };
151        Self::new(Op::Gep { vector: self.clone(), indices }, dtype)
152    }
153
154    /// Contract unrolled values back into vectorized form (fallible version).
155    ///
156    /// # Errors
157    /// - `ContractCountMismatch` if dtype.vcount != product of axis sizes
158    pub fn try_contract(self: &Arc<Self>, upcast_ranges: Vec<(usize, usize)>) -> Result<Arc<Self>> {
159        let base_dtype = self.dtype();
160        let dtype_count = base_dtype.vcount();
161        let axis_product: usize = upcast_ranges.iter().map(|(_, size)| size).product();
162
163        // Only validate if dtype is not void (STORE ops have void dtype)
164        if base_dtype != DType::Void {
165            ensure!(dtype_count == axis_product, ContractCountMismatchSnafu { dtype_count, axis_product });
166        }
167
168        let dtype = if axis_product > 1 { base_dtype.vec(axis_product) } else { base_dtype };
169
170        Ok(Self::new(Op::Contract { src: self.clone(), upcast_ranges }, dtype))
171    }
172
173    /// Contract unrolled values back into vectorized form.
174    ///
175    /// Pairs with UNROLL: UNROLL expands loops for optimization,
176    /// CONTRACT combines the results. Used in WMMA and vectorization passes.
177    pub fn contract(self: &Arc<Self>, upcast_ranges: Vec<(usize, usize)>) -> Arc<Self> {
178        let base_dtype = self.dtype();
179        let vec_size = upcast_ranges.iter().map(|(_, size)| size).product::<usize>();
180        let dtype = if vec_size > 1 { base_dtype.vec(vec_size) } else { base_dtype };
181        Self::new(Op::Contract { src: self.clone(), upcast_ranges }, dtype)
182    }
183
184    /// Expand a value across unrolled loop iterations (fallible version).
185    ///
186    /// # Errors
187    /// - `UnrollCountMismatch` if src.dtype.vcount != product of axis sizes
188    pub fn try_unroll(self: &Arc<Self>, unroll_axes: Vec<(usize, usize)>) -> Result<Arc<Self>> {
189        let dtype = self.dtype();
190        let dtype_count = dtype.vcount();
191        let axis_product: usize = unroll_axes.iter().map(|(_, size)| size).product();
192
193        // Only validate if we have axes to unroll
194        if !unroll_axes.is_empty() {
195            ensure!(dtype_count == axis_product, UnrollCountMismatchSnafu { dtype_count, axis_product });
196        }
197
198        Ok(Self::new(Op::Unroll { src: self.clone(), unroll_axes }, dtype))
199    }
200
201    /// Expand a value across unrolled loop iterations.
202    ///
203    /// Creates multiple versions of the computation for each unroll axis.
204    /// Pairs with CONTRACT which combines results back together.
205    pub fn unroll(self: &Arc<Self>, unroll_axes: Vec<(usize, usize)>) -> Arc<Self> {
206        let dtype = self.dtype();
207        Self::new(Op::Unroll { src: self.clone(), unroll_axes }, dtype)
208    }
209
210    /// Create UNROLL with explicit dtype (for do_contract pattern).
211    ///
212    /// Used when UNROLL dtype should differ from source dtype,
213    /// specifically when CONTRACT collapses UNROLL via GEP and
214    /// we need to preserve the per-iteration element type.
215    ///
216    /// Based on Tinygrad's pattern where partial contraction creates
217    /// UNROLL with remaining axes but CONTRACT's dtype.
218    pub fn unroll_with_dtype(self: &Arc<Self>, unroll_axes: Vec<(usize, usize)>, dtype: DType) -> Arc<Self> {
219        Self::new(Op::Unroll { src: self.clone(), unroll_axes }, dtype)
220    }
221
222    /// Create a CAT operation (concatenate vectors).
223    ///
224    /// # Example
225    /// ```ignore
226    /// // Infer dtype (sum of vcounts)
227    /// UOp::cat().sources(vec![a, b]).call()
228    ///
229    /// // Explicit dtype
230    /// UOp::cat().sources(vec![a, b]).dtype(vec8_dtype).call()
231    /// ```
232    #[builder]
233    pub fn cat(sources: Vec<Arc<Self>>, dtype: Option<DType>) -> Arc<Self> {
234        assert!(!sources.is_empty(), "CAT requires at least one source");
235        let dtype = dtype.unwrap_or_else(|| {
236            let total_count: usize = sources.iter().map(|s| s.dtype().vcount()).sum();
237            DType::Scalar(sources[0].dtype.base()).vec(total_count)
238        });
239        Self::new(Op::Cat { sources: SmallVec::from_vec(sources) }, dtype)
240    }
241
242    /// Create a PTRCAT operation (concatenate pointers).
243    ///
244    /// # Example
245    /// ```ignore
246    /// UOp::ptrcat().sources(vec![a, b]).dtype(ptr_dtype).call()
247    /// ```
248    #[builder]
249    pub fn ptrcat(sources: Vec<Arc<Self>>, dtype: Option<DType>) -> Arc<Self> {
250        assert!(!sources.is_empty(), "PTRCAT requires at least one source");
251        let dtype = dtype.unwrap_or_else(|| {
252            // Compute vcount from total source pointer vcount, matching CAT's approach.
253            let total_vcount: usize = sources
254                .iter()
255                .map(|s| match s.dtype() {
256                    DType::Ptr { base, .. } => base.vcount(),
257                    other => other.vcount(),
258                })
259                .sum();
260            let base = &sources[0].dtype;
261            match base {
262                DType::Ptr { base, addrspace, size, .. } => {
263                    DType::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: total_vcount }
264                }
265                _ => base.clone(),
266            }
267        });
268        Self::new(Op::PtrCat { sources: SmallVec::from_vec(sources) }, dtype)
269    }
270
271    // =========================================================================
272    // Multi-Device Operations
273    // =========================================================================
274
275    /// Stack multiple buffers (multi-device tensors).
276    ///
277    /// MStack combines buffers from multiple devices into a single logical tensor.
278    /// Used for distributed/multi-GPU tensor operations.
279    pub fn mstack(buffers: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
280        let dtype = buffers.first().map(|b| b.dtype()).unwrap_or(DType::Void);
281        Self::new(Op::MStack { buffers }, dtype)
282    }
283
284    /// Select buffer by device index (multi-device access).
285    ///
286    /// MSelect retrieves a specific device's buffer from a multi-device tensor.
287    pub fn mselect(self: &Arc<Self>, device_index: usize) -> Arc<Self> {
288        let dtype = self.dtype();
289        Self::new(Op::MSelect { buffer: self.clone(), device_index }, dtype)
290    }
291
292    // =========================================================================
293    // Kernel Operations
294    // =========================================================================
295
296    /// Kernel wrapper.
297    ///
298    /// Creates a KERNEL operation with the given sources (kernel arguments) and AST (computation).
299    ///
300    /// # Arguments
301    ///
302    /// * `sources` - Kernel arguments (buffers and variables)
303    /// * `ast` - The computation graph (usually SINK, COPY, or BUFFER_VIEW)
304    pub fn kernel(sources: SmallVec<[Arc<Self>; 4]>, ast: Arc<Self>) -> Arc<Self> {
305        Self::new(Op::Kernel { sources, ast }, DType::Void)
306    }
307}