1use cubecl_core::ir as core;
2use cubecl_core::ir::Metadata;
3use rspirv::spirv::{StorageClass, Word};
4
5use crate::{SpirvCompiler, SpirvTarget, item::Item, variable::Variable};
6
7impl<T: SpirvTarget> SpirvCompiler<T> {
8 pub fn compile_meta(&mut self, meta: Metadata, out: Option<core::Variable>, uniform: bool) {
9 let out = out.unwrap();
10 match meta {
11 Metadata::Rank { var } => {
12 let var = self.compile_variable(var);
13 let out = self.compile_variable(out);
14 let pos = self.ext_pos(&var);
15
16 let out_id = self.write_id(&out);
17 self.mark_uniformity(out_id, uniform);
18
19 let offset = self.metadata.rank_index(pos);
20 self.load_const_metadata(offset, Some(out_id), out.item());
21 self.write(&out, out_id);
22 }
23 Metadata::Length { var } => {
24 let var = self.compile_variable(var);
25 let out = self.compile_variable(out);
26 self.length(&var, Some(&out), uniform);
27 }
28 Metadata::BufferLength { var } => {
29 let var = self.compile_variable(var);
30 let out = self.compile_variable(out);
31 self.buffer_length(&var, Some(&out), uniform);
32 }
33 Metadata::Stride { dim, var } => {
34 let var = self.compile_variable(var);
35 let dim = self.compile_variable(dim);
36 let out = self.compile_variable(out);
37
38 let ty_id = out.item().id(self);
39 let out_id = out.id(self);
40 self.mark_uniformity(out_id, uniform);
41
42 let pos = self.ext_pos(&var);
43
44 let offs_offset = self.metadata.stride_offset_index(pos);
45 let offset = self.load_const_metadata(offs_offset, None, out.item());
46 let dim_id = self.read(&dim);
47
48 let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
49 self.mark_uniformity(index, uniform);
50 let index = Variable::Raw(index, out.item());
51 self.load_dyn_metadata(&index, &out, out.item());
52 }
53 Metadata::Shape { dim, var } => {
54 let var = self.compile_variable(var);
55 let dim = self.compile_variable(dim);
56 let out = self.compile_variable(out);
57
58 let ty_id = out.item().id(self);
59 let out_id = out.id(self);
60 self.mark_uniformity(out_id, uniform);
61
62 let pos = self.ext_pos(&var);
63
64 let offs_offset = self.metadata.shape_offset_index(pos);
65 let offset = self.load_const_metadata(offs_offset, None, out.item());
66 let dim_id = self.read(&dim);
67
68 let index = self.i_add(ty_id, None, offset, dim_id).unwrap();
69 let index = Variable::Id(index);
70 self.load_dyn_metadata(&index, &out, out.item());
71 }
72 }
73 }
74
75 pub fn length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
76 let (out_id, out_ty) = if let Some(out) = out {
77 let out_id = self.write_id(out);
78 self.mark_uniformity(out_id, uniform);
79 (Some(out_id), out.item())
80 } else {
81 (None, self.compile_type(self.addr_type.into()))
82 };
83 let ty_id = out_ty.id(self);
84
85 let id = match var {
86 Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => {
87 let offset = self.metadata.len_index(*pos);
88 let id = self.load_const_metadata(offset, out_id, out_ty);
89
90 if let Some(out_id) = out_id {
91 self.debug_name(out_id, format!("len({pos})"));
92 }
93 id
94 }
95 Variable::Slice {
96 const_len: Some(len),
97 ..
98 } => {
99 let len = out_ty.const_u32(self, *len);
100 if out.is_some() {
101 self.copy_object(ty_id, out_id, len).unwrap()
102 } else {
103 len
104 }
105 }
106 Variable::Slice { offset, end, .. } => {
107 self.i_sub(ty_id, out_id, *end, *offset).unwrap()
108 }
109 Variable::SharedArray(_, _, len)
110 | Variable::ConstantArray(_, _, len)
111 | Variable::LocalArray(_, _, len) => out_ty.const_u32(self, *len),
112 var => unimplemented!("Var {var:?} doesn't have length"),
113 };
114 if let Some(out) = out {
115 self.write(out, id);
116 }
117 id
118 }
119
120 pub fn buffer_length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
121 let out_id = out.map(|it| self.write_id(it));
122 if let Some(out_id) = out_id {
123 self.mark_uniformity(out_id, uniform);
124 }
125 let out_ty = out
126 .map(|it| it.item())
127 .unwrap_or_else(|| self.compile_type(self.addr_type.into()));
128
129 let position = match var {
130 Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
131 _ => panic!("Only Input and Output have a buffer length, got: {var:?}"),
132 };
133 let offset = self.metadata.buffer_len_index(position);
134 let id = self.load_const_metadata(offset, out_id, out_ty);
135
136 if let Some(out) = out {
137 self.debug_name(out_id.unwrap(), format!("buffer_len({position})"));
138 self.write(out, id);
139 }
140 id
141 }
142
143 pub fn load_const_metadata(&mut self, index: u32, out: Option<Word>, ty: Item) -> Word {
144 self.insert_in_setup(|b| {
145 let ty_id = ty.id(b);
146 let int_ptr = Item::Pointer(StorageClass::StorageBuffer, Box::new(ty)).id(b);
147 let info = b.state.info;
148 let zero = b.const_u32(0);
149 let index = b.const_u32(index);
150 let info_ptr = b
151 .access_chain(int_ptr, None, info, vec![zero, index])
152 .unwrap();
153 b.load(ty_id, out, info_ptr, None, vec![]).unwrap()
154 })
155 }
156
157 pub fn load_dyn_metadata(&mut self, index: &Variable, out: &Variable, ty: Item) -> Word {
158 let info = Variable::Named {
159 id: self.state.info,
160 item: ty,
161 is_array: false,
162 };
163 self.read_indexed_unchecked(out, &info, index)
164 }
165
166 fn ext_pos(&self, var: &Variable) -> u32 {
167 let pos = match var {
168 Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
169 _ => panic!("Only global buffers have rank"),
170 };
171 self.ext_meta_pos[pos as usize]
172 }
173}