Skip to main content

cubecl_spirv/
metadata.rs

1use cubecl_core::ir as core;
2use cubecl_core::ir::Metadata;
3use rspirv::spirv::Word;
4
5use crate::{SpirvCompiler, SpirvTarget, item::Item, variable::Variable};
6
7impl<T: SpirvTarget> SpirvCompiler<T> {
8    pub fn compile_meta(&mut self, meta: Metadata, out: Option<core::Variable>, uniform: bool) {
9        let out = out.unwrap();
10        match meta {
11            Metadata::Rank { var } => {
12                let var = self.compile_variable(var);
13                let out = self.compile_variable(out);
14                let pos = self.ext_pos(&var);
15
16                let out_id = self.write_id(&out);
17                self.mark_uniformity(out_id, uniform);
18
19                let offset = self.info.metadata.rank_index(pos);
20                self.load_const_metadata(offset, Some(out_id), out.item());
21                self.write(&out, out_id);
22            }
23            Metadata::Length { var } => {
24                let var = self.compile_variable(var);
25                let out = self.compile_variable(out);
26                self.length(&var, Some(&out), uniform);
27            }
28            Metadata::BufferLength { var } => {
29                let var = self.compile_variable(var);
30                let out = self.compile_variable(out);
31                self.buffer_length(&var, Some(&out), uniform);
32            }
33            Metadata::Stride { dim, var } => {
34                let var = self.compile_variable(var);
35                let dim = self.compile_variable(dim);
36                let out = self.compile_variable(out);
37
38                let ty_id = out.item().id(self);
39                let out_id = out.id(self);
40                self.mark_uniformity(out_id, uniform);
41
42                let pos = self.ext_pos(&var);
43
44                let offs_offset = self.info.metadata.stride_offset_index(pos);
45                let offset = self.load_const_metadata(offs_offset, None, out.item());
46                let dim_id = self.read_as(&dim, &out.item());
47
48                let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
49                self.mark_uniformity(index, uniform);
50                let index = Variable::Raw(index, out.item());
51                self.load_dyn_metadata(&index, Some(out_id), out.item());
52                self.write(&out, out_id);
53            }
54            Metadata::Shape { dim, var } => {
55                let var = self.compile_variable(var);
56                let dim = self.compile_variable(dim);
57                let out = self.compile_variable(out);
58
59                let ty_id = out.item().id(self);
60                let out_id = out.id(self);
61                self.mark_uniformity(out_id, uniform);
62
63                let pos = self.ext_pos(&var);
64
65                let offs_offset = self.info.metadata.shape_offset_index(pos);
66                let offset = self.load_const_metadata(offs_offset, None, out.item());
67                let dim_id = self.read_as(&dim, &out.item());
68
69                let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
70                let index = Variable::Id(index);
71                self.load_dyn_metadata(&index, Some(out_id), out.item());
72                self.write(&out, out_id);
73            }
74        }
75    }
76
77    pub fn length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
78        let (out_id, out_ty) = if let Some(out) = out {
79            let out_id = self.write_id(out);
80            self.mark_uniformity(out_id, uniform);
81            (Some(out_id), out.item())
82        } else {
83            (None, self.compile_type(self.addr_type.into()))
84        };
85        let ty_id = out_ty.id(self);
86
87        let id = match var {
88            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => {
89                let offset = self.info.metadata.len_index(*pos);
90                let id = self.load_const_metadata(offset, out_id, out_ty);
91
92                if let Some(out_id) = out_id {
93                    self.debug_name(out_id, format!("len({pos})"));
94                }
95                id
96            }
97            Variable::Slice {
98                const_len: Some(len),
99                ..
100            } => {
101                let len = out_ty.const_u32(self, *len);
102                if out.is_some() {
103                    self.copy_object(ty_id, out_id, len).unwrap()
104                } else {
105                    len
106                }
107            }
108            Variable::Slice { offset, end, .. } => {
109                self.i_sub(ty_id, out_id, *end, *offset).unwrap()
110            }
111            Variable::SharedArray(_, _, len)
112            | Variable::ConstantArray(_, _, len)
113            | Variable::LocalArray(_, _, len) => out_ty.const_u32(self, *len),
114            var => unimplemented!("Var {var:?} doesn't have length"),
115        };
116        if let Some(out) = out {
117            self.write(out, id);
118        }
119        id
120    }
121
122    pub fn buffer_length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
123        let out_id = out.map(|it| self.write_id(it));
124        if let Some(out_id) = out_id {
125            self.mark_uniformity(out_id, uniform);
126        }
127        let out_ty = out
128            .map(|it| it.item())
129            .unwrap_or_else(|| self.compile_type(self.addr_type.into()));
130
131        let position = match var {
132            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
133            _ => panic!("Only Input and Output have a buffer length, got: {var:?}"),
134        };
135        let offset = self.info.metadata.buffer_len_index(position);
136        let id = self.load_const_metadata(offset, out_id, out_ty);
137
138        if let Some(out) = out {
139            self.debug_name(out_id.unwrap(), format!("buffer_len({position})"));
140            self.write(out, id);
141        }
142        id
143    }
144
145    pub fn load_const_metadata(&mut self, index: u32, out: Option<Word>, ty: Item) -> Word {
146        self.insert_in_setup(|b| {
147            let ty_id = ty.id(b);
148            let storage_class = T::info_storage_class(b);
149            let ptr_ty = Item::Pointer(storage_class, Box::new(ty)).id(b);
150            let info = b.state.info;
151            let offset = b.const_u32(b.state.scalar_bindings.len() as u32);
152            let index = b.const_u32(index);
153            let info_ptr = b
154                .access_chain(ptr_ty, None, info, vec![offset, index])
155                .unwrap();
156            b.load(ty_id, out, info_ptr, None, vec![]).unwrap()
157        })
158    }
159
160    pub fn load_dyn_metadata(&mut self, index: &Variable, out: Option<Word>, ty: Item) -> Word {
161        let ty_id = ty.id(self);
162        let storage_class = T::info_storage_class(self);
163        let ptr_ty = Item::Pointer(storage_class, Box::new(ty)).id(self);
164        let info = self.state.info;
165        let offset = self.const_u32(self.state.scalar_bindings.len() as u32 + 1);
166        let index = self.read(index);
167        let info_ptr = self
168            .access_chain(ptr_ty, None, info, vec![offset, index])
169            .unwrap();
170        self.load(ty_id, out, info_ptr, None, vec![]).unwrap()
171    }
172
173    fn ext_pos(&self, var: &Variable) -> u32 {
174        let pos = match var {
175            Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
176            _ => panic!("Only global buffers have rank"),
177        };
178        self.ext_meta_pos[pos as usize]
179    }
180}