1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6 SpirvCompiler, SpirvTarget,
7 item::{Elem, Item},
8 lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantScalarValue, FloatKind, Id};
11use rspirv::{
12 dr::Builder,
13 spirv::{FPEncoding, StorageClass, Word},
14};
15
/// A variable after lowering to the SPIR-V backend.
///
/// Most variants carry the SPIR-V result id (`Word`) that holds the variable's value or pointer,
/// together with its compiled [`Item`] / [`Elem`] type. The SSA-style variants (`Versioned`,
/// `LocalBinding`) keep their IR id instead and resolve the SPIR-V id through the compiler.
#[derive(Debug, Clone, PartialEq)]
pub enum Variable {
    SubgroupSize(Word),
    /// Input buffer: `(id, item, pos)`; `pos` is the buffer position `id` was looked up with.
    GlobalInputArray(Word, Item, u32),
    /// Output buffer: `(id, item, pos)`; `pos` is the buffer position `id` was looked up with.
    GlobalOutputArray(Word, Item, u32),
    GlobalScalar(Word, Elem),
    /// Constant scalar: `(id, value, elem)`. Constants are deduplicated per `(value, type)`.
    ConstantScalar(Word, ConstVal, Elem),
    /// Mutable local: `id` is loaded from / stored to rather than used directly as a value.
    Local {
        id: Word,
        item: Item,
    },
    /// SSA version `(IR id, version)`; the SPIR-V id is resolved through `get_versioned`.
    Versioned {
        id: (Id, u16),
        item: Item,
        variable: ir::Variable,
    },
    /// Immutable local binding; the SPIR-V id is resolved through `get_binding`.
    LocalBinding {
        id: Id,
        item: Item,
        variable: ir::Variable,
    },
    /// An already-computed SPIR-V value id with its item type.
    Raw(Word, Item),
    /// Named resource, loaded from `id` and indexed through a `StorageBuffer` access chain.
    // NOTE(review): constructed outside this file; the exact meaning of `is_array` is not
    // visible here — confirm against the constructor.
    Named {
        id: Word,
        item: Item,
        is_array: bool,
    },
    /// View into `ptr`: indices are shifted by `offset`, and a known `const_len` allows static
    /// bounds-check elision.
    // NOTE(review): `end` is not read in this file; presumably the exclusive end index.
    Slice {
        ptr: Box<Variable>,
        offset: Word,
        end: Word,
        const_len: Option<u32>,
        item: Item,
    },
    /// Workgroup array: `(id, item, length)`.
    SharedArray(Word, Item, u32),
    /// Single workgroup value: `(id, item)`.
    Shared(Word, Item),
    /// Constant array: `(id, item, length)`.
    ConstantArray(Word, Item, u32),
    /// Function-local array: `(id, item, length)`.
    LocalArray(Word, Item, u32),
    /// Cooperative matrix, keyed by its IR id in the compiler's matrix table, with its element.
    CoopMatrix(Id, Elem),
    // Builtins: each variant carries the SPIR-V id holding the builtin's value.
    Id(Word),
    LocalInvocationIndex(Word),
    LocalInvocationIdX(Word),
    LocalInvocationIdY(Word),
    LocalInvocationIdZ(Word),
    Rank(Word),
    WorkgroupId(Word),
    WorkgroupIdX(Word),
    WorkgroupIdY(Word),
    WorkgroupIdZ(Word),
    GlobalInvocationIndex(Word),
    GlobalInvocationIdX(Word),
    GlobalInvocationIdY(Word),
    GlobalInvocationIdZ(Word),
    WorkgroupSize(Word),
    WorkgroupSizeX(Word),
    WorkgroupSizeY(Word),
    WorkgroupSizeZ(Word),
    NumWorkgroups(Word),
    NumWorkgroupsX(Word),
    NumWorkgroupsY(Word),
    NumWorkgroupsZ(Word),
}
78
/// Raw bit pattern of a constant scalar.
///
/// Values no wider than 32 bits (8/16/32-bit integers and floats, and bools) are stored
/// zero-extended in `Bit32`; 64-bit values use `Bit64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    /// Bit pattern of a value at most 32 bits wide.
    Bit32(u32),
    /// Bit pattern of a 64-bit value.
    Bit64(u64),
}
84
85impl ConstVal {
86 pub fn as_u64(&self) -> u64 {
87 match self {
88 ConstVal::Bit32(val) => *val as u64,
89 ConstVal::Bit64(val) => *val,
90 }
91 }
92
93 pub fn as_u32(&self) -> u32 {
94 match self {
95 ConstVal::Bit32(val) => *val,
96 ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
97 }
98 }
99
100 pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
101 match (width, encoding) {
102 (64, _) => f64::from_bits(self.as_u64()),
103 (32, _) => f32::from_bits(self.as_u32()) as f64,
104 (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
105 (_, Some(FPEncoding::BFloat16KHR)) => {
106 half::bf16::from_bits(self.as_u32() as u16).to_f64()
107 }
108 (_, Some(FPEncoding::Float8E4M3EXT)) => {
109 cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
110 }
111 (_, Some(FPEncoding::Float8E5M2EXT)) => {
112 cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
113 }
114 _ => unreachable!(),
115 }
116 }
117
118 pub fn as_int(&self, width: u32) -> i64 {
119 unsafe {
120 match width {
121 64 => transmute::<u64, i64>(self.as_u64()),
122 32 => transmute::<u32, i32>(self.as_u32()) as i64,
123 16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
124 8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
125 _ => unreachable!(),
126 }
127 }
128 }
129
130 pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
131 match (width, encoding) {
132 (64, _) => ConstVal::Bit64(value.to_bits()),
133 (32, _) => ConstVal::Bit32((value as f32).to_bits()),
134 (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
135 (_, Some(FPEncoding::BFloat16KHR)) => {
136 ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
137 }
138 (_, Some(FPEncoding::Float8E4M3EXT)) => {
139 ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
140 }
141 (_, Some(FPEncoding::Float8E5M2EXT)) => {
142 ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
143 }
144 _ => unreachable!(),
145 }
146 }
147
148 pub fn from_int(value: i64, width: u32) -> Self {
149 match width {
150 64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
151 32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
152 16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
153 8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
154 _ => unreachable!(),
155 }
156 }
157
158 pub fn from_uint(value: u64, width: u32) -> Self {
159 match width {
160 64 => ConstVal::Bit64(value),
161 32 => ConstVal::Bit32(value as u32),
162 16 => ConstVal::Bit32(value as u16 as u32),
163 8 => ConstVal::Bit32(value as u8 as u32),
164 _ => unreachable!(),
165 }
166 }
167
168 pub fn from_bool(value: bool) -> Self {
169 ConstVal::Bit32(value as u32)
170 }
171}
172
173impl From<ConstantScalarValue> for ConstVal {
174 fn from(value: ConstantScalarValue) -> Self {
175 let width = value.elem_type().size() as u32 * 8;
176 match value {
177 ConstantScalarValue::Int(val, _) => ConstVal::from_int(val, width),
178 ConstantScalarValue::Float(val, FloatKind::BF16) => {
179 ConstVal::from_float(val, width, Some(FPEncoding::BFloat16KHR))
180 }
181 ConstantScalarValue::Float(val, FloatKind::E4M3) => {
182 ConstVal::from_float(val, width, Some(FPEncoding::Float8E4M3EXT))
183 }
184 ConstantScalarValue::Float(val, FloatKind::E5M2) => {
185 ConstVal::from_float(val, width, Some(FPEncoding::Float8E5M2EXT))
186 }
187 ConstantScalarValue::Float(val, _) => ConstVal::from_float(val, width, None),
188 ConstantScalarValue::UInt(val, _) => ConstVal::from_uint(val, width),
189 ConstantScalarValue::Bool(val) => ConstVal::from_bool(val),
190 }
191 }
192}
193
194impl From<u32> for ConstVal {
195 fn from(value: u32) -> Self {
196 ConstVal::Bit32(value)
197 }
198}
199
200impl From<f32> for ConstVal {
201 fn from(value: f32) -> Self {
202 ConstVal::Bit32(value.to_bits())
203 }
204}
205
206impl Variable {
207 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
208 match self {
209 Variable::GlobalInputArray(id, _, _) => *id,
210 Variable::GlobalOutputArray(id, _, _) => *id,
211 Variable::GlobalScalar(id, _) => *id,
212 Variable::ConstantScalar(id, _, _) => *id,
213 Variable::Local { id, .. } => *id,
214 Variable::Versioned {
215 id, variable: var, ..
216 } => b.get_versioned(*id, var),
217 Variable::LocalBinding {
218 id, variable: var, ..
219 } => b.get_binding(*id, var),
220 Variable::Raw(id, _) => *id,
221 Variable::Named { id, .. } => *id,
222 Variable::Slice { ptr, .. } => ptr.id(b),
223 Variable::SharedArray(id, _, _) => *id,
224 Variable::Shared(id, _) => *id,
225 Variable::ConstantArray(id, _, _) => *id,
226 Variable::LocalArray(id, _, _) => *id,
227 Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
228 Variable::SubgroupSize(id) => *id,
229 Variable::Id(id) => *id,
230 Variable::LocalInvocationIndex(id) => *id,
231 Variable::LocalInvocationIdX(id) => *id,
232 Variable::LocalInvocationIdY(id) => *id,
233 Variable::LocalInvocationIdZ(id) => *id,
234 Variable::Rank(id) => *id,
235 Variable::WorkgroupId(id) => *id,
236 Variable::WorkgroupIdX(id) => *id,
237 Variable::WorkgroupIdY(id) => *id,
238 Variable::WorkgroupIdZ(id) => *id,
239 Variable::GlobalInvocationIndex(id) => *id,
240 Variable::GlobalInvocationIdX(id) => *id,
241 Variable::GlobalInvocationIdY(id) => *id,
242 Variable::GlobalInvocationIdZ(id) => *id,
243 Variable::WorkgroupSize(id) => *id,
244 Variable::WorkgroupSizeX(id) => *id,
245 Variable::WorkgroupSizeY(id) => *id,
246 Variable::WorkgroupSizeZ(id) => *id,
247 Variable::NumWorkgroups(id) => *id,
248 Variable::NumWorkgroupsX(id) => *id,
249 Variable::NumWorkgroupsY(id) => *id,
250 Variable::NumWorkgroupsZ(id) => *id,
251 }
252 }
253
254 pub fn item(&self) -> Item {
255 match self {
256 Variable::GlobalInputArray(_, item, _) => item.clone(),
257 Variable::GlobalOutputArray(_, item, _) => item.clone(),
258 Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
259 Variable::ConstantScalar(_, _, elem) => Item::Scalar(*elem),
260 Variable::Local { item, .. } => item.clone(),
261 Variable::Versioned { item, .. } => item.clone(),
262 Variable::LocalBinding { item, .. } => item.clone(),
263 Variable::Named { item, .. } => item.clone(),
264 Variable::Slice { item, .. } => item.clone(),
265 Variable::SharedArray(_, item, _) => item.clone(),
266 Variable::Shared(_, item) => item.clone(),
267 Variable::ConstantArray(_, item, _) => item.clone(),
268 Variable::LocalArray(_, item, _) => item.clone(),
269 Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
270 _ => Item::Scalar(Elem::Int(32, false)), }
272 }
273
274 pub fn indexed_item(&self) -> Item {
275 match self {
276 Variable::LocalBinding {
277 item: Item::Vector(elem, _),
278 ..
279 } => Item::Scalar(*elem),
280 Variable::Local {
281 item: Item::Vector(elem, _),
282 ..
283 } => Item::Scalar(*elem),
284 Variable::Versioned {
285 item: Item::Vector(elem, _),
286 ..
287 } => Item::Scalar(*elem),
288 other => other.item(),
289 }
290 }
291
292 pub fn elem(&self) -> Elem {
293 self.item().elem()
294 }
295
296 pub fn has_len(&self) -> bool {
297 matches!(
298 self,
299 Variable::GlobalInputArray(_, _, _)
300 | Variable::GlobalOutputArray(_, _, _)
301 | Variable::Named {
302 is_array: false,
303 ..
304 }
305 | Variable::Slice { .. }
306 | Variable::SharedArray(_, _, _)
307 | Variable::ConstantArray(_, _, _)
308 | Variable::LocalArray(_, _, _)
309 )
310 }
311
312 pub fn has_buffer_len(&self) -> bool {
313 matches!(
314 self,
315 Variable::GlobalInputArray(_, _, _)
316 | Variable::GlobalOutputArray(_, _, _)
317 | Variable::Named {
318 is_array: false,
319 ..
320 }
321 )
322 }
323
324 pub fn as_const(&self) -> Option<ConstVal> {
325 match self {
326 Self::ConstantScalar(_, val, _) => Some(*val),
327 _ => None,
328 }
329 }
330
331 pub fn as_binding(&self) -> Option<Id> {
332 match self {
333 Self::LocalBinding { id, .. } => Some(*id),
334 _ => None,
335 }
336 }
337}
338
/// Result of resolving `variable[index]` (see `SpirvCompiler::index`).
#[derive(Debug)]
pub enum IndexedVariable {
    /// Pointer id to the element (from an access chain) and the pointee's item.
    Pointer(Word, Item),
    /// Vector value id, constant literal component index, and the vector's item.
    Composite(Word, u32, Item),
    /// Vector value id, id of the runtime component index, and the vector's item.
    DynamicComposite(Word, u32, Item),
    /// Variable that is not indexed element-wise; reads and writes use it as a whole.
    Scalar(Variable),
}
346
347impl<T: SpirvTarget> SpirvCompiler<T> {
348 pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
349 let item = variable.ty;
350 match variable.kind {
351 ir::VariableKind::ConstantScalar(value) => {
352 let item = self.compile_type(ir::Type::new(value.storage_type()));
353 let const_val = value.into();
354
355 if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
356 Variable::ConstantScalar(*existing, const_val, item.elem())
357 } else {
358 let id = item.elem().constant(self, const_val);
359 self.state.constants.insert((const_val, item.clone()), id);
360 Variable::ConstantScalar(id, const_val, item.elem())
361 }
362 }
363 ir::VariableKind::GlobalInputArray(pos) => {
364 let id = self.state.buffers[pos as usize];
365 Variable::GlobalInputArray(id, self.compile_type(item), pos)
366 }
367 ir::VariableKind::GlobalOutputArray(pos) => {
368 let id = self.state.buffers[pos as usize];
369 Variable::GlobalOutputArray(id, self.compile_type(item), pos)
370 }
371 ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
372 ir::VariableKind::LocalMut { id } => {
373 let item = self.compile_type(item);
374 let var = self.get_local(id, &item, variable);
375 Variable::Local { id: var, item }
376 }
377 ir::VariableKind::Versioned { id, version } => {
378 let item = self.compile_type(item);
379 let id = (id, version);
380 Variable::Versioned { id, item, variable }
381 }
382 ir::VariableKind::LocalConst { id } => {
383 let item = self.compile_type(item);
384 Variable::LocalBinding { id, item, variable }
385 }
386 ir::VariableKind::Builtin(builtin) => self.compile_builtin(builtin),
387 ir::VariableKind::ConstantArray { id, length, .. } => {
388 let item = self.compile_type(item);
389 let id = self.state.const_arrays[id as usize].id;
390 Variable::ConstantArray(id, item, length)
391 }
392 ir::VariableKind::SharedArray { id, length, .. } => {
393 let item = self.compile_type(item);
394 let id = self.state.shared_arrays[&id].id;
395 Variable::SharedArray(id, item, length)
396 }
397 ir::VariableKind::Shared { id } => {
398 let item = self.compile_type(item);
399 let id = self.state.shared[&id].id;
400 Variable::Shared(id, item)
401 }
402 ir::VariableKind::LocalArray {
403 id,
404 length,
405 unroll_factor,
406 } => {
407 let item = self.compile_type(item);
408 let id = if let Some(arr) = self.state.local_arrays.get(&id) {
409 arr.id
410 } else {
411 let arr_ty = Item::Array(Box::new(item.clone()), length);
412 let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
413 let arr_id = self.declare_function_variable(ptr_ty);
414 self.debug_var_name(arr_id, variable);
415 let arr = Array {
416 id: arr_id,
417 item: item.clone(),
418 len: length * unroll_factor,
419 var: variable,
420 alignment: None,
421 };
422 self.state.local_arrays.insert(id, arr);
423 arr_id
424 };
425 Variable::LocalArray(id, item, length)
426 }
427 ir::VariableKind::Matrix { id, mat } => {
428 let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
429 if self.state.matrices.contains_key(&id) {
430 Variable::CoopMatrix(id, elem)
431 } else {
432 let matrix = self.init_coop_matrix(mat, variable);
433 self.state.matrices.insert(id, matrix);
434 Variable::CoopMatrix(id, elem)
435 }
436 }
437 ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
438 ir::VariableKind::BarrierToken { .. } => {
439 panic!("Barrier not supported.")
440 }
441 ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
442 ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
443 }
444 }
445
446 pub fn read(&mut self, variable: &Variable) -> Word {
447 match variable {
448 Variable::Slice { ptr, .. } => self.read(ptr),
449 Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
450 let ty = item.id(self);
451 let ptr_ty =
452 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
453 let index = vec![self.const_u32(0)];
454 let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
455 self.load(ty, None, access, None, []).unwrap()
456 }
457 Variable::Local { id, item } | Variable::Shared(id, item) => {
458 let ty = item.id(self);
459 self.load(ty, None, *id, None, []).unwrap()
460 }
461 Variable::Named { id, item, .. } => {
462 let ty = item.id(self);
463 self.load(ty, None, *id, None, []).unwrap()
464 }
465 ssa => ssa.id(self),
466 }
467 }
468
469 pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
470 if let Some(as_const) = variable.as_const() {
471 self.static_cast(as_const, &variable.elem(), item)
472 } else {
473 let id = self.read(variable);
474 variable.item().cast_to(self, None, id, item)
475 }
476 }
477
478 pub fn index(
479 &mut self,
480 variable: &Variable,
481 index: &Variable,
482 unchecked: bool,
483 ) -> IndexedVariable {
484 let access_chain = if unchecked {
485 Builder::in_bounds_access_chain
486 } else {
487 Builder::access_chain
488 };
489 let index_id = self.read(index);
490 match variable {
491 Variable::GlobalInputArray(id, item, _)
492 | Variable::GlobalOutputArray(id, item, _)
493 | Variable::Named { id, item, .. } => {
494 let ptr_ty =
495 Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
496 let zero = self.const_u32(0);
497 let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();
498
499 IndexedVariable::Pointer(id, item.clone())
500 }
501 Variable::Local {
502 id,
503 item: Item::Vector(elem, _),
504 } => {
505 let ptr_ty =
506 Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
507 let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
508
509 IndexedVariable::Pointer(id, Item::Scalar(*elem))
510 }
511 Variable::Shared(id, Item::Vector(elem, _)) => {
512 let ptr_ty =
513 Item::Pointer(StorageClass::Workgroup, Box::new(Item::Scalar(*elem))).id(self);
514
515 let mut index = vec![index_id];
516 if self.compilation_options.supports_explicit_smem {
517 index.insert(0, self.const_u32(0));
518 }
519
520 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
521
522 IndexedVariable::Pointer(id, Item::Scalar(*elem))
523 }
524 Variable::LocalBinding {
525 id,
526 item: Item::Vector(elem, vec),
527 variable,
528 } if index.as_const().is_some() => IndexedVariable::Composite(
529 self.get_binding(*id, variable),
530 index.as_const().unwrap().as_u32(),
531 Item::Vector(*elem, *vec),
532 ),
533 Variable::LocalBinding {
534 id,
535 item: Item::Vector(elem, vec),
536 variable,
537 } => IndexedVariable::DynamicComposite(
538 self.get_binding(*id, variable),
539 index_id,
540 Item::Vector(*elem, *vec),
541 ),
542 Variable::Versioned {
543 id,
544 item: Item::Vector(elem, vec),
545 variable,
546 } if index.as_const().is_some() => IndexedVariable::Composite(
547 self.get_versioned(*id, variable),
548 index.as_const().unwrap().as_u32(),
549 Item::Vector(*elem, *vec),
550 ),
551 Variable::Versioned {
552 id,
553 item: Item::Vector(elem, vec),
554 variable,
555 } => IndexedVariable::DynamicComposite(
556 self.get_versioned(*id, variable),
557 index_id,
558 Item::Vector(*elem, *vec),
559 ),
560 Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
561 let ptr_ty =
562 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
563 let index = vec![self.const_u32(0)];
564 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
565 IndexedVariable::Pointer(id, item.clone())
566 }
567 Variable::Local { .. }
568 | Variable::Shared(..)
569 | Variable::LocalBinding { .. }
570 | Variable::Versioned { .. } => IndexedVariable::Scalar(variable.clone()),
571 Variable::Slice { ptr, offset, .. } => {
572 let item = Item::Scalar(Elem::Int(32, false));
573 let int = item.id(self);
574 let index = match index.as_const() {
575 Some(ConstVal::Bit32(0)) => *offset,
576 _ => self.i_add(int, None, *offset, index_id).unwrap(),
577 };
578 self.index(ptr, &Variable::Raw(index, item), unchecked)
579 }
580 Variable::SharedArray(id, item, _) => {
581 let ptr_ty =
582 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
583 let mut index = vec![index_id];
584 if self.compilation_options.supports_explicit_smem {
585 index.insert(0, self.const_u32(0));
586 }
587 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
588 IndexedVariable::Pointer(id, item.clone())
589 }
590 Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
591 let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
592 let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
593 IndexedVariable::Pointer(id, item.clone())
594 }
595 var => unimplemented!("Can't index into {var:?}"),
596 }
597 }
598
599 pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
600 let always_in_bounds = is_always_in_bounds(variable, index);
601 let indexed = self.index(variable, index, always_in_bounds);
602
603 let read = |b: &mut Self| match indexed {
604 IndexedVariable::Pointer(ptr, item) => {
605 let ty = item.id(b);
606 let out_id = b.write_id(out);
607 b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
608 }
609 IndexedVariable::Composite(var, index, item) => {
610 let elem = item.elem();
611 let ty = elem.id(b);
612 let out_id = b.write_id(out);
613 b.composite_extract(ty, Some(out_id), var, vec![index])
614 .unwrap()
615 }
616 IndexedVariable::DynamicComposite(var, index, item) => {
617 let elem = item.elem();
618 let ty = elem.id(b);
619 let out_id = b.write_id(out);
620 b.vector_extract_dynamic(ty, Some(out_id), var, index)
621 .unwrap()
622 }
623 IndexedVariable::Scalar(var) => {
624 let ty = out.item().id(b);
625 let input = b.read(&var);
626 let out_id = b.write_id(out);
627 b.copy_object(ty, Some(out_id), input).unwrap();
628 b.write(out, out_id);
629 out_id
630 }
631 };
632
633 read(self)
634 }
635
636 pub fn read_indexed_unchecked(
637 &mut self,
638 out: &Variable,
639 variable: &Variable,
640 index: &Variable,
641 ) -> Word {
642 let indexed = self.index(variable, index, true);
643
644 match indexed {
645 IndexedVariable::Pointer(ptr, item) => {
646 let ty = item.id(self);
647 let out_id = self.write_id(out);
648 self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
649 }
650 IndexedVariable::Composite(var, index, item) => {
651 let elem = item.elem();
652 let ty = elem.id(self);
653 let out_id = self.write_id(out);
654 self.composite_extract(ty, Some(out_id), var, vec![index])
655 .unwrap()
656 }
657 IndexedVariable::DynamicComposite(var, index, item) => {
658 let elem = item.elem();
659 let ty = elem.id(self);
660 let out_id = self.write_id(out);
661 self.vector_extract_dynamic(ty, Some(out_id), var, index)
662 .unwrap()
663 }
664 IndexedVariable::Scalar(var) => {
665 let ty = out.item().id(self);
666 let input = self.read(&var);
667 let out_id = self.write_id(out);
668 self.copy_object(ty, Some(out_id), input).unwrap();
669 self.write(out, out_id);
670 out_id
671 }
672 }
673 }
674
675 pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
676 let always_in_bounds = is_always_in_bounds(var, index);
677 match self.index(var, index, always_in_bounds) {
678 IndexedVariable::Pointer(ptr, _) => ptr,
679 other => unreachable!("{other:?}"),
680 }
681 }
682
683 pub fn write_id(&mut self, variable: &Variable) -> Word {
684 match variable {
685 Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
686 Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
687 Variable::Local { .. } => self.id(),
688 Variable::Shared(..) => self.id(),
689 Variable::GlobalScalar(id, _) => *id,
690 Variable::Raw(id, _) => *id,
691 Variable::ConstantScalar(_, _, _) => panic!("Can't write to constant scalar"),
692 Variable::GlobalInputArray(_, _, _)
693 | Variable::GlobalOutputArray(_, _, _)
694 | Variable::Slice { .. }
695 | Variable::Named { .. }
696 | Variable::SharedArray(_, _, _)
697 | Variable::ConstantArray(_, _, _)
698 | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
699 global => panic!("Can't write to builtin {global:?}"),
700 }
701 }
702
703 pub fn write(&mut self, variable: &Variable, value: Word) {
704 match variable {
705 Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
706 let ptr_ty =
707 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
708 let index = vec![self.const_u32(0)];
709 let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
710 self.store(access, value, None, []).unwrap()
711 }
712 Variable::Local { id, .. } | Variable::Shared(id, _) => {
713 self.store(*id, value, None, []).unwrap()
714 }
715
716 Variable::Slice { ptr, .. } => self.write(ptr, value),
717 _ => {}
718 }
719 }
720
721 pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
722 let always_in_bounds = is_always_in_bounds(out, index);
723 let variable = self.index(out, index, always_in_bounds);
724
725 let write = |b: &mut Self| match variable {
726 IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
727 IndexedVariable::Composite(var, index, item) => {
728 let ty = item.id(b);
729 let id = b
730 .composite_insert(ty, None, value, var, vec![index])
731 .unwrap();
732 b.write(out, id);
733 }
734 IndexedVariable::DynamicComposite(var, index, item) => {
735 let ty = item.id(b);
736 let id = b
737 .vector_insert_dynamic(ty, None, value, var, index)
738 .unwrap();
739 b.write(out, id);
740 }
741 IndexedVariable::Scalar(var) => b.write(&var, value),
742 };
743
744 write(self)
745 }
746
747 pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
748 let variable = self.index(out, index, true);
749
750 match variable {
751 IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
752 IndexedVariable::Composite(var, index, item) => {
753 let ty = item.id(self);
754 let out_id = self
755 .composite_insert(ty, None, value, var, vec![index])
756 .unwrap();
757 self.write(out, out_id);
758 }
759 IndexedVariable::DynamicComposite(var, index, item) => {
760 let ty = item.id(self);
761 let out_id = self
762 .vector_insert_dynamic(ty, None, value, var, index)
763 .unwrap();
764 self.write(out, out_id);
765 }
766 IndexedVariable::Scalar(var) => self.write(&var, value),
767 }
768 }
769}
770
771fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
772 let len = match var {
773 Variable::SharedArray(_, _, len)
774 | Variable::ConstantArray(_, _, len)
775 | Variable::LocalArray(_, _, len)
776 | Variable::Slice {
777 const_len: Some(len),
778 ..
779 } => *len,
780 _ => return false,
781 };
782
783 let const_index = match index {
784 Variable::ConstantScalar(_, value, _) => value.as_u32(),
785 _ => return false,
786 };
787
788 const_index < len
789}