1use cubecl_core::ir as core;
2use cubecl_core::ir::Metadata;
3use rspirv::spirv::Word;
4
5use crate::{SpirvCompiler, SpirvTarget, item::Item, variable::Variable};
6
7impl<T: SpirvTarget> SpirvCompiler<T> {
8 pub fn compile_meta(&mut self, meta: Metadata, out: Option<core::Variable>, uniform: bool) {
9 let out = out.unwrap();
10 match meta {
11 Metadata::Rank { var } => {
12 let var = self.compile_variable(var);
13 let out = self.compile_variable(out);
14 let pos = self.ext_pos(&var);
15
16 let out_id = self.write_id(&out);
17 self.mark_uniformity(out_id, uniform);
18
19 let offset = self.info.metadata.rank_index(pos);
20 self.load_const_metadata(offset, Some(out_id), out.item());
21 self.write(&out, out_id);
22 }
23 Metadata::Length { var } => {
24 let var = self.compile_variable(var);
25 let out = self.compile_variable(out);
26 self.length(&var, Some(&out), uniform);
27 }
28 Metadata::BufferLength { var } => {
29 let var = self.compile_variable(var);
30 let out = self.compile_variable(out);
31 self.buffer_length(&var, Some(&out), uniform);
32 }
33 Metadata::Stride { dim, var } => {
34 let var = self.compile_variable(var);
35 let dim = self.compile_variable(dim);
36 let out = self.compile_variable(out);
37
38 let ty_id = out.item().id(self);
39 let out_id = out.id(self);
40 self.mark_uniformity(out_id, uniform);
41
42 let pos = self.ext_pos(&var);
43
44 let offs_offset = self.info.metadata.stride_offset_index(pos);
45 let offset = self.load_const_metadata(offs_offset, None, out.item());
46 let dim_id = self.read_as(&dim, &out.item());
47
48 let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
49 self.mark_uniformity(index, uniform);
50 let index = Variable::Raw(index, out.item());
51 self.load_dyn_metadata(&index, Some(out_id), out.item());
52 self.write(&out, out_id);
53 }
54 Metadata::Shape { dim, var } => {
55 let var = self.compile_variable(var);
56 let dim = self.compile_variable(dim);
57 let out = self.compile_variable(out);
58
59 let ty_id = out.item().id(self);
60 let out_id = out.id(self);
61 self.mark_uniformity(out_id, uniform);
62
63 let pos = self.ext_pos(&var);
64
65 let offs_offset = self.info.metadata.shape_offset_index(pos);
66 let offset = self.load_const_metadata(offs_offset, None, out.item());
67 let dim_id = self.read_as(&dim, &out.item());
68
69 let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
70 let index = Variable::Id(index);
71 self.load_dyn_metadata(&index, Some(out_id), out.item());
72 self.write(&out, out_id);
73 }
74 }
75 }
76
77 pub fn length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
78 let (out_id, out_ty) = if let Some(out) = out {
79 let out_id = self.write_id(out);
80 self.mark_uniformity(out_id, uniform);
81 (Some(out_id), out.item())
82 } else {
83 (None, self.compile_type(self.addr_type.into()))
84 };
85 let ty_id = out_ty.id(self);
86
87 let id = match var {
88 Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => {
89 let offset = self.info.metadata.len_index(*pos);
90 let id = self.load_const_metadata(offset, out_id, out_ty);
91
92 if let Some(out_id) = out_id {
93 self.debug_name(out_id, format!("len({pos})"));
94 }
95 id
96 }
97 Variable::Slice {
98 const_len: Some(len),
99 ..
100 } => {
101 let len = out_ty.const_u32(self, *len);
102 if out.is_some() {
103 self.copy_object(ty_id, out_id, len).unwrap()
104 } else {
105 len
106 }
107 }
108 Variable::Slice { offset, end, .. } => {
109 self.i_sub(ty_id, out_id, *end, *offset).unwrap()
110 }
111 Variable::SharedArray(_, _, len)
112 | Variable::ConstantArray(_, _, len)
113 | Variable::LocalArray(_, _, len) => out_ty.const_u32(self, *len),
114 var => unimplemented!("Var {var:?} doesn't have length"),
115 };
116 if let Some(out) = out {
117 self.write(out, id);
118 }
119 id
120 }
121
122 pub fn buffer_length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
123 let out_id = out.map(|it| self.write_id(it));
124 if let Some(out_id) = out_id {
125 self.mark_uniformity(out_id, uniform);
126 }
127 let out_ty = out
128 .map(|it| it.item())
129 .unwrap_or_else(|| self.compile_type(self.addr_type.into()));
130
131 let position = match var {
132 Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
133 _ => panic!("Only Input and Output have a buffer length, got: {var:?}"),
134 };
135 let offset = self.info.metadata.buffer_len_index(position);
136 let id = self.load_const_metadata(offset, out_id, out_ty);
137
138 if let Some(out) = out {
139 self.debug_name(out_id.unwrap(), format!("buffer_len({position})"));
140 self.write(out, id);
141 }
142 id
143 }
144
145 pub fn load_const_metadata(&mut self, index: u32, out: Option<Word>, ty: Item) -> Word {
146 self.insert_in_setup(|b| {
147 let ty_id = ty.id(b);
148 let storage_class = T::info_storage_class(b);
149 let ptr_ty = Item::Pointer(storage_class, Box::new(ty)).id(b);
150 let info = b.state.info;
151 let offset = b.const_u32(b.state.scalar_bindings.len() as u32);
152 let index = b.const_u32(index);
153 let info_ptr = b
154 .access_chain(ptr_ty, None, info, vec![offset, index])
155 .unwrap();
156 b.load(ty_id, out, info_ptr, None, vec![]).unwrap()
157 })
158 }
159
160 pub fn load_dyn_metadata(&mut self, index: &Variable, out: Option<Word>, ty: Item) -> Word {
161 let ty_id = ty.id(self);
162 let storage_class = T::info_storage_class(self);
163 let ptr_ty = Item::Pointer(storage_class, Box::new(ty)).id(self);
164 let info = self.state.info;
165 let offset = self.const_u32(self.state.scalar_bindings.len() as u32 + 1);
166 let index = self.read(index);
167 let info_ptr = self
168 .access_chain(ptr_ty, None, info, vec![offset, index])
169 .unwrap();
170 self.load(ty_id, out, info_ptr, None, vec![]).unwrap()
171 }
172
173 fn ext_pos(&self, var: &Variable) -> u32 {
174 let pos = match var {
175 Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
176 _ => panic!("Only global buffers have rank"),
177 };
178 self.ext_meta_pos[pos as usize]
179 }
180}