use std::{collections::HashSet, fmt::Debug};
use std::{fmt::Display, hash::Hash};
use cubecl_core::ir::{ConstantValue, Processor};
use crate::shared::{
FmtLeft, IndexedVariable, MmaShape, SupportedMmaCombinations, SupportedScaledMmaCombinations,
reduce_comparison, reduce_exclusive, reduce_inclusive, reduce_operator, reduce_quantifier,
};
use super::{
Architecture, AtomicKind, Body, Component, CubeIndexFlags, Elem, Flags, Fragment,
FragmentIdent, FragmentLayout, Instruction, Item, KernelArg, SharedMemory, Variable,
WarpInstruction, WmmaInstruction,
};
/// Umbrella trait for a code-generation dialect.
///
/// A dialect type must implement every `Dialect*` sub-trait listed below, each
/// instantiated with the dialect itself (`Self`), and must also be `Default`,
/// `Clone`, `Copy`, `Debug`, `Send`, `Sync`, `Eq`, `Hash` and `'static`.
pub trait Dialect:
DialectIncludes<Self>
+ DialectTypes<Self>
+ DialectBindings<Self>
+ DialectWarpReduceCompiler<Self>
+ DialectCubeBuiltins<Self>
+ DialectInstructions<Self>
+ DialectWmmaCompiler<Self>
+ DialectProcessors<Self>
+ Default
+ Clone
+ Copy
+ Debug
+ Send
+ Sync
+ Eq
+ Hash
+ 'static
{
/// Architecture type associated with this dialect. It is the `arch` argument
/// type of the `supported_*_combinations` queries on [`DialectWmmaCompiler`].
type Architecture: Architecture;
}
/// Hooks for the include/extension preamble of generated code.
///
/// All methods are associated functions (no `self`); the dialect is selected
/// through the type parameter `D`.
pub trait DialectIncludes<D: Dialect> {
/// Dialect-specific extension descriptor collected by the `register_*`
/// methods and later handed to [`Self::compile_extensions`].
type Extension: Debug + Clone + Sync + Send;
/// Writes the dialect's include section to `f`; `flags` is available to the
/// implementor to decide what to emit (presumably feature-gated includes —
/// confirm against the concrete dialects).
fn compile_includes(f: &mut std::fmt::Formatter<'_>, flags: &Flags<D>) -> std::fmt::Result;
/// Writes the code for the given, previously registered `extensions` to `f`.
fn compile_extensions(
f: &mut std::fmt::Formatter<'_>,
extensions: &[Self::Extension],
) -> std::fmt::Result;
/// Inspects `instruction` and may push the extension(s) it needs onto
/// `extensions`.
fn register_instruction_extension(
extensions: &mut Vec<Self::Extension>,
instruction: &Instruction<D>,
);
/// Same as [`Self::register_instruction_extension`], for warp instructions.
fn register_warp_instruction_extension(
extensions: &mut Vec<Self::Extension>,
instruction: &WarpInstruction<D>,
);
/// Same as [`Self::register_instruction_extension`], for WMMA instructions.
///
/// The default implementation registers nothing.
#[allow(unused_variables)]
fn register_wmma_instruction_extension(
extensions: &mut Vec<Self::Extension>,
instruction: &WmmaInstruction<D>,
) {
}
}
/// Type-related code generation hooks: how elements and items are spelled,
/// which type definitions are emitted, and how memory is declared.
pub trait DialectTypes<D: Dialect> {
    /// Dialect-provided switch; presumably tells callers whether items may be
    /// rewritten into an optimized form — confirm against the call sites.
    fn item_can_be_optimized() -> bool;

    /// Writes the dialect's spelling of `elem` to `f`.
    ///
    /// NOTE(review): the meaning of `word` is not visible in this file.
    fn compile_elem(
        f: &mut std::fmt::Formatter<'_>,
        elem: &Elem<D>,
        word: bool,
    ) -> std::fmt::Result;

    /// Writes the element type that backs the atomic `kind`.
    ///
    /// Every concrete kind is rendered through the `Display` impl of the
    /// matching [`Elem`] variant; the `_Dialect` marker variant writes nothing.
    fn compile_atomic_kind(
        f: &mut std::fmt::Formatter<'_>,
        kind: &AtomicKind<D>,
    ) -> std::fmt::Result {
        let backing = match kind {
            AtomicKind::I32 => Elem::<D>::I32,
            AtomicKind::I64 => Elem::<D>::I64,
            AtomicKind::U32 => Elem::<D>::U32,
            AtomicKind::U64 => Elem::<D>::U64,
            AtomicKind::F16 => Elem::<D>::F16,
            AtomicKind::F16x2 => Elem::<D>::F16x2,
            AtomicKind::BF16 => Elem::<D>::BF16,
            AtomicKind::BF16x2 => Elem::<D>::BF16x2,
            AtomicKind::F32 => Elem::<D>::F32,
            AtomicKind::F64 => Elem::<D>::F64,
            AtomicKind::_Dialect(_) => return Ok(()),
        };
        write!(f, "{backing}")
    }

    /// Writes the dialect's spelling of `item` to `f`.
    fn compile_item(f: &mut std::fmt::Formatter<'_>, item: &Item<D>) -> std::fmt::Result;

    /// Writes the type definitions required by a kernel to `f`, given the set
    /// of `items` in use, the `scalars` table, kernel `info` and `flags`.
    fn compile_type_definitions(
        f: &mut std::fmt::Formatter<'_>,
        items: &HashSet<Item<D>>,
        scalars: &[(Elem<D>, usize)],
        info: &cubecl_core::Info,
        flags: &Flags<D>,
    ) -> std::fmt::Result;

    /// Writes the qualifier used when declaring local memory.
    fn compile_local_memory_qualifier(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;

    /// Writes the declaration of one shared-memory object.
    ///
    /// Both variants emit a size comment followed by a `shared_memory_{index}`
    /// binding obtained by `reinterpret_cast`-ing into `dynamic_shared_mem` at
    /// the object's `offset`: arrays become a pointer, single values a reference.
    fn compile_shared_memory_declaration(
        f: &mut std::fmt::Formatter<'_>,
        shared: &SharedMemory<D>,
    ) -> std::fmt::Result {
        match shared {
            SharedMemory::Array {
                index: slot,
                item: ty,
                length: len,
                offset: start,
                ..
            } => {
                let total_bytes = *len * ty.size();
                writeln!(f, "// Shared array size: {len}, {total_bytes} bytes")?;
                writeln!(
                    f,
                    "{ty} *shared_memory_{slot} = reinterpret_cast<{ty}*>(&dynamic_shared_mem[{start}]);"
                )
            }
            SharedMemory::Value {
                index: slot,
                item: ty,
                offset: start,
                ..
            } => {
                let total_bytes = ty.size() as u32;
                writeln!(f, "// Shared value size: {total_bytes} bytes")?;
                writeln!(
                    f,
                    "{ty} &shared_memory_{slot} = reinterpret_cast<{ty}&>(dynamic_shared_mem[{start}]);"
                )
            }
        }
    }

    /// Writes dialect polyfills; the default emits nothing.
    fn compile_polyfills(_f: &mut std::fmt::Formatter<'_>, _flags: &Flags<D>) -> std::fmt::Result {
        Ok(())
    }

    /// Returns the address-space prefix to use for `_variable`; the default is
    /// the empty string.
    fn address_space_for_variable(_variable: &Variable<D>) -> String {
        String::new()
    }
}
/// Hooks for emitting the kernel entry point and its argument bindings.
pub trait DialectBindings<D: Dialect> {
/// Writes the signature of the kernel named `kernel_name` to `f`, given its
/// `tensor_maps` and `buffers` arguments and the compilation `flags`.
fn compile_kernel_signature(
f: &mut std::fmt::Formatter<'_>,
kernel_name: &str,
tensor_maps: &[KernelArg<D>],
buffers: &[KernelArg<D>],
flags: &Flags<D>,
) -> std::fmt::Result;
/// Writes binding-related code derived from the kernel body.
///
/// The default implementation writes nothing and returns `Ok(())`.
fn compile_bindings_body(
_f: &mut std::fmt::Formatter<'_>,
_body: &Body<D>,
) -> std::fmt::Result {
Ok(())
}
}
/// Spelling of the built-in index variables (absolute position, cube
/// count/dim/pos, unit position, plane helpers, cluster position).
///
/// Every method has a default; the defaults use the names `gridDim`,
/// `blockDim`, `blockIdx`, `threadIdx` and `warpSize`.
pub trait DialectCubeBuiltins<D: Dialect> {
    /// Expands the requested builtin `flags` with the builtins they depend on.
    ///
    /// Dependencies added by the default rules:
    /// * `unit_pos_plane` enables `plane_dim` and `absolute_pos`;
    /// * `plane_dim_checked` enables `plane_dim` and `cube_dim_tuple`;
    /// * `absolute_pos` enables `absolute_pos_tuple`, `cube_dim_tuple` and
    ///   `cube_count_tuple`;
    /// * each of `cube_dim`, `cube_pos`, `unit_pos` enables its `*_tuple` flag.
    ///
    /// All other flags are copied through unchanged.
    fn builtin_rules(flags: &CubeIndexFlags) -> CubeIndexFlags {
        // Computed once because several tuple flags depend on the expanded value.
        let absolute_pos = flags.absolute_pos || flags.unit_pos_plane;

        CubeIndexFlags {
            absolute_pos,
            absolute_pos_tuple: flags.absolute_pos_tuple || absolute_pos,
            cube_count: flags.cube_count,
            cube_count_tuple: flags.cube_count_tuple || absolute_pos,
            cube_dim: flags.cube_dim,
            cube_dim_tuple: flags.cube_dim_tuple
                || flags.cube_dim
                || absolute_pos
                || flags.plane_dim_checked,
            cube_pos: flags.cube_pos,
            cube_pos_tuple: flags.cube_pos_tuple || flags.cube_pos,
            plane_dim: flags.plane_dim || flags.plane_dim_checked || flags.unit_pos_plane,
            plane_dim_checked: flags.plane_dim_checked,
            plane_pos: flags.plane_pos,
            unit_pos_tuple: flags.unit_pos_tuple || flags.unit_pos,
            unit_pos: flags.unit_pos,
            unit_pos_plane: flags.unit_pos_plane,
            cluster_pos: flags.cluster_pos,
        }
    }

    /// Writes the statement that defines the absolute-position tuple: for each
    /// axis, `cube_pos * cube_dim + unit_pos`, packed with `make_<type>`.
    fn compile_absolute_pos_tuple_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tuple = Variable::<D>::AbsolutePosBaseName;
        let ty = tuple.item();
        writeln!(
            f,
            "{ty} {variable} = make_{ty}(\n\
             {cube_pos_x} * {cube_dim_x} + {unit_pos_x},\n\
             {cube_pos_y} * {cube_dim_y} + {unit_pos_y},\n\
             {cube_pos_z} * {cube_dim_z} + {unit_pos_z}\n\
             );",
            variable = tuple,
            cube_pos_x = Variable::<D>::CubePosX,
            cube_pos_y = Variable::<D>::CubePosY,
            cube_pos_z = Variable::<D>::CubePosZ,
            cube_dim_x = Variable::<D>::CubeDimX,
            cube_dim_y = Variable::<D>::CubeDimY,
            cube_dim_z = Variable::<D>::CubeDimZ,
            unit_pos_x = Variable::<D>::UnitPosX,
            unit_pos_y = Variable::<D>::UnitPosY,
            unit_pos_z = Variable::<D>::UnitPosZ,
        )
    }

    /// Writes the name of the absolute-position tuple (`absoluteIdx`).
    fn compile_absolute_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "absoluteIdx")
    }

    /// Writes the name of the scalar absolute position (`idxGlobal`).
    fn compile_absolute_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "idxGlobal")
    }

    /// Writes the `.x` component of the absolute-position tuple.
    fn compile_absolute_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_absolute_pos_base_name(f)?;
        f.write_str(".x")
    }

    /// Writes the `.y` component of the absolute-position tuple.
    fn compile_absolute_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_absolute_pos_base_name(f)?;
        f.write_str(".y")
    }

    /// Writes the `.z` component of the absolute-position tuple.
    fn compile_absolute_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_absolute_pos_base_name(f)?;
        f.write_str(".z")
    }

    /// Writes the name of the cube-count tuple (`gridDim`).
    fn compile_cube_count_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "gridDim")
    }

    /// Writes the name of the scalar cube count (`gridDimGlobal`).
    fn compile_cube_count(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "gridDimGlobal")
    }

    /// Writes the `.x` component of the cube-count tuple.
    fn compile_cube_count_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_cube_count_base_name(f)?;
        f.write_str(".x")
    }

    /// Writes the `.y` component of the cube-count tuple.
    fn compile_cube_count_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_cube_count_base_name(f)?;
        f.write_str(".y")
    }

    /// Writes the `.z` component of the cube-count tuple.
    fn compile_cube_count_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_cube_count_base_name(f)?;
        f.write_str(".z")
    }

    /// Writes the name of the cube-dimension tuple (`blockDim`).
    fn compile_cube_dim_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "blockDim")
    }

    /// Writes the name of the scalar cube dimension (`blockDimGlobal`).
    fn compile_cube_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "blockDimGlobal")
    }

    /// Writes the `.x` component of the cube-dimension tuple.
    fn compile_cube_dim_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_cube_dim_base_name(f)?;
        f.write_str(".x")
    }

    /// Writes the `.y` component of the cube-dimension tuple.
    fn compile_cube_dim_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_cube_dim_base_name(f)?;
        f.write_str(".y")
    }

    /// Writes the `.z` component of the cube-dimension tuple.
    fn compile_cube_dim_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_cube_dim_base_name(f)?;
        f.write_str(".z")
    }

    /// Writes the name of the cube-position tuple (`blockIdx`).
    fn compile_cube_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "blockIdx")
    }

    /// Writes the name of the scalar cube position (`blockIdxGlobal`).
    fn compile_cube_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "blockIdxGlobal")
    }

    /// Writes the `.x` component of the cube-position tuple.
    fn compile_cube_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_cube_pos_base_name(f)?;
        f.write_str(".x")
    }

    /// Writes the `.y` component of the cube-position tuple.
    fn compile_cube_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_cube_pos_base_name(f)?;
        f.write_str(".y")
    }

    /// Writes the `.z` component of the cube-position tuple.
    fn compile_cube_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_cube_pos_base_name(f)?;
        f.write_str(".z")
    }

    /// Writes the statement that defines the scalar unit position as
    /// `x + y * dim_x + z * (dim_x * dim_y)`.
    fn compile_unit_pos_computation(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let linear = Variable::<D>::UnitPos;
        let ty = linear.item();
        writeln!(
            f,
            "{ty} {variable} = {unit_pos_x} + {unit_pos_y} * {cube_dim_x} + {unit_pos_z} * ({cube_dim_x} * {cube_dim_y});",
            variable = linear,
            cube_dim_x = Variable::<D>::CubeDimX,
            cube_dim_y = Variable::<D>::CubeDimY,
            unit_pos_x = Variable::<D>::UnitPosX,
            unit_pos_y = Variable::<D>::UnitPosY,
            unit_pos_z = Variable::<D>::UnitPosZ,
        )
    }

    /// Writes the name of the scalar unit position (`threadIdxGlobal`).
    fn compile_unit_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "threadIdxGlobal")
    }

    /// Writes the name of the unit-position tuple (`threadIdx`).
    fn compile_unit_pos_base_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "threadIdx")
    }

    /// Writes the `.x` component of the unit-position tuple.
    fn compile_unit_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_unit_pos_base_name(f)?;
        f.write_str(".x")
    }

    /// Writes the `.y` component of the unit-position tuple.
    fn compile_unit_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_unit_pos_base_name(f)?;
        f.write_str(".y")
    }

    /// Writes the `.z` component of the unit-position tuple.
    fn compile_unit_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Self::compile_unit_pos_base_name(f)?;
        f.write_str(".z")
    }

    /// Writes the plane dimension (`warpSize`).
    fn compile_plane_dim(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "warpSize")
    }

    /// Writes the name of the checked plane dimension (`warpSizeChecked`).
    fn compile_plane_dim_checked(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "warpSizeChecked")
    }

    /// Writes the plane position expression: the unit's x position divided by
    /// the plane dimension.
    fn compile_plane_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{} / {}",
            Variable::<D>::UnitPosX,
            Variable::<D>::PlaneDim
        )
    }

    /// Writes the unit's position inside its plane: the `u32` absolute
    /// position, cast to the plane-dimension type, modulo the plane dimension.
    fn compile_unit_pos_plane(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let global = Variable::<D>::AbsolutePos(Elem::U32);
        let lanes = Variable::<D>::PlaneDim;
        let cast_ty = lanes.item();
        write!(f, "{cast_ty}({global}) % {lanes}")
    }

    /// Writes the cluster position; the default is the constant `0`.
    fn compile_cluster_pos(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("0")
    }

    /// Writes the x cluster position; the default is the constant `0`.
    fn compile_cluster_pos_x(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("0")
    }

    /// Writes the y cluster position; the default is the constant `0`.
    fn compile_cluster_pos_y(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("0")
    }

    /// Writes the z cluster position; the default is the constant `0`.
    fn compile_cluster_pos_z(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("0")
    }
}
/// Instruction-level code generation hooks: atomics, math helpers, bit
/// manipulation, synchronization and warp primitives.
///
/// Methods with a body provide defaults that emit `atomic*`, `__popc`,
/// `__brev`, `powf`, … calls; methods without a body must be supplied by each
/// dialect.
pub trait DialectInstructions<D: Dialect> {
    /// Writes `out = atomicAdd(lhs, rhs)`.
    ///
    /// Special cases, selected on the element of the optimized `out`:
    /// * `I64`: the pointer and operand are cast to `U64`;
    /// * `F32` with an `out` vectorization above 1: a `float<N>` vector add
    ///   through a temporary, then reinterpreted as the output item;
    /// * `F16x2` / `BF16x2`: packed add through a temporary, then
    ///   reinterpreted as the output item.
    fn compile_atomic_add(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        let [dst, val, out_optimized] = Variable::optimized_args([*lhs, *rhs, *out]).args;
        let addr_space = D::address_space_for_variable(out);
        let out_item = out.item();
        let out = out.fmt_left();

        // The two reinterpret paths stage the result in a temporary and share
        // the final write-back below; the other paths return directly.
        let staged = match out_optimized.elem() {
            Elem::I64 => {
                return writeln!(
                    f,
                    "{out} = atomicAdd(reinterpret_cast<{uint}*>({dst}), {uint}({val}));",
                    uint = Elem::<D>::U64
                );
            }
            Elem::F32 if out_item.vectorization > 1 => {
                let vec_ty = format!("float{}", out_item.vectorization);
                let tmp = Variable::tmp(out_optimized.item());
                writeln!(
                    f,
                    "{vec_ty} {tmp} = atomicAdd(\n\
                     reinterpret_cast<{addr_space}{vec_ty}*>({dst}),\n\
                     reinterpret_cast<const {addr_space}{vec_ty}&>({val}));",
                )?;
                tmp
            }
            Elem::F16x2 | Elem::BF16x2 => {
                let tmp = Variable::tmp(out_optimized.item());
                writeln!(
                    f,
                    "{} = atomicAdd(\n\
                     reinterpret_cast<{addr_space}{}*>({dst}),\n\
                     reinterpret_cast<const {addr_space}{}&>({val}));",
                    tmp.fmt_left(),
                    dst.item(),
                    val.item()
                )?;
                tmp
            }
            _ => return writeln!(f, "{out} = atomicAdd({dst}, {val});"),
        };
        writeln!(
            f,
            "{out} = reinterpret_cast<{addr_space}{out_item}&>({staged});"
        )
    }

    /// Writes `out = atomicAnd(lhs, rhs);`.
    fn compile_atomic_and(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        writeln!(f, "{} = atomicAnd({lhs}, {rhs});", out.fmt_left())
    }

    /// Writes `out = atomicCAS(input, cmp, val);`.
    ///
    /// Two-wide `F32` values are punned to `U64`, and two-wide `F16`/`BF16`
    /// values to `U32`; the result is staged in a temporary and reinterpreted
    /// as the output item.
    fn compile_atomic_cas(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable<D>,
        cmp: &Variable<D>,
        val: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        let out_item = out.item();
        let out = out.fmt_left();

        // Integer element wide enough to carry the packed value, if punning applies.
        let punned = match val.elem() {
            Elem::F32 if val.item().vectorization == 2 => Some(Elem::<D>::U64),
            Elem::F16 | Elem::BF16 if val.item().vectorization == 2 => Some(Elem::<D>::U32),
            _ => None,
        };

        match punned {
            Some(int_elem) => {
                let word = Item::new(int_elem, 1, true);
                let staged = Variable::tmp(word);
                writeln!(
                    f,
                    "{} = atomicCAS(\n\
                     reinterpret_cast<{word}*>({input}),\n\
                     reinterpret_cast<{word}&>({cmp}),\n\
                     reinterpret_cast<{word}&>({val}));",
                    staged.fmt_left()
                )?;
                writeln!(f, "{out} = reinterpret_cast<{out_item}&>({staged});")
            }
            None => writeln!(f, "{out} = atomicCAS({input}, {cmp}, {val});"),
        }
    }

    /// Writes an atomic load, implemented as an atomic add of a zero constant
    /// (of the input's item type) to `input`.
    fn compile_atomic_load(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        let nothing = Variable::Constant(ConstantValue::UInt(0), input.item());
        Self::compile_atomic_add(f, input, &nothing, out)
    }

    /// Writes `out = atomicMax(lhs, rhs);`.
    fn compile_atomic_max(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        writeln!(f, "{} = atomicMax({lhs}, {rhs});", out.fmt_left())
    }

    /// Writes `out = atomicMin(lhs, rhs);`.
    fn compile_atomic_min(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        writeln!(f, "{} = atomicMin({lhs}, {rhs});", out.fmt_left())
    }

    /// Writes `out = atomicOr(lhs, rhs);`.
    fn compile_atomic_or(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        writeln!(f, "{} = atomicOr({lhs}, {rhs});", out.fmt_left())
    }

    /// Writes an atomic store, implemented as an atomic swap of `input` into
    /// `out` whose previous value lands in a fresh temporary.
    fn compile_atomic_store(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        let discarded = Variable::tmp(input.item());
        Self::compile_atomic_swap(f, out, input, &discarded)
    }

    /// Writes an atomic subtraction.
    ///
    /// `U32`/`I32` use `atomicSub`; `I64` adds the negated operand through a
    /// `U64` cast; every other element type adds the negated operand directly.
    fn compile_atomic_sub(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        let dst = out.fmt_left();
        match rhs.elem() {
            Elem::U32 | Elem::I32 => writeln!(f, "{dst} = atomicSub({lhs}, {rhs});"),
            Elem::I64 => writeln!(
                f,
                "{dst} = atomicAdd(reinterpret_cast<{uint}*>({lhs}), {uint}(-{rhs}));",
                uint = Elem::<D>::U64
            ),
            // `U64` and all remaining element types share the negated add.
            _ => writeln!(f, "{dst} = atomicAdd({lhs}, -{rhs});"),
        }
    }

    /// Writes `out = atomicExch(lhs, rhs);`.
    ///
    /// Two-wide `F32` values are punned to `U64`, and two-wide `F16`/`BF16`
    /// values to `U32`; the result is staged in a temporary and reinterpreted
    /// as the output item.
    fn compile_atomic_swap(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        let out_item = out.item();
        let out = out.fmt_left();

        // Integer element wide enough to carry the packed value, if punning applies.
        let punned = match rhs.elem() {
            Elem::F32 if rhs.item().vectorization == 2 => Some(Elem::<D>::U64),
            Elem::F16 | Elem::BF16 if rhs.item().vectorization == 2 => Some(Elem::<D>::U32),
            _ => None,
        };

        match punned {
            Some(int_elem) => {
                let word = Item::new(int_elem, 1, true);
                let staged = Variable::tmp(word);
                writeln!(
                    f,
                    "{} = atomicExch(\n\
                     reinterpret_cast<{word}*>({lhs}),\n\
                     reinterpret_cast<{word}&>({rhs}));",
                    staged.fmt_left()
                )?;
                writeln!(f, "{out} = reinterpret_cast<{out_item}&>({staged});")
            }
            None => writeln!(f, "{out} = atomicExch({lhs}, {rhs});"),
        }
    }

    /// Writes `out = atomicXor(lhs, rhs);`.
    fn compile_atomic_xor(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> std::fmt::Result {
        writeln!(f, "{} = atomicXor({lhs}, {rhs});", out.fmt_left())
    }

    /// Writes a saturating addition of `lhs` and `rhs` for values of type `item`.
    fn compile_saturating_add(
        f: &mut std::fmt::Formatter<'_>,
        lhs: impl Display,
        rhs: impl Display,
        item: Item<D>,
    ) -> std::fmt::Result;

    /// Writes a saturating subtraction of `rhs` from `lhs` for values of type `item`.
    fn compile_saturating_sub(
        f: &mut std::fmt::Formatter<'_>,
        lhs: impl Display,
        rhs: impl Display,
        item: Item<D>,
    ) -> std::fmt::Result;

    /// Writes a `printf(...)` statement.
    ///
    /// `format_string` is rendered with Rust's `Debug` formatting (quoted and
    /// escaped); the first argument is preceded by `", "` and subsequent ones
    /// by `","`.
    fn compile_instruction_printf(
        f: &mut std::fmt::Formatter<'_>,
        format_string: &str,
        args: &[Variable<D>],
    ) -> std::fmt::Result {
        write!(f, "printf({format_string:?}")?;
        if let Some((first, rest)) = args.split_first() {
            write!(f, ", {first}")?;
            for arg in rest {
                write!(f, ",{arg}")?;
            }
        }
        writeln!(f, ");")
    }

    /// Writes a scalar `log1p` call; half-precision element types are computed
    /// in `float` and cast back to the element type.
    fn compile_instruction_log1p_scalar<T: Component<D>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
    ) -> std::fmt::Result {
        let elem = input.elem();
        let is_half = matches!(
            elem,
            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2
        );
        if is_half {
            write!(f, "{elem}(log1p(float({input})))")
        } else {
            write!(f, "log1p({input})")
        }
    }

    /// Writes the dialect's thread synchronization instruction.
    fn compile_instruction_sync_threads(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;

    /// Writes the dialect's warp synchronization instruction.
    fn compile_instruction_sync_warp(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;

    /// Writes the dialect's thread fence instruction.
    fn compile_instruction_thread_fence(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;

    /// Writes a scalar `tanh` call; half-precision element types are computed
    /// in `float` and cast back to the element type.
    fn compile_instruction_tanh_scalar<T: Component<D>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
    ) -> std::fmt::Result {
        let elem = input.elem();
        let is_half = matches!(
            elem,
            Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2
        );
        if is_half {
            write!(f, "{elem}(tanh(float({input})))")
        } else {
            write!(f, "tanh({input})")
        }
    }

    /// Writes a find-first-set expression on `input` producing `out_elem`.
    fn compile_instruction_find_first_set<T: Component<D>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<D>,
    ) -> std::fmt::Result;

    /// Writes a scalar leading-zeros expression on `input` producing `out_elem`.
    fn compile_instruction_leading_zeros_scalar<T: Component<D>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<D>,
    ) -> std::fmt::Result;

    /// Writes a scalar trailing-zeros expression on `input` producing `out_elem`.
    fn compile_instruction_trailing_zeros_scalar<T: Component<D>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<D>,
    ) -> std::fmt::Result;

    /// Writes a scalar population count, cast to `out_elem`.
    ///
    /// 32-bit inputs use `__popc`, 64-bit inputs `__popcll` (signed inputs are
    /// first cast to the unsigned type of the same width); any other input is
    /// zero-extended and passed to `__popc`.
    fn compile_instruction_popcount_scalar<T: Component<D>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<D>,
    ) -> std::fmt::Result {
        write!(f, "{out_elem}(")?;
        match input.elem() {
            Elem::I32 => write!(f, "__popc({}({input}))", Elem::<D>::U32)?,
            Elem::U32 => write!(f, "__popc({input})")?,
            Elem::I64 => write!(f, "__popcll({}({input}))", Elem::<D>::U64)?,
            Elem::U64 => write!(f, "__popcll({input})")?,
            _ => write!(f, "__popc({})", super::unary::zero_extend(input))?,
        }
        f.write_str(")")
    }

    /// Writes a scalar bit reversal, cast to `out_elem`.
    ///
    /// The intrinsic is selected on `out_elem`: 32-bit types use `__brev`,
    /// 64-bit types `__brevll` (signed inputs are first cast to the unsigned
    /// type of the same width). Any other type is zero-extended, reversed with
    /// `__brev`, and shifted right by the number of bits it is narrower than
    /// `u32`.
    fn compile_instruction_reverse_bits_scalar<T: Component<D>>(
        f: &mut std::fmt::Formatter<'_>,
        input: T,
        out_elem: Elem<D>,
    ) -> std::fmt::Result {
        write!(f, "{out_elem}(")?;
        match out_elem {
            Elem::I32 => write!(f, "__brev({}({input}))", Elem::<D>::U32)?,
            Elem::U32 => write!(f, "__brev({input})")?,
            Elem::I64 => write!(f, "__brevll({}({input}))", Elem::<D>::U64)?,
            Elem::U64 => write!(f, "__brevll({input})")?,
            _ => write!(
                f,
                "__brev({}) >> {}",
                super::unary::zero_extend(input),
                (size_of::<u32>() - out_elem.size()) * 8
            )?,
        }
        f.write_str(")")
    }

    /// Writes the name of the `max` function to use for `item`.
    fn compile_instruction_max_function_name(
        f: &mut std::fmt::Formatter<'_>,
        item: Item<D>,
    ) -> std::fmt::Result;

    /// Writes the name of the `min` function to use for `item`.
    fn compile_instruction_min_function_name(
        f: &mut std::fmt::Formatter<'_>,
        item: Item<D>,
    ) -> std::fmt::Result;

    /// Writes a power call: `powf` for `F32`, `pow` for `F64`; any other
    /// element type emits an `#error` line.
    fn compile_instruction_powf(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &str,
        rhs: &str,
        elem: Elem<D>,
    ) -> std::fmt::Result {
        let func = match elem {
            Elem::F32 => "powf",
            Elem::F64 => "pow",
            _ => return write!(f, "#error Unsupported type for powf: {elem}"),
        };
        write!(f, "{func}({lhs}, {rhs})")
    }

    /// Writes a hypotenuse call: `hypotf` for `F32`, `hypot` for `F64`; any
    /// other element type emits an `#error` line.
    fn compile_instruction_hypot(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &str,
        rhs: &str,
        elem: Elem<D>,
    ) -> std::fmt::Result {
        let func = match elem {
            Elem::F32 => "hypotf",
            Elem::F64 => "hypot",
            _ => return write!(f, "#error Unsupported type for hypot: {elem}"),
        };
        write!(f, "{func}({lhs}, {rhs})")
    }

    /// Writes a reciprocal-hypotenuse call: `rhypotf` for `F32`, `rhypot` for
    /// `F64`; any other element type emits an `#error` line.
    fn compile_instruction_rhypot(
        f: &mut std::fmt::Formatter<'_>,
        lhs: &str,
        rhs: &str,
        elem: Elem<D>,
    ) -> std::fmt::Result {
        let func = match elem {
            Elem::F32 => "rhypotf",
            Elem::F64 => "rhypot",
            _ => return write!(f, "#error Unsupported type for rhypot: {elem}"),
        };
        write!(f, "{func}({lhs}, {rhs})")
    }

    /// Prefix of half-precision function names; `"h"` by default.
    fn compile_instruction_half_function_name_prefix() -> &'static str {
        "h"
    }

    /// Prefix of packed half-precision function names; `"h2"` by default.
    fn compile_instruction_half2_function_name_prefix() -> &'static str {
        "h2"
    }

    /// Writes a warp shuffle of `var` from lane `source`.
    fn compile_warp_shuffle(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        source: &str,
    ) -> std::fmt::Result;

    /// Writes a warp xor-shuffle of `var` (of type `elem`) with `offset`.
    fn compile_warp_shuffle_xor(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        elem: &Elem<D>,
        offset: &str,
    ) -> std::fmt::Result;

    /// Writes a warp shuffle-up of `var` by `offset`.
    fn compile_warp_shuffle_up(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        offset: &str,
    ) -> std::fmt::Result;

    /// Writes a warp shuffle-down of `var` by `offset`.
    fn compile_warp_shuffle_down(
        f: &mut std::fmt::Formatter<'_>,
        var: &str,
        offset: &str,
    ) -> std::fmt::Result;

    /// Writes the warp-wide "all" vote on `input`.
    fn compile_warp_all<T: Component<D>>(
        f: &mut std::fmt::Formatter<'_>,
        input: &T,
    ) -> std::fmt::Result;

    /// Writes the warp-wide "any" vote on `input`.
    fn compile_warp_any<T: Component<D>>(
        f: &mut std::fmt::Formatter<'_>,
        input: &T,
    ) -> std::fmt::Result;

    /// Writes the warp ballot of `input`, producing a value of `out_elem`.
    fn compile_warp_ballot(
        f: &mut std::fmt::Formatter<'_>,
        input: &Variable<D>,
        out_elem: &Elem<D>,
    ) -> std::fmt::Result;

    /// Writes a leader election: `out` is set to whether
    /// `threadIdx.x % warpSize` equals the lowest set bit index of
    /// `__activemask()`.
    fn compile_warp_elect(f: &mut std::fmt::Formatter<'_>, out: &str) -> std::fmt::Result {
        write!(
            f,
            "\n\
             unsigned int mask = __activemask();\n\
             unsigned int leader = __ffs(mask) - 1;\n\
             {out} = threadIdx.x % warpSize == leader;\n"
        )
    }

    /// Writes the dialect's "unreachable" marker.
    fn compile_unreachable(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
}
/// Borrowed operands of a manual MMA instruction, passed to
/// `DialectWmmaCompiler::compile_manual_mma` and `compile_scaled_mma`.
///
/// The `new` derive generates a constructor taking the fields in declaration
/// order, so the field order is part of the public interface.
#[derive(Debug, Clone, Copy, new)]
pub struct ManualMma<'a, D: Dialect> {
/// Shape of the MMA operation.
pub shape: MmaShape<D>,
/// Fragment `a` (presumably the left-hand matrix operand — confirm).
pub frag_a: &'a Variable<D>,
/// Fragment `b` (presumably the right-hand matrix operand — confirm).
pub frag_b: &'a Variable<D>,
/// Fragment `c` (presumably the accumulator input — confirm).
pub frag_c: &'a Variable<D>,
/// Fragment `d` (presumably the result — confirm).
pub frag_d: &'a Variable<D>,
}
/// Warp-level reduction code generation.
///
/// Every method has a default that delegates to one of the shared `reduce_*`
/// helpers imported from `crate::shared`; dialects override only what differs.
pub trait DialectWarpReduceCompiler<D: Dialect>:
Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
/// Sum reduction: delegates to `reduce_operator` with the `+=` operator.
fn warp_reduce_sum(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_operator(f, input, out, "+=")
}
/// Product reduction: delegates to `reduce_operator` with the `*=` operator.
fn warp_reduce_prod(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_operator(f, input, out, "*=")
}
/// Max reduction: delegates to `reduce_comparison` using the dialect's
/// `compile_instruction_max_function_name`.
fn warp_reduce_max(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_comparison(f, input, out, D::compile_instruction_max_function_name)
}
/// Min reduction: delegates to `reduce_comparison` using the dialect's
/// `compile_instruction_min_function_name`.
fn warp_reduce_min(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_comparison(f, input, out, D::compile_instruction_min_function_name)
}
/// "All" reduction: delegates to `reduce_quantifier` using the dialect's
/// `compile_warp_all`, instantiated for `IndexedVariable<D>`.
fn warp_reduce_all(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_quantifier(f, input, out, D::compile_warp_all::<IndexedVariable<D>>)
}
/// "Any" reduction: delegates to `reduce_quantifier` using the dialect's
/// `compile_warp_any`, instantiated for `IndexedVariable<D>`.
fn warp_reduce_any(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_quantifier(f, input, out, D::compile_warp_any::<IndexedVariable<D>>)
}
/// Inclusive sum scan: delegates to `reduce_inclusive` with `+=`.
fn warp_reduce_sum_inclusive(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_inclusive(f, input, out, "+=")
}
/// Inclusive product scan: delegates to `reduce_inclusive` with `*=`.
fn warp_reduce_prod_inclusive(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_inclusive(f, input, out, "*=")
}
/// Exclusive sum scan: delegates to `reduce_exclusive` with `+=` and `"0"`
/// as the last argument (presumably the identity value — confirm in the helper).
fn warp_reduce_sum_exclusive(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_exclusive(f, input, out, "+=", "0")
}
/// Exclusive product scan: delegates to `reduce_exclusive` with `*=` and `"1"`
/// as the last argument (presumably the identity value — confirm in the helper).
fn warp_reduce_prod_exclusive(
f: &mut core::fmt::Formatter<'_>,
input: &Variable<D>,
out: &Variable<D>,
) -> core::fmt::Result {
reduce_exclusive(f, input, out, "*=", "1")
}
}
/// WMMA / MMA code generation hooks and capability queries.
///
/// The `compile_wmma_includes` … `compile_wmma_fragment` hooks default to
/// writing nothing; the remaining `compile_*` methods and the two
/// `supported_*` queries without a body must be supplied by each dialect.
pub trait DialectWmmaCompiler<D: Dialect>:
Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
/// Writes WMMA-specific includes; the default writes nothing.
#[allow(unused_variables)]
fn compile_wmma_includes(
f: &mut std::fmt::Formatter<'_>,
flags: &Flags<D>,
) -> std::fmt::Result {
Ok(())
}
/// Writes WMMA-specific type definitions; the default writes nothing.
#[allow(unused_variables)]
fn compile_wmma_type_definitions(
f: &mut std::fmt::Formatter<'_>,
flags: &Flags<D>,
) -> std::fmt::Result {
Ok(())
}
/// Writes WMMA-specific local variables; the default writes nothing.
#[allow(unused_variables)]
fn compile_wmma_local_variables(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Ok(())
}
/// Writes a fragment identifier; the default writes nothing.
///
/// NOTE(review): the name is spelled `wwma` (not `wmma`); it is left as is
/// because implementors and callers depend on the current spelling.
#[allow(unused_variables)]
fn compile_wwma_fragment_ident(
f: &mut std::fmt::Formatter<'_>,
ident: &FragmentIdent<D>,
) -> std::fmt::Result {
Ok(())
}
/// Writes a fragment layout; the default writes nothing.
#[allow(unused_variables)]
fn compile_wmma_fragment_layout(
f: &mut std::fmt::Formatter<'_>,
layout: &FragmentLayout<D>,
) -> std::fmt::Result {
Ok(())
}
/// Writes a fragment type; the default writes nothing.
#[allow(unused_variables)]
fn compile_wmma_fragment(
f: &mut std::fmt::Formatter<'_>,
fragment: &Fragment<D>,
) -> std::fmt::Result {
Ok(())
}
/// Writes the declaration of the fragment variable `var`.
fn compile_wmma_fragment_declaration(
f: &mut std::fmt::Formatter<'_>,
var: &Variable<D>,
) -> std::fmt::Result;
/// Writes the code for a WMMA `instruction`.
fn compile_wmma_instruction(
f: &mut std::fmt::Formatter<'_>,
instruction: &WmmaInstruction<D>,
) -> std::fmt::Result;
/// Writes the code for a manual MMA described by `mma`.
fn compile_manual_mma(f: &mut std::fmt::Formatter<'_>, mma: ManualMma<D>) -> std::fmt::Result;
/// Writes the code for a scaled MMA described by `mma`, with the scale
/// operands `scales_a` / `scales_b` and the given `scales_factor`.
fn compile_scaled_mma(
f: &mut std::fmt::Formatter<'_>,
mma: ManualMma<D>,
scales_a: Variable<D>,
scales_b: Variable<D>,
scales_factor: u32,
) -> std::fmt::Result;
/// Returns the WMMA combinations supported on `arch`.
fn supported_wmma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
/// Returns the MMA combinations supported on `arch`.
fn supported_mma_combinations(arch: &D::Architecture) -> SupportedMmaCombinations;
/// Returns the scaled-MMA combinations supported on `_arch`; the default is
/// an empty list.
fn supported_scaled_mma_combinations(
_arch: &D::Architecture,
) -> SupportedScaledMmaCombinations {
Vec::new()
}
}
/// Supplies the IR processors associated with a dialect.
pub trait DialectProcessors<D: Dialect> {
/// Returns the dialect's list of boxed [`Processor`]s (presumably applied to
/// the IR before code generation — confirm against the caller).
fn processors() -> Vec<Box<dyn Processor>>;
}