use cubecl_ir::{ExpandElement, Instruction};
use paste::paste;
use crate::{
ir::{BarrierOps, Scope},
unexpanded,
};
use super::{
CubeDebug, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, Line, ReadOnly, ReadWrite,
Slice, SliceExpand, SliceMut, TensorMap,
};
/// Kernel-side handle to a barrier.
///
/// Apart from the constructors, the host-side methods on this type only call
/// `unexpanded!()`. The IR is produced by the matching `__expand_*` functions, which
/// work on [`BarrierExpand`].
#[derive(Copy, Clone)]
pub struct Barrier;

impl CubeType for Barrier {
    // At expansion time a `Barrier` is represented by a `BarrierExpand`.
    type ExpandType = BarrierExpand;
}
/// Expansion-time counterpart of [`Barrier`]: wraps the IR element created for the barrier.
#[derive(Clone)]
pub struct BarrierExpand {
    elem: ExpandElement,
}

impl IntoMut for BarrierExpand {
    /// Identity conversion: returns `self` unchanged and does not touch the scope.
    fn into_mut(self, _: &mut Scope) -> Self {
        self
    }
}

impl CubeDebug for BarrierExpand {
    /// Attaches `name` to the IR variable backing this barrier.
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
        let variable = *self.elem;
        scope.update_variable_name(variable, name);
    }
}
/// Level of a barrier. It is converted into `cubecl_ir::BarrierLevel` when the barrier
/// variable is created during expansion.
///
/// The wrapped enum is private, so values can only be built through the `BarrierLevel`
/// constructors.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct BarrierLevel(InnerBarrierLevel);

impl CubeType for BarrierLevel {
    // A level is a plain value known at expansion time, so it expands to itself.
    type ExpandType = Self;
}

impl IntoMut for BarrierLevel {
    fn into_mut(self, _scope: &mut Scope) -> Self {
        self
    }
}

impl CubeDebug for BarrierLevel {
    // No IR variable backs a level, so there is nothing to name.
    fn set_debug_name(&self, _scope: &mut Scope, _name: &'static str) {}
}

/// Private representation of [`BarrierLevel`]. Each variant maps one-to-one onto the
/// identically named `cubecl_ir::BarrierLevel` variant.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum InnerBarrierLevel {
    Unit,
    /// Payload is the elected unit passed to `BarrierLevel::cube_coop`.
    CubeCoop(u32),
    /// Payload is the elected unit passed to `BarrierLevel::cube_manual`.
    CubeManual(u32),
}
impl BarrierLevel {
pub fn unit() -> Self {
BarrierLevel(InnerBarrierLevel::Unit)
}
pub fn cube_coop(elected_unit: u32) -> Self {
BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
}
pub fn cube_manual(elected_unit: u32) -> Self {
BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
}
pub fn __expand_unit(_scope: &mut Scope) -> BarrierLevel {
BarrierLevel(InnerBarrierLevel::Unit)
}
pub fn __expand_cube_coop(_scope: &mut Scope, elected_unit: u32) -> Self {
BarrierLevel(InnerBarrierLevel::CubeCoop(elected_unit))
}
pub fn __expand_cube_manual(_scope: &mut Scope, elected_unit: u32) -> Self {
BarrierLevel(InnerBarrierLevel::CubeManual(elected_unit))
}
}
impl From<InnerBarrierLevel> for cubecl_ir::BarrierLevel {
    /// Maps each front-end level onto the identically named IR level, carrying the elected
    /// unit through unchanged.
    fn from(val: InnerBarrierLevel) -> Self {
        match val {
            InnerBarrierLevel::Unit => Self::Unit,
            InnerBarrierLevel::CubeCoop(unit) => Self::CubeCoop(unit),
            InnerBarrierLevel::CubeManual(unit) => Self::CubeManual(unit),
        }
    }
}
// Generates the `tma_load_{dim}d` API for one tensor-map rank.
//
// `$dim` is the rank spliced into the generated names (via `paste!`). The `$arg` names are
// the coordinate parameters; they are forwarded, in the order written, as the `indices` of
// the `BarrierOps::TmaLoad` operation.
//
// For each rank this emits:
// - `Barrier::tma_load_{dim}d`: host-side stub that only calls `unexpanded!()`;
// - `Barrier::__expand_tma_load_{dim}d`: associated-function form that forwards to the method;
// - `BarrierExpand::__expand_tma_load_{dim}d_method`: registers the `TmaLoad` instruction.
//
// NOTE(review): `$arg` is matched as `expr` but used as a parameter name; `ident` would state
// that intent more precisely — confirm it still compiles through `paste!` before changing.
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Host-side stub for a tensor-map load into `destination`; only calls
                /// `unexpanded!()`. Use inside a cube kernel, where it is expanded.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                /// Expansion entry point; forwards to the `BarrierExpand` method of the same name.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoad` on this barrier, reading from `source` at
                /// the given coordinates and writing into `destination`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    // Split the slice into the value used as the instruction output and its offset.
                    let (destination, destination_offset) = destination.__to_raw_parts();
                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        // Coordinates in the order the macro invocation listed them.
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };
                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
// Generates the `tma_load_im2col_{dim}d` API for one tensor-map rank.
//
// Names before `;` (`$arg`) become `i32` coordinate parameters, forwarded in order as
// `indices`. Names after it (`$offset`) become `u16` parameters, forwarded in order as
// `offsets` of `BarrierOps::TmaLoadIm2col`.
//
// For each rank this emits:
// - a host-side stub that only calls `unexpanded!()`;
// - a `Barrier::__expand_*` associated function that forwards to the method;
// - the `BarrierExpand::__expand_*_method` that registers the instruction.
//
// NOTE(review): `$arg`/`$offset` are matched as `expr` but used as parameter names; `ident`
// would express that more precisely — confirm it still compiles through `paste!` first.
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Host-side stub for an im2col tensor-map load into `destination`; only
                /// calls `unexpanded!()`. Use inside a cube kernel, where it is expanded.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    &self,
                    source: &TensorMap<C>,
                    destination: &mut SliceMut<Line<C>>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                /// Expansion entry point; forwards to the `BarrierExpand` method of the same name.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C: CubePrimitive>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }

            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoadIm2col` on this barrier, reading from
                /// `source` at the given coordinates and offsets and writing into `destination`.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C: CubePrimitive>(
                    &self,
                    scope: &mut Scope,
                    source: ExpandElementTyped<TensorMap<C>>,
                    destination: SliceExpand<Line<C>, ReadWrite>,
                    $($arg: ExpandElementTyped<i32>,)*
                    $($offset: ExpandElementTyped<u16>),*
                ) {
                    let barrier = *self.elem;
                    let source = *source.expand;
                    // Split the slice into the value used as the instruction output and its offset.
                    let (destination, destination_offset) = destination.__to_raw_parts();
                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        // Coordinates, then offsets, each in the order the invocation listed them.
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };
                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
// Plain tensor-map loads for ranks 1 through 5. The coordinate names are forwarded, in the
// order written here, as the `indices` of the `TmaLoad` operation.
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);

// Im2col loads for ranks 3 through 5. Names before `;` are the `i32` coordinates (`indices`);
// names after it are the `u16` offsets (`offsets`) of `TmaLoadIm2col`.
// NOTE(review): the names suggest NWC / NHWC / NDHWC layouts — confirm against the tensor map.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
impl Barrier {
pub fn new(_level: BarrierLevel) -> Self {
Self
}
pub fn new_with_tma_proxy(_level: BarrierLevel) -> Self {
Self
}
pub fn memcpy_async<C: CubePrimitive>(
&self,
_source: &Slice<Line<C>>,
_destination: &mut SliceMut<Line<C>>,
) {
unexpanded!()
}
pub fn arrive(&self) {
unexpanded!()
}
pub fn arrive_tx(&self, _arrival_count: u32, _transaction_count: u32) {
unexpanded!()
}
pub fn expect_tx(&self, _expected_count: u32) {
unexpanded!()
}
pub fn wait(&self) {
unexpanded!()
}
pub fn arrive_and_wait(&self) {
unexpanded!()
}
pub fn __expand_new(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand {
let variable = scope.create_barrier(level.0.into());
scope.register(BarrierOps::Init {
barrier: *variable,
with_cta_fence: false,
});
BarrierExpand { elem: variable }
}
pub fn __expand_new_with_tma_proxy(scope: &mut Scope, level: BarrierLevel) -> BarrierExpand {
let variable = scope.create_barrier(level.0.into());
scope.register(BarrierOps::Init {
barrier: *variable,
with_cta_fence: true,
});
BarrierExpand { elem: variable }
}
pub fn __expand_memcpy_async<C: CubePrimitive>(
scope: &mut Scope,
expand: BarrierExpand,
source: SliceExpand<Line<C>, ReadOnly>,
destination: SliceExpand<Line<C>, ReadWrite>,
) {
expand.__expand_memcpy_async_method(scope, source, destination);
}
pub fn __expand_arrive(scope: &mut Scope, expand: BarrierExpand) {
expand.__expand_arrive_method(scope);
}
pub fn __expand_arrive_tx(
scope: &mut Scope,
expand: BarrierExpand,
arrival_count: ExpandElementTyped<u32>,
transaction_count: ExpandElementTyped<u32>,
) {
expand.__expand_arrive_tx_method(scope, arrival_count, transaction_count);
}
pub fn __expand_expect_tx(
scope: &mut Scope,
expand: BarrierExpand,
expected_count: ExpandElementTyped<u32>,
) {
expand.__expand_expect_tx_method(scope, expected_count);
}
pub fn __expand_wait(scope: &mut Scope, expand: BarrierExpand) {
expand.__expand_wait_method(scope);
}
pub fn __expand_arrive_and_wait(scope: &mut Scope, expand: BarrierExpand) {
expand.__expand_arrive_and_wait_method(scope);
}
}
impl BarrierExpand {
    /// Registers a `BarrierOps::MemCopyAsync` that copies the whole `source` slice into
    /// `destination`, tracked by this barrier.
    pub fn __expand_memcpy_async_method<C: CubePrimitive>(
        &self,
        scope: &mut Scope,
        source: SliceExpand<Line<C>, ReadOnly>,
        destination: SliceExpand<Line<C>, ReadWrite>,
    ) {
        // Capture the length before the slice is split into raw parts.
        let source_length = *source.length.expand;
        let (source_base, offset_source) = source.__to_raw_parts();
        let (out_base, offset_out) = destination.__to_raw_parts();
        let op = BarrierOps::MemCopyAsync {
            barrier: *self.elem,
            source: source_base,
            source_length,
            offset_source,
            offset_out,
        };
        scope.register(Instruction::new(op, out_base));
    }

    /// Registers a `BarrierOps::Arrive` on this barrier.
    pub fn __expand_arrive_method(&self, scope: &mut Scope) {
        scope.register(BarrierOps::Arrive {
            barrier: *self.elem,
        });
    }

    /// Registers a `BarrierOps::ArriveTx` carrying the arrival and transaction count updates.
    pub fn __expand_arrive_tx_method(
        &self,
        scope: &mut Scope,
        arrival_count: ExpandElementTyped<u32>,
        transaction_count: ExpandElementTyped<u32>,
    ) {
        let arrivals: ExpandElement = arrival_count.into();
        let transactions: ExpandElement = transaction_count.into();
        scope.register(BarrierOps::ArriveTx {
            barrier: *self.elem,
            arrive_count_update: arrivals.consume(),
            transaction_count_update: transactions.consume(),
        });
    }

    /// Registers a `BarrierOps::ExpectTx` carrying the transaction count update.
    pub fn __expand_expect_tx_method(
        &self,
        scope: &mut Scope,
        transaction_count: ExpandElementTyped<u32>,
    ) {
        let transactions: ExpandElement = transaction_count.into();
        scope.register(BarrierOps::ExpectTx {
            barrier: *self.elem,
            transaction_count_update: transactions.consume(),
        });
    }

    /// Registers a `BarrierOps::Wait` on this barrier.
    pub fn __expand_wait_method(&self, scope: &mut Scope) {
        scope.register(BarrierOps::Wait {
            barrier: *self.elem,
        });
    }

    /// Registers a `BarrierOps::ArriveAndWait` on this barrier.
    pub fn __expand_arrive_and_wait_method(&self, scope: &mut Scope) {
        scope.register(BarrierOps::ArriveAndWait {
            barrier: *self.elem,
        });
    }
}