// cubecl-cpp 0.9.0
//
// CPP transpiler for CubeCL
// Documentation
use core::fmt::Formatter;

use crate::{
    Dialect,
    cuda::ptx,
    shared::{Elem, FragmentIdent, MmaShape},
};

/// Optional extension attached to generated code for dialect `D`.
///
/// Defaults to [`Extension::NoExtension`].
#[derive(Debug, Clone, Default, PartialEq)]
pub enum Extension<D: Dialect> {
    /// No extension is present; this is the `Default` value.
    #[default]
    NoExtension,
    /// A matrix-multiply-accumulate related extension, see [`MmaExtension`].
    Mma(MmaExtension<D>),
}

/// The kinds of MMA extension that can be emitted.
///
/// Each variant wraps a descriptor whose `format_extension` method writes the
/// text returned by the corresponding `ptx::*_template` function.
#[derive(Debug, Clone, PartialEq)]
pub enum MmaExtension<D: Dialect> {
    /// Emitted through `ptx::mma_template`.
    Execute(MmaExecute<D>),
    /// Emitted through `ptx::mma_scaled_template`.
    ExecuteScaled(MmaExecuteScaled<D>),
    /// Emitted through `ptx::ldmatrix_template`.
    LdMatrix(LdMatrix<D>),
    /// Emitted through `ptx::stmatrix_template`.
    StMatrix(StMatrix<D>),
}

/// Newtype wrapper around the [`Elem`] of a fragment.
///
/// NOTE(review): presumably this is the element type stored in one MMA
/// operand fragment — confirm against callers.
#[derive(Debug, Clone, PartialEq, Copy)]
pub struct Fragment<D: Dialect>(pub Elem<D>);

impl<D: Dialect> Fragment<D> {
    /// Returns a copy of the wrapped element type.
    pub fn elem(&self) -> Elem<D> {
        let Self(inner) = *self;
        inner
    }
}

/// Descriptor for an MMA instruction emitted through `ptx::mma_template`.
///
/// The `new` constructor is generated by `#[derive(new)]` and takes the fields
/// in declaration order.
#[derive(new, Debug, Clone, PartialEq)]
pub struct MmaExecute<D: Dialect> {
    /// Shape of the operation; its `k` field and `num_elems` are read when formatting.
    pub shape: MmaShape<D>,
    /// Fragment of the `A` operand.
    pub frag_a: Fragment<D>,
    /// Fragment of the `B` operand.
    pub frag_b: Fragment<D>,
    /// Fragment of the `C` operand; sized using the accumulator element count.
    pub frag_c: Fragment<D>,
    /// Fragment of the `D` operand; sized using the accumulator element count.
    /// Only its element size is used when formatting — the element itself is
    /// not passed to the template.
    pub frag_d: Fragment<D>,
}

/// Descriptor for a scaled MMA instruction emitted through
/// `ptx::mma_scaled_template`.
///
/// The `new` constructor is generated by `#[derive(new)]` and takes the fields
/// in declaration order.
#[derive(new, Debug, Clone, PartialEq)]
pub struct MmaExecuteScaled<D: Dialect> {
    /// Shape of the operation; its `k` field and `num_elems` are read when formatting.
    pub shape: MmaShape<D>,
    /// Fragment of the `A` operand.
    pub frag_a: Fragment<D>,
    /// Fragment of the `B` operand.
    pub frag_b: Fragment<D>,
    /// Fragment of the `C` operand; sized using the accumulator element count.
    pub frag_c: Fragment<D>,
    /// Fragment of the `D` operand; sized using the accumulator element count.
    /// Only its element size is used when formatting — the element itself is
    /// not passed to the template.
    pub frag_d: Fragment<D>,
    /// Element type of the scales, forwarded unchanged to the template.
    pub scales_elem: Elem<D>,
    /// Scale factor, forwarded unchanged to the template.
    /// NOTE(review): exact meaning is defined by `ptx::mma_scaled_template` — confirm there.
    pub scales_factor: u32,
}

/// Descriptor for the code emitted through `ptx::ldmatrix_template`.
///
/// All three fields are forwarded unchanged to that template function.
#[derive(new, Debug, Clone, PartialEq)]
pub struct LdMatrix<D: Dialect> {
    /// Element type passed to the template.
    pub elem: Elem<D>,
    /// Factor passed to the template.
    /// NOTE(review): presumably the number of matrices handled per instruction — confirm.
    pub factor: u32,
    /// Transpose flag passed to the template.
    pub transpose: bool,
}

/// Descriptor for the code emitted through `ptx::stmatrix_template`.
///
/// All three fields are forwarded unchanged to that template function.
#[derive(new, Debug, Clone, PartialEq)]
pub struct StMatrix<D: Dialect> {
    /// Element type passed to the template.
    pub elem: Elem<D>,
    /// Factor passed to the template.
    /// NOTE(review): presumably the number of matrices handled per instruction — confirm.
    pub factor: u32,
    /// Transpose flag passed to the template.
    pub transpose: bool,
}

impl<D: Dialect> MmaExtension<D> {
    /// Writes this extension into `f` by delegating to the `format_extension`
    /// method of whichever descriptor the variant holds.
    ///
    /// # Errors
    ///
    /// Returns whatever error the delegated call reports.
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Execute(op) => op.format_extension(f),
            Self::ExecuteScaled(op) => op.format_extension(f),
            Self::LdMatrix(op) => op.format_extension(f),
            Self::StMatrix(op) => op.format_extension(f),
        }
    }
}

impl<D: Dialect> MmaExecute<D> {
    /// Writes the text returned by `ptx::mma_template` for this operation into `f`.
    ///
    /// For every operand a register count is computed as
    /// `(num_elems / 32) / (32 / size_bits)`, where `num_elems` comes from the
    /// shape and `size_bits` from the operand's unpacked element. `C` and `D`
    /// both use the accumulator element count.
    ///
    /// # Panics
    ///
    /// Integer division by zero occurs if an operand's unpacked element
    /// reports a size of 0 bits or of more than 32 bits.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing to `f`.
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let (elem_a, elem_b, elem_c, elem_d) = (
            self.frag_a.elem(),
            self.frag_b.elem(),
            self.frag_c.elem(),
            self.frag_d.elem(),
        );

        // How many unpacked elements of each operand fit into 32 bits.
        // NOTE(review): 32 here is presumably the register width in bits — confirm.
        let a_per_reg = 32 / elem_a.unpacked().size_bits();
        let b_per_reg = 32 / elem_b.unpacked().size_bits();
        let c_per_reg = 32 / elem_c.unpacked().size_bits();
        let d_per_reg = 32 / elem_d.unpacked().size_bits();

        // Fragment element counts divided by 32.
        // NOTE(review): this 32 is presumably the number of threads sharing the
        // fragment — confirm.
        let a_count = (self.shape.num_elems(FragmentIdent::<D>::A) / 32) as usize;
        let b_count = (self.shape.num_elems(FragmentIdent::<D>::B) / 32) as usize;
        // C and D are both sized from the accumulator.
        let acc_count = (self.shape.num_elems(FragmentIdent::<D>::Accumulator) / 32) as usize;

        let code = ptx::mma_template(
            elem_a.unpacked(),
            elem_b.unpacked(),
            elem_c.unpacked(),
            self.shape.k,
            a_count / a_per_reg,
            b_count / b_per_reg,
            acc_count / c_per_reg,
            acc_count / d_per_reg,
        );

        f.write_str(&code)
    }
}

impl<D: Dialect> MmaExecuteScaled<D> {
    /// Writes the text returned by `ptx::mma_scaled_template` for this
    /// operation into `f`.
    ///
    /// For every operand a register count is computed as
    /// `(num_elems / 32) / (32 / size_bits)`, where `num_elems` comes from the
    /// shape and `size_bits` from the operand's unpacked element. `C` and `D`
    /// both use the accumulator element count. The scales element and factor
    /// are forwarded to the template unchanged.
    ///
    /// # Panics
    ///
    /// Integer division by zero occurs if an operand's unpacked element
    /// reports a size of 0 bits or of more than 32 bits.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing to `f`.
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let (elem_a, elem_b, elem_c, elem_d) = (
            self.frag_a.elem(),
            self.frag_b.elem(),
            self.frag_c.elem(),
            self.frag_d.elem(),
        );

        // How many unpacked elements of each operand fit into 32 bits.
        // NOTE(review): 32 here is presumably the register width in bits — confirm.
        let a_per_reg = 32 / elem_a.unpacked().size_bits();
        let b_per_reg = 32 / elem_b.unpacked().size_bits();
        let c_per_reg = 32 / elem_c.unpacked().size_bits();
        let d_per_reg = 32 / elem_d.unpacked().size_bits();

        // Fragment element counts divided by 32.
        // NOTE(review): this 32 is presumably the number of threads sharing the
        // fragment — confirm.
        let a_count = (self.shape.num_elems(FragmentIdent::<D>::A) / 32) as usize;
        let b_count = (self.shape.num_elems(FragmentIdent::<D>::B) / 32) as usize;
        // C and D are both sized from the accumulator.
        let acc_count = (self.shape.num_elems(FragmentIdent::<D>::Accumulator) / 32) as usize;

        let code = ptx::mma_scaled_template(
            elem_a.unpacked(),
            elem_b.unpacked(),
            elem_c.unpacked(),
            self.shape.k,
            a_count / a_per_reg,
            b_count / b_per_reg,
            acc_count / c_per_reg,
            acc_count / d_per_reg,
            self.scales_elem,
            self.scales_factor,
        );

        f.write_str(&code)
    }
}

impl<D: Dialect> LdMatrix<D> {
    /// Writes the text returned by `ptx::ldmatrix_template` for this
    /// descriptor's element, factor and transpose flag into `f`.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing to `f`.
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            elem,
            factor,
            transpose,
        } = self;
        let code = ptx::ldmatrix_template(*elem, *factor, *transpose);
        f.write_str(&code)
    }
}

impl<D: Dialect> StMatrix<D> {
    /// Writes the text returned by `ptx::stmatrix_template` for this
    /// descriptor's element, factor and transpose flag into `f`.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing to `f`.
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            elem,
            factor,
            transpose,
        } = self;
        let code = ptx::stmatrix_template(*elem, *factor, *transpose);
        f.write_str(&code)
    }
}