1use cubecl_core::ir as core;
2use cubecl_core::ir::Metadata;
3use rspirv::spirv::{StorageClass, Word};
4
5use crate::{
6 SpirvCompiler, SpirvTarget,
7 item::{Elem, Item},
8 variable::Variable,
9};
10
11impl<T: SpirvTarget> SpirvCompiler<T> {
12 pub fn compile_meta(&mut self, meta: Metadata, out: Option<core::Variable>, uniform: bool) {
13 let out = out.unwrap();
14 match meta {
15 Metadata::Rank { var } => {
16 let var = self.compile_variable(var);
17 let out = self.compile_variable(out);
18 let pos = self.ext_pos(&var);
19
20 let out_id = self.write_id(&out);
21 self.mark_uniformity(out_id, uniform);
22
23 let offset = self.metadata.rank_index(pos);
24 self.load_const_metadata(offset, Some(out_id));
25 self.write(&out, out_id);
26 }
27 Metadata::Length { var } => {
28 let var = self.compile_variable(var);
29 let out = self.compile_variable(out);
30 self.length(&var, Some(&out), uniform);
31 }
32 Metadata::BufferLength { var } => {
33 let var = self.compile_variable(var);
34 let out = self.compile_variable(out);
35 self.buffer_length(&var, Some(&out), uniform);
36 }
37 Metadata::Stride { dim, var } => {
38 let int = self.type_int(32, 0);
39
40 let var = self.compile_variable(var);
41 let dim = self.compile_variable(dim);
42 let out = self.compile_variable(out);
43
44 let out_id = out.id(self);
45 self.mark_uniformity(out_id, uniform);
46
47 let pos = self.ext_pos(&var);
48
49 let offs_offset = self.metadata.stride_offset_index(pos);
50 let offset = self.load_const_metadata(offs_offset, None);
51 let dim_id = self.read(&dim);
52
53 let index = self.i_add(int, None, offset, dim_id).unwrap();
54 self.mark_uniformity(index, uniform);
55 let index = Variable::Id(index);
56 self.load_dyn_metadata(&index, &out);
57 }
58 Metadata::Shape { dim, var } => {
59 let int = self.type_int(32, 0);
60
61 let var = self.compile_variable(var);
62 let dim = self.compile_variable(dim);
63 let out = self.compile_variable(out);
64
65 let out_id = out.id(self);
66 self.mark_uniformity(out_id, uniform);
67
68 let pos = self.ext_pos(&var);
69
70 let offs_offset = self.metadata.shape_offset_index(pos);
71 let offset = self.load_const_metadata(offs_offset, None);
72 let dim_id = self.read(&dim);
73
74 let index = self.i_add(int, None, offset, dim_id).unwrap();
75 let index = Variable::Id(index);
76 self.load_dyn_metadata(&index, &out);
77 }
78 }
79 }
80
81 pub fn length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
82 let (out_id, out_ty) = if let Some(out) = out {
83 let out_id = self.write_id(out);
84 self.mark_uniformity(out_id, uniform);
85 let out_ty = out.elem().id(self);
86 (Some(out_id), out_ty)
87 } else {
88 (None, self.type_int(32, 0))
89 };
90
91 let id = match var {
92 Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => {
93 let offset = self.metadata.len_index(*pos);
94 let id = self.load_const_metadata(offset, out_id);
95
96 if let Some(out_id) = out_id {
97 self.debug_name(out_id, format!("len({pos})"));
98 }
99 id
100 }
101 Variable::Slice {
102 const_len: Some(len),
103 ..
104 } => {
105 let len = self.const_u32(*len);
106 if out.is_some() {
107 self.copy_object(out_ty, out_id, len).unwrap()
108 } else {
109 len
110 }
111 }
112 Variable::Slice { offset, end, .. } => {
113 let len_ty = Elem::Int(32, false).id(self);
114 self.i_sub(len_ty, out_id, *end, *offset).unwrap()
115 }
116 Variable::SharedMemory(_, _, len)
117 | Variable::ConstantArray(_, _, len)
118 | Variable::LocalArray(_, _, len) => self.const_u32(*len),
119 var => unimplemented!("Var {var:?} doesn't have length"),
120 };
121 if let Some(out) = out {
122 self.write(out, id);
123 }
124 id
125 }
126
127 pub fn buffer_length(&mut self, var: &Variable, out: Option<&Variable>, uniform: bool) -> Word {
128 let out_id = out.map(|it| self.write_id(it));
129 if let Some(out_id) = out_id {
130 self.mark_uniformity(out_id, uniform);
131 }
132
133 let position = match var {
134 Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
135 _ => panic!("Only Input and Output have a buffer length, got: {var:?}"),
136 };
137 let offset = self.metadata.buffer_len_index(position);
138 let id = self.load_const_metadata(offset, out_id);
139
140 if let Some(out) = out {
141 self.debug_name(out_id.unwrap(), format!("buffer_len({position})"));
142 self.write(out, id);
143 }
144 id
145 }
146
147 pub fn load_const_metadata(&mut self, index: u32, out: Option<Word>) -> Word {
148 self.insert_global(|b| {
149 let int = Item::Scalar(Elem::Int(32, false));
150 let int_ty = int.id(b);
151 let int_ptr = Item::Pointer(StorageClass::StorageBuffer, Box::new(int)).id(b);
152 let info = b.state.info;
153 let zero = b.const_u32(0);
154 let index = b.const_u32(index);
155 let info_ptr = b
156 .access_chain(int_ptr, None, info, vec![zero, index])
157 .unwrap();
158 b.load(int_ty, out, info_ptr, None, vec![]).unwrap()
159 })
160 }
161
162 pub fn load_dyn_metadata(&mut self, index: &Variable, out: &Variable) -> Word {
163 let int_ty = Item::Scalar(Elem::Int(32, false));
164 let info = Variable::Named {
165 id: self.state.info,
166 item: int_ty,
167 is_array: false,
168 };
169 self.read_indexed_unchecked(out, &info, index)
170 }
171
172 fn ext_pos(&self, var: &Variable) -> u32 {
173 let pos = match var {
174 Variable::GlobalInputArray(_, _, pos) | Variable::GlobalOutputArray(_, _, pos) => *pos,
175 _ => panic!("Only global buffers have rank"),
176 };
177 self.ext_meta_pos[pos as usize]
178 }
179}