use std::fmt::Formatter;
use crate::{
Dialect,
hip::{HipDialect, arch::AMDArchitecture},
shared::{
Architecture, Component, DialectWmmaCompiler, Elem, Flags, FmtLeft, Fragment,
FragmentIdent, FragmentLayout, Item, ManualMma, MmaShape, SupportedMmaCombinations,
Variable, WmmaInstruction, frag_as_ptr, frag_ident_str, frag_layout_str, variable_to_frag,
wmma_api_base,
},
};
use cubecl_core::ir::{self as gpu, Matrix, MatrixIdent, features::MmaConfig};
/// WMMA compiler for HIP that lowers matrix-multiply-accumulate operations to
/// the `__builtin_amdgcn_wmma_*` intrinsics (RDNA3-class, wave32 hardware).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct WmmaIntrinsicCompiler {}
/// Extension generating a device function that fills every element of a
/// fragment with a single value.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaFill<D: Dialect> {
    /// Fragment being filled; determines the generated function's name and
    /// parameter type.
    frag: Fragment<D>,
}
/// Extension generating a device function that loads tile data from memory
/// into a fragment.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaLoad<D: Dialect> {
    /// Destination fragment.
    frag: Fragment<D>,
    /// Layout of the source data; required when loading into an accumulator
    /// fragment (see `WmmaLoad::format_extension`).
    layout: Option<FragmentLayout<D>>,
}
/// Extension generating a device function that stores an accumulator fragment
/// back to memory.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaStore<D: Dialect> {
    /// Source fragment.
    frag: Fragment<D>,
    /// Layout of the destination memory.
    layout: FragmentLayout<D>,
}
/// Extension generating a wrapper around the AMD WMMA intrinsic computing
/// `frag_d = frag_a * frag_b + frag_c`.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaExecute<D: Dialect> {
    frag_a: Fragment<D>,
    frag_b: Fragment<D>,
    frag_c: Fragment<D>,
    frag_d: Fragment<D>,
}
/// Extension generating a device function that copies one fragment into
/// another, element by element, possibly changing the element type.
#[derive(new, Debug, Clone, PartialEq)]
pub struct WmmaCast<D: Dialect> {
    frag_input: Fragment<D>,
    frag_output: Fragment<D>,
}
impl<D: Dialect> WmmaFill<D> {
pub fn fn_name(&self) -> String {
let layout = frag_layout_str(&self.frag.layout);
let ident = frag_ident_str(&self.frag.ident);
let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
let elem = self.frag.elem;
format!("wmma_fill_{elem}_{ident}_{m}x{n}x{k}_{layout}",)
}
pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let elem = self.frag.elem;
let frag = self.frag;
let name = self.fn_name();
write!(
f,
"
// Fill the fragment.
__device__ void {name}({frag}& frag, {elem} value) {{
#pragma unroll
for (uint i = 0; i < 8; ++i) {{
frag[i] = value;
}}
}}
"
)
}
}
impl<D: Dialect> WmmaLoad<D> {
    /// Name of the generated load function, unique per element type, fragment
    /// role, shape, fragment layout and source-data layout.
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        let layout = frag_layout_str(&self.layout);
        let ident = frag_ident_str(&self.frag.ident);
        let elem = self.frag.elem;
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        format!("wmma_load_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
    }
    /// Emits the HIP device function that copies tile data from memory into
    /// the fragment's register.
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();
        // The index expression, loop trip count and per-iteration fragment
        // stride all depend on the fragment role.
        let (index_body, length, step) = match frag.ident {
            FragmentIdent::A | FragmentIdent::B => {
                // A/B fragments hold 16 elements per lane, written densely.
                let length = 16;
                let step = 1;
                // The (role, layout) pairing decides whether consecutive `i`
                // walk along the stride or across it. NOTE: `unwrap` panics if
                // an A/B fragment has no layout.
                let index = if (frag.ident == FragmentIdent::A
                    && frag.layout.unwrap() == FragmentLayout::ColMajor)
                    || (frag.ident == FragmentIdent::B
                        && frag.layout.unwrap() == FragmentLayout::RowMajor)
                {
                    "i * stride + wmmaLane".to_string()
                } else {
                    "i + wmmaLane * stride".to_string()
                };
                (index, length, step)
            }
            FragmentIdent::Accumulator => {
                // Accumulators expose 8 elements per lane; `step` skips the
                // unused half-slots when the accumulator is f16/bf16 (see
                // `get_output_accumulator_index_step`).
                let length = 8;
                let step = get_output_accumulator_index_step(&elem, &frag);
                // Loading into an accumulator needs the layout of the source
                // memory, supplied separately from the fragment.
                let index = match self.layout {
                    Some(FragmentLayout::ColMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) + wmmaLane * stride".to_string()
                    }
                    Some(FragmentLayout::RowMajor) => {
                        "(i * uint(2) + threadIdx.x / uint(16)) * stride + wmmaLane".to_string()
                    }
                    _ => panic!(
                        "cannot load data to an accumulator without knowing the layout of the data"
                    ),
                };
                (index, length, step)
            }
            other => panic!("unknown matrix identifier {other}"),
        };
        write!(
            f,
            "
// Load the fragment.
__device__ void {name}({frag}& frag, const {elem}* value_ptr, const uint stride) {{
{WMMA_LANE_DEF}
#pragma unroll
for (uint i = 0; i < {length}; ++i) {{
const uint index = {index_body};
frag[i * {step}] = value_ptr[index];
}}
}}
"
        )
    }
}
impl<D: Dialect> WmmaStore<D> {
    /// Name of the generated store function, unique per element type, fragment
    /// role, shape, fragment layout and destination-data layout.
    pub fn fn_name(&self) -> String {
        let layout_frag = frag_layout_str(&self.frag.layout);
        // `frag_layout_str` takes an `Option`, but a store always has a
        // concrete destination layout.
        let layout_option = Some(self.layout);
        let layout = frag_layout_str(&layout_option);
        let ident = frag_ident_str(&self.frag.ident);
        let (m, n, k) = (self.frag.m, self.frag.n, self.frag.k);
        let elem = self.frag.elem;
        format!("wmma_store_{elem}_{ident}_{m}x{n}x{k}_{layout_frag}_{layout}",)
    }
    /// Emits the HIP device function that writes the lane's 8 accumulator
    /// elements back to memory.
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let elem = self.frag.elem;
        let frag = self.frag;
        let name = self.fn_name();
        // f16/bf16 fragments keep their data in every other slot of the
        // 16-wide register, so reads are strided by 2; f32 is dense.
        let frag_idx = match elem {
            Elem::F16 | Elem::BF16 => "elemIdx * 2",
            Elem::F32 => "elemIdx",
            other => {
                panic!("C fragment format cannot be {other}. Only f16, bf16 and f32 are supported.")
            }
        };
        // Destination index depends on the output matrix layout; the
        // `_Dialect` placeholder has no meaningful layout here.
        let output_idx = match self.layout {
            FragmentLayout::ColMajor => "wmmaLane * stride + rowIdx".to_string(),
            FragmentLayout::RowMajor => "wmmaLane + rowIdx * stride".to_string(),
            FragmentLayout::_Dialect(_) => String::new(),
        };
        write!(
            f,
            "
// Store the fragment.
__device__ void {name}({frag}& frag, {elem}* output_ptr, uint stride) {{
{WMMA_LANE_DEF}
#pragma unroll
for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
const uint rowIdx = elemIdx * uint(2) + threadIdx.x / uint(16);
output_ptr[{output_idx}] = frag[{frag_idx}];
}}
}}
"
        )
    }
}
impl<D: Dialect> WmmaExecute<D> {
    /// Builds the fragment set used by a manual MMA: A col-major, B row-major,
    /// with C and D as accumulators sharing the same shape and element type.
    pub fn from_manual(shape: MmaShape<D>, ab_elem: Elem<D>, cd_elem: Elem<D>) -> Self {
        let frag_a = Fragment {
            ident: FragmentIdent::A,
            m: shape.m,
            n: shape.n,
            k: shape.k,
            elem: ab_elem,
            layout: Some(FragmentLayout::ColMajor),
        };
        let frag_b = Fragment {
            ident: FragmentIdent::B,
            layout: Some(FragmentLayout::RowMajor),
            ..frag_a
        };
        // C and D use an identical description.
        let frag_cd = Fragment {
            ident: FragmentIdent::Accumulator,
            elem: cd_elem,
            ..frag_b
        };
        WmmaExecute::new(frag_a, frag_b, frag_cd, frag_cd)
    }
    /// Name of the generated execute wrapper. The intrinsic family emitted
    /// below only exists for the 16x16x16 shape, so only element types vary.
    pub fn fn_name(&self) -> String {
        format!(
            "wmma_execute_16x16x16_{}_{}",
            self.frag_a.elem, self.frag_c.elem
        )
    }
    /// Emits a wrapper around the matching `__builtin_amdgcn_wmma_*`
    /// intrinsic.
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = self.fn_name();
        // Input (A/B) element suffix of the intrinsic name.
        let ab_format = match self.frag_a.elem {
            Elem::F32 => "f32",
            Elem::BF16 => "bf16",
            Elem::F16 => "f16",
            _ => panic!(),
        };
        // Accumulator suffix plus trailing operand: half-precision intrinsics
        // take an extra bool (opsel); `false` presumably selects the low
        // half-slots of the accumulator register — confirm against ROCm docs.
        let (cd_format, opsel) = match self.frag_c.elem {
            Elem::F32 => ("f32", ""),
            Elem::BF16 => ("bf16", ", false"),
            Elem::F16 => ("f16", ", false"),
            _ => panic!(),
        };
        // Only the wave32 ("_w32") variant is emitted; the caller validates
        // the warp size (see `compile_wmma_instruction`).
        let warp_size = 32;
        write!(
            f,
            "
// Execute wmma.
__device__ void {name}(const {}& frag_a, const {}& frag_b, const {}& frag_c, {}& frag_d) {{
frag_d = __builtin_amdgcn_wmma_{cd_format}_16x16x16_{ab_format}_w{warp_size}(frag_a, frag_b, frag_c{opsel});
}}
", self.frag_a, self.frag_b, self.frag_c, self.frag_d
        )
    }
}
impl<D: Dialect> WmmaCast<D> {
    /// Builds the unique name of the generated cast function from the input
    /// fragment's description and the output element type.
    pub fn fn_name(&self) -> String {
        let input = &self.frag_input;
        format!(
            "wmma_cast_{}_to_{}_{}_{}x{}x{}_{}",
            input.elem,
            self.frag_output.elem,
            frag_ident_str(&input.ident),
            input.m,
            input.n,
            input.k,
            frag_layout_str(&input.layout),
        )
    }
    /// Emits the HIP device function that copies the 8 input elements into the
    /// output fragment.
    pub fn format_extension(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let input = self.frag_input;
        let output = self.frag_output;
        let name = self.fn_name();
        // Writes into an accumulator may need to skip register slots (see
        // `get_output_accumulator_index_step`); everything else is dense.
        let step = if output.ident == FragmentIdent::Accumulator {
            get_output_accumulator_index_step(&self.frag_input.elem, &output)
        } else {
            1
        };
        write!(
            f,
            "
// Cast the fragment.
__device__ void {name}({input}& input, {output}& output) {{
#pragma unroll
for (uint elemIdx = 0; elemIdx < uint(8); ++elemIdx) {{
output[elemIdx * {step}] = input[elemIdx];
}}
}}
"
        )
    }
}
impl DialectWmmaCompiler<HipDialect<Self>> for WmmaIntrinsicCompiler {
    /// Emits the vector typedefs backing fragments: 16-wide for A/B inputs and
    /// half-precision accumulators, 8-wide for f32 accumulators. Typedefs for
    /// f16/bf16 are only emitted when the kernel uses those element types.
    fn compile_wmma_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        flags: &Flags<HipDialect<Self>>,
    ) -> std::fmt::Result {
        if flags.elem_bf16 {
            f.write_str("typedef __bf16 bhalf8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef __bf16 bhalf16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        if flags.elem_f16 {
            f.write_str("typedef _Float16 half8_t __attribute__((ext_vector_type(8)));\n")?;
            f.write_str("typedef _Float16 half16_t __attribute__((ext_vector_type(16)));\n")?;
        }
        // The f32 accumulator typedef is emitted unconditionally.
        f.write_str("typedef float float8_t __attribute__((ext_vector_type(8)));\n")
    }
    /// Delegates fragment variable declarations to the shared WMMA API helper.
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<HipDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }
    /// Maps a fragment to the vector typedef emitted by
    /// `compile_wmma_type_definitions`.
    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match fragment.ident {
            FragmentIdent::A | FragmentIdent::B => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            // Half-precision accumulators also use the 16-wide register type
            // (only every other slot carries data — see the `* 2` strides in
            // the load/store extensions); f32 accumulators are 8-wide.
            FragmentIdent::Accumulator => match fragment.elem {
                Elem::F16 => write!(f, "half16_t"),
                Elem::BF16 => write!(f, "bhalf16_t"),
                Elem::F32 => write!(f, "float8_t"),
                other => panic!("unsupported type {other} for {fragment}"),
            },
            FragmentIdent::_Dialect(_) => Ok(()),
        }
    }
    /// Lowers one WMMA IR instruction to a call of the matching generated
    /// extension function (`WmmaFill`/`WmmaLoad`/`WmmaExecute`/…).
    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<HipDialect<Self>>,
    ) -> std::fmt::Result {
        match instruction {
            WmmaInstruction::Fill { frag, value } => {
                let extension = WmmaFill::new(match frag {
                    Variable::WmmaFragment { frag, .. } => *frag,
                    _ => panic!(),
                });
                let name = extension.fn_name();
                writeln!(f, "{name}({frag}, {value});")
            }
            WmmaInstruction::Load {
                frag,
                value,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaLoad::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                // Pre-applies `offset` so the generated function only sees a
                // base pointer plus a stride.
                let value_ptr = frag_as_ptr(f, value, offset);
                writeln!(f, "{name}({frag}, {value_ptr}, {stride});")
            }
            // No HIP lowering exists for these; surface a preprocessor error
            // in the generated source.
            WmmaInstruction::LdMatrix { .. } | WmmaInstruction::StMatrix { .. } => {
                f.write_str("#error LdMatrix & StMatrix are not supported on HIP\n")
            }
            WmmaInstruction::Execute {
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                warp_size,
            } => {
                // Only the wave32 intrinsic variant is emitted (see
                // `WmmaExecute::format_extension`).
                if *warp_size != 32 {
                    f.write_str(
                        "#error Only warp size of 32 supported for Wmma::Execute on HIP\n",
                    )?;
                }
                let extension = WmmaExecute::new(
                    variable_to_frag(frag_a),
                    variable_to_frag(frag_b),
                    variable_to_frag(frag_c),
                    variable_to_frag(frag_d),
                );
                let name = extension.fn_name();
                writeln!(f, "{name}({frag_a}, {frag_b}, {frag_c}, {frag_d});")
            }
            WmmaInstruction::ExecuteManual {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
            } => {
                Self::compile_manual_mma(f, ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d))
            }
            WmmaInstruction::ExecuteScaled {
                shape,
                frag_a,
                frag_b,
                frag_c,
                frag_d,
                scales_a,
                scales_b,
                scales_factor,
            } => Self::compile_scaled_mma(
                f,
                ManualMma::new(*shape, frag_a, frag_b, frag_c, frag_d),
                *scales_a,
                *scales_b,
                *scales_factor,
            ),
            WmmaInstruction::Store {
                output,
                frag,
                layout,
                offset,
                stride,
            } => {
                let extension = WmmaStore::new(variable_to_frag(frag), *layout);
                let name = extension.fn_name();
                let output_ptr = frag_as_ptr(f, output, offset);
                writeln!(f, "{name}({frag}, {output_ptr}, {stride});")
            }
            WmmaInstruction::Cast { input, output } => {
                let extension = WmmaCast::new(variable_to_frag(input), variable_to_frag(output));
                let name = extension.fn_name();
                writeln!(f, "{name}({input}, {output});")
            }
        }
    }
    /// Forwards to the free `compile_manual_mma` defined in this module.
    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<HipDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma.shape, mma.frag_a, mma.frag_b, mma.frag_c, mma.frag_d)
    }
    /// Scaled MMA has no HIP lowering; emit a preprocessor error into the
    /// generated source rather than panicking the compiler.
    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        _mma: ManualMma<HipDialect<Self>>,
        _scales_a: Variable<HipDialect<Self>>,
        _scales_b: Variable<HipDialect<Self>>,
        _scales_factor: u32,
    ) -> std::fmt::Result {
        f.write_str("#error scaled mma not supported in HIP\n")
    }
    /// Fragment-API WMMA combinations: all 16x16x16 with f16/bf16 inputs.
    fn supported_wmma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.is_wmma_capable() {
            // (a, b, cd) element-type triples supported by the intrinsics.
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .map(|(a, b, c)| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m: 16,
                    n: 16,
                    k: 16,
                })
                .collect();
            result.extend(combinations);
        }
        result
    }
    /// Forwards to the free function of the same name defined in this module
    /// (an unqualified call resolves to the module item, not this method).
    fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }
}
/// Index stride between consecutive logical elements when writing into an
/// accumulator fragment.
///
/// f16/bf16 accumulator fragments use the 16-wide register type but only
/// populate every other slot (the load/store extensions index with the same
/// `* 2` stride), so the step is 2; f32 accumulators are densely packed in
/// `float8_t`, so the step is 1.
///
/// # Panics
///
/// Panics if `output` is not an accumulator fragment, or if either element
/// type is not one of f16, bf16 or f32.
fn get_output_accumulator_index_step<D: Dialect>(
    input_elem: &Elem<D>,
    output: &Fragment<D>,
) -> u32 {
    assert_eq!(output.ident, FragmentIdent::<D>::Accumulator);
    match input_elem {
        Elem::F16 | Elem::BF16 | Elem::F32 => match output.elem {
            // Half-precision accumulators occupy every other register slot.
            Elem::F16 | Elem::BF16 => 2,
            // f32 accumulators are dense.
            Elem::F32 => 1,
            other => panic!("unsupported format {other} for {output}"),
        },
        // BUG FIX: the original message printed the input element twice
        // ("unsupported format X for X"); reference the output fragment.
        other => panic!("unsupported input format {other} for accumulator {output}"),
    }
}
/// Compiles a manual (non-fragment-API) MMA: flattens the operand variables
/// into brace-initialized vector literals, calls the generated
/// `wmma_execute_*` wrapper, then scatters the result back into `frag_d`.
pub(super) fn compile_manual_mma<D: Dialect>(
    f: &mut std::fmt::Formatter<'_>,
    shape: MmaShape<D>,
    frag_a: &Variable<D>,
    frag_b: &Variable<D>,
    frag_c: &Variable<D>,
    frag_d: &Variable<D>,
) -> std::fmt::Result {
    let extension = WmmaExecute::from_manual(shape, frag_a.elem(), frag_c.elem());
    // C/D elements owned by each lane; 32 is the warp size of the emitted
    // wave32 intrinsic.
    let cd_elems = shape.num_elems(FragmentIdent::<D>::Accumulator) / 32;
    // Register slots per C/D element: 2 for 16-bit accumulators, 1 for f32
    // (i.e. 4 bytes divided by the element size, rounded up).
    let frag_cd_step = 4usize.div_ceil(frag_c.elem().size());
    // Temporary receiving the intrinsic's result before it is copied into the
    // (possibly vectorized) destination variable.
    let frag_d_tmp = Variable::tmp_declared(Item::new(Elem::<D>::I32, 1, true)).fmt_left();
    // Renders `len` scalar accesses for a variable, unpacking `.i_N`
    // components when the variable is vectorized.
    let frag = |var: &Variable<D>, len: usize| {
        let vec = var.item().vectorization;
        let frag: Vec<_> = if vec > 1 {
            (0..len)
                .map(|i| format!("{var}[{}].i_{}", i / vec, i % vec))
                .collect()
        } else {
            (0..len).map(|i| format!("{var}[{}]", i)).collect()
        };
        frag.join(", ")
    };
    let frag_a = frag(frag_a, 16);
    let frag_b = frag(frag_b, 16);
    let frag_c = {
        let vec = frag_c.item().vectorization;
        // Each C element is repeated `frag_cd_step` times so the initializer
        // matches the register width of the accumulator fragment type.
        // NOTE(review): presumably this fills both half-slots of 16-bit
        // accumulators — confirm against the intrinsic's packing.
        let frag: Vec<_> = if vec > 1 {
            (0..cd_elems as usize)
                .flat_map(|i| {
                    (0..frag_cd_step).map(move |_| format!("{frag_c}[{}].i_{}", i / vec, i % vec))
                })
                .collect()
        } else {
            (0..cd_elems as usize)
                .flat_map(|i| (0..frag_cd_step).map(move |_| format!("{frag_c}[{}]", i)))
                .collect()
        };
        frag.join(", ")
    };
    let name = extension.fn_name();
    // Declare the temporary with the accumulator fragment type, zero-init.
    writeln!(f, "{} {frag_d_tmp} = {{}};", extension.frag_d)?;
    writeln!(
        f,
        "{name}({}{{{frag_a}}}, {}{{{frag_b}}}, {}{{{frag_c}}}, {frag_d_tmp});",
        extension.frag_a, extension.frag_b, extension.frag_c
    )?;
    // Copy the result out, stepping over the duplicated half-slots.
    for i in 0..cd_elems as usize {
        let vec = frag_d.item().vectorization;
        if vec > 1 {
            writeln!(
                f,
                "{frag_d}[{}].i_{} = {frag_d_tmp}[{i} * {frag_cd_step}];",
                i / vec,
                i % vec
            )?;
        } else {
            writeln!(f, "{frag_d}[{i}] = {frag_d_tmp}[{i} * {frag_cd_step}];")?;
        }
    }
    Ok(())
}
/// Manual-MMA combinations supported by the intrinsic path: 16x16x16 tiles
/// with f16 or bf16 inputs and f32 accumulators, available only on
/// WMMA-capable architectures.
///
/// Removed the dead `const ENABLED: bool = true` gate from the original —
/// the early return behind it was unreachable.
pub(super) fn supported_mma_combinations(arch: &AMDArchitecture) -> SupportedMmaCombinations {
    if !arch.is_wmma_capable() {
        return Vec::new();
    }
    // (a/b input element, accumulator element) pairs.
    let types = [
        (
            gpu::ElemType::Float(gpu::FloatKind::F16),
            gpu::ElemType::Float(gpu::FloatKind::F32),
        ),
        (
            gpu::ElemType::Float(gpu::FloatKind::BF16),
            gpu::ElemType::Float(gpu::FloatKind::F32),
        ),
    ];
    types
        .into_iter()
        .map(|(ab_elem, cd_elem)| MmaConfig {
            a_type: ab_elem.into(),
            b_type: ab_elem.into(),
            cd_type: cd_elem.into(),
            m: 16,
            n: 16,
            k: 16,
        })
        .collect()
}
/// Number of contiguous elements that can be transferred at once for a matrix
/// fragment on RDNA3.
pub fn contiguous_elements_rdna3(ident: MatrixIdent, matrix: Matrix) -> usize {
    // Widest transfer is 16 bytes, expressed here in element units.
    let max_vector_size = 16 / matrix.storage.size();
    match ident {
        // A/B tiles are at most 16 elements long per lane.
        MatrixIdent::A | MatrixIdent::B => max_vector_size.min(16),
        // Accumulator lanes are strided, so only single elements are
        // contiguous.
        MatrixIdent::Accumulator => 1,
    }
}
/// Preamble emitted into generated load/store bodies: the lane's position
/// within the 16-wide tile row/column.
static WMMA_LANE_DEF: &str = "uint wmmaLane = uint(threadIdx.x % 16);";