cubecl_spirv/
metadata.rs

1use cubecl_core::ir as core;
2use cubecl_core::ir::Metadata;
3use rspirv::spirv::{StorageClass, Word};
4
5use crate::{
6    SpirvCompiler, SpirvTarget,
7    item::{Elem, Item},
8    variable::Variable,
9};
10
11impl<T: SpirvTarget> SpirvCompiler<T> {
12    pub fn compile_meta(&mut self, meta: Metadata, out: Option<core::Variable>, uniform: bool) {
13        let out = out.unwrap();
14        match meta {
15            Metadata::Rank { var } => {
16                let var = self.compile_variable(var);
17                let out = self.compile_variable(out);
18                let pos = self.ext_pos(&var);
19
20                let out_id = self.write_id(&out);
21                self.mark_uniformity(out_id, uniform);
22
23                let offset = self.metadata.rank_index(pos);
24                self.load_const_metadata(offset, Some(out_id));
25                self.write(&out, out_id);
26            }
27            Metadata::Length { var } => {
28                let var = self.compile_variable(var);
29                let out = self.compile_variable(out);
30                self.length(&var, Some(&out), uniform);
31            }
32            Metadata::BufferLength { var } => {
33                let var = self.compile_variable(var);
34                let out = self.compile_variable(out);
35                self.buffer_length(&var, Some(&out), uniform);
36            }
37            Metadata::Stride { dim, var } => {
38                let int = self.type_int(32, 0);
39
40                let var = self.compile_variable(var);
41                let dim = self.compile_variable(dim);
42                let out = self.compile_variable(out);
43
44                let out_id = out.id(self);
45                self.mark_uniformity(out_id, uniform);
46
47                let pos = self.ext_pos(&var);
48
49                let offs_offset = self.metadata.stride_offset_index(pos);
50                let offset = self.load_const_metadata(offs_offset, None);
51                let dim_id = self.read(&dim);
52
53                let index = self.i_add(int, None, offset, dim_id).unwrap();
54                self.mark_uniformity(index, uniform);
55                let index = Variable::Id(index);
56                self.load_dyn_metadata(&index, &out);
57            }
58            Metadata::Shape { dim, var } => {
59                let int = self.type_int(32, 0);
60
61                let var = self.compile_variable(var);
62                let dim = self.compile_variable(dim);
63                let out = self.compile_variable(out);
64
65                let out_id = out.id(self);
66                self.mark_uniformity(out_id, uniform);
67
68                let pos = self.ext_pos(&var);
69
70                let offs_offset = self.metadata.shape_offset_index(pos);
71                let offset = self.load_const_metadata(offs_offset, None);
72                let dim_id = self.read(&dim);
73
74                let index = self.i_add(int, None, offset, dim_id).unwrap();
75                let index = Variable::Id(index);
76                self.load_dyn_metadata(&index, &out);
77            }
78        }
79    }
80
81    pub fn length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
82        let (out_id, out_ty) = if let Some(out) = out {
83            let out_id = self.write_id(out);
84            self.mark_uniformity(out_id, uniform);
85            let out_ty = out.elem().id(self);
86            (Some(out_id), out_ty)
87        } else {
88            (None, self.type_int(32, 0))
89        };
90
91        let id = match var {
92            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => {
93                let offset = self.metadata.len_index(*pos);
94                let id = self.load_const_metadata(offset, out_id);
95
96                if let Some(out_id) = out_id {
97                    self.debug_name(out_id, format!("len({pos})"));
98                }
99                id
100            }
101            Variable::Slice {
102                const_len: Some(len),
103                ..
104            } => {
105                let len = self.const_u32(*len);
106                if out.is_some() {
107                    self.copy_object(out_ty, out_id, len).unwrap()
108                } else {
109                    len
110                }
111            }
112            Variable::Slice { offset, end, .. } => {
113                let len_ty = Elem::Int(32, false).id(self);
114                self.i_sub(len_ty, out_id, *end, *offset).unwrap()
115            }
116            Variable::SharedMemory(_, _, len)
117            | Variable::ConstantArray(_, _, len)
118            | Variable::LocalArray(_, _, len) => self.const_u32(*len),
119            var => unimplemented!("Var {var:?} doesn't have length"),
120        };
121        if let Some(out) = out {
122            self.write(out, id);
123        }
124        id
125    }
126
127    pub fn buffer_length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
128        let out_id = out.map(|it| self.write_id(it));
129        if let Some(out_id) = out_id {
130            self.mark_uniformity(out_id, uniform);
131        }
132
133        let position = match var {
134            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
135            _ => panic!("Only Input and Output have a buffer length, got: {var:?}"),
136        };
137        let offset = self.metadata.buffer_len_index(position);
138        let id = self.load_const_metadata(offset, out_id);
139
140        if let Some(out) = out {
141            self.debug_name(out_id.unwrap(), format!("buffer_len({position})"));
142            self.write(out, id);
143        }
144        id
145    }
146
147    pub fn load_const_metadata(&mut self, index: u32, out: Option<Word>) -> Word {
148        self.insert_global(|b| {
149            let int = Item::Scalar(Elem::Int(32, false));
150            let int_ty = int.id(b);
151            let int_ptr = Item::Pointer(StorageClass::StorageBuffer, Box::new(int)).id(b);
152            let info = b.state.info;
153            let zero = b.const_u32(0);
154            let index = b.const_u32(index);
155            let info_ptr = b
156                .access_chain(int_ptr, None, info, vec![zero, index])
157                .unwrap();
158            b.load(int_ty, out, info_ptr, None, vec![]).unwrap()
159        })
160    }
161
162    pub fn load_dyn_metadata(&mut self, index: &Variable, out: &Variable) -> Word {
163        let int_ty = Item::Scalar(Elem::Int(32, false));
164        let info = Variable::Named {
165            id: self.state.info,
166            item: int_ty,
167            is_array: false,
168        };
169        self.read_indexed_unchecked(out, &info, index)
170    }
171
172    fn ext_pos(&self, var: &Variable) -> u32 {
173        let pos = match var {
174            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
175            _ => panic!("Only global buffers have rank"),
176        };
177        self.ext_meta_pos[pos as usize]
178    }
179}