morok_ir/uop/constructors/hardware.rs
1//! Hardware-specific operations: WMMA, vectorize, kernel.
2//!
3//! This module contains hardware-specific operations:
4//! - Tensor cores: wmma
5//! - Vectorization: vectorize, gep, contract, unroll, cat, ptrcat
6//! - Multi-device: mstack, mselect
7//! - Kernels: kernel
8
9use std::sync::Arc;
10
11use bon::bon;
12use morok_dtype::DType;
13use smallvec::SmallVec;
14use snafu::ensure;
15
16use crate::Result;
17use crate::error::{
18 BroadcastRequiresScalarSnafu, ContractCountMismatchSnafu, GepIndexOutOfBoundsSnafu, GepRequiresVectorSnafu,
19 UnrollCountMismatchSnafu, VectorizeDTypeMismatchSnafu, VectorizeEmptySnafu,
20};
21use crate::op::Op;
22use crate::types::WmmaMetadata;
23use crate::uop::UOp;
24
25#[bon]
26impl UOp {
27 // =========================================================================
28 // Tensor Core Operations
29 // =========================================================================
30
31 /// Warp Matrix Multiply-Accumulate for tensor cores.
32 ///
33 /// Computes D = A × B + C using hardware matrix units.
34 /// `metadata` specifies dimensions, dtypes, and upcast axes for vectorization.
35 pub fn wmma(a: Arc<Self>, b: Arc<Self>, c: Arc<Self>, metadata: WmmaMetadata) -> Arc<Self> {
36 let base_dtype = metadata.dtype_out.clone();
37
38 // Calculate vector size from C (output) upcast axes
39 let vec_size = metadata.upcast_axes.c.iter().map(|(_, size)| size).product::<usize>();
40
41 let dtype = if vec_size > 1 { base_dtype.vec(vec_size) } else { base_dtype };
42
43 Self::new(Op::Wmma { a, b, c, metadata }, dtype)
44 }
45
46 // =========================================================================
47 // Vectorization Operations
48 // =========================================================================
49
50 /// Create vector from scalar elements (fallible version with validation).
51 ///
52 /// # Errors
53 /// - `VectorizeRequiresMultiple` if elements is empty
54 /// - `VectorizeDTypeMismatch` if elements have different scalar dtypes
55 pub fn try_vectorize(elements: SmallVec<[Arc<Self>; 4]>) -> Result<Arc<Self>> {
56 ensure!(!elements.is_empty(), VectorizeEmptySnafu);
57
58 // Use full dtype (not scalar_dtype) to preserve Ptr type for pointer vectors.
59 // This matches Tinygrad's broadcast: `UOp(Ops.VECTORIZE, self.dtype.vec(count), ...)`
60 // For Ptr types: Ptr{vcount:1}.vec(N) → Ptr{vcount:N} (vector of pointers)
61 // For Scalar types: Scalar(Float32).vec(N) → Vector{Float32, N}
62 let expected_dtype = elements[0].dtype();
63 for elem in elements.iter().skip(1) {
64 let actual = elem.dtype();
65 ensure!(expected_dtype == actual, VectorizeDTypeMismatchSnafu { expected: expected_dtype, actual });
66 }
67
68 let vec_dtype = expected_dtype.vec(elements.len());
69 Ok(Self::new(Op::Vectorize { elements }, vec_dtype))
70 }
71
72 /// Create vector from scalar elements (panics on violation).
73 pub fn vectorize(elements: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
74 Self::try_vectorize(elements).expect("vectorize precondition violated")
75 }
76
77 /// Broadcast a scalar value to a vector by replication (fallible version).
78 ///
79 /// Creates a VECTORIZE operation with `count` copies of the source.
80 /// If `count == 1`, returns the source unchanged.
81 ///
82 /// # Errors
83 /// - `BroadcastRequiresScalar` if source has vcount > 1
84 pub fn try_broadcast(self: &Arc<Self>, count: usize) -> Result<Arc<Self>> {
85 ensure!(self.dtype().vcount() == 1, BroadcastRequiresScalarSnafu { dtype: self.dtype() });
86
87 if count == 1 {
88 return Ok(self.clone());
89 }
90 let elements: SmallVec<[Arc<Self>; 4]> = (0..count).map(|_| self.clone()).collect();
91 Ok(Self::vectorize(elements))
92 }
93
94 /// Broadcast a scalar value to a vector by replication.
95 ///
96 /// Creates a VECTORIZE operation with `count` copies of the source.
97 /// If `count == 1`, returns the source unchanged.
98 ///
99 /// # Example
100 ///
101 /// ```ignore
102 /// let vector = scalar.broadcast(4);
103 /// ```
104 pub fn broadcast(self: &Arc<Self>, count: usize) -> Arc<Self> {
105 if count == 1 {
106 return self.clone();
107 }
108 let elements: SmallVec<[Arc<Self>; 4]> = (0..count).map(|_| self.clone()).collect();
109 Self::vectorize(elements)
110 }
111
112 /// Extract element(s) from vector (fallible version with validation).
113 ///
114 /// # Errors
115 /// - `GepRequiresVector` if source has vcount <= 1
116 /// - `GepIndexOutOfBounds` if any index >= source vcount
117 pub fn try_gep(self: &Arc<Self>, indices: Vec<usize>) -> Result<Arc<Self>> {
118 let vector_dtype = self.dtype();
119 let vcount = vector_dtype.vcount();
120
121 ensure!(vcount > 1, GepRequiresVectorSnafu { dtype: vector_dtype.clone() });
122
123 for &index in &indices {
124 ensure!(index < vcount, GepIndexOutOfBoundsSnafu { index, vcount });
125 }
126
127 let dtype = if indices.len() == 1 {
128 DType::Scalar(vector_dtype.base())
129 } else {
130 DType::Scalar(vector_dtype.base()).vec(indices.len())
131 };
132
133 Ok(Self::new(Op::Gep { vector: self.clone(), indices }, dtype))
134 }
135
136 /// Extract element(s) from vector (Get Element Pointer).
137 ///
138 /// # Example
139 ///
140 /// ```ignore
141 /// let elem = vector.gep(vec![0]); // Extract single element
142 /// let sub = vector.gep(vec![0, 2]); // Extract multiple elements
143 /// ```
144 pub fn gep(self: &Arc<Self>, indices: Vec<usize>) -> Arc<Self> {
145 let vector_dtype = self.dtype();
146 let dtype = if indices.len() == 1 {
147 DType::Scalar(vector_dtype.base())
148 } else {
149 DType::Scalar(vector_dtype.base()).vec(indices.len())
150 };
151 Self::new(Op::Gep { vector: self.clone(), indices }, dtype)
152 }
153
154 /// Contract unrolled values back into vectorized form (fallible version).
155 ///
156 /// # Errors
157 /// - `ContractCountMismatch` if dtype.vcount != product of axis sizes
158 pub fn try_contract(self: &Arc<Self>, upcast_ranges: Vec<(usize, usize)>) -> Result<Arc<Self>> {
159 let base_dtype = self.dtype();
160 let dtype_count = base_dtype.vcount();
161 let axis_product: usize = upcast_ranges.iter().map(|(_, size)| size).product();
162
163 // Only validate if dtype is not void (STORE ops have void dtype)
164 if base_dtype != DType::Void {
165 ensure!(dtype_count == axis_product, ContractCountMismatchSnafu { dtype_count, axis_product });
166 }
167
168 let dtype = if axis_product > 1 { base_dtype.vec(axis_product) } else { base_dtype };
169
170 Ok(Self::new(Op::Contract { src: self.clone(), upcast_ranges }, dtype))
171 }
172
173 /// Contract unrolled values back into vectorized form.
174 ///
175 /// Pairs with UNROLL: UNROLL expands loops for optimization,
176 /// CONTRACT combines the results. Used in WMMA and vectorization passes.
177 pub fn contract(self: &Arc<Self>, upcast_ranges: Vec<(usize, usize)>) -> Arc<Self> {
178 let base_dtype = self.dtype();
179 let vec_size = upcast_ranges.iter().map(|(_, size)| size).product::<usize>();
180 let dtype = if vec_size > 1 { base_dtype.vec(vec_size) } else { base_dtype };
181 Self::new(Op::Contract { src: self.clone(), upcast_ranges }, dtype)
182 }
183
184 /// Expand a value across unrolled loop iterations (fallible version).
185 ///
186 /// # Errors
187 /// - `UnrollCountMismatch` if src.dtype.vcount != product of axis sizes
188 pub fn try_unroll(self: &Arc<Self>, unroll_axes: Vec<(usize, usize)>) -> Result<Arc<Self>> {
189 let dtype = self.dtype();
190 let dtype_count = dtype.vcount();
191 let axis_product: usize = unroll_axes.iter().map(|(_, size)| size).product();
192
193 // Only validate if we have axes to unroll
194 if !unroll_axes.is_empty() {
195 ensure!(dtype_count == axis_product, UnrollCountMismatchSnafu { dtype_count, axis_product });
196 }
197
198 Ok(Self::new(Op::Unroll { src: self.clone(), unroll_axes }, dtype))
199 }
200
201 /// Expand a value across unrolled loop iterations.
202 ///
203 /// Creates multiple versions of the computation for each unroll axis.
204 /// Pairs with CONTRACT which combines results back together.
205 pub fn unroll(self: &Arc<Self>, unroll_axes: Vec<(usize, usize)>) -> Arc<Self> {
206 let dtype = self.dtype();
207 Self::new(Op::Unroll { src: self.clone(), unroll_axes }, dtype)
208 }
209
210 /// Create UNROLL with explicit dtype (for do_contract pattern).
211 ///
212 /// Used when UNROLL dtype should differ from source dtype,
213 /// specifically when CONTRACT collapses UNROLL via GEP and
214 /// we need to preserve the per-iteration element type.
215 ///
216 /// Based on Tinygrad's pattern where partial contraction creates
217 /// UNROLL with remaining axes but CONTRACT's dtype.
218 pub fn unroll_with_dtype(self: &Arc<Self>, unroll_axes: Vec<(usize, usize)>, dtype: DType) -> Arc<Self> {
219 Self::new(Op::Unroll { src: self.clone(), unroll_axes }, dtype)
220 }
221
222 /// Create a CAT operation (concatenate vectors).
223 ///
224 /// # Example
225 /// ```ignore
226 /// // Infer dtype (sum of vcounts)
227 /// UOp::cat().sources(vec![a, b]).call()
228 ///
229 /// // Explicit dtype
230 /// UOp::cat().sources(vec![a, b]).dtype(vec8_dtype).call()
231 /// ```
232 #[builder]
233 pub fn cat(sources: Vec<Arc<Self>>, dtype: Option<DType>) -> Arc<Self> {
234 assert!(!sources.is_empty(), "CAT requires at least one source");
235 let dtype = dtype.unwrap_or_else(|| {
236 let total_count: usize = sources.iter().map(|s| s.dtype().vcount()).sum();
237 DType::Scalar(sources[0].dtype.base()).vec(total_count)
238 });
239 Self::new(Op::Cat { sources: SmallVec::from_vec(sources) }, dtype)
240 }
241
242 /// Create a PTRCAT operation (concatenate pointers).
243 ///
244 /// # Example
245 /// ```ignore
246 /// UOp::ptrcat().sources(vec![a, b]).dtype(ptr_dtype).call()
247 /// ```
248 #[builder]
249 pub fn ptrcat(sources: Vec<Arc<Self>>, dtype: Option<DType>) -> Arc<Self> {
250 assert!(!sources.is_empty(), "PTRCAT requires at least one source");
251 let dtype = dtype.unwrap_or_else(|| {
252 // Compute vcount from total source pointer vcount, matching CAT's approach.
253 let total_vcount: usize = sources
254 .iter()
255 .map(|s| match s.dtype() {
256 DType::Ptr { base, .. } => base.vcount(),
257 other => other.vcount(),
258 })
259 .sum();
260 let base = &sources[0].dtype;
261 match base {
262 DType::Ptr { base, addrspace, size, .. } => {
263 DType::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: total_vcount }
264 }
265 _ => base.clone(),
266 }
267 });
268 Self::new(Op::PtrCat { sources: SmallVec::from_vec(sources) }, dtype)
269 }
270
271 // =========================================================================
272 // Multi-Device Operations
273 // =========================================================================
274
275 /// Stack multiple buffers (multi-device tensors).
276 ///
277 /// MStack combines buffers from multiple devices into a single logical tensor.
278 /// Used for distributed/multi-GPU tensor operations.
279 pub fn mstack(buffers: SmallVec<[Arc<Self>; 4]>) -> Arc<Self> {
280 let dtype = buffers.first().map(|b| b.dtype()).unwrap_or(DType::Void);
281 Self::new(Op::MStack { buffers }, dtype)
282 }
283
284 /// Select buffer by device index (multi-device access).
285 ///
286 /// MSelect retrieves a specific device's buffer from a multi-device tensor.
287 pub fn mselect(self: &Arc<Self>, device_index: usize) -> Arc<Self> {
288 let dtype = self.dtype();
289 Self::new(Op::MSelect { buffer: self.clone(), device_index }, dtype)
290 }
291
292 // =========================================================================
293 // Kernel Operations
294 // =========================================================================
295
296 /// Kernel wrapper.
297 ///
298 /// Creates a KERNEL operation with the given sources (kernel arguments) and AST (computation).
299 ///
300 /// # Arguments
301 ///
302 /// * `sources` - Kernel arguments (buffers and variables)
303 /// * `ast` - The computation graph (usually SINK, COPY, or BUFFER_VIEW)
304 pub fn kernel(sources: SmallVec<[Arc<Self>; 4]>, ast: Arc<Self>) -> Arc<Self> {
305 Self::new(Op::Kernel { sources, ast }, DType::Void)
306 }
307}