cubecl_spirv/
metadata.rs

1use cubecl_core::ir as core;
2use cubecl_core::ir::Metadata;
3use rspirv::spirv::{StorageClass, Word};
4
5use crate::{SpirvCompiler, SpirvTarget, item::Item, variable::Variable};
6
7impl<T: SpirvTarget> SpirvCompiler<T> {
8    pub fn compile_meta(&mut self, meta: Metadata, out: Option<core::Variable>, uniform: bool) {
9        let out = out.unwrap();
10        match meta {
11            Metadata::Rank { var } => {
12                let var = self.compile_variable(var);
13                let out = self.compile_variable(out);
14                let pos = self.ext_pos(&var);
15
16                let out_id = self.write_id(&out);
17                self.mark_uniformity(out_id, uniform);
18
19                let offset = self.metadata.rank_index(pos);
20                self.load_const_metadata(offset, Some(out_id), out.item());
21                self.write(&out, out_id);
22            }
23            Metadata::Length { var } => {
24                let var = self.compile_variable(var);
25                let out = self.compile_variable(out);
26                self.length(&var, Some(&out), uniform);
27            }
28            Metadata::BufferLength { var } => {
29                let var = self.compile_variable(var);
30                let out = self.compile_variable(out);
31                self.buffer_length(&var, Some(&out), uniform);
32            }
33            Metadata::Stride { dim, var } => {
34                let var = self.compile_variable(var);
35                let dim = self.compile_variable(dim);
36                let out = self.compile_variable(out);
37
38                let ty_id = out.item().id(self);
39                let out_id = out.id(self);
40                self.mark_uniformity(out_id, uniform);
41
42                let pos = self.ext_pos(&var);
43
44                let offs_offset = self.metadata.stride_offset_index(pos);
45                let offset = self.load_const_metadata(offs_offset, None, out.item());
46                let dim_id = self.read(&dim);
47
48                let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
49                self.mark_uniformity(index, uniform);
50                let index = Variable::Raw(index, out.item());
51                self.load_dyn_metadata(&index, &out, out.item());
52            }
53            Metadata::Shape { dim, var } => {
54                let var = self.compile_variable(var);
55                let dim = self.compile_variable(dim);
56                let out = self.compile_variable(out);
57
58                let ty_id = out.item().id(self);
59                let out_id = out.id(self);
60                self.mark_uniformity(out_id, uniform);
61
62                let pos = self.ext_pos(&var);
63
64                let offs_offset = self.metadata.shape_offset_index(pos);
65                let offset = self.load_const_metadata(offs_offset, None, out.item());
66                let dim_id = self.read(&dim);
67
68                let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
69                let index = Variable::Id(index);
70                self.load_dyn_metadata(&index, &out, out.item());
71            }
72        }
73    }
74
75    pub fn length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
76        let (out_id, out_ty) = if let Some(out) = out {
77            let out_id = self.write_id(out);
78            self.mark_uniformity(out_id, uniform);
79            (Some(out_id), out.item())
80        } else {
81            (None, self.compile_type(self.addr_type.into()))
82        };
83        let ty_id = out_ty.id(self);
84
85        let id = match var {
86            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => {
87                let offset = self.metadata.len_index(*pos);
88                let id = self.load_const_metadata(offset, out_id, out_ty);
89
90                if let Some(out_id) = out_id {
91                    self.debug_name(out_id, format!("len({pos})"));
92                }
93                id
94            }
95            Variable::Slice {
96                const_len: Some(len),
97                ..
98            } => {
99                let len = out_ty.const_u32(self, *len);
100                if out.is_some() {
101                    self.copy_object(ty_id, out_id, len).unwrap()
102                } else {
103                    len
104                }
105            }
106            Variable::Slice { offset, end, .. } => {
107                self.i_sub(ty_id, out_id, *end, *offset).unwrap()
108            }
109            Variable::SharedArray(_, _, len)
110            | Variable::ConstantArray(_, _, len)
111            | Variable::LocalArray(_, _, len) => out_ty.const_u32(self, *len),
112            var => unimplemented!("Var {var:?} doesn't have length"),
113        };
114        if let Some(out) = out {
115            self.write(out, id);
116        }
117        id
118    }
119
120    pub fn buffer_length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
121        let out_id = out.map(|it| self.write_id(it));
122        if let Some(out_id) = out_id {
123            self.mark_uniformity(out_id, uniform);
124        }
125        let out_ty = out
126            .map(|it| it.item())
127            .unwrap_or_else(|| self.compile_type(self.addr_type.into()));
128
129        let position = match var {
130            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
131            _ => panic!("Only Input and Output have a buffer length, got: {var:?}"),
132        };
133        let offset = self.metadata.buffer_len_index(position);
134        let id = self.load_const_metadata(offset, out_id, out_ty);
135
136        if let Some(out) = out {
137            self.debug_name(out_id.unwrap(), format!("buffer_len({position})"));
138            self.write(out, id);
139        }
140        id
141    }
142
143    pub fn load_const_metadata(&mut self, index: u32, out: Option<Word>, ty: Item) -> Word {
144        self.insert_in_setup(|b| {
145            let ty_id = ty.id(b);
146            let int_ptr = Item::Pointer(StorageClass::StorageBuffer, Box::new(ty)).id(b);
147            let info = b.state.info;
148            let zero = b.const_u32(0);
149            let index = b.const_u32(index);
150            let info_ptr = b
151                .access_chain(int_ptr, None, info, vec![zero, index])
152                .unwrap();
153            b.load(ty_id, out, info_ptr, None, vec![]).unwrap()
154        })
155    }
156
157    pub fn load_dyn_metadata(&mut self, index: &Variable, out: &Variable, ty: Item) -> Word {
158        let info = Variable::Named {
159            id: self.state.info,
160            item: ty,
161            is_array: false,
162        };
163        self.read_indexed_unchecked(out, &info, index)
164    }
165
166    fn ext_pos(&self, var: &Variable) -> u32 {
167        let pos = match var {
168            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
169            _ => panic!("Only global buffers have rank"),
170        };
171        self.ext_meta_pos[pos as usize]
172    }
173}