1use std::marker::PhantomData;
50
51use crate::{
52 ir::{self, Instruction},
53 unexpanded,
54};
55
56use super::{
57 CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, ReadOnly, Slice, SliceExpand,
58 SliceMut,
59};
60
61use cubecl_ir::{ExpandElement, Scope};
62pub use ir::{MatrixIdent, MatrixLayout};
63
64#[derive(Copy, Clone)]
69pub struct Matrix<C: CubeType> {
70 _c: PhantomData<C>,
71}
72
73pub struct MatrixExpand<C: CubeType> {
75 elem: ExpandElement,
76 ident: MatrixIdent,
77 _c: PhantomData<C>,
78}
79
80impl<C: CubeType> Clone for MatrixExpand<C> {
81 fn clone(&self) -> Self {
82 Self {
83 elem: self.elem.clone(),
84 ident: self.ident,
85 _c: self._c,
86 }
87 }
88}
89
90impl<C: CubeType> CubeType for Matrix<C> {
91 type ExpandType = MatrixExpand<C>;
92}
93
94impl<C: CubeType> IntoMut for MatrixExpand<C> {
95 fn into_mut(self, _scope: &mut Scope) -> Self {
96 self
97 }
98}
99
100impl<C: CubeType> CubeDebug for MatrixExpand<C> {
101 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
102 scope.update_variable_name(*self.elem, name);
103 }
104}
105
106impl<C: CubePrimitive> Matrix<C> {
107 #[allow(unused_variables)]
125 pub unsafe fn uninitialized(
126 ident: MatrixIdent,
127 m: u32,
128 n: u32,
129 k: u32,
130 layout: MatrixLayout,
131 ) -> Self {
132 Matrix { _c: PhantomData }
133 }
134
135 #[allow(unused_variables)]
149 pub fn from_value(
150 ident: MatrixIdent,
151 m: u32,
152 n: u32,
153 k: u32,
154 layout: MatrixLayout,
155 value: C,
156 ) -> Self {
157 Matrix { _c: PhantomData }
158 }
159
160 #[allow(unused_variables)]
174 pub fn from_slice(
175 ident: MatrixIdent,
176 m: u32,
177 n: u32,
178 k: u32,
179 layout: MatrixLayout,
180 value: &Slice<C>,
181 stride: u32,
182 ) -> Self {
183 Matrix { _c: PhantomData }
184 }
185
186 pub fn __expand_uninitialized(
187 scope: &mut Scope,
188 ident: MatrixIdent,
189 m: ExpandElementTyped<u32>,
190 n: ExpandElementTyped<u32>,
191 k: ExpandElementTyped<u32>,
192 layout: MatrixLayout,
193 ) -> MatrixExpand<C> {
194 let elem = C::as_elem(scope);
195 let elem = scope.create_matrix(ir::Matrix {
196 ident,
197 m: m.constant().unwrap().as_u32() as u8,
198 n: n.constant().unwrap().as_u32() as u8,
199 k: k.constant().unwrap().as_u32() as u8,
200 elem,
201 layout,
202 });
203 MatrixExpand {
204 elem,
205 ident,
206 _c: PhantomData,
207 }
208 }
209
210 pub fn __expand_from_value(
211 scope: &mut Scope,
212 ident: MatrixIdent,
213 m: ExpandElementTyped<u32>,
214 n: ExpandElementTyped<u32>,
215 k: ExpandElementTyped<u32>,
216 layout: MatrixLayout,
217 value: ExpandElementTyped<C>,
218 ) -> MatrixExpand<C> {
219 let mat = Self::__expand_uninitialized(scope, ident, m, n, k, layout);
220 fill::expand(scope, mat.clone(), value);
221 mat
222 }
223
224 #[allow(clippy::too_many_arguments)]
225 pub fn __expand_from_slice(
226 scope: &mut Scope,
227 ident: MatrixIdent,
228 m: ExpandElementTyped<u32>,
229 n: ExpandElementTyped<u32>,
230 k: ExpandElementTyped<u32>,
231 layout: MatrixLayout,
232 value: SliceExpand<C, ReadOnly>,
233 stride: ExpandElementTyped<u32>,
234 ) -> MatrixExpand<C> {
235 let mat = Self::__expand_uninitialized(scope, ident, m, n, k, layout);
236 load::expand(scope, mat.clone(), value, stride);
237 mat
238 }
239}
240
241#[allow(unused_variables)]
243pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
244 unexpanded!()
245}
246
247pub mod fill {
249 use super::*;
250
251 pub fn expand<C: CubeType>(
253 scope: &mut Scope,
254 mat: MatrixExpand<C>,
255 value: ExpandElementTyped<C>,
256 ) {
257 let value: ExpandElement = value.into();
258 scope.register(Instruction::new(
259 ir::CoopMma::Fill { value: *value },
260 *mat.elem,
261 ));
262 }
263}
264
265#[allow(unused_variables)]
267pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
268 unexpanded!()
269}
270
271pub mod load {
273 use super::*;
274
275 #[allow(unused_variables)]
277 pub fn expand<C: CubePrimitive, V: CubePrimitive>(
278 scope: &mut Scope,
279 mat: MatrixExpand<C>,
280 value: SliceExpand<V, ReadOnly>,
281 stride: ExpandElementTyped<u32>,
282 ) {
283 let stride: ExpandElement = stride.into();
284 assert_ne!(
285 mat.ident,
286 MatrixIdent::Accumulator,
287 "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
288 );
289
290 let (value, offset) = value.__to_raw_parts();
291
292 scope.register(Instruction::new(
293 ir::CoopMma::Load {
294 value,
295 stride: *stride,
296 offset,
297 layout: None,
298 },
299 *mat.elem,
300 ));
301 }
302}
303
304#[allow(unused_variables)]
307pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
308 mat: &Matrix<C>,
309 value: &Slice<V>,
310 stride: u32,
311 layout: MatrixLayout,
312) {
313 unexpanded!()
314}
315
316pub mod load_with_layout {
318 use super::*;
319
320 #[allow(unused_variables)]
322 pub fn expand<C: CubeType, V: CubePrimitive>(
323 scope: &mut Scope,
324 mat: MatrixExpand<C>,
325 value: SliceExpand<V, ReadOnly>,
326 stride: ExpandElementTyped<u32>,
327 layout: MatrixLayout,
328 ) {
329 let stride: ExpandElement = stride.into();
330 let (value, offset) = value.__to_raw_parts();
331
332 scope.register(Instruction::new(
333 ir::CoopMma::Load {
334 value,
335 stride: *stride,
336 offset,
337 layout: Some(layout),
338 },
339 *mat.elem,
340 ));
341 }
342}
343
344#[allow(unused_variables)]
346pub fn store<C: CubePrimitive, O: CubePrimitive>(
347 output: &mut SliceMut<O>,
348 mat: &Matrix<C>,
349 stride: u32,
350 layout: MatrixLayout,
351) {
352 unexpanded!()
353}
354
355pub mod store {
357 use crate::prelude::ReadWrite;
358
359 use super::*;
360
361 #[allow(unused_variables)]
363 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
364 scope: &mut Scope,
365 output: SliceExpand<O, ReadWrite>,
366 mat: MatrixExpand<C>,
367 stride: ExpandElementTyped<u32>,
368 layout: MatrixLayout,
369 ) {
370 let stride: ExpandElement = stride.into();
371
372 let (output, offset) = output.__to_raw_parts();
373
374 scope.register(Instruction::new(
375 ir::CoopMma::Store {
376 mat: *mat.elem,
377 offset,
378 stride: *stride,
379 layout,
380 },
381 output,
382 ));
383 }
384}
385
386#[allow(unused_variables)]
388pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
389 mat_a: &Matrix<A>,
390 mat_b: &Matrix<B>,
391 mat_c: &Matrix<C>,
392 mat_d: &Matrix<D>,
393) {
394 unexpanded!()
395}
396
397pub mod execute {
399 use super::*;
400
401 pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
403 scope: &mut Scope,
404 mat_a: MatrixExpand<A>,
405 mat_b: MatrixExpand<B>,
406 mat_c: MatrixExpand<C>,
407 mat_d: MatrixExpand<D>,
408 ) {
409 scope.register(Instruction::new(
410 ir::CoopMma::Execute {
411 mat_a: *mat_a.elem,
412 mat_b: *mat_b.elem,
413 mat_c: *mat_c.elem,
414 },
415 *mat_d.elem,
416 ));
417 }
418}
419
420#[allow(unused_variables)]
422pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
423 unexpanded!()
424}
425
426pub mod cast {
428 use super::*;
429
430 #[allow(unused_variables)]
432 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
433 scope: &mut Scope,
434 input: MatrixExpand<C>,
435 ) -> MatrixExpand<O> {
436 let ident = input.ident;
437
438 if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
439 return MatrixExpand {
440 elem: input.elem,
441 ident,
442 _c: PhantomData,
443 };
444 }
445 let input = *input.elem;
446 let input_mat = match input.kind {
447 ir::VariableKind::Matrix { mat, .. } => mat,
448 _ => unreachable!(),
449 };
450
451 let elem = O::as_elem(scope);
452 let elem = scope.create_matrix(ir::Matrix {
453 ident,
454 m: input_mat.m,
455 n: input_mat.n,
456 k: input_mat.k,
457 elem,
458 layout: MatrixLayout::Undefined,
459 });
460
461 let output = MatrixExpand {
462 ident,
463 elem,
464 _c: PhantomData,
465 };
466 scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
467
468 output
469 }
470}
471
472impl CubeType for MatrixLayout {
473 type ExpandType = Self;
474}
475
476impl IntoMut for MatrixLayout {
477 fn into_mut(self, _scope: &mut crate::ir::Scope) -> Self {
478 self
479 }
480}
481
482impl CubeDebug for MatrixLayout {}