use std::fmt::{Display, Write};
use cubecl_core::ir::BarrierLevel;
use super::{Component, Dialect, Variable};
/// A barrier-related operation, rendered to CUDA C++ source by this type's `Display` impl.
#[derive(Debug, Clone)]
pub enum BarrierOps<D: Dialect> {
    /// Declares a barrier's block-level helpers.
    ///
    /// For `BarrierLevel::Cube` this emits a `cooperative_groups::thread_block`
    /// handle named `block_<barrier>` and an arrival token named `barrier_<id>_token`;
    /// for `BarrierLevel::Unit` nothing is emitted.
    Declare {
        barrier: Variable<D>,
        level: BarrierLevel,
    },
    /// Initializes `barrier` with `arrival_count` via `init(&barrier, arrival_count)`.
    ///
    /// At `BarrierLevel::Cube` the block handle and token are also declared, `init`
    /// runs only when `is_elected` is true, and a `__syncthreads()` follows.
    /// At `BarrierLevel::Unit`, `is_elected` is not used.
    Init {
        barrier: Variable<D>,
        is_elected: Variable<D>,
        arrival_count: Variable<D>,
        level: BarrierLevel,
    },
    /// Emits `init(&barrier, arrival_count);` unconditionally (no election, no sync).
    InitManual {
        barrier: Variable<D>,
        arrival_count: Variable<D>,
    },
    /// Asynchronous copy through `cuda::memcpy_async`, completion tracked by `barrier`.
    MemCopyAsync {
        barrier: Variable<D>,
        source: Variable<D>,
        destination: Variable<D>,
        /// Element count; multiplied by `sizeof` of `source`'s item type to form the
        /// byte count passed to `memcpy_async`.
        source_length: Variable<D>,
        offset_source: Variable<D>,
        offset_out: Variable<D>,
        /// When true, the `block_<barrier>` group handle is passed as the first argument
        /// (that handle is emitted by the cube-level `Declare`/`Init`).
        cooperative: bool,
    },
    /// Asynchronous copy through `cuda::device::memcpy_async_tx`, tracked by `barrier`.
    MemCopyAsyncTx {
        barrier: Variable<D>,
        source: Variable<D>,
        destination: Variable<D>,
        /// Element count; multiplied by `sizeof` of `source`'s item type.
        source_length: Variable<D>,
        offset_source: Variable<D>,
        offset_out: Variable<D>,
    },
    /// Emits `__cp_async_shared_global<copy_size>(source + offset_source, destination + offset_out ...)`.
    ///
    /// Not tied to a barrier: `barrier_id` reports `0` for this variant.
    CopyAsync {
        source: Variable<D>,
        destination: Variable<D>,
        source_length: Variable<D>,
        offset_source: Variable<D>,
        offset_out: Variable<D>,
        /// Template argument of `__cp_async_shared_global` (presumably bytes per copy —
        /// confirm against the helper's definition).
        copy_size: u32,
        /// When true, `source_length * sizeof(item)` is passed as an extra third argument.
        checked: bool,
    },
    /// Tensor (TMA) copy via `cuda::device::experimental::cp_async_bulk_tensor_<rank>d_global_to_shared`,
    /// where `rank` is `indices.len()`; the indices are emitted in reverse order.
    MemCopyAsyncTensorGlobalToShared {
        barrier: Variable<D>,
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// Emits `tma_load_im2col_<rank>d(...)` with `rank = indices.len()`, followed by the
    /// indices in reverse order and then the offsets in reverse order.
    TmaLoadIm2col {
        barrier: Variable<D>,
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
        offsets: Vec<Variable<D>>,
    },
    /// Emits `token = barrier.arrive();`.
    Arrive {
        barrier: Variable<D>,
        token: Variable<D>,
    },
    /// Emits `token = cuda::device::barrier_arrive_tx(barrier, arrive_count_update, transaction_count_update);`.
    ArriveTx {
        barrier: Variable<D>,
        token: Variable<D>,
        arrive_count_update: Variable<D>,
        transaction_count_update: Variable<D>,
    },
    /// Emits `__cp_async_arrive(barrier);`.
    ArriveCopyAsync {
        barrier: Variable<D>,
    },
    /// Emits `cuda::device::barrier_expect_tx(barrier, transaction_count_update);`.
    ExpectTx {
        barrier: Variable<D>,
        transaction_count_update: Variable<D>,
    },
    /// Emits `barrier.wait(std::move(token));`.
    Wait {
        barrier: Variable<D>,
        token: Variable<D>,
    },
    /// Emits `barrier.wait_parity(phase);`.
    WaitParity {
        barrier: Variable<D>,
        phase: Variable<D>,
    },
    /// Emits `barrier.arrive_and_wait();`.
    ArriveAndWait {
        barrier: Variable<D>,
        /// Not used by the `Display` impl.
        level: BarrierLevel,
    },
}
impl<D: Dialect> BarrierOps<D> {
    /// Returns the id of the barrier variable this operation acts on.
    ///
    /// `CopyAsync` has no barrier operand, so it reports `0`.
    ///
    /// # Panics
    ///
    /// Panics if the barrier variable's `id()` is not available (it is `unwrap`ped).
    pub fn barrier_id(&self) -> u32 {
        let barrier = match self {
            // The only operation without a barrier operand.
            BarrierOps::CopyAsync { .. } => return 0,
            BarrierOps::Declare { barrier, .. }
            | BarrierOps::Init { barrier, .. }
            | BarrierOps::InitManual { barrier, .. }
            | BarrierOps::MemCopyAsync { barrier, .. }
            | BarrierOps::MemCopyAsyncTx { barrier, .. }
            | BarrierOps::MemCopyAsyncTensorGlobalToShared { barrier, .. }
            | BarrierOps::TmaLoadIm2col { barrier, .. }
            | BarrierOps::Arrive { barrier, .. }
            | BarrierOps::ArriveTx { barrier, .. }
            | BarrierOps::ArriveCopyAsync { barrier }
            | BarrierOps::ExpectTx { barrier, .. }
            | BarrierOps::Wait { barrier, .. }
            | BarrierOps::WaitParity { barrier, .. }
            | BarrierOps::ArriveAndWait { barrier, .. } => barrier,
        };
        barrier.id().unwrap()
    }
}
impl<D: Dialect> Display for BarrierOps<D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BarrierOps::Declare { barrier, level } => {
match level {
BarrierLevel::Unit => Ok(()),
BarrierLevel::Cube => write!(
f,
"
cooperative_groups::thread_block block_{barrier} = cooperative_groups::this_thread_block();
cuda::barrier<cuda::thread_scope_block>::arrival_token barrier_{}_token;
",
barrier.id().unwrap()
),
}
}
BarrierOps::Init {
barrier,
is_elected,
arrival_count,
level,
} => {
match level {
BarrierLevel::Unit => write!(
f,
"
init(&{barrier}, {arrival_count});
"
),
BarrierLevel::Cube => write!(
f,
"
cooperative_groups::thread_block block_{barrier} = cooperative_groups::this_thread_block();
cuda::barrier<cuda::thread_scope_block>::arrival_token barrier_{}_token;
if ({is_elected}) {{
init(&{barrier}, {arrival_count});
}}
__syncthreads();
",
barrier.id().unwrap()
),
}
}
BarrierOps::InitManual {
barrier,
arrival_count,
} => {
writeln!(f, "init(&{barrier}, {arrival_count});")
}
BarrierOps::MemCopyAsync {
barrier,
source,
destination,
source_length,
offset_source,
offset_out,
cooperative,
} => {
let item = source.item();
let size = format!("sizeof({item})");
match cooperative {
false => write!(
f,
"
cuda::memcpy_async({destination} + {offset_out}, {source} + {offset_source}, {source_length} * {size}, {barrier});
"
),
true => write!(
f,
"
cuda::memcpy_async(block_{barrier}, {destination} + {offset_out}, {source} + {offset_source}, {source_length} * {size}, {barrier});
"
),
}
}
BarrierOps::MemCopyAsyncTx {
barrier,
source,
destination,
source_length,
offset_source,
offset_out,
} => {
let item = source.item();
let size = format!("sizeof({item})");
write!(
f,
"
cuda::device::memcpy_async_tx({destination} + {offset_out}, {source} + {offset_source}, {source_length} * {size}, {barrier});
"
)
}
BarrierOps::CopyAsync {
source,
destination,
source_length,
offset_source,
offset_out,
copy_size,
checked,
} => {
let item = source.item();
let size = format!("{source_length} * sizeof({item})");
match *checked {
false => write!(
f,
"
__cp_async_shared_global<{copy_size}>({source} + {offset_source}, {destination} + {offset_out});
"
),
true => write!(
f,
"
__cp_async_shared_global<{copy_size}>({source} + {offset_source}, {destination} + {offset_out}, {size});
"
),
}
}
BarrierOps::MemCopyAsyncTensorGlobalToShared {
barrier,
smem_buffer,
smem_offset,
tensor_map,
indices,
} => {
let rank = indices.len();
let smem_ptr = smem_buffer.fmt_ptr();
let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
let _ = write!(s, "{it}, ");
s
});
writeln!(
f,
"cuda::device::experimental::cp_async_bulk_tensor_{rank}d_global_to_shared({smem_ptr} + {smem_offset}, &{tensor_map}, {indices} {barrier});"
)
}
BarrierOps::TmaLoadIm2col {
barrier,
smem_buffer,
smem_offset,
tensor_map,
indices,
offsets,
} => {
let rank = indices.len();
let smem_ptr = smem_buffer.fmt_ptr();
let args: Vec<_> = indices
.iter()
.rev()
.map(|it| it.to_string())
.chain(offsets.iter().rev().map(|it| it.to_string()))
.collect();
writeln!(
f,
"tma_load_im2col_{rank}d(&{tensor_map}, {barrier}, {smem_ptr} + {smem_offset}, {});",
args.join(", ")
)
}
BarrierOps::Arrive { barrier, token, .. } => {
writeln!(f, "{token} = {barrier}.arrive();")
}
BarrierOps::ArriveTx {
barrier,
token,
arrive_count_update,
transaction_count_update,
} => {
writeln!(
f,
"{token} = cuda::device::barrier_arrive_tx({barrier}, {arrive_count_update}, {transaction_count_update});"
)
}
BarrierOps::ArriveCopyAsync { barrier } => {
writeln!(f, "__cp_async_arrive({barrier});")
}
BarrierOps::ExpectTx {
barrier,
transaction_count_update,
} => {
writeln!(
f,
"cuda::device::barrier_expect_tx({barrier}, {transaction_count_update});"
)
}
BarrierOps::Wait { barrier, token } => {
writeln!(f, "{barrier}.wait(std::move({token}));")
}
BarrierOps::WaitParity { barrier, phase } => {
writeln!(f, "{barrier}.wait_parity({phase});")
}
BarrierOps::ArriveAndWait { barrier, .. } => {
writeln!(f, "{barrier}.arrive_and_wait();")
}
}
}
}