// cubecl-spirv 0.10.0
//
// SPIR-V compiler for CubeCL
// Documentation
use cubecl_core::ir as core;
use cubecl_core::ir::Metadata;
use rspirv::spirv::Word;

use crate::{SpirvCompiler, SpirvTarget, item::Item, variable::Variable};

impl<T: SpirvTarget> SpirvCompiler<T> {
    pub fn compile_meta(&mut self, meta: Metadata, out: Option<core::Variable>, uniform: bool) {
        let out = out.unwrap();
        match meta {
            Metadata::Rank { var } => {
                let var = self.compile_variable(var);
                let out = self.compile_variable(out);
                let pos = self.ext_pos(&var);

                let out_id = self.write_id(&out);
                self.mark_uniformity(out_id, uniform);

                let offset = self.info.metadata.rank_index(pos);
                self.load_const_metadata(offset, Some(out_id), out.item());
                self.write(&out, out_id);
            }
            Metadata::Length { var } => {
                let var = self.compile_variable(var);
                let out = self.compile_variable(out);
                self.length(&var, Some(&out), uniform);
            }
            Metadata::BufferLength { var } => {
                let var = self.compile_variable(var);
                let out = self.compile_variable(out);
                self.buffer_length(&var, Some(&out), uniform);
            }
            Metadata::Stride { dim, var } => {
                let var = self.compile_variable(var);
                let dim = self.compile_variable(dim);
                let out = self.compile_variable(out);

                let ty_id = out.item().id(self);
                let out_id = out.id(self);
                self.mark_uniformity(out_id, uniform);

                let pos = self.ext_pos(&var);

                let offs_offset = self.info.metadata.stride_offset_index(pos);
                let offset = self.load_const_metadata(offs_offset, None, out.item());
                let dim_id = self.read_as(&dim, &out.item());

                let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
                self.mark_uniformity(index, uniform);
                let index = Variable::Raw(index, out.item());
                self.load_dyn_metadata(&index, Some(out_id), out.item());
                self.write(&out, out_id);
            }
            Metadata::Shape { dim, var } => {
                let var = self.compile_variable(var);
                let dim = self.compile_variable(dim);
                let out = self.compile_variable(out);

                let ty_id = out.item().id(self);
                let out_id = out.id(self);
                self.mark_uniformity(out_id, uniform);

                let pos = self.ext_pos(&var);

                let offs_offset = self.info.metadata.shape_offset_index(pos);
                let offset = self.load_const_metadata(offs_offset, None, out.item());
                let dim_id = self.read_as(&dim, &out.item());

                let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
                let index = Variable::Id(index);
                self.load_dyn_metadata(&index, Some(out_id), out.item());
                self.write(&out, out_id);
            }
        }
    }

    /// Emits the length of `var` and returns the id that holds it.
    ///
    /// When `out` is given, a result id is reserved for it with `write_id`, marked with
    /// `uniform`, and the computed length is written to `out`. Without `out`, the length
    /// is typed as the compiled `addr_type` and only the id is returned.
    ///
    /// Supported variables:
    /// - global input/output arrays: loaded from the constant metadata at `len_index(pos)`;
    /// - slices with a known `const_len`, and shared/constant/local arrays: a constant;
    /// - other slices: `end - offset`.
    ///
    /// # Panics
    ///
    /// Panics (`unimplemented!`) for any other variable kind.
    pub fn length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
        let (out_id, out_ty) = if let Some(out) = out {
            let out_id = self.write_id(out);
            self.mark_uniformity(out_id, uniform);
            (Some(out_id), out.item())
        } else {
            (None, self.compile_type(self.addr_type.into()))
        };
        let ty_id = out_ty.id(self);

        let id = match var {
            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => {
                let offset = self.info.metadata.len_index(*pos);
                let id = self.load_const_metadata(offset, out_id, out_ty);

                if let Some(out_id) = out_id {
                    self.debug_name(out_id, format!("len({pos})"));
                }
                id
            }
            Variable::Slice {
                const_len: Some(len),
                ..
            } => {
                let len = out_ty.const_u32(self, *len);
                self.copy_into_result(ty_id, out_id, len)
            }
            Variable::Slice { offset, end, .. } => {
                self.i_sub(ty_id, out_id, *end, *offset).unwrap()
            }
            Variable::SharedArray(_, _, len)
            | Variable::ConstantArray(_, _, len)
            | Variable::LocalArray(_, _, len) => {
                // Handled like the constant-length slice above: the id reserved for `out`
                // must be defined by an instruction, otherwise it would stay undefined.
                let len = out_ty.const_u32(self, *len);
                self.copy_into_result(ty_id, out_id, len)
            }
            var => unimplemented!("Var {var:?} doesn't have length"),
        };
        if let Some(out) = out {
            self.write(out, id);
        }
        id
    }

    /// Copies `value` into the reserved result id `out_id` when there is one and returns
    /// the resulting id; returns `value` itself otherwise.
    fn copy_into_result(&mut self, ty_id: Word, out_id: Option<Word>, value: Word) -> Word {
        match out_id {
            Some(_) => self.copy_object(ty_id, out_id, value).unwrap(),
            None => value,
        }
    }

    /// Emits a load of the buffer length of a global input/output array and returns
    /// the id of the loaded value.
    ///
    /// With `out` present, its write id is reserved and marked with `uniform`, the load
    /// targets that id, the id gets a `buffer_len(<position>)` debug name, and the value
    /// is written to `out`. Without `out`, the value is typed as the compiled `addr_type`.
    ///
    /// # Panics
    ///
    /// Panics when `var` is neither a global input array nor a global output array.
    pub fn buffer_length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
        let (out_id, out_ty) = match out {
            Some(target) => {
                let target_id = self.write_id(target);
                self.mark_uniformity(target_id, uniform);
                (Some(target_id), target.item())
            }
            None => (None, self.compile_type(self.addr_type.into())),
        };

        let position = match var {
            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
            _ => panic!("Only Input and Output have a buffer length, got: {var:?}"),
        };

        let meta_index = self.info.metadata.buffer_len_index(position);
        let loaded = self.load_const_metadata(meta_index, out_id, out_ty);

        // Both are `Some` exactly when an output variable was supplied.
        if let (Some(target), Some(target_id)) = (out, out_id) {
            self.debug_name(target_id, format!("buffer_len({position})"));
            self.write(target, loaded);
        }
        loaded
    }

    /// Loads the metadata entry at the compile-time known position `index` and returns
    /// the id of the loaded value (`out` when provided).
    ///
    /// The instructions are emitted through `insert_in_setup`. The access chain into the
    /// info object first selects member `scalar_bindings.len()` and then element `index`;
    /// the pointer uses the storage class reported by `T::info_storage_class`.
    pub fn load_const_metadata(&mut self, index: u32, out: Option<Word>, ty: Item) -> Word {
        self.insert_in_setup(|builder| {
            let value_ty = ty.id(builder);
            let class = T::info_storage_class(builder);
            let pointer_ty = Item::Pointer(class, Box::new(ty)).id(builder);

            let info_id = builder.state.info;
            // The selected member comes right after one slot per scalar binding.
            let member_pos = builder.state.scalar_bindings.len() as u32;
            let member = builder.const_u32(member_pos);
            let element = builder.const_u32(index);

            let entry_ptr = builder
                .access_chain(pointer_ty, None, info_id, vec![member, element])
                .unwrap();
            builder.load(value_ty, out, entry_ptr, None, vec![]).unwrap()
        })
    }

    /// Loads the metadata entry whose position is only known at runtime and returns the
    /// id of the loaded value (`out` when provided).
    ///
    /// `index` is read to obtain the element id. The access chain into the info object
    /// selects member `scalar_bindings.len() + 1`, i.e. the member following the one used
    /// by `load_const_metadata`, and then that element.
    pub fn load_dyn_metadata(&mut self, index: &Variable, out: Option<Word>, ty: Item) -> Word {
        let value_ty = ty.id(self);
        let class = T::info_storage_class(self);
        let pointer_ty = Item::Pointer(class, Box::new(ty)).id(self);

        let info_id = self.state.info;
        let member_pos = self.state.scalar_bindings.len() as u32 + 1;
        let member = self.const_u32(member_pos);
        let element = self.read(index);

        let entry_ptr = self
            .access_chain(pointer_ty, None, info_id, vec![member, element])
            .unwrap();
        self.load(value_ty, out, entry_ptr, None, vec![]).unwrap()
    }

    /// Maps a global buffer variable to the entry of `self.ext_meta_pos` stored at the
    /// buffer's position.
    ///
    /// The returned value is what the rank, stride and shape queries pass to the
    /// corresponding `*_index` lookups on `self.info.metadata`.
    ///
    /// # Panics
    ///
    /// Panics when `var` is not a global input or output array, or when its position is
    /// out of bounds for `ext_meta_pos`.
    fn ext_pos(&self, var: &Variable) -> u32 {
        let pos = match var {
            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
            // This helper serves rank, shape and stride queries alike, so name all three
            // and report the offending variable.
            _ => panic!("Only global buffers have rank, shape and stride metadata, got: {var:?}"),
        };
        self.ext_meta_pos[pos as usize]
    }
}