//! cubecl-spirv 0.10.0-pre.4
//!
//! SPIR-V compiler for CubeCL.
//! Documentation
use cubecl_core::ir::{self, Builtin, ElemType, UIntKind};
use rspirv::spirv::{BuiltIn, Word};

use crate::{
    SpirvCompiler, SpirvTarget,
    item::{Elem, Item},
    variable::Variable,
};

impl<T: SpirvTarget> SpirvCompiler<T> {
    pub fn compile_builtin(&mut self, builtin: Builtin, ty: Item) -> Variable {
        match builtin {
            Builtin::UnitPos => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.load_builtin(BuiltIn::LocalInvocationIndex, &ty);
                    b.debug_name(id, "UNIT_POS");
                    id
                }),
                ty,
            ),
            Builtin::UnitPosX => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::LocalInvocationId, 0, &ty);
                    b.debug_name(id, "UNIT_POS_X");
                    id
                }),
                ty,
            ),
            Builtin::UnitPosY => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::LocalInvocationId, 1, &ty);
                    b.debug_name(id, "UNIT_POS_Y");
                    id
                }),
                ty,
            ),
            Builtin::UnitPosZ => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::LocalInvocationId, 2, &ty);
                    b.debug_name(id, "UNIT_POS_Z");
                    id
                }),
                ty,
            ),
            Builtin::CubePosX => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::WorkgroupId, 0, &ty);
                    b.debug_name(id, "CUBE_POS_X");
                    id
                }),
                ty,
            ),
            Builtin::CubePosY => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::WorkgroupId, 1, &ty);
                    b.debug_name(id, "CUBE_POS_Y");
                    id
                }),
                ty,
            ),
            Builtin::CubePosZ => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::WorkgroupId, 2, &ty);
                    b.debug_name(id, "CUBE_POS_Z");
                    id
                }),
                ty,
            ),
            Builtin::CubePosCluster
            | Builtin::CubePosClusterX
            | Builtin::CubePosClusterY
            | Builtin::CubePosClusterZ => self.constant_var(0, ty),
            Builtin::CubeDim => Variable::Builtin(self.state.cube_size, ty),
            Builtin::CubeDimX => Variable::Builtin(self.state.cube_dims[0], ty),
            Builtin::CubeDimY => Variable::Builtin(self.state.cube_dims[1], ty),
            Builtin::CubeDimZ => Variable::Builtin(self.state.cube_dims[2], ty),
            Builtin::CubeClusterDim
            | Builtin::CubeClusterDimX
            | Builtin::CubeClusterDimY
            | Builtin::CubeClusterDimZ => self.constant_var(1, ty),
            Builtin::CubeCount => Variable::Builtin(
                self.insert_global(builtin, |b: &mut SpirvCompiler<T>| {
                    let ty_id = ty.id(b);
                    let x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
                    let y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
                    let z = b.compile_variable(builtin_u32(Builtin::CubeCountZ)).id(b);

                    let x = Item::builtin_u32().cast_to(b, None, x, &ty);
                    let y = Item::builtin_u32().cast_to(b, None, y, &ty);
                    let z = Item::builtin_u32().cast_to(b, None, z, &ty);

                    let count = b.i_mul(ty_id, None, x, y).unwrap();
                    let count = b.i_mul(ty_id, None, count, z).unwrap();
                    b.debug_name(count, "CUBE_COUNT");
                    count
                }),
                ty,
            ),
            Builtin::CubeCountX => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::NumWorkgroups, 0, &ty);
                    b.debug_name(id, "CUBE_COUNT_X");
                    id
                }),
                ty,
            ),
            Builtin::CubeCountY => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::NumWorkgroups, 1, &ty);
                    b.debug_name(id, "CUBE_COUNT_Y");
                    id
                }),
                ty,
            ),
            Builtin::CubeCountZ => Variable::Builtin(
                self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::NumWorkgroups, 2, &ty);
                    b.debug_name(id, "CUBE_COUNT_Z");
                    id
                }),
                ty,
            ),
            Builtin::PlaneDim => {
                let id = self.insert_global(builtin, |b| {
                    let id = b.load_builtin(BuiltIn::SubgroupSize, &ty);
                    b.debug_name(id, "PLANE_DIM");
                    id
                });
                Variable::Builtin(id, ty)
            }
            Builtin::PlanePos => {
                let id = self.insert_global(builtin, |b| {
                    let id = b.load_builtin(BuiltIn::SubgroupId, &ty);
                    b.debug_name(id, "PLANE_POS");
                    id
                });
                Variable::Builtin(id, ty)
            }
            Builtin::UnitPosPlane => {
                let id = self.insert_global(builtin, |b| {
                    let id = b.load_builtin(BuiltIn::SubgroupLocalInvocationId, &ty);
                    b.debug_name(id, "UNIT_POS_PLANE");
                    id
                });
                Variable::Builtin(id, ty)
            }
            Builtin::CubePos => {
                let id = self.insert_global(builtin, |b| {
                    let x = b.compile_variable(builtin_u32(Builtin::CubePosX)).id(b);
                    let y = b.compile_variable(builtin_u32(Builtin::CubePosY)).id(b);
                    let z = b.compile_variable(builtin_u32(Builtin::CubePosZ)).id(b);

                    let x = Item::builtin_u32().cast_to(b, None, x, &ty);
                    let y = Item::builtin_u32().cast_to(b, None, y, &ty);
                    let z = Item::builtin_u32().cast_to(b, None, z, &ty);

                    let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
                    let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);

                    let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
                    let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);

                    let ty = ty.id(b);
                    let id = b.i_mul(ty, None, z, groups_y).unwrap();
                    let id = b.i_add(ty, None, id, y).unwrap();
                    let id = b.i_mul(ty, None, id, groups_x).unwrap();
                    let id = b.i_add(ty, None, id, x).unwrap();
                    b.debug_name(id, "CUBE_POS");
                    id
                });
                Variable::Builtin(id, ty)
            }
            Builtin::AbsolutePos => {
                let id = self.insert_global(builtin, |b| {
                    let x = b.compile_variable(builtin_u32(Builtin::AbsolutePosX)).id(b);
                    let y = b.compile_variable(builtin_u32(Builtin::AbsolutePosY)).id(b);
                    let z = b.compile_variable(builtin_u32(Builtin::AbsolutePosZ)).id(b);

                    let x = Item::builtin_u32().cast_to(b, None, x, &ty);
                    let y = Item::builtin_u32().cast_to(b, None, y, &ty);
                    let z = Item::builtin_u32().cast_to(b, None, z, &ty);

                    let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
                    let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);

                    let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
                    let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);

                    let size_x = ty.const_u32(b, b.cube_dim.x);
                    let size_y = ty.const_u32(b, b.cube_dim.y);

                    let ty = ty.id(b);
                    let size_x = b.i_mul(ty, None, groups_x, size_x).unwrap();
                    let size_y = b.i_mul(ty, None, groups_y, size_y).unwrap();
                    let id = b.i_mul(ty, None, z, size_y).unwrap();
                    let id = b.i_add(ty, None, id, y).unwrap();
                    let id = b.i_mul(ty, None, id, size_x).unwrap();
                    let id = b.i_add(ty, None, id, x).unwrap();
                    b.debug_name(id, "ABSOLUTE_POS");
                    id
                });
                Variable::Builtin(id, ty)
            }
            Builtin::AbsolutePosX => {
                let id = self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::GlobalInvocationId, 0, &ty);
                    b.debug_name(id, "ABSOLUTE_POS_X");
                    id
                });

                Variable::Builtin(id, ty)
            }
            Builtin::AbsolutePosY => {
                let id = self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::GlobalInvocationId, 1, &ty);
                    b.debug_name(id, "ABSOLUTE_POS_Y");
                    id
                });

                Variable::Builtin(id, ty)
            }
            Builtin::AbsolutePosZ => {
                let id = self.insert_global(builtin, |b| {
                    let id = b.extract(BuiltIn::GlobalInvocationId, 2, &ty);
                    b.debug_name(id, "ABSOLUTE_POS_Z");
                    id
                });

                Variable::Builtin(id, ty)
            }
        }
    }

    /// Builds a [`Variable::Builtin`] holding the constant `value`, created as a constant of
    /// type `ty` via `Item::const_u32`.
    fn constant_var(&mut self, value: u32, ty: Item) -> Variable {
        let id = ty.const_u32(self, value);
        // `ty` is owned and not used afterwards, so it is moved rather than cloned.
        Variable::Builtin(id, ty)
    }

    /// Returns component `idx` of the vector builtin `builtin`, extracted with result type `ty`.
    ///
    /// The vector itself is obtained through [`Self::vec_global`].
    ///
    /// # Panics
    ///
    /// Panics if the composite-extract instruction cannot be emitted.
    fn extract(&mut self, builtin: BuiltIn, idx: u32, ty: &Item) -> Word {
        // Fetch the vector before resolving the result type, as the original order does.
        let vector = self.vec_global(builtin);
        let component_ty = ty.id(self);
        let extracted = self.composite_extract(component_ty, None, vector, vec![idx]);
        extracted.unwrap()
    }

    /// Loads `builtin` as a 3-component vector of `Elem::Int(32, false)` inside
    /// `insert_builtin`, keyed by the SPIR-V builtin itself.
    // NOTE(review): `Elem::Int(32, false)` presumably denotes an unsigned 32-bit integer; confirm
    // against the `Elem` definition.
    fn vec_global(&mut self, builtin: BuiltIn) -> Word {
        let vec3 = Item::Vector(Elem::Int(32, false), 3);
        self.insert_builtin(builtin, |compiler| {
            compiler.load_builtin(builtin, &vec3)
        })
    }

    /// Emits a load of the SPIR-V builtin `builtin` with result type `item` and returns the id
    /// of the loaded value.
    ///
    /// # Panics
    ///
    /// Panics if the load instruction cannot be emitted.
    fn load_builtin(&mut self, builtin: BuiltIn, item: &Item) -> Word {
        // Resolve the result type first, then the builtin variable to load from.
        let result_ty = item.id(self);
        let pointer = self.builtin(builtin, item.clone());
        let loaded = self.load(result_ty, None, pointer, None, vec![]);
        loaded.unwrap()
    }
}

/// Creates the IR variable for `builtin` typed as a `u32` (`ElemType::UInt(UIntKind::U32)`).
fn builtin_u32(builtin: Builtin) -> ir::Variable {
    let u32_ty = ElemType::UInt(UIntKind::U32);
    ir::Variable::builtin(builtin, u32_ty.into())
}