use alloc::vec;
use core::ops::{Deref, DerefMut};
use crate as cubecl;
use cubecl_ir::{Instruction, ManagedVariable, OpaqueType};
use cubecl_macros::intrinsic;
use paste::paste;
use crate::{
ir::{BarrierOps, Scope},
prelude::*,
unexpanded,
};
use super::{
CubePrimitive, CubeType, NativeExpand, ReadOnly, ReadWrite, Slice, SliceExpand, SliceMut,
TensorMap,
};
/// Synchronization primitive used to coordinate asynchronous operations
/// (async memory copies, TMA loads) between units in a kernel.
///
/// This is an opaque GPU-side handle; it carries no host-side state.
/// `Debug` is derived so the public type can be printed in diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Barrier;

/// Expand (compile-time IR) counterpart of [`Barrier`].
pub type BarrierExpand = NativeExpand<Barrier>;
/// Token returned by arrive-style operations (e.g. [`Barrier::arrive`]) and
/// passed back to [`Barrier::wait`] to block until the corresponding phase
/// completes.
///
/// `Debug` and `Eq` are derived for parity with [`Barrier`]; both are free
/// on a unit struct.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct BarrierToken;
/// Barriers expand to a native (opaque) IR variable rather than a composite
/// expand type.
impl CubeType for Barrier {
    type ExpandType = NativeExpand<Barrier>;
}
/// Treats `Barrier` as a primitive so it can be stored in kernel variables.
///
/// A barrier can never be materialized from a constant, so
/// `from_const_value` is unreachable by construction.
impl CubePrimitive for Barrier {
    type Size = Const<1>;
    type Scalar = u32;
    type WithScalar<S: Scalar> = S;

    fn from_const_value(_value: cubecl_ir::ConstantValue) -> Self {
        unreachable!("Can't create from const value")
    }
}
impl NativeAssign for Barrier {
    /// Barriers require no extra initialization on assignment: the element
    /// is returned unchanged.
    fn elem_init_mut(_scope: &mut Scope, elem: ManagedVariable) -> ManagedVariable {
        elem
    }
}
/// Tokens, like barriers, expand to a native (opaque) IR variable.
impl CubeType for BarrierToken {
    type ExpandType = NativeExpand<BarrierToken>;
}
impl NativeAssign for BarrierToken {
fn elem_init_mut(_scope: &mut crate::ir::Scope, elem: ManagedVariable) -> ManagedVariable {
elem
}
}
/// Generates, for rank `$dim`, the GPU-side `tma_load_{dim}d` method on
/// [`Barrier`] plus its expand counterparts. The source is a tiled tensor
/// map; the destination is a mutable slice.
///
/// `$arg` names one `i32` index parameter per dimension (outermost first,
/// per the invocations below).
macro_rules! tensor_map_load {
    ($dim: literal, $($arg: expr),*) => {
        paste! {
            impl Barrier {
                /// Loads a region of `source` at the given indices into
                /// `destination`, tracked by this barrier.
                ///
                /// GPU-side stub: the body never runs on the host.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    source: &TensorMap<C1, Tiled>,
                    destination: &mut SliceMut<C2>,
                    $($arg: i32),*
                ) {
                    unexpanded!()
                }

                /// Free-function expand form; forwards to the method on
                /// [`BarrierExpand`].
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: NativeExpand<TensorMap<C1, Tiled>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>),*
                ) {
                    expand.[<__expand_tma_load_ $dim d_method>](scope, source, destination, $($arg),*);
                }
            }
            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoad` instruction whose output
                /// is the destination slice's underlying variable.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_ $dim d_method>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    scope: &mut Scope,
                    source: NativeExpand<TensorMap<C1, Tiled>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    // Split the destination into its raw variable and offset.
                    let (destination, destination_offset) = destination.__to_raw_parts();
                    let mem_copy = BarrierOps::TmaLoad {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offset_out: destination_offset
                    };
                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
/// Generates, for rank `$dim`, the GPU-side `tma_load_im2col_{dim}d` method
/// on [`Barrier`] plus its expand counterparts, for im2col-formatted tensor
/// maps.
///
/// `$arg` names one `i32` index per dimension; `$offset` names one `u16`
/// spatial offset per spatial dimension (see the invocations below).
macro_rules! tensor_map_load_im2col {
    ($dim: literal, $($arg: expr),*; $($offset: expr),*) => {
        paste! {
            impl Barrier {
                /// Loads a region of `source` at the given indices/offsets
                /// into `destination`, tracked by this barrier.
                ///
                /// GPU-side stub: the body never runs on the host.
                #[allow(unused, clippy::too_many_arguments)]
                pub fn [<tma_load_im2col_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    source: &TensorMap<C1, Im2col>,
                    destination: &mut SliceMut<C2>,
                    $($arg: i32,)*
                    $($offset: u16),*
                ) {
                    unexpanded!()
                }

                /// Free-function expand form; forwards to the method on
                /// [`BarrierExpand`].
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    scope: &mut Scope,
                    expand: BarrierExpand,
                    source: NativeExpand<TensorMap<C1, Im2col>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>,)*
                    $($offset: NativeExpand<u16>),*
                ) {
                    expand.[<__expand_tma_load_im2col_ $dim d_method>](scope, source, destination, $($arg),*, $($offset),*);
                }
            }
            impl BarrierExpand {
                /// Registers a `BarrierOps::TmaLoadIm2col` instruction whose
                /// output is the destination slice's underlying variable.
                #[allow(clippy::too_many_arguments)]
                pub fn [<__expand_tma_load_im2col_ $dim d_method>]<C1: CubePrimitive, C2: CubePrimitive<Scalar = C1::Scalar>>(
                    &self,
                    scope: &mut Scope,
                    source: NativeExpand<TensorMap<C1, Im2col>>,
                    destination: SliceExpand<C2, ReadWrite>,
                    $($arg: NativeExpand<i32>,)*
                    $($offset: NativeExpand<u16>),*
                ) {
                    let barrier = *self.expand;
                    let source = *source.expand;
                    // Split the destination into its raw variable and offset.
                    let (destination, destination_offset) = destination.__to_raw_parts();
                    let mem_copy = BarrierOps::TmaLoadIm2col {
                        barrier,
                        tensor_map: source,
                        indices: vec![$(*$arg.expand),*],
                        offsets: vec![$(*$offset.expand),*],
                        offset_out: destination_offset,
                    };
                    scope.register(Instruction::new(mem_copy, destination));
                }
            }
        }
    };
}
// Tiled TMA loads for ranks 1-5; index parameters are outermost-first.
tensor_map_load!(1, x);
tensor_map_load!(2, y, x);
tensor_map_load!(3, z, y, x);
tensor_map_load!(4, w, z, y, x);
tensor_map_load!(5, v, w, z, y, x);
// Im2col TMA loads for ranks 3-5; spatial offsets follow the indices.
tensor_map_load_im2col!(3, n, w, c; w_offset);
tensor_map_load_im2col!(4, n, h, w, c; h_offset, w_offset);
tensor_map_load_im2col!(5, n, d, h, w, c; d_offset, h_offset, w_offset);
#[cube(self_type = "ref")]
impl Barrier {
    /// Creates a unit-level barrier, immediately initialized with an arrival
    /// count of 1.
    pub fn local() -> Self {
        intrinsic!(|scope| {
            let variable =
                scope.create_local_mut(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Unit));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: true.into(),
                arrival_count: 1.into(),
            });
            variable.into()
        })
    }
    /// Creates a cube-level barrier in shared memory, initialized with
    /// `arrival_count`.
    ///
    /// NOTE(review): `is_elected` is forwarded to the `Init` op — presumably
    /// it flags the unit that performs the actual initialization; confirm
    /// against the backend's handling of `BarrierOps::Init`.
    #[allow(unused_variables)]
    pub fn shared(arrival_count: u32, is_elected: bool) -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Init {
                barrier: *variable,
                is_elected: *is_elected.expand,
                arrival_count: *arrival_count.expand,
            });
            variable.into()
        })
    }
    /// Creates a cube-level barrier in shared memory WITHOUT initializing it
    /// (only a `Declare` op is registered); pair with [`Self::init_manual`].
    pub fn shared_uninit() -> Shared<Barrier> {
        intrinsic!(|scope| {
            let variable = scope.create_shared(OpaqueType::Barrier(cubecl_ir::BarrierLevel::Cube));
            scope.register(BarrierOps::Declare { barrier: *variable });
            variable.into()
        })
    }
    /// Manually initializes this barrier with `arrival_count`, for barriers
    /// created via [`Self::shared_uninit`].
    #[allow(unused_variables)]
    pub fn init_manual(&self, arrival_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand.clone();
            scope.register(BarrierOps::InitManual {
                barrier,
                arrival_count: *arrival_count.expand,
            });
        })
    }
}
#[cube(self_type = "ref")]
impl Barrier {
    /// Registers an asynchronous copy of `source` into `destination`,
    /// tracked by this barrier.
    #[allow(unused_variables)]
    pub fn memcpy_async<C: CubePrimitive>(&self, source: &Slice<C>, destination: &mut SliceMut<C>) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Capture the runtime length before consuming the slice expand.
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();
            let mem_copy = BarrierOps::MemCopyAsync {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };
            scope.register(Instruction::new(mem_copy, destination));
        })
    }
    /// Cooperative variant of [`Self::memcpy_async`]: emits
    /// `BarrierOps::MemCopyAsyncCooperative` instead — exact semantics are
    /// backend-defined.
    #[allow(unused_variables)]
    pub fn memcpy_async_cooperative<C: CubePrimitive>(
        &self,
        source: &Slice<C>,
        destination: &mut SliceMut<C>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Capture the runtime length before consuming the slice expand.
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();
            let mem_copy = BarrierOps::MemCopyAsyncCooperative {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };
            scope.register(Instruction::new(mem_copy, destination));
        })
    }
    /// Transaction (`Tx`) variant of [`Self::memcpy_async`]: emits
    /// `BarrierOps::MemCopyAsyncTx` — presumably intended to pair with
    /// [`Self::expect_tx`] / [`Self::arrive_and_expect_tx`]; confirm in the
    /// backend.
    #[allow(unused_variables)]
    pub fn memcpy_async_tx<C: CubePrimitive>(
        &self,
        source: &Slice<C>,
        destination: &mut SliceMut<C>,
    ) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // Capture the runtime length before consuming the slice expand.
            let source_length = *source.length.expand;
            let (source, source_offset) = source.__to_raw_parts();
            let (destination, destination_offset) = destination.__to_raw_parts();
            let mem_copy = BarrierOps::MemCopyAsyncTx {
                barrier,
                source,
                source_length,
                offset_source: source_offset,
                offset_out: destination_offset,
            };
            scope.register(Instruction::new(mem_copy, destination));
        })
    }
}
#[cube(self_type = "ref")]
impl Barrier {
    /// Registers an `Arrive` op on this barrier and returns a token
    /// (derived from the barrier's index and level) that can be passed to
    /// [`Self::wait`].
    pub fn arrive(&self) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // The barrier's level is recovered from its opaque storage type.
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(BarrierOps::Arrive { barrier }, *token));
            token.into()
        })
    }
    /// Registers an `ArriveTx` op, updating both the arrival count and the
    /// expected transaction count, and returns a wait token.
    #[allow(unused_variables)]
    pub fn arrive_and_expect_tx(&self, arrival_count: u32, transaction_count: u32) -> BarrierToken {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            let arrival_count: ManagedVariable = arrival_count.into();
            let transaction_count: ManagedVariable = transaction_count.into();
            scope.register(Instruction::new(
                BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update: arrival_count.consume(),
                    transaction_count_update: transaction_count.consume(),
                },
                *token,
            ));
            token.into()
        })
    }
    /// Registers an `ExpectTx` op with `expected_count` as the transaction
    /// count update for this barrier.
    #[allow(unused_variables)]
    pub fn expect_tx(&self, expected_count: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let transaction_count: ManagedVariable = expected_count.into();
            scope.register(BarrierOps::ExpectTx {
                barrier,
                transaction_count_update: transaction_count.consume(),
            });
        })
    }
    /// Registers a combined `ArriveAndWait` op on this barrier.
    pub fn arrive_and_wait(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            scope.register(BarrierOps::ArriveAndWait { barrier });
        })
    }
    /// Registers a `Wait` op blocking on the phase identified by `token`
    /// (obtained from an arrive-style call).
    #[allow(unused_variables)]
    pub fn wait(&self, token: BarrierToken) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let token = *token.expand;
            scope.register(BarrierOps::Wait { barrier, token });
        })
    }
    /// Registers a `WaitParity` op on this barrier with the given `phase`
    /// value.
    #[allow(unused_variables)]
    pub fn wait_parity(&self, phase: u32) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            let phase = *phase.expand;
            scope.register(BarrierOps::WaitParity { barrier, phase });
        })
    }
}
/// Asynchronously copies `_copy_size` elements from `_source` into
/// `_destination` (unchecked variant).
///
/// GPU-side stub: the body never runs on the host; see [`copy_async::expand`].
pub fn copy_async<C: CubePrimitive>(
    _source: &Slice<C>,
    _destination: &mut SliceMut<C>,
    _copy_size: u32,
) {
    unexpanded!()
}
pub mod copy_async {
    use super::*;

    /// Expand function of [`super::copy_async`].
    ///
    /// Registers an unchecked `BarrierOps::CopyAsync`; the source length is
    /// taken from `copy_length`, and the byte count is `copy_length` scaled
    /// by the scalar size of `C`.
    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<C, ReadOnly>,
        destination: SliceExpand<C, ReadWrite>,
        copy_length: u32,
    ) {
        let elem_bytes = C::as_type(scope).storage_type().size() as u32;
        let (src, offset_source) = source.__to_raw_parts();
        let (dst, offset_out) = destination.__to_raw_parts();
        let op = BarrierOps::CopyAsync {
            source: src,
            source_length: copy_length.into(),
            offset_source,
            offset_out,
            copy_length: copy_length * elem_bytes,
            checked: false,
        };
        scope.register(Instruction::new(op, dst));
    }
}
/// Asynchronously copies `_copy_size` elements from `_source` into
/// `_destination`, with the checked flag set on the emitted op.
///
/// GPU-side stub: the body never runs on the host; see
/// [`copy_async_checked::expand`].
pub fn copy_async_checked<C: CubePrimitive>(
    _source: &Slice<C>,
    _destination: &mut SliceMut<C>,
    _copy_size: u32,
) {
    unexpanded!();
}
pub mod copy_async_checked {
    use super::*;

    /// Expand function of [`super::copy_async_checked`].
    ///
    /// Like [`super::copy_async::expand`], but emits a checked
    /// `BarrierOps::CopyAsync` whose source length comes from the slice's
    /// own runtime length rather than from `copy_length`.
    pub fn expand<C: CubePrimitive>(
        scope: &mut Scope,
        source: SliceExpand<C, ReadOnly>,
        destination: SliceExpand<C, ReadWrite>,
        copy_length: u32,
    ) {
        let elem_bytes = C::as_type(scope).storage_type().size() as u32;
        // Capture the runtime length before consuming the slice expand.
        let source_length = *source.length.expand;
        let (src, offset_source) = source.__to_raw_parts();
        let (dst, offset_out) = destination.__to_raw_parts();
        let op = BarrierOps::CopyAsync {
            source: src,
            source_length,
            offset_source,
            offset_out,
            copy_length: copy_length * elem_bytes,
            checked: true,
        };
        scope.register(Instruction::new(op, dst));
    }
}
#[cube(self_type = "ref")]
impl Barrier {
    /// Registers a `CommitCopyAsync` op for this barrier. A token is created
    /// internally as the instruction's output but is not returned to the
    /// caller.
    pub fn commit_copy_async(&self) {
        intrinsic!(|scope| {
            let barrier = *self.expand;
            // The barrier's level is recovered from its opaque storage type.
            let StorageType::Opaque(OpaqueType::Barrier(level)) = barrier.ty.storage_type() else {
                unreachable!()
            };
            let token = scope.create_barrier_token(barrier.index().unwrap(), level);
            scope.register(Instruction::new(
                BarrierOps::CommitCopyAsync { barrier },
                *token,
            ));
        })
    }
}
/// GPU-side auto-deref so a `Shared<Barrier>` can be used wherever a
/// `&Barrier` is expected inside a kernel.
impl Deref for Shared<Barrier> {
    type Target = Barrier;
    fn deref(&self) -> &Self::Target {
        // Host-side stub: real behavior comes from the `SharedExpand` impl.
        unexpanded!()
    }
}
/// Expand-side counterpart of the `Deref` impl above: exposes the shared
/// barrier expand as a [`BarrierExpand`].
impl Deref for SharedExpand<Barrier> {
    type Target = BarrierExpand;
    fn deref(&self) -> &Self::Target {
        // SAFETY: presumably sound because `SharedExpand<Barrier>` and
        // `NativeExpand<Barrier>` share the same underlying representation —
        // TODO confirm against `as_type_ref_unchecked`'s contract.
        unsafe { self.as_type_ref_unchecked::<Barrier>() }
    }
}
impl DerefMut for Shared<Barrier> {
fn deref_mut(&mut self) -> &mut Self::Target {
todo!()
}
}
/// Expand-side counterpart of the `DerefMut` impl above.
impl DerefMut for SharedExpand<Barrier> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: presumably sound because `SharedExpand<Barrier>` and
        // `NativeExpand<Barrier>` share the same underlying representation —
        // TODO confirm against `as_type_mut_unchecked`'s contract.
        unsafe { self.as_type_mut_unchecked::<Barrier>() }
    }
}
/// Converts a shared-memory barrier expand into a plain [`BarrierExpand`]
/// by forwarding its inner expand value.
impl From<SharedExpand<Barrier>> for BarrierExpand {
    fn from(value: SharedExpand<Barrier>) -> Self {
        let inner = value.expand;
        inner.into()
    }
}