1use std::marker::PhantomData;
50
51use crate::{
52 ir::{self, Instruction, Operation},
53 unexpanded,
54};
55
56use super::{
57 CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, IntoRuntime,
58 Slice, SliceMut,
59};
60
61pub use ir::{MatrixIdent, MatrixLayout};
62
63#[derive(Copy, Clone)]
68pub struct Matrix<C: CubeType> {
69 _c: PhantomData<C>,
70}
71
72pub struct MatrixExpand<C: CubeType> {
74 elem: ExpandElement,
75 ident: MatrixIdent,
76 _c: PhantomData<C>,
77}
78
79impl<C: CubeType> Clone for MatrixExpand<C> {
80 fn clone(&self) -> Self {
81 Self {
82 elem: self.elem.clone(),
83 ident: self.ident,
84 _c: self._c,
85 }
86 }
87}
88
89impl<C: CubeType> CubeType for Matrix<C> {
90 type ExpandType = MatrixExpand<C>;
91}
92
93impl<C: CubeType> IntoRuntime for Matrix<C> {
94 fn __expand_runtime_method(self, _context: &mut CubeContext) -> MatrixExpand<C> {
95 unimplemented!("Matrices can't exist at compile time")
96 }
97}
98
99impl<C: CubeType> Init for MatrixExpand<C> {
100 fn init(self, _context: &mut CubeContext) -> Self {
101 self
102 }
103}
104
105impl<C: CubePrimitive> Matrix<C> {
106 #[allow(unused_variables)]
124 pub unsafe fn uninitialized(
125 ident: MatrixIdent,
126 m: u32,
127 n: u32,
128 k: u32,
129 layout: MatrixLayout,
130 ) -> Self {
131 Matrix { _c: PhantomData }
132 }
133
134 #[allow(unused_variables)]
148 pub fn from_value(
149 ident: MatrixIdent,
150 m: u32,
151 n: u32,
152 k: u32,
153 layout: MatrixLayout,
154 value: C,
155 ) -> Self {
156 Matrix { _c: PhantomData }
157 }
158
159 #[allow(unused_variables)]
173 pub fn from_slice(
174 ident: MatrixIdent,
175 m: u32,
176 n: u32,
177 k: u32,
178 layout: MatrixLayout,
179 value: &Slice<C>,
180 stride: u32,
181 ) -> Self {
182 Matrix { _c: PhantomData }
183 }
184
185 pub fn __expand_uninitialized(
186 context: &mut CubeContext,
187 ident: MatrixIdent,
188 m: ExpandElementTyped<u32>,
189 n: ExpandElementTyped<u32>,
190 k: ExpandElementTyped<u32>,
191 layout: MatrixLayout,
192 ) -> MatrixExpand<C> {
193 let elem = C::as_elem(context);
194 let elem = context.create_matrix(ir::Matrix {
195 ident,
196 m: m.constant().unwrap().as_u32() as u8,
197 n: n.constant().unwrap().as_u32() as u8,
198 k: k.constant().unwrap().as_u32() as u8,
199 elem,
200 layout,
201 });
202 MatrixExpand {
203 elem,
204 ident,
205 _c: PhantomData,
206 }
207 }
208
209 pub fn __expand_from_value(
210 context: &mut CubeContext,
211 ident: MatrixIdent,
212 m: ExpandElementTyped<u32>,
213 n: ExpandElementTyped<u32>,
214 k: ExpandElementTyped<u32>,
215 layout: MatrixLayout,
216 value: ExpandElementTyped<C>,
217 ) -> MatrixExpand<C> {
218 let mat = Self::__expand_uninitialized(context, ident, m, n, k, layout);
219 fill::expand(context, mat.clone(), value);
220 mat
221 }
222
223 #[allow(clippy::too_many_arguments)]
224 pub fn __expand_from_slice(
225 context: &mut CubeContext,
226 ident: MatrixIdent,
227 m: ExpandElementTyped<u32>,
228 n: ExpandElementTyped<u32>,
229 k: ExpandElementTyped<u32>,
230 layout: MatrixLayout,
231 value: ExpandElementTyped<Slice<C>>,
232 stride: ExpandElementTyped<u32>,
233 ) -> MatrixExpand<C> {
234 let mat = Self::__expand_uninitialized(context, ident, m, n, k, layout);
235 load::expand(context, mat.clone(), value, stride);
236 mat
237 }
238}
239
240#[allow(unused_variables)]
242pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
243 unexpanded!()
244}
245
246pub mod fill {
248 use super::*;
249
250 pub fn expand<C: CubeType>(
252 context: &mut CubeContext,
253 mat: MatrixExpand<C>,
254 value: ExpandElementTyped<C>,
255 ) {
256 let value: ExpandElement = value.into();
257 context.register(Instruction::new(
258 ir::CoopMma::Fill { value: *value },
259 *mat.elem,
260 ));
261 }
262}
263
264#[allow(unused_variables)]
266pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
267 unexpanded!()
268}
269
270pub mod load {
272 use super::*;
273
274 #[allow(unused_variables)]
276 pub fn expand<C: CubePrimitive, V: CubePrimitive>(
277 context: &mut CubeContext,
278 mat: MatrixExpand<C>,
279 value: ExpandElementTyped<Slice<V>>,
280 stride: ExpandElementTyped<u32>,
281 ) {
282 let stride: ExpandElement = stride.into();
283 assert_ne!(
284 mat.ident,
285 MatrixIdent::Accumulator,
286 "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
287 );
288
289 context.register(Instruction::new(
290 ir::CoopMma::Load {
291 value: *value.expand,
292 stride: *stride,
293 layout: None,
294 },
295 *mat.elem,
296 ));
297 }
298}
299
300#[allow(unused_variables)]
303pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
304 mat: &Matrix<C>,
305 value: &Slice<V>,
306 stride: u32,
307 layout: MatrixLayout,
308) {
309 unexpanded!()
310}
311
312pub mod load_with_layout {
314 use super::*;
315
316 #[allow(unused_variables)]
318 pub fn expand<C: CubeType, V: CubePrimitive>(
319 context: &mut CubeContext,
320 mat: MatrixExpand<C>,
321 value: ExpandElementTyped<Slice<V>>,
322 stride: ExpandElementTyped<u32>,
323 layout: MatrixLayout,
324 ) {
325 let stride: ExpandElement = stride.into();
326
327 context.register(Instruction::new(
328 ir::CoopMma::Load {
329 value: *value.expand,
330 stride: *stride,
331 layout: Some(layout),
332 },
333 *mat.elem,
334 ));
335 }
336}
337
338#[allow(unused_variables)]
340pub fn store<C: CubePrimitive, O: CubePrimitive>(
341 output: &mut SliceMut<O>,
342 mat: &Matrix<C>,
343 stride: u32,
344 layout: MatrixLayout,
345) {
346 unexpanded!()
347}
348
349pub mod store {
351 use super::*;
352
353 #[allow(unused_variables)]
355 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
356 context: &mut CubeContext,
357 output: ExpandElementTyped<SliceMut<O>>,
358 mat: MatrixExpand<C>,
359 stride: ExpandElementTyped<u32>,
360 layout: MatrixLayout,
361 ) {
362 let stride: ExpandElement = stride.into();
363
364 context.register(Instruction::new(
365 ir::CoopMma::Store {
366 mat: *mat.elem,
367 stride: *stride,
368 layout,
369 },
370 *output.expand,
371 ));
372 }
373}
374
375#[allow(unused_variables)]
377pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
378 mat_a: &Matrix<A>,
379 mat_b: &Matrix<B>,
380 mat_c: &Matrix<C>,
381 mat_d: &Matrix<D>,
382) {
383 unexpanded!()
384}
385
386pub mod execute {
388 use super::*;
389
390 pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
392 context: &mut CubeContext,
393 mat_a: MatrixExpand<A>,
394 mat_b: MatrixExpand<B>,
395 mat_c: MatrixExpand<C>,
396 mat_d: MatrixExpand<D>,
397 ) {
398 context.register(Instruction::new(
399 ir::CoopMma::Execute {
400 mat_a: *mat_a.elem,
401 mat_b: *mat_b.elem,
402 mat_c: *mat_c.elem,
403 },
404 *mat_d.elem,
405 ));
406 }
407}
408
409#[allow(unused_variables)]
411pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
412 unexpanded!()
413}
414
415pub mod cast {
417 use super::*;
418
419 #[allow(unused_variables)]
421 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
422 context: &mut CubeContext,
423 input: MatrixExpand<C>,
424 ) -> MatrixExpand<O> {
425 let ident = input.ident;
426
427 if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
428 return MatrixExpand {
429 elem: input.elem,
430 ident,
431 _c: PhantomData,
432 };
433 }
434 let input = *input.elem;
435 let input_mat = match input.kind {
436 ir::VariableKind::Matrix { mat, .. } => mat,
437 _ => unreachable!(),
438 };
439
440 let elem = O::as_elem(context);
441 let elem = context.create_matrix(ir::Matrix {
442 ident,
443 m: input_mat.m,
444 n: input_mat.n,
445 k: input_mat.k,
446 elem,
447 layout: MatrixLayout::Undefined,
448 });
449
450 let output = MatrixExpand {
451 ident,
452 elem,
453 _c: PhantomData,
454 };
455 context.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
456
457 output
458 }
459}
460
461impl From<ir::CoopMma> for Operation {
462 fn from(value: ir::CoopMma) -> Self {
463 Operation::CoopMma(value)
464 }
465}