1use std::marker::PhantomData;
50
51use crate::{
52 ir::{self, Instruction},
53 unexpanded,
54};
55
56use super::{CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, Init, Slice, SliceMut};
57
58use cubecl_ir::{ExpandElement, Scope};
59pub use ir::{MatrixIdent, MatrixLayout};
60
61#[derive(Copy, Clone)]
66pub struct Matrix<C: CubeType> {
67 _c: PhantomData<C>,
68}
69
70pub struct MatrixExpand<C: CubeType> {
72 elem: ExpandElement,
73 ident: MatrixIdent,
74 _c: PhantomData<C>,
75}
76
77impl<C: CubeType> Clone for MatrixExpand<C> {
78 fn clone(&self) -> Self {
79 Self {
80 elem: self.elem.clone(),
81 ident: self.ident,
82 _c: self._c,
83 }
84 }
85}
86
87impl<C: CubeType> CubeType for Matrix<C> {
88 type ExpandType = MatrixExpand<C>;
89}
90
91impl<C: CubeType> Init for MatrixExpand<C> {
92 fn init(self, _scope: &mut Scope) -> Self {
93 self
94 }
95}
96
97impl<C: CubeType> CubeDebug for MatrixExpand<C> {
98 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
99 scope.update_variable_name(*self.elem, name);
100 }
101}
102
103impl<C: CubePrimitive> Matrix<C> {
104 #[allow(unused_variables)]
122 pub unsafe fn uninitialized(
123 ident: MatrixIdent,
124 m: u32,
125 n: u32,
126 k: u32,
127 layout: MatrixLayout,
128 ) -> Self {
129 Matrix { _c: PhantomData }
130 }
131
132 #[allow(unused_variables)]
146 pub fn from_value(
147 ident: MatrixIdent,
148 m: u32,
149 n: u32,
150 k: u32,
151 layout: MatrixLayout,
152 value: C,
153 ) -> Self {
154 Matrix { _c: PhantomData }
155 }
156
157 #[allow(unused_variables)]
171 pub fn from_slice(
172 ident: MatrixIdent,
173 m: u32,
174 n: u32,
175 k: u32,
176 layout: MatrixLayout,
177 value: &Slice<C>,
178 stride: u32,
179 ) -> Self {
180 Matrix { _c: PhantomData }
181 }
182
183 pub fn __expand_uninitialized(
184 scope: &mut Scope,
185 ident: MatrixIdent,
186 m: ExpandElementTyped<u32>,
187 n: ExpandElementTyped<u32>,
188 k: ExpandElementTyped<u32>,
189 layout: MatrixLayout,
190 ) -> MatrixExpand<C> {
191 let elem = C::as_elem(scope);
192 let elem = scope.create_matrix(ir::Matrix {
193 ident,
194 m: m.constant().unwrap().as_u32() as u8,
195 n: n.constant().unwrap().as_u32() as u8,
196 k: k.constant().unwrap().as_u32() as u8,
197 elem,
198 layout,
199 });
200 MatrixExpand {
201 elem,
202 ident,
203 _c: PhantomData,
204 }
205 }
206
207 pub fn __expand_from_value(
208 scope: &mut Scope,
209 ident: MatrixIdent,
210 m: ExpandElementTyped<u32>,
211 n: ExpandElementTyped<u32>,
212 k: ExpandElementTyped<u32>,
213 layout: MatrixLayout,
214 value: ExpandElementTyped<C>,
215 ) -> MatrixExpand<C> {
216 let mat = Self::__expand_uninitialized(scope, ident, m, n, k, layout);
217 fill::expand(scope, mat.clone(), value);
218 mat
219 }
220
221 #[allow(clippy::too_many_arguments)]
222 pub fn __expand_from_slice(
223 scope: &mut Scope,
224 ident: MatrixIdent,
225 m: ExpandElementTyped<u32>,
226 n: ExpandElementTyped<u32>,
227 k: ExpandElementTyped<u32>,
228 layout: MatrixLayout,
229 value: ExpandElementTyped<Slice<C>>,
230 stride: ExpandElementTyped<u32>,
231 ) -> MatrixExpand<C> {
232 let mat = Self::__expand_uninitialized(scope, ident, m, n, k, layout);
233 load::expand(scope, mat.clone(), value, stride);
234 mat
235 }
236}
237
238#[allow(unused_variables)]
240pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
241 unexpanded!()
242}
243
244pub mod fill {
246 use super::*;
247
248 pub fn expand<C: CubeType>(
250 scope: &mut Scope,
251 mat: MatrixExpand<C>,
252 value: ExpandElementTyped<C>,
253 ) {
254 let value: ExpandElement = value.into();
255 scope.register(Instruction::new(
256 ir::CoopMma::Fill { value: *value },
257 *mat.elem,
258 ));
259 }
260}
261
262#[allow(unused_variables)]
264pub fn load<C: CubePrimitive, V: CubePrimitive>(mat: &Matrix<C>, value: &Slice<V>, stride: u32) {
265 unexpanded!()
266}
267
268pub mod load {
270 use super::*;
271
272 #[allow(unused_variables)]
274 pub fn expand<C: CubePrimitive, V: CubePrimitive>(
275 scope: &mut Scope,
276 mat: MatrixExpand<C>,
277 value: ExpandElementTyped<Slice<V>>,
278 stride: ExpandElementTyped<u32>,
279 ) {
280 let stride: ExpandElement = stride.into();
281 assert_ne!(
282 mat.ident,
283 MatrixIdent::Accumulator,
284 "Loading accumulator requires explicit layout. Use `load_with_layout` instead."
285 );
286
287 scope.register(Instruction::new(
288 ir::CoopMma::Load {
289 value: *value.expand,
290 stride: *stride,
291 layout: None,
292 },
293 *mat.elem,
294 ));
295 }
296}
297
298#[allow(unused_variables)]
301pub fn load_with_layout<C: CubePrimitive, V: CubePrimitive>(
302 mat: &Matrix<C>,
303 value: &Slice<V>,
304 stride: u32,
305 layout: MatrixLayout,
306) {
307 unexpanded!()
308}
309
310pub mod load_with_layout {
312 use super::*;
313
314 #[allow(unused_variables)]
316 pub fn expand<C: CubeType, V: CubePrimitive>(
317 scope: &mut Scope,
318 mat: MatrixExpand<C>,
319 value: ExpandElementTyped<Slice<V>>,
320 stride: ExpandElementTyped<u32>,
321 layout: MatrixLayout,
322 ) {
323 let stride: ExpandElement = stride.into();
324
325 scope.register(Instruction::new(
326 ir::CoopMma::Load {
327 value: *value.expand,
328 stride: *stride,
329 layout: Some(layout),
330 },
331 *mat.elem,
332 ));
333 }
334}
335
336#[allow(unused_variables)]
338pub fn store<C: CubePrimitive, O: CubePrimitive>(
339 output: &mut SliceMut<O>,
340 mat: &Matrix<C>,
341 stride: u32,
342 layout: MatrixLayout,
343) {
344 unexpanded!()
345}
346
347pub mod store {
349 use super::*;
350
351 #[allow(unused_variables)]
353 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
354 scope: &mut Scope,
355 output: ExpandElementTyped<SliceMut<O>>,
356 mat: MatrixExpand<C>,
357 stride: ExpandElementTyped<u32>,
358 layout: MatrixLayout,
359 ) {
360 let stride: ExpandElement = stride.into();
361
362 scope.register(Instruction::new(
363 ir::CoopMma::Store {
364 mat: *mat.elem,
365 stride: *stride,
366 layout,
367 },
368 *output.expand,
369 ));
370 }
371}
372
373#[allow(unused_variables)]
375pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
376 mat_a: &Matrix<A>,
377 mat_b: &Matrix<B>,
378 mat_c: &Matrix<C>,
379 mat_d: &Matrix<D>,
380) {
381 unexpanded!()
382}
383
384pub mod execute {
386 use super::*;
387
388 pub fn expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
390 scope: &mut Scope,
391 mat_a: MatrixExpand<A>,
392 mat_b: MatrixExpand<B>,
393 mat_c: MatrixExpand<C>,
394 mat_d: MatrixExpand<D>,
395 ) {
396 scope.register(Instruction::new(
397 ir::CoopMma::Execute {
398 mat_a: *mat_a.elem,
399 mat_b: *mat_b.elem,
400 mat_c: *mat_c.elem,
401 },
402 *mat_d.elem,
403 ));
404 }
405}
406
407#[allow(unused_variables)]
409pub fn cast<C: CubePrimitive, O: CubePrimitive>(input: &Matrix<C>) -> Matrix<O> {
410 unexpanded!()
411}
412
413pub mod cast {
415 use super::*;
416
417 #[allow(unused_variables)]
419 pub fn expand<C: CubePrimitive, O: CubePrimitive>(
420 scope: &mut Scope,
421 input: MatrixExpand<C>,
422 ) -> MatrixExpand<O> {
423 let ident = input.ident;
424
425 if core::any::TypeId::of::<C>() == core::any::TypeId::of::<O>() {
426 return MatrixExpand {
427 elem: input.elem,
428 ident,
429 _c: PhantomData,
430 };
431 }
432 let input = *input.elem;
433 let input_mat = match input.kind {
434 ir::VariableKind::Matrix { mat, .. } => mat,
435 _ => unreachable!(),
436 };
437
438 let elem = O::as_elem(scope);
439 let elem = scope.create_matrix(ir::Matrix {
440 ident,
441 m: input_mat.m,
442 n: input_mat.n,
443 k: input_mat.k,
444 elem,
445 layout: MatrixLayout::Undefined,
446 });
447
448 let output = MatrixExpand {
449 ident,
450 elem,
451 _c: PhantomData,
452 };
453 scope.register(Instruction::new(ir::CoopMma::Cast { input }, *output.elem));
454
455 output
456 }
457}
458
459impl CubeType for MatrixLayout {
460 type ExpandType = Self;
461}
462
463impl Init for MatrixLayout {
464 fn init(self, _scope: &mut crate::ir::Scope) -> Self {
465 self
466 }
467}
468
/// Uses the trait's default methods. `MatrixLayout` expands to itself (its
/// `ExpandType` is `Self`), so there is no IR variable of its own to name.
impl CubeDebug for MatrixLayout {}