use std::marker::PhantomData;
use crate::{
ir::{self, Operation},
unexpanded,
};
use super::{
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut,
UInt,
};
pub use ir::{MatrixIdent, MatrixLayout};
/// Marker type for a matrix whose elements are of type `C`.
///
/// The struct stores no data: `C` is only tracked through [`PhantomData`].
/// Its expansion-time counterpart is [`MatrixExpand`].
#[derive(Clone, Copy)]
pub struct Matrix<C: CubeType> {
    /// Zero-sized marker tying the element type `C` to this matrix.
    _c: PhantomData<C>,
}
/// Expansion-time representation of a [`Matrix`].
///
/// Wraps the [`ExpandElement`] that the `__expand` functions in this file
/// dereference when registering `CoopMma` operations.
#[derive(Clone)]
pub struct MatrixExpand {
    /// Element produced by `CubeContext::create_matrix` for this matrix.
    elem: ExpandElement,
}
/// Every `Matrix<C>` expands to the same [`MatrixExpand`] type, regardless of
/// the element type `C`.
impl<C: CubeType> CubeType for Matrix<C> {
    type ExpandType = MatrixExpand;
}
impl Init for MatrixExpand {
    /// Initialisation is a no-op: the value is handed back unchanged and the
    /// context is not touched.
    fn init(self, _context: &mut CubeContext) -> Self {
        self
    }
}
impl<C: CubePrimitive> Matrix<C> {
    /// Non-expanded constructor.
    ///
    /// All arguments are ignored here and a bare marker value is returned;
    /// the matrix is actually created in the IR by [`Matrix::__expand_new`].
    // The parameters exist only to mirror the signature of `__expand_new`.
    #[allow(unused_variables)]
    pub fn new(ident: MatrixIdent, m: u32, n: u32, k: u32, layout: MatrixLayout) -> Self {
        Matrix { _c: PhantomData }
    }

    /// Expansion of [`Matrix::new`]: creates the matrix in `context` and
    /// returns its expand handle.
    ///
    /// The element type of the created matrix is `C::as_elem()`.
    ///
    /// # Panics
    ///
    /// Panics if `m`, `n` or `k` is not a constant, or if its value does not
    /// fit in the `u8` stored by `ir::Matrix`.
    pub fn __expand_new(
        context: &mut CubeContext,
        ident: MatrixIdent,
        m: ExpandElementTyped<UInt>,
        n: ExpandElementTyped<UInt>,
        k: ExpandElementTyped<UInt>,
        layout: MatrixLayout,
    ) -> MatrixExpand {
        /// Extracts a constant dimension and narrows it to `u8` with a check.
        fn dimension(value: ExpandElementTyped<UInt>, name: &str) -> u8 {
            let raw = value
                .constant()
                .expect("matrix dimensions must be constants")
                .as_u32();
            // A plain `as u8` would silently wrap (e.g. 256 -> 0) and create a
            // matrix of the wrong size, so reject out-of-range values instead.
            u8::try_from(raw).unwrap_or_else(|_| {
                panic!(
                    "matrix dimension `{}` = {} does not fit in u8",
                    name, raw
                )
            })
        }

        let elem = context.create_matrix(ir::Matrix {
            ident,
            m: dimension(m, "m"),
            n: dimension(n, "n"),
            k: dimension(k, "k"),
            elem: C::as_elem(),
            layout,
        });
        MatrixExpand { elem }
    }
}
/// Non-expanded form of the matrix fill operation.
///
/// The body is only `unexpanded!()`; the operation itself is registered by
/// `fill::__expand`.
// Parameters are unused here; they only describe the call signature.
#[allow(unused_variables)]
pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
    unexpanded!()
}
/// Expansion of the `fill` function.
pub mod fill {
    use super::*;

    /// Registers a `CoopMma::Fill` operation on `context`, with `mat` as the
    /// matrix operand and `value` as the fill value.
    pub fn __expand<C: CubeType>(
        context: &mut CubeContext,
        mat: MatrixExpand,
        value: ExpandElementTyped<C>,
    ) {
        // Drop the type parameter so the element can be dereferenced below.
        let fill_value: ExpandElement = value.into();
        let op = ir::CoopMma::Fill {
            mat: *mat.elem,
            value: *fill_value,
        };
        context.register(Operation::CoopMma(op));
    }
}
/// Non-expanded form of the matrix load operation.
///
/// The body is only `unexpanded!()`; the operation itself is registered by
/// `load::__expand`.
// Parameters are unused here; they only describe the call signature.
#[allow(unused_variables)]
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
    unexpanded!()
}
/// Expansion of the `load` function.
pub mod load {
    use super::*;

    /// Registers a `CoopMma::Load` operation on `context`, with `mat` as the
    /// matrix operand, `value` as the source slice and `stride` as the stride
    /// operand.
    #[allow(unused_variables)]
    pub fn __expand<C: CubeType>(
        context: &mut CubeContext,
        mat: MatrixExpand,
        value: ExpandElementTyped<Slice<'static, C>>,
        stride: ExpandElementTyped<UInt>,
    ) {
        // Drop the type parameter so the element can be dereferenced below.
        let stride_elem: ExpandElement = stride.into();
        let op = ir::CoopMma::Load {
            mat: *mat.elem,
            value: *value.expand,
            stride: *stride_elem,
        };
        context.register(Operation::CoopMma(op));
    }
}
/// Non-expanded form of the matrix store operation.
///
/// The body is only `unexpanded!()`; the operation itself is registered by
/// `store::__expand`.
// Parameters are unused here; they only describe the call signature.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive>(
    output: &mut SliceMut<'_, C>,
    mat: &Matrix<C>,
    stride: UInt,
    layout: MatrixLayout,
) {
    unexpanded!()
}
/// Expansion of the `store` function.
pub mod store {
    use super::*;

    /// Registers a `CoopMma::Store` operation on `context`, with `output` as
    /// the destination slice, `mat` as the matrix operand, and the given
    /// `stride` and `layout`.
    #[allow(unused_variables)]
    pub fn __expand<C: CubePrimitive>(
        context: &mut CubeContext,
        output: ExpandElementTyped<SliceMut<'static, C>>,
        mat: MatrixExpand,
        stride: ExpandElementTyped<UInt>,
        layout: MatrixLayout,
    ) {
        // Drop the type parameter so the element can be dereferenced below.
        let stride_elem: ExpandElement = stride.into();
        let op = ir::CoopMma::Store {
            output: *output.expand,
            mat: *mat.elem,
            stride: *stride_elem,
            layout,
        };
        context.register(Operation::CoopMma(op));
    }
}
/// Non-expanded form of the matrix execute operation over four matrices.
///
/// The body is only `unexpanded!()`; the operation itself is registered by
/// `execute::__expand`.
// Parameters are unused here; they only describe the call signature.
#[allow(unused_variables)]
pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
    mat_a: &Matrix<A>,
    mat_b: &Matrix<B>,
    mat_c: &Matrix<C>,
    mat_d: &Matrix<D>,
) {
    unexpanded!()
}
/// Expansion of the `execute` function.
pub mod execute {
    use super::*;

    /// Registers a `CoopMma::Execute` operation on `context`, passing the four
    /// matrices through as `mat_a`, `mat_b`, `mat_c` and `mat_d`.
    pub fn __expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
        context: &mut CubeContext,
        mat_a: MatrixExpand,
        mat_b: MatrixExpand,
        mat_c: MatrixExpand,
        mat_d: MatrixExpand,
    ) {
        let op = ir::CoopMma::Execute {
            mat_a: *mat_a.elem,
            mat_b: *mat_b.elem,
            mat_c: *mat_c.elem,
            mat_d: *mat_d.elem,
        };
        context.register(Operation::CoopMma(op));
    }
}