cubecl-cpp 0.3.0

CPP transpiler for CubeCL
Documentation
use super::{Body, Dialect, Item, Variable};
use cubecl_core::{ir::CubeDim, CompilerRepresentation};
use std::{collections::HashSet, fmt::Display};

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding<D: Dialect> {
    pub item: Item<D>,
    pub size: Option<usize>,
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SharedMemory<D: Dialect> {
    pub index: u16,
    pub item: Item<D>,
    pub size: u32,
}

#[derive(Debug, PartialEq, Clone)]
pub struct ConstArray<D: Dialect> {
    pub index: u16,
    pub item: Item<D>,
    pub size: u32,
    pub values: Vec<Variable<D>>,
}

#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LocalArray<D: Dialect> {
    pub index: u16,
    pub item: Item<D>,
    pub depth: u8,
    pub size: u32,
}

impl<D: Dialect> LocalArray<D> {
    pub fn new(index: u16, item: Item<D>, depth: u8, size: u32) -> Self {
        Self {
            index,
            item,
            depth,
            size,
        }
    }
}

impl<D: Dialect> SharedMemory<D> {
    pub fn new(index: u16, item: Item<D>, size: u32) -> Self {
        Self { index, item, size }
    }
}

#[derive(Debug, Clone)]
pub struct ComputeKernel<D: Dialect> {
    pub inputs: Vec<Binding<D>>,
    pub outputs: Vec<Binding<D>>,
    pub named: Vec<(String, Binding<D>)>,
    pub cube_dim: CubeDim,
    pub body: Body<D>,
    pub wmma_activated: bool,
    pub bf16: bool,
    pub f16: bool,
    pub items: HashSet<super::Item<D>>,
}

impl<D: Dialect> CompilerRepresentation for ComputeKernel<D> {
    fn shared_memory_size(&self) -> usize {
        let mut current = 0usize;

        for var in self.body.shared_memories.iter() {
            let factor = var.item.vectorization;
            let elem_size_bytes = var.item.elem().size();
            current += (var.size as usize) * factor * elem_size_bytes;
        }

        current
    }
}

impl<D: Dialect> Display for ComputeKernel<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.wmma_activated {
            D::include_wmma(f)?;
        }
        if self.bf16 {
            D::include_bf16(f)?;
        }

        if self.f16 {
            D::include_f16(f)?;
        }

        f.write_str("typedef unsigned int uint;\n")?;

        for item in self.items.iter() {
            let elem = item.elem;
            let size = item.vectorization;
            let alignment = elem.size() * size;
            if size > 1 {
                write!(
                    f,
                    "
struct __align__({alignment}) {item} {{"
                )?;

                for i in 0..size {
                    write!(
                        f,
                        "
    {elem} i_{i};"
                    )?;
                }

                f.write_str("\n};\n")?;
            }
        }

        write!(
            f,
            "

extern \"C\" __global__ void kernel(
",
        )?;

        let num_bindings = self.inputs.len() + self.outputs.len() + self.named.len();
        let mut binding_index = 0;
        for (index, binding) in self.inputs.iter().enumerate() {
            binding_index += 1;
            write!(f, "{} input_{}[]", binding.item, index)?;
            if binding_index < num_bindings {
                f.write_str(",")?;
            }
        }
        for (index, binding) in self.outputs.iter().enumerate() {
            binding_index += 1;
            write!(f, "{} output_{}[]", binding.item, index)?;
            if binding_index < num_bindings {
                f.write_str(",")?;
            }
        }
        for (name, binding) in self.named.iter() {
            binding_index += 1;
            write!(f, "{} {}[]", binding.item, name)?;

            if binding_index < num_bindings {
                f.write_str(",")?;
            }
        }

        f.write_str("\n) {\n")?;

        write!(f, "{}", self.body)?;
        f.write_str("\n}")?;

        Ok(())
    }
}