1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6 SpirvCompiler, SpirvTarget,
7 item::{Elem, Item},
8 lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantScalarValue, FloatKind, Id};
11use rspirv::{
12 dr::Builder,
13 spirv::{FPEncoding, StorageClass, Word},
14};
15
/// A value or memory object that the SPIR-V compiler can read, write or index.
///
/// Most variants carry the SPIR-V id (`Word`) of the object directly.
/// `Versioned`, `LocalBinding` and `CoopMatrix` carry an IR-level id instead and
/// are resolved through the compiler when needed.
#[derive(Debug, Clone, PartialEq)]
pub enum Variable {
    SubgroupSize(Word),
    /// Input buffer: (buffer id, item, buffer position).
    GlobalInputArray(Word, Item, u32),
    /// Output buffer: (buffer id, item, buffer position).
    GlobalOutputArray(Word, Item, u32),
    /// Scalar with its id and element type.
    GlobalScalar(Word, Elem),
    /// Constant: (id of the constant, raw bits, element type).
    ConstantScalar(Word, ConstVal, Elem),
    /// Local variable; `read` loads from `id` and `write` stores to it.
    Local {
        id: Word,
        item: Item,
    },
    /// Versioned IR variable; `id` is `(IR id, version)` and the SPIR-V id is
    /// looked up with `get_versioned`.
    Versioned {
        id: (Id, u16),
        item: Item,
        variable: ir::Variable,
    },
    /// IR binding whose SPIR-V id is looked up with `get_binding`.
    LocalBinding {
        id: Id,
        item: Item,
        variable: ir::Variable,
    },
    /// A bare SPIR-V id paired with an item; the id is used as-is.
    Raw(Word, Item),
    /// Named object; `read` loads from `id`, and indexing goes through a
    /// `StorageBuffer` access chain.
    Named {
        id: Word,
        item: Item,
        is_array: bool,
    },
    /// A view into another variable.
    Slice {
        /// The variable being sliced.
        ptr: Box<Variable>,
        /// SPIR-V id of the offset that is added to every index.
        offset: Word,
        // NOTE(review): presumably the id of the slice's end bound; it is not
        // read anywhere in this file — confirm against the slice operators.
        end: Word,
        /// Length known at compile time, if any; used for bounds elision.
        const_len: Option<u32>,
        item: Item,
    },
    /// Shared memory array: (id, item, length).
    SharedMemory(Word, Item, u32),
    /// Constant array: (id, item, length).
    ConstantArray(Word, Item, u32),
    /// Function-local array: (id, item, length).
    LocalArray(Word, Item, u32),
    /// Cooperative matrix, keyed by its IR id, with its element type.
    CoopMatrix(Id, Elem),
    // The remaining `Word`-only variants are rejected by `write_id` as builtins
    // and are reported by `item()` as a 32-bit integer scalar.
    Id(Word),
    LocalInvocationIndex(Word),
    LocalInvocationIdX(Word),
    LocalInvocationIdY(Word),
    LocalInvocationIdZ(Word),
    Rank(Word),
    WorkgroupId(Word),
    WorkgroupIdX(Word),
    WorkgroupIdY(Word),
    WorkgroupIdZ(Word),
    GlobalInvocationIndex(Word),
    GlobalInvocationIdX(Word),
    GlobalInvocationIdY(Word),
    GlobalInvocationIdZ(Word),
    WorkgroupSize(Word),
    WorkgroupSizeX(Word),
    WorkgroupSizeY(Word),
    WorkgroupSizeZ(Word),
    NumWorkgroups(Word),
    NumWorkgroupsX(Word),
    NumWorkgroupsY(Word),
    NumWorkgroupsZ(Word),
}
77
/// Raw bit pattern of a scalar constant.
///
/// Values of 32 bits or fewer are stored in `Bit32` (narrower values occupy the
/// low bits); 64-bit values are stored in `Bit64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    /// Payload of at most 32 bits.
    Bit32(u32),
    /// 64-bit payload.
    Bit64(u64),
}
83
84impl ConstVal {
85 pub fn as_u64(&self) -> u64 {
86 match self {
87 ConstVal::Bit32(val) => *val as u64,
88 ConstVal::Bit64(val) => *val,
89 }
90 }
91
92 pub fn as_u32(&self) -> u32 {
93 match self {
94 ConstVal::Bit32(val) => *val,
95 ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
96 }
97 }
98
99 pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
100 match (width, encoding) {
101 (64, _) => f64::from_bits(self.as_u64()),
102 (32, _) => f32::from_bits(self.as_u32()) as f64,
103 (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
104 (_, Some(FPEncoding::BFloat16KHR)) => {
105 half::bf16::from_bits(self.as_u32() as u16).to_f64()
106 }
107 (_, Some(FPEncoding::Float8E4M3EXT)) => {
108 cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
109 }
110 (_, Some(FPEncoding::Float8E5M2EXT)) => {
111 cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
112 }
113 _ => unreachable!(),
114 }
115 }
116
117 pub fn as_int(&self, width: u32) -> i64 {
118 unsafe {
119 match width {
120 64 => transmute::<u64, i64>(self.as_u64()),
121 32 => transmute::<u32, i32>(self.as_u32()) as i64,
122 16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
123 8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
124 _ => unreachable!(),
125 }
126 }
127 }
128
129 pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
130 match (width, encoding) {
131 (64, _) => ConstVal::Bit64(value.to_bits()),
132 (32, _) => ConstVal::Bit32((value as f32).to_bits()),
133 (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
134 (_, Some(FPEncoding::BFloat16KHR)) => {
135 ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
136 }
137 (_, Some(FPEncoding::Float8E4M3EXT)) => {
138 ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
139 }
140 (_, Some(FPEncoding::Float8E5M2EXT)) => {
141 ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
142 }
143 _ => unreachable!(),
144 }
145 }
146
147 pub fn from_int(value: i64, width: u32) -> Self {
148 match width {
149 64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
150 32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
151 16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
152 8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
153 _ => unreachable!(),
154 }
155 }
156
157 pub fn from_uint(value: u64, width: u32) -> Self {
158 match width {
159 64 => ConstVal::Bit64(value),
160 32 => ConstVal::Bit32(value as u32),
161 16 => ConstVal::Bit32(value as u16 as u32),
162 8 => ConstVal::Bit32(value as u8 as u32),
163 _ => unreachable!(),
164 }
165 }
166
167 pub fn from_bool(value: bool) -> Self {
168 ConstVal::Bit32(value as u32)
169 }
170}
171
172impl From<ConstantScalarValue> for ConstVal {
173 fn from(value: ConstantScalarValue) -> Self {
174 let width = value.elem_type().size() as u32 * 8;
175 match value {
176 ConstantScalarValue::Int(val, _) => ConstVal::from_int(val, width),
177 ConstantScalarValue::Float(val, FloatKind::BF16) => {
178 ConstVal::from_float(val, width, Some(FPEncoding::BFloat16KHR))
179 }
180 ConstantScalarValue::Float(val, FloatKind::E4M3) => {
181 ConstVal::from_float(val, width, Some(FPEncoding::Float8E4M3EXT))
182 }
183 ConstantScalarValue::Float(val, FloatKind::E5M2) => {
184 ConstVal::from_float(val, width, Some(FPEncoding::Float8E5M2EXT))
185 }
186 ConstantScalarValue::Float(val, _) => ConstVal::from_float(val, width, None),
187 ConstantScalarValue::UInt(val, _) => ConstVal::from_uint(val, width),
188 ConstantScalarValue::Bool(val) => ConstVal::from_bool(val),
189 }
190 }
191}
192
193impl From<u32> for ConstVal {
194 fn from(value: u32) -> Self {
195 ConstVal::Bit32(value)
196 }
197}
198
199impl From<f32> for ConstVal {
200 fn from(value: f32) -> Self {
201 ConstVal::Bit32(value.to_bits())
202 }
203}
204
205impl Variable {
206 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
207 match self {
208 Variable::GlobalInputArray(id, _, _) => *id,
209 Variable::GlobalOutputArray(id, _, _) => *id,
210 Variable::GlobalScalar(id, _) => *id,
211 Variable::ConstantScalar(id, _, _) => *id,
212 Variable::Local { id, .. } => *id,
213 Variable::Versioned {
214 id, variable: var, ..
215 } => b.get_versioned(*id, var),
216 Variable::LocalBinding {
217 id, variable: var, ..
218 } => b.get_binding(*id, var),
219 Variable::Raw(id, _) => *id,
220 Variable::Named { id, .. } => *id,
221 Variable::Slice { ptr, .. } => ptr.id(b),
222 Variable::SharedMemory(id, _, _) => *id,
223 Variable::ConstantArray(id, _, _) => *id,
224 Variable::LocalArray(id, _, _) => *id,
225 Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
226 Variable::SubgroupSize(id) => *id,
227 Variable::Id(id) => *id,
228 Variable::LocalInvocationIndex(id) => *id,
229 Variable::LocalInvocationIdX(id) => *id,
230 Variable::LocalInvocationIdY(id) => *id,
231 Variable::LocalInvocationIdZ(id) => *id,
232 Variable::Rank(id) => *id,
233 Variable::WorkgroupId(id) => *id,
234 Variable::WorkgroupIdX(id) => *id,
235 Variable::WorkgroupIdY(id) => *id,
236 Variable::WorkgroupIdZ(id) => *id,
237 Variable::GlobalInvocationIndex(id) => *id,
238 Variable::GlobalInvocationIdX(id) => *id,
239 Variable::GlobalInvocationIdY(id) => *id,
240 Variable::GlobalInvocationIdZ(id) => *id,
241 Variable::WorkgroupSize(id) => *id,
242 Variable::WorkgroupSizeX(id) => *id,
243 Variable::WorkgroupSizeY(id) => *id,
244 Variable::WorkgroupSizeZ(id) => *id,
245 Variable::NumWorkgroups(id) => *id,
246 Variable::NumWorkgroupsX(id) => *id,
247 Variable::NumWorkgroupsY(id) => *id,
248 Variable::NumWorkgroupsZ(id) => *id,
249 }
250 }
251
252 pub fn item(&self) -> Item {
253 match self {
254 Variable::GlobalInputArray(_, item, _) => item.clone(),
255 Variable::GlobalOutputArray(_, item, _) => item.clone(),
256 Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
257 Variable::ConstantScalar(_, _, elem) => Item::Scalar(*elem),
258 Variable::Local { item, .. } => item.clone(),
259 Variable::Versioned { item, .. } => item.clone(),
260 Variable::LocalBinding { item, .. } => item.clone(),
261 Variable::Named { item, .. } => item.clone(),
262 Variable::Slice { item, .. } => item.clone(),
263 Variable::SharedMemory(_, item, _) => item.clone(),
264 Variable::ConstantArray(_, item, _) => item.clone(),
265 Variable::LocalArray(_, item, _) => item.clone(),
266 Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
267 _ => Item::Scalar(Elem::Int(32, false)), }
269 }
270
271 pub fn indexed_item(&self) -> Item {
272 match self {
273 Variable::LocalBinding {
274 item: Item::Vector(elem, _),
275 ..
276 } => Item::Scalar(*elem),
277 Variable::Local {
278 item: Item::Vector(elem, _),
279 ..
280 } => Item::Scalar(*elem),
281 Variable::Versioned {
282 item: Item::Vector(elem, _),
283 ..
284 } => Item::Scalar(*elem),
285 other => other.item(),
286 }
287 }
288
289 pub fn elem(&self) -> Elem {
290 self.item().elem()
291 }
292
293 pub fn has_len(&self) -> bool {
294 matches!(
295 self,
296 Variable::GlobalInputArray(_, _, _)
297 | Variable::GlobalOutputArray(_, _, _)
298 | Variable::Named {
299 is_array: false,
300 ..
301 }
302 | Variable::Slice { .. }
303 | Variable::SharedMemory(_, _, _)
304 | Variable::ConstantArray(_, _, _)
305 | Variable::LocalArray(_, _, _)
306 )
307 }
308
309 pub fn has_buffer_len(&self) -> bool {
310 matches!(
311 self,
312 Variable::GlobalInputArray(_, _, _)
313 | Variable::GlobalOutputArray(_, _, _)
314 | Variable::Named {
315 is_array: false,
316 ..
317 }
318 )
319 }
320
321 pub fn as_const(&self) -> Option<ConstVal> {
322 match self {
323 Self::ConstantScalar(_, val, _) => Some(*val),
324 _ => None,
325 }
326 }
327
328 pub fn as_binding(&self) -> Option<Id> {
329 match self {
330 Self::LocalBinding { id, .. } => Some(*id),
331 _ => None,
332 }
333 }
334}
335
/// Result of indexing into a [`Variable`], as produced by `SpirvCompiler::index`.
#[derive(Debug)]
pub enum IndexedVariable {
    /// Pointer to the indexed element (id from an access chain) and the item
    /// stored behind it.
    Pointer(Word, Item),
    /// Composite value id, constant element index, and the item of the whole
    /// composite.
    Composite(Word, u32, Item),
    /// Composite value id, SPIR-V id of a runtime index, and the item of the
    /// whole composite.
    DynamicComposite(Word, u32, Item),
    /// The variable is not a vector, so the index is ignored and the variable
    /// itself is used.
    Scalar(Variable),
}
343
344impl<T: SpirvTarget> SpirvCompiler<T> {
    /// Lowers an IR variable into the compiler's [`Variable`] representation.
    ///
    /// Constants, local arrays and cooperative matrices are created on first
    /// use and cached in `self.state`; later calls reuse the cached object.
    ///
    /// # Panics
    /// Panics for pipeline, barrier, barrier-token and tensor-map variables,
    /// which are not supported here.
    pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
        let item = variable.ty;
        match variable.kind {
            ir::VariableKind::ConstantScalar(value) => {
                // The constant's own storage type is used, not `variable.ty`.
                let item = self.compile_type(ir::Type::new(value.storage_type()));
                let const_val = value.into();

                // Constants are deduplicated by (bits, item).
                if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
                    Variable::ConstantScalar(*existing, const_val, item.elem())
                } else {
                    let id = item.elem().constant(self, const_val);
                    self.state.constants.insert((const_val, item.clone()), id);
                    Variable::ConstantScalar(id, const_val, item.elem())
                }
            }
            ir::VariableKind::GlobalInputArray(pos) => {
                let id = self.state.buffers[pos as usize];
                Variable::GlobalInputArray(id, self.compile_type(item), pos)
            }
            ir::VariableKind::GlobalOutputArray(pos) => {
                let id = self.state.buffers[pos as usize];
                Variable::GlobalOutputArray(id, self.compile_type(item), pos)
            }
            ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
            ir::VariableKind::LocalMut { id } => {
                let item = self.compile_type(item);
                let var = self.get_local(id, &item, variable);
                Variable::Local { id: var, item }
            }
            ir::VariableKind::Versioned { id, version } => {
                let item = self.compile_type(item);
                let id = (id, version);
                Variable::Versioned { id, item, variable }
            }
            ir::VariableKind::LocalConst { id } => {
                let item = self.compile_type(item);
                Variable::LocalBinding { id, item, variable }
            }
            ir::VariableKind::Builtin(builtin) => self.compile_builtin(builtin),
            ir::VariableKind::ConstantArray { id, length, .. } => {
                let item = self.compile_type(item);
                let id = self.state.const_arrays[id as usize].id;
                Variable::ConstantArray(id, item, length)
            }
            ir::VariableKind::SharedMemory { id, length, .. } => {
                let item = self.compile_type(item);
                let id = self.state.shared_memories[&id].id;
                Variable::SharedMemory(id, item, length)
            }
            ir::VariableKind::LocalArray {
                id,
                length,
                unroll_factor,
            } => {
                let item = self.compile_type(item);
                // Declare the array as a function variable the first time it
                // is seen; afterwards reuse the cached id.
                let id = if let Some(arr) = self.state.local_arrays.get(&id) {
                    arr.id
                } else {
                    // NOTE(review): the declared array type uses `length`, while
                    // the recorded `len` below is `length * unroll_factor` —
                    // confirm the mismatch is intended.
                    let arr_ty = Item::Array(Box::new(item.clone()), length);
                    let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
                    let arr_id = self.declare_function_variable(ptr_ty);
                    self.debug_var_name(arr_id, variable);
                    let arr = Array {
                        id: arr_id,
                        item: item.clone(),
                        len: length * unroll_factor,
                        var: variable,
                        alignment: None,
                    };
                    self.state.local_arrays.insert(id, arr);
                    arr_id
                };
                Variable::LocalArray(id, item, length)
            }
            ir::VariableKind::Matrix { id, mat } => {
                let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
                // Initialize the matrix once; both paths return the same value.
                if self.state.matrices.contains_key(&id) {
                    Variable::CoopMatrix(id, elem)
                } else {
                    let matrix = self.init_coop_matrix(mat, variable);
                    self.state.matrices.insert(id, matrix);
                    Variable::CoopMatrix(id, elem)
                }
            }
            ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
            ir::VariableKind::Barrier { .. } | ir::VariableKind::BarrierToken { .. } => {
                panic!("Barrier not supported.")
            }
            ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
            ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
        }
    }
437
438 pub fn read(&mut self, variable: &Variable) -> Word {
439 match variable {
440 Variable::Slice { ptr, .. } => self.read(ptr),
441 Variable::Local { id, item } => {
442 let ty = item.id(self);
443 self.load(ty, None, *id, None, vec![]).unwrap()
444 }
445 Variable::Named { id, item, .. } => {
446 let ty = item.id(self);
447 self.load(ty, None, *id, None, vec![]).unwrap()
448 }
449 ssa => ssa.id(self),
450 }
451 }
452
453 pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
454 if let Some(as_const) = variable.as_const() {
455 self.static_cast(as_const, &variable.elem(), item)
456 } else {
457 let id = self.read(variable);
458 variable.item().cast_to(self, None, id, item)
459 }
460 }
461
    /// Resolves `variable[index]` into an [`IndexedVariable`].
    ///
    /// Memory-backed variables produce a pointer through an access chain;
    /// vector-typed bindings produce a (dynamic) composite access; non-vector
    /// locals and bindings are returned unchanged as `Scalar`.
    ///
    /// When `unchecked` is `true`, `in_bounds_access_chain` is used instead of
    /// `access_chain`.
    ///
    /// # Panics
    /// Panics (`unimplemented!`) for variants that cannot be indexed, and when
    /// a constant index into a vector binding holds a 64-bit payload.
    pub fn index(
        &mut self,
        variable: &Variable,
        index: &Variable,
        unchecked: bool,
    ) -> IndexedVariable {
        let access_chain = if unchecked {
            Builder::in_bounds_access_chain
        } else {
            Builder::access_chain
        };
        // The index is read up front, before any arm emits its own instructions.
        let index_id = self.read(index);
        match variable {
            Variable::GlobalInputArray(id, item, _)
            | Variable::GlobalOutputArray(id, item, _)
            | Variable::Named { id, item, .. } => {
                let ptr_ty =
                    Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
                // A leading constant 0 index precedes the element index.
                let zero = self.const_u32(0);
                let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();

                IndexedVariable::Pointer(id, item.clone())
            }
            Variable::Local {
                id,
                item: Item::Vector(elem, _),
            } => {
                let ptr_ty =
                    Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();

                IndexedVariable::Pointer(id, Item::Scalar(*elem))
            }
            // Vector bindings: a constant index becomes a static composite
            // access, anything else a dynamic one.
            Variable::LocalBinding {
                id,
                item: Item::Vector(elem, vec),
                variable,
            } if index.as_const().is_some() => IndexedVariable::Composite(
                self.get_binding(*id, variable),
                index.as_const().unwrap().as_u32(),
                Item::Vector(*elem, *vec),
            ),
            Variable::LocalBinding {
                id,
                item: Item::Vector(elem, vec),
                variable,
            } => IndexedVariable::DynamicComposite(
                self.get_binding(*id, variable),
                index_id,
                Item::Vector(*elem, *vec),
            ),
            Variable::Versioned {
                id,
                item: Item::Vector(elem, vec),
                variable,
            } if index.as_const().is_some() => IndexedVariable::Composite(
                self.get_versioned(*id, variable),
                index.as_const().unwrap().as_u32(),
                Item::Vector(*elem, *vec),
            ),
            Variable::Versioned {
                id,
                item: Item::Vector(elem, vec),
                variable,
            } => IndexedVariable::DynamicComposite(
                self.get_versioned(*id, variable),
                index_id,
                Item::Vector(*elem, *vec),
            ),
            // Must come after the vector arms above: only non-vector items
            // reach this point.
            Variable::Local { .. } | Variable::LocalBinding { .. } | Variable::Versioned { .. } => {
                IndexedVariable::Scalar(variable.clone())
            }
            Variable::Slice { ptr, offset, .. } => {
                let item = Item::Scalar(Elem::Int(32, false));
                let int = item.id(self);
                // A constant zero index needs no addition; otherwise the slice
                // offset is added to the index before indexing the target.
                let index = match index.as_const() {
                    Some(ConstVal::Bit32(0)) => *offset,
                    _ => self.i_add(int, None, *offset, index_id).unwrap(),
                };
                self.index(ptr, &Variable::Raw(index, item), unchecked)
            }
            Variable::SharedMemory(id, item, _) => {
                let ptr_ty =
                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
                let mut index = vec![index_id];
                // With explicit shared memory support, a leading constant 0
                // index is prepended.
                if self.compilation_options.supports_explicit_smem {
                    index.insert(0, self.const_u32(0));
                }
                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
                IndexedVariable::Pointer(id, item.clone())
            }
            Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
                let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
                IndexedVariable::Pointer(id, item.clone())
            }
            var => unimplemented!("Can't index into {var:?}"),
        }
    }
561
562 pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
563 let always_in_bounds = is_always_in_bounds(variable, index);
564 let indexed = self.index(variable, index, always_in_bounds);
565
566 let read = |b: &mut Self| match indexed {
567 IndexedVariable::Pointer(ptr, item) => {
568 let ty = item.id(b);
569 let out_id = b.write_id(out);
570 b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
571 }
572 IndexedVariable::Composite(var, index, item) => {
573 let elem = item.elem();
574 let ty = elem.id(b);
575 let out_id = b.write_id(out);
576 b.composite_extract(ty, Some(out_id), var, vec![index])
577 .unwrap()
578 }
579 IndexedVariable::DynamicComposite(var, index, item) => {
580 let elem = item.elem();
581 let ty = elem.id(b);
582 let out_id = b.write_id(out);
583 b.vector_extract_dynamic(ty, Some(out_id), var, index)
584 .unwrap()
585 }
586 IndexedVariable::Scalar(var) => {
587 let ty = out.item().id(b);
588 let input = b.read(&var);
589 let out_id = b.write_id(out);
590 b.copy_object(ty, Some(out_id), input).unwrap();
591 b.write(out, out_id);
592 out_id
593 }
594 };
595
596 read(self)
597 }
598
599 pub fn read_indexed_unchecked(
600 &mut self,
601 out: &Variable,
602 variable: &Variable,
603 index: &Variable,
604 ) -> Word {
605 let indexed = self.index(variable, index, true);
606
607 match indexed {
608 IndexedVariable::Pointer(ptr, item) => {
609 let ty = item.id(self);
610 let out_id = self.write_id(out);
611 self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
612 }
613 IndexedVariable::Composite(var, index, item) => {
614 let elem = item.elem();
615 let ty = elem.id(self);
616 let out_id = self.write_id(out);
617 self.composite_extract(ty, Some(out_id), var, vec![index])
618 .unwrap()
619 }
620 IndexedVariable::DynamicComposite(var, index, item) => {
621 let elem = item.elem();
622 let ty = elem.id(self);
623 let out_id = self.write_id(out);
624 self.vector_extract_dynamic(ty, Some(out_id), var, index)
625 .unwrap()
626 }
627 IndexedVariable::Scalar(var) => {
628 let ty = out.item().id(self);
629 let input = self.read(&var);
630 let out_id = self.write_id(out);
631 self.copy_object(ty, Some(out_id), input).unwrap();
632 self.write(out, out_id);
633 out_id
634 }
635 }
636 }
637
638 pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
639 let always_in_bounds = is_always_in_bounds(var, index);
640 match self.index(var, index, always_in_bounds) {
641 IndexedVariable::Pointer(ptr, _) => ptr,
642 other => unreachable!("{other:?}"),
643 }
644 }
645
646 pub fn write_id(&mut self, variable: &Variable) -> Word {
647 match variable {
648 Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
649 Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
650 Variable::Local { .. } => self.id(),
651 Variable::GlobalScalar(id, _) => *id,
652 Variable::Raw(id, _) => *id,
653 Variable::ConstantScalar(_, _, _) => panic!("Can't write to constant scalar"),
654 Variable::GlobalInputArray(_, _, _)
655 | Variable::GlobalOutputArray(_, _, _)
656 | Variable::Slice { .. }
657 | Variable::Named { .. }
658 | Variable::SharedMemory(_, _, _)
659 | Variable::ConstantArray(_, _, _)
660 | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
661 global => panic!("Can't write to builtin {global:?}"),
662 }
663 }
664
665 pub fn write(&mut self, variable: &Variable, value: Word) {
666 match variable {
667 Variable::Local { id, .. } => self.store(*id, value, None, vec![]).unwrap(),
668 Variable::Slice { ptr, .. } => self.write(ptr, value),
669 _ => {}
670 }
671 }
672
673 pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
674 let always_in_bounds = is_always_in_bounds(out, index);
675 let variable = self.index(out, index, always_in_bounds);
676
677 let write = |b: &mut Self| match variable {
678 IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
679 IndexedVariable::Composite(var, index, item) => {
680 let ty = item.id(b);
681 let id = b
682 .composite_insert(ty, None, value, var, vec![index])
683 .unwrap();
684 b.write(out, id);
685 }
686 IndexedVariable::DynamicComposite(var, index, item) => {
687 let ty = item.id(b);
688 let id = b
689 .vector_insert_dynamic(ty, None, value, var, index)
690 .unwrap();
691 b.write(out, id);
692 }
693 IndexedVariable::Scalar(var) => b.write(&var, value),
694 };
695
696 write(self)
697 }
698
699 pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
700 let variable = self.index(out, index, true);
701
702 match variable {
703 IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
704 IndexedVariable::Composite(var, index, item) => {
705 let ty = item.id(self);
706 let out_id = self
707 .composite_insert(ty, None, value, var, vec![index])
708 .unwrap();
709 self.write(out, out_id);
710 }
711 IndexedVariable::DynamicComposite(var, index, item) => {
712 let ty = item.id(self);
713 let out_id = self
714 .vector_insert_dynamic(ty, None, value, var, index)
715 .unwrap();
716 self.write(out, out_id);
717 }
718 IndexedVariable::Scalar(var) => self.write(&var, value),
719 }
720 }
721}
722
723fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
724 let len = match var {
725 Variable::SharedMemory(_, _, len)
726 | Variable::ConstantArray(_, _, len)
727 | Variable::LocalArray(_, _, len)
728 | Variable::Slice {
729 const_len: Some(len),
730 ..
731 } => *len,
732 _ => return false,
733 };
734
735 let const_index = match index {
736 Variable::ConstantScalar(_, value, _) => value.as_u32(),
737 _ => return false,
738 };
739
740 const_index < len
741}