cubecl-cpp 0.10.0-pre.3

CPP transpiler for CubeCL
Documentation
use cubecl_core::prelude::Visibility;

use crate::{
    Dialect,
    shared::{Component, KernelArg, Variable},
};

use super::BufferAttribute;
use std::fmt::Display;

#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum AddressSpace {
    Constant,
    ConstDevice,
    Device,
    Thread,
    ThreadGroup,
    None,
}

impl AddressSpace {
    pub fn attribute(&self) -> BufferAttribute {
        match self {
            AddressSpace::Constant | AddressSpace::ConstDevice | AddressSpace::Device => {
                BufferAttribute::Buffer
            }
            AddressSpace::ThreadGroup => BufferAttribute::ThreadGroup,
            _ => BufferAttribute::None,
        }
    }
}

impl Display for AddressSpace {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AddressSpace::Constant => f.write_str("constant"),
            AddressSpace::ConstDevice => f.write_str("const device"),
            AddressSpace::Device => f.write_str("device"),
            AddressSpace::ThreadGroup => f.write_str("threadgroup"),
            AddressSpace::Thread => f.write_str("thread"),
            AddressSpace::None => Ok(()),
        }
    }
}

impl From<AddressSpace> for Visibility {
    fn from(val: AddressSpace) -> Self {
        match val {
            AddressSpace::Constant => Visibility::Read,
            _ => Visibility::ReadWrite,
        }
    }
}

impl<D: Dialect> From<&KernelArg<D>> for AddressSpace {
    fn from(value: &KernelArg<D>) -> Self {
        match value.vis {
            Visibility::Read => AddressSpace::ConstDevice,
            Visibility::ReadWrite => AddressSpace::Device,
        }
    }
}

impl<D: Dialect> From<&Variable<D>> for AddressSpace {
    fn from(value: &Variable<D>) -> Self {
        match value {
            Variable::AbsolutePosBaseName
            | Variable::AbsolutePosX
            | Variable::AbsolutePosY
            | Variable::AbsolutePosZ
            | Variable::UnitPosBaseName
            | Variable::UnitPosX
            | Variable::UnitPosY
            | Variable::UnitPosZ
            | Variable::CubePosBaseName
            | Variable::CubePosX
            | Variable::CubePosY
            | Variable::CubePosZ
            | Variable::CubeDimBaseName
            | Variable::CubeDimX
            | Variable::CubeDimY
            | Variable::CubeDimZ
            | Variable::CubeCountBaseName
            | Variable::CubeCountX
            | Variable::CubeCountY
            | Variable::CubeCountZ
            | Variable::PlaneDim
            | Variable::UnitPosPlane => AddressSpace::None,
            Variable::GlobalInputArray(..) => AddressSpace::ConstDevice,
            Variable::GlobalOutputArray(..) => AddressSpace::Device,
            Variable::GlobalScalar { .. } => {
                if value.is_const() {
                    AddressSpace::ConstDevice
                } else {
                    AddressSpace::Device
                }
            }
            Variable::SharedArray(..) => AddressSpace::ThreadGroup,
            _ => AddressSpace::Thread,
        }
    }
}