cubecl-cpp 0.10.0-pre.3

CPP transpiler for CubeCL
Documentation
use std::fmt::{Display, Write};

use cubecl_core::ir::BarrierLevel;

use super::{Component, Dialect, Variable};

/// Barrier and asynchronous-copy operations for the C++ backend.
///
/// Each variant is turned into a C++ snippet by this type's `Display` impl;
/// the per-variant docs below describe what that impl emits.
#[derive(Debug, Clone)]
pub enum BarrierOps<D: Dialect> {
    /// Declares per-barrier helper state.
    ///
    /// At `BarrierLevel::Cube` this emits a `cooperative_groups::thread_block`
    /// handle named `block_<barrier>` and an arrival token named
    /// `barrier_<id>_token`; at `BarrierLevel::Unit` nothing is emitted.
    Declare {
        barrier: Variable<D>,
        level: BarrierLevel,
    },
    /// Initializes `barrier` with `arrival_count`.
    ///
    /// At `BarrierLevel::Cube` the helper handle and token are declared, only
    /// the thread for which `is_elected` is true calls `init`, and a
    /// `__syncthreads()` follows. At `BarrierLevel::Unit` `init` is emitted
    /// without any election, and `is_elected` is not used.
    Init {
        barrier: Variable<D>,
        is_elected: Variable<D>,
        arrival_count: Variable<D>,
        level: BarrierLevel,
    },
    /// Emits an unconditional `init(&barrier, arrival_count)` with no election
    /// or synchronization around it.
    InitManual {
        barrier: Variable<D>,
        arrival_count: Variable<D>,
    },
    /// `cuda::memcpy_async` from `source + offset_source` to
    /// `destination + offset_out`, copying `source_length * sizeof(item)` bytes
    /// and completing on `barrier`.
    ///
    /// When `cooperative` is true, the `block_<barrier>` group handle is passed
    /// as the first argument.
    MemCopyAsync {
        barrier: Variable<D>,
        source: Variable<D>,
        destination: Variable<D>,
        source_length: Variable<D>,
        offset_source: Variable<D>,
        offset_out: Variable<D>,
        cooperative: bool,
    },
    /// Same copy as `MemCopyAsync`, but emitted as
    /// `cuda::device::memcpy_async_tx`.
    MemCopyAsyncTx {
        barrier: Variable<D>,
        source: Variable<D>,
        destination: Variable<D>,
        source_length: Variable<D>,
        offset_source: Variable<D>,
        offset_out: Variable<D>,
    },
    /// Emits `__cp_async_shared_global<copy_size>(source + offset_source,
    /// destination + offset_out)`; no barrier is referenced.
    ///
    /// When `checked` is true, the byte count `source_length * sizeof(item)` is
    /// passed as an extra third argument (presumably to limit the copy —
    /// TODO confirm against the helper's definition).
    CopyAsync {
        source: Variable<D>,
        destination: Variable<D>,
        source_length: Variable<D>,
        offset_source: Variable<D>,
        offset_out: Variable<D>,
        copy_size: u32,
        checked: bool,
    },
    /// Emits `cuda::device::experimental::cp_async_bulk_tensor_<rank>d_global_to_shared`.
    ///
    /// `rank` is `indices.len()`, and the indices are emitted in reverse order.
    MemCopyAsyncTensorGlobalToShared {
        barrier: Variable<D>,
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
    },
    /// Emits `tma_load_im2col_<rank>d`.
    ///
    /// `rank` is `indices.len()`. The arguments are the reversed `indices`
    /// followed by the reversed `offsets`.
    TmaLoadIm2col {
        barrier: Variable<D>,
        smem_buffer: Variable<D>,
        smem_offset: Variable<D>,
        tensor_map: Variable<D>,
        indices: Vec<Variable<D>>,
        offsets: Vec<Variable<D>>,
    },
    /// Emits `token = barrier.arrive();`.
    Arrive {
        barrier: Variable<D>,
        token: Variable<D>,
    },
    /// Emits `token = cuda::device::barrier_arrive_tx(barrier,
    /// arrive_count_update, transaction_count_update);`.
    ArriveTx {
        barrier: Variable<D>,
        token: Variable<D>,
        arrive_count_update: Variable<D>,
        transaction_count_update: Variable<D>,
    },
    /// Emits `__cp_async_arrive(barrier);`.
    ArriveCopyAsync {
        barrier: Variable<D>,
    },
    /// Emits `cuda::device::barrier_expect_tx(barrier, transaction_count_update);`.
    ExpectTx {
        barrier: Variable<D>,
        transaction_count_update: Variable<D>,
    },
    /// Emits `barrier.wait(std::move(token));`.
    Wait {
        barrier: Variable<D>,
        token: Variable<D>,
    },
    /// Emits `barrier.wait_parity(phase);`.
    WaitParity {
        barrier: Variable<D>,
        phase: Variable<D>,
    },
    /// Emits `barrier.arrive_and_wait();`.
    ///
    /// `level` is not used when rendering.
    ArriveAndWait {
        barrier: Variable<D>,
        level: BarrierLevel,
    },
}

impl<D: Dialect> BarrierOps<D> {
    /// Returns the id of the barrier variable this operation targets.
    ///
    /// `CopyAsync` is the only variant without a barrier and reports `0`.
    /// Every other variant unwraps `Variable::id()` on its `barrier`, so this
    /// panics if that variable carries no id.
    pub fn barrier_id(&self) -> u32 {
        let target = match self {
            // No barrier to report for the standalone async copy.
            BarrierOps::CopyAsync { .. } => return 0,
            BarrierOps::Declare { barrier, .. }
            | BarrierOps::Init { barrier, .. }
            | BarrierOps::InitManual { barrier, .. }
            | BarrierOps::MemCopyAsync { barrier, .. }
            | BarrierOps::MemCopyAsyncTx { barrier, .. }
            | BarrierOps::MemCopyAsyncTensorGlobalToShared { barrier, .. }
            | BarrierOps::TmaLoadIm2col { barrier, .. }
            | BarrierOps::Arrive { barrier, .. }
            | BarrierOps::ArriveTx { barrier, .. }
            | BarrierOps::ArriveCopyAsync { barrier }
            | BarrierOps::ExpectTx { barrier, .. }
            | BarrierOps::Wait { barrier, .. }
            | BarrierOps::WaitParity { barrier, .. }
            | BarrierOps::ArriveAndWait { barrier, .. } => barrier,
        };
        target.id().unwrap()
    }
}

// Renders a barrier operation as a snippet of CUDA C++ source text.
//
// NOTE: several arms emit multi-line string literals whose leading newline and
// trailing indentation are part of the generated text; re-indenting those
// literals changes the emitted source.
impl<D: Dialect> Display for BarrierOps<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BarrierOps::Declare { barrier, level } => {
                match level {
                    // Note: Arrival token exists for cuda::thread_scope_thread, but is not public.
                    // So skip creating the token for unit barriers.
                    BarrierLevel::Unit => Ok(()),
                    // Declares the `block_<barrier>` thread-block group handle (used by the
                    // cooperative form of `MemCopyAsync`) and a `barrier_<id>_token` arrival token.
                    BarrierLevel::Cube => write!(
                        f,
                        "
cooperative_groups::thread_block block_{barrier} = cooperative_groups::this_thread_block();
cuda::barrier<cuda::thread_scope_block>::arrival_token barrier_{}_token;
",
                        barrier.id().unwrap()
                    ),
                }
            }
            BarrierOps::Init {
                barrier,
                is_elected,
                arrival_count,
                level,
            } => {
                match level {
                    // Note: Arrival token exists for cuda::thread_scope_thread, but is not public.
                    // So skip creating the token for unit barriers.
                    // No election here: every thread that runs this calls `init`, and
                    // `is_elected` is unused.
                    BarrierLevel::Unit => write!(
                        f,
                        "
init(&{barrier}, {arrival_count});
                "
                    ),
                    // Declares the same `block_<barrier>` handle and token as the Cube arm of
                    // `Declare`, lets only the elected thread call `init`, then `__syncthreads()`.
                    // NOTE(review): if `Declare` (Cube) is also emitted for this barrier in the
                    // same scope, these two names are declared twice — presumably callers emit
                    // only one of the two; verify.
                    BarrierLevel::Cube => write!(
                        f,
                        "
cooperative_groups::thread_block block_{barrier} = cooperative_groups::this_thread_block();
cuda::barrier<cuda::thread_scope_block>::arrival_token barrier_{}_token;
if ({is_elected}) {{
   init(&{barrier}, {arrival_count});
}}
__syncthreads();
",
                        barrier.id().unwrap()
                    ),
                }
            }
            BarrierOps::InitManual {
                barrier,
                arrival_count,
            } => {
                // Unconditional `init`, with no election or synchronization around it.
                writeln!(f, "init(&{barrier}, {arrival_count});")
            }
            BarrierOps::MemCopyAsync {
                barrier,
                source,
                destination,
                source_length,
                offset_source,
                offset_out,
                cooperative,
            } => {
                // Byte count is `source_length * sizeof(<item type of source>)`.
                let item = source.item();
                let size = format!("sizeof({item})");
                match cooperative {
                    // Overload without a group argument.
                    false => write!(
                        f,
                        "
cuda::memcpy_async({destination} + {offset_out}, {source} + {offset_source}, {source_length} * {size}, {barrier});
                    "
                    ),
                    // Passes the `block_<barrier>` group handle declared by the Cube-level
                    // `Declare` / `Init` arms.
                    true => write!(
                        f,
                        "
cuda::memcpy_async(block_{barrier}, {destination} + {offset_out}, {source} + {offset_source}, {source_length} * {size}, {barrier});
                        "
                    ),
                }
            }
            BarrierOps::MemCopyAsyncTx {
                barrier,
                source,
                destination,
                source_length,
                offset_source,
                offset_out,
            } => {
                // Same byte-count computation as `MemCopyAsync`, via the `_tx` variant.
                let item = source.item();
                let size = format!("sizeof({item})");
                write!(
                        f,
                        "
cuda::device::memcpy_async_tx({destination} + {offset_out}, {source} + {offset_source}, {source_length} * {size}, {barrier});
                        "
                    )
            }
            BarrierOps::CopyAsync {
                source,
                destination,
                source_length,
                offset_source,
                offset_out,
                copy_size,
                checked,
            } => {
                // Byte count, only used by the checked form below.
                let item = source.item();
                let size = format!("{source_length} * sizeof({item})");
                match *checked {
                    false => write!(
                        f,
                        "
__cp_async_shared_global<{copy_size}>({source} + {offset_source}, {destination} + {offset_out});
                    "
                    ),
                    // Checked form: passes the byte count as an extra third argument.
                    true => write!(
                        f,
                        "
__cp_async_shared_global<{copy_size}>({source} + {offset_source}, {destination} + {offset_out}, {size});
                        "
                    ),
                }
            }
            BarrierOps::MemCopyAsyncTensorGlobalToShared {
                barrier,
                smem_buffer,
                smem_offset,
                tensor_map,
                indices,
            } => {
                // The emitted function name encodes the rank (number of coordinates).
                let rank = indices.len();
                let smem_ptr = smem_buffer.fmt_ptr();
                // Coordinates are emitted in reverse order (presumably innermost dimension
                // first — TODO confirm). Each one is followed by ", " so the barrier comes last.
                let indices = indices.iter().rev().fold(String::new(), |mut s, it| {
                    let _ = write!(s, "{it}, ");
                    s
                });
                writeln!(
                    f,
                    "cuda::device::experimental::cp_async_bulk_tensor_{rank}d_global_to_shared({smem_ptr} + {smem_offset}, &{tensor_map}, {indices} {barrier});"
                )
            }
            BarrierOps::TmaLoadIm2col {
                barrier,
                smem_buffer,
                smem_offset,
                tensor_map,
                indices,
                offsets,
            } => {
                let rank = indices.len();
                let smem_ptr = smem_buffer.fmt_ptr();
                // Trailing arguments: reversed `indices`, then reversed `offsets`.
                // NOTE(review): if both lists are empty the emitted call ends in ", )".
                let args: Vec<_> = indices
                    .iter()
                    .rev()
                    .map(|it| it.to_string())
                    .chain(offsets.iter().rev().map(|it| it.to_string()))
                    .collect();
                writeln!(
                    f,
                    "tma_load_im2col_{rank}d(&{tensor_map}, {barrier}, {smem_ptr} + {smem_offset}, {});",
                    args.join(", ")
                )
            }
            BarrierOps::Arrive { barrier, token, .. } => {
                // Stores the token returned by `arrive()` for a later `Wait`.
                writeln!(f, "{token} = {barrier}.arrive();")
            }
            BarrierOps::ArriveTx {
                barrier,
                token,
                arrive_count_update,
                transaction_count_update,
            } => {
                writeln!(
                    f,
                    "{token} = cuda::device::barrier_arrive_tx({barrier}, {arrive_count_update}, {transaction_count_update});"
                )
            }
            BarrierOps::ArriveCopyAsync { barrier } => {
                writeln!(f, "__cp_async_arrive({barrier});")
            }
            BarrierOps::ExpectTx {
                barrier,
                transaction_count_update,
            } => {
                writeln!(
                    f,
                    "cuda::device::barrier_expect_tx({barrier}, {transaction_count_update});"
                )
            }
            BarrierOps::Wait { barrier, token } => {
                // The arrival token is moved into `wait`.
                writeln!(f, "{barrier}.wait(std::move({token}));")
            }
            BarrierOps::WaitParity { barrier, phase } => {
                writeln!(f, "{barrier}.wait_parity({phase});")
            }
            BarrierOps::ArriveAndWait { barrier, .. } => {
                // `level` is not used when rendering.
                writeln!(f, "{barrier}.arrive_and_wait();")
            }
        }
    }
}