1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6 SpirvCompiler, SpirvTarget,
7 item::{Elem, Item},
8 lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantScalarValue, FloatKind, Id};
11use rspirv::{
12 dr::Builder,
13 spirv::{FPEncoding, StorageClass, Word},
14};
15
/// A value or resource the SPIR-V compiler can refer to while lowering a kernel.
///
/// Most variants carry a `Word` that `Variable::id` returns unchanged. `Versioned` and
/// `LocalBinding` are resolved through the compiler instead, and `CoopMatrix` has no
/// directly retrievable id.
#[derive(Debug, Clone, PartialEq)]
pub enum Variable {
    /// Id-only value; `item()` reports it as `Item::Scalar(Elem::Int(32, false))`.
    SubgroupSize(Word),
    /// Input buffer: the id taken from `state.buffers`, the element item, and the buffer
    /// position used for that lookup.
    GlobalInputArray(Word, Item, u32),
    /// Output buffer; same field layout as `GlobalInputArray`.
    GlobalOutputArray(Word, Item, u32),
    /// Scalar identified by an id together with its element type.
    GlobalScalar(Word, Elem),
    /// Constant: the id of the emitted constant, its raw bits, and its element type.
    ConstantScalar(Word, ConstVal, Elem),
    /// Mutable local (built from `ir::VariableKind::LocalMut`); `read` loads from `id` and
    /// `write` stores to it.
    Local {
        id: Word,
        item: Item,
    },
    /// Versioned variable keyed by `(id, version)`; its SPIR-V id is looked up with
    /// `get_versioned`.
    Versioned {
        id: (Id, u16),
        item: Item,
        variable: ir::Variable,
    },
    /// Binding built from `ir::VariableKind::LocalConst`; its SPIR-V id is looked up with
    /// `get_binding`.
    LocalBinding {
        id: Id,
        item: Item,
        variable: ir::Variable,
    },
    /// A bare id paired with the item it is tagged with.
    Raw(Word, Item),
    /// Variable that `read` loads from `id` and that `index` addresses like a global buffer.
    Named {
        id: Word,
        item: Item,
        is_array: bool,
    },
    /// View into another variable.
    Slice {
        /// Variable the slice points into.
        ptr: Box<Variable>,
        /// Id of the start offset; `index` adds it to every index.
        offset: Word,
        end: Word,
        /// Length, when present, used as the bound by `is_always_in_bounds`.
        const_len: Option<u32>,
        item: Item,
    },
    /// Shared memory: id, element item, and length.
    SharedMemory(Word, Item, u32),
    /// Constant array: id, element item, and length.
    ConstantArray(Word, Item, u32),
    /// Function-local array: id, element item, and length.
    LocalArray(Word, Item, u32),
    /// Matrix keyed by `Id` in `state.matrices`; `Variable::id` is unimplemented for it.
    CoopMatrix(Id, Elem),
    // The remaining variants carry only an id. `item()` reports all of them as
    // `Item::Scalar(Elem::Int(32, false))` and `write_id` refuses to write to them.
    Id(Word),
    LocalInvocationIndex(Word),
    LocalInvocationIdX(Word),
    LocalInvocationIdY(Word),
    LocalInvocationIdZ(Word),
    Rank(Word),
    WorkgroupId(Word),
    WorkgroupIdX(Word),
    WorkgroupIdY(Word),
    WorkgroupIdZ(Word),
    GlobalInvocationIndex(Word),
    GlobalInvocationIdX(Word),
    GlobalInvocationIdY(Word),
    GlobalInvocationIdZ(Word),
    WorkgroupSize(Word),
    WorkgroupSizeX(Word),
    WorkgroupSizeY(Word),
    WorkgroupSizeZ(Word),
    NumWorkgroups(Word),
    NumWorkgroupsX(Word),
    NumWorkgroupsY(Word),
    NumWorkgroupsZ(Word),
}
77
/// Raw bit pattern of a scalar constant, held in either 32 or 64 bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    /// Bits of a value that is at most 32 bits wide; narrower values sit in the low bits.
    Bit32(u32),
    /// Bits of a 64-bit value.
    Bit64(u64),
}
83
84impl ConstVal {
85 pub fn as_u64(&self) -> u64 {
86 match self {
87 ConstVal::Bit32(val) => *val as u64,
88 ConstVal::Bit64(val) => *val,
89 }
90 }
91
92 pub fn as_u32(&self) -> u32 {
93 match self {
94 ConstVal::Bit32(val) => *val,
95 ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
96 }
97 }
98
99 pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
100 match (width, encoding) {
101 (64, _) => f64::from_bits(self.as_u64()),
102 (32, _) => f32::from_bits(self.as_u32()) as f64,
103 (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
104 (_, Some(FPEncoding::BFloat16KHR)) => {
105 half::bf16::from_bits(self.as_u32() as u16).to_f64()
106 }
107 (_, Some(FPEncoding::Float8E4M3EXT)) => {
108 cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
109 }
110 (_, Some(FPEncoding::Float8E5M2EXT)) => {
111 cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
112 }
113 _ => unreachable!(),
114 }
115 }
116
117 pub fn as_int(&self, width: u32) -> i64 {
118 unsafe {
119 match width {
120 64 => transmute::<u64, i64>(self.as_u64()),
121 32 => transmute::<u32, i32>(self.as_u32()) as i64,
122 16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
123 8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
124 _ => unreachable!(),
125 }
126 }
127 }
128
129 pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
130 match (width, encoding) {
131 (64, _) => ConstVal::Bit64(value.to_bits()),
132 (32, _) => ConstVal::Bit32((value as f32).to_bits()),
133 (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
134 (_, Some(FPEncoding::BFloat16KHR)) => {
135 ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
136 }
137 (_, Some(FPEncoding::Float8E4M3EXT)) => {
138 ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
139 }
140 (_, Some(FPEncoding::Float8E5M2EXT)) => {
141 ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
142 }
143 _ => unreachable!(),
144 }
145 }
146
147 pub fn from_int(value: i64, width: u32) -> Self {
148 match width {
149 64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
150 32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
151 16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
152 8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
153 _ => unreachable!(),
154 }
155 }
156
157 pub fn from_uint(value: u64, width: u32) -> Self {
158 match width {
159 64 => ConstVal::Bit64(value),
160 32 => ConstVal::Bit32(value as u32),
161 16 => ConstVal::Bit32(value as u16 as u32),
162 8 => ConstVal::Bit32(value as u8 as u32),
163 _ => unreachable!(),
164 }
165 }
166
167 pub fn from_bool(value: bool) -> Self {
168 ConstVal::Bit32(value as u32)
169 }
170}
171
172impl From<ConstantScalarValue> for ConstVal {
173 fn from(value: ConstantScalarValue) -> Self {
174 let width = value.elem_type().size() as u32 * 8;
175 match value {
176 ConstantScalarValue::Int(val, _) => ConstVal::from_int(val, width),
177 ConstantScalarValue::Float(val, FloatKind::BF16) => {
178 ConstVal::from_float(val, width, Some(FPEncoding::BFloat16KHR))
179 }
180 ConstantScalarValue::Float(val, FloatKind::E4M3) => {
181 ConstVal::from_float(val, width, Some(FPEncoding::Float8E4M3EXT))
182 }
183 ConstantScalarValue::Float(val, FloatKind::E5M2) => {
184 ConstVal::from_float(val, width, Some(FPEncoding::Float8E5M2EXT))
185 }
186 ConstantScalarValue::Float(val, _) => ConstVal::from_float(val, width, None),
187 ConstantScalarValue::UInt(val, _) => ConstVal::from_uint(val, width),
188 ConstantScalarValue::Bool(val) => ConstVal::from_bool(val),
189 }
190 }
191}
192
193impl From<u32> for ConstVal {
194 fn from(value: u32) -> Self {
195 ConstVal::Bit32(value)
196 }
197}
198
199impl From<f32> for ConstVal {
200 fn from(value: f32) -> Self {
201 ConstVal::Bit32(value.to_bits())
202 }
203}
204
205impl Variable {
206 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
207 match self {
208 Variable::GlobalInputArray(id, _, _) => *id,
209 Variable::GlobalOutputArray(id, _, _) => *id,
210 Variable::GlobalScalar(id, _) => *id,
211 Variable::ConstantScalar(id, _, _) => *id,
212 Variable::Local { id, .. } => *id,
213 Variable::Versioned {
214 id, variable: var, ..
215 } => b.get_versioned(*id, var),
216 Variable::LocalBinding {
217 id, variable: var, ..
218 } => b.get_binding(*id, var),
219 Variable::Raw(id, _) => *id,
220 Variable::Named { id, .. } => *id,
221 Variable::Slice { ptr, .. } => ptr.id(b),
222 Variable::SharedMemory(id, _, _) => *id,
223 Variable::ConstantArray(id, _, _) => *id,
224 Variable::LocalArray(id, _, _) => *id,
225 Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
226 Variable::SubgroupSize(id) => *id,
227 Variable::Id(id) => *id,
228 Variable::LocalInvocationIndex(id) => *id,
229 Variable::LocalInvocationIdX(id) => *id,
230 Variable::LocalInvocationIdY(id) => *id,
231 Variable::LocalInvocationIdZ(id) => *id,
232 Variable::Rank(id) => *id,
233 Variable::WorkgroupId(id) => *id,
234 Variable::WorkgroupIdX(id) => *id,
235 Variable::WorkgroupIdY(id) => *id,
236 Variable::WorkgroupIdZ(id) => *id,
237 Variable::GlobalInvocationIndex(id) => *id,
238 Variable::GlobalInvocationIdX(id) => *id,
239 Variable::GlobalInvocationIdY(id) => *id,
240 Variable::GlobalInvocationIdZ(id) => *id,
241 Variable::WorkgroupSize(id) => *id,
242 Variable::WorkgroupSizeX(id) => *id,
243 Variable::WorkgroupSizeY(id) => *id,
244 Variable::WorkgroupSizeZ(id) => *id,
245 Variable::NumWorkgroups(id) => *id,
246 Variable::NumWorkgroupsX(id) => *id,
247 Variable::NumWorkgroupsY(id) => *id,
248 Variable::NumWorkgroupsZ(id) => *id,
249 }
250 }
251
252 pub fn item(&self) -> Item {
253 match self {
254 Variable::GlobalInputArray(_, item, _) => item.clone(),
255 Variable::GlobalOutputArray(_, item, _) => item.clone(),
256 Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
257 Variable::ConstantScalar(_, _, elem) => Item::Scalar(*elem),
258 Variable::Local { item, .. } => item.clone(),
259 Variable::Versioned { item, .. } => item.clone(),
260 Variable::LocalBinding { item, .. } => item.clone(),
261 Variable::Named { item, .. } => item.clone(),
262 Variable::Slice { item, .. } => item.clone(),
263 Variable::SharedMemory(_, item, _) => item.clone(),
264 Variable::ConstantArray(_, item, _) => item.clone(),
265 Variable::LocalArray(_, item, _) => item.clone(),
266 Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
267 _ => Item::Scalar(Elem::Int(32, false)), }
269 }
270
271 pub fn indexed_item(&self) -> Item {
272 match self {
273 Variable::LocalBinding {
274 item: Item::Vector(elem, _),
275 ..
276 } => Item::Scalar(*elem),
277 Variable::Local {
278 item: Item::Vector(elem, _),
279 ..
280 } => Item::Scalar(*elem),
281 Variable::Versioned {
282 item: Item::Vector(elem, _),
283 ..
284 } => Item::Scalar(*elem),
285 other => other.item(),
286 }
287 }
288
289 pub fn elem(&self) -> Elem {
290 self.item().elem()
291 }
292
293 pub fn has_len(&self) -> bool {
294 matches!(
295 self,
296 Variable::GlobalInputArray(_, _, _)
297 | Variable::GlobalOutputArray(_, _, _)
298 | Variable::Named {
299 is_array: false,
300 ..
301 }
302 | Variable::Slice { .. }
303 | Variable::SharedMemory(_, _, _)
304 | Variable::ConstantArray(_, _, _)
305 | Variable::LocalArray(_, _, _)
306 )
307 }
308
309 pub fn has_buffer_len(&self) -> bool {
310 matches!(
311 self,
312 Variable::GlobalInputArray(_, _, _)
313 | Variable::GlobalOutputArray(_, _, _)
314 | Variable::Named {
315 is_array: false,
316 ..
317 }
318 )
319 }
320
321 pub fn as_const(&self) -> Option<ConstVal> {
322 match self {
323 Self::ConstantScalar(_, val, _) => Some(*val),
324 _ => None,
325 }
326 }
327
328 pub fn as_binding(&self) -> Option<Id> {
329 match self {
330 Self::LocalBinding { id, .. } => Some(*id),
331 _ => None,
332 }
333 }
334}
335
/// Result of indexing into a `Variable`, describing how the element must be accessed.
#[derive(Debug)]
pub enum IndexedVariable {
    /// Id of a pointer produced by an access chain, plus the item it points at.
    Pointer(Word, Item),
    /// Id of a vector value, a constant element index, and the vector's item.
    Composite(Word, u32, Item),
    /// Id of a vector value, the id of the index value, and the vector's item.
    DynamicComposite(Word, u32, Item),
    /// The variable was not indexable as a vector and is accessed as a whole.
    Scalar(Variable),
}
343
344impl<T: SpirvTarget> SpirvCompiler<T> {
345 pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
346 let item = variable.ty;
347 match variable.kind {
348 ir::VariableKind::ConstantScalar(value) => {
349 let item = self.compile_type(ir::Type::new(value.storage_type()));
350 let const_val = value.into();
351
352 if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
353 Variable::ConstantScalar(*existing, const_val, item.elem())
354 } else {
355 let id = item.elem().constant(self, const_val);
356 self.state.constants.insert((const_val, item.clone()), id);
357 Variable::ConstantScalar(id, const_val, item.elem())
358 }
359 }
360 ir::VariableKind::GlobalInputArray(pos) => {
361 let id = self.state.buffers[pos as usize];
362 Variable::GlobalInputArray(id, self.compile_type(item), pos)
363 }
364 ir::VariableKind::GlobalOutputArray(pos) => {
365 let id = self.state.buffers[pos as usize];
366 Variable::GlobalOutputArray(id, self.compile_type(item), pos)
367 }
368 ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
369 ir::VariableKind::LocalMut { id } => {
370 let item = self.compile_type(item);
371 let var = self.get_local(id, &item, variable);
372 Variable::Local { id: var, item }
373 }
374 ir::VariableKind::Versioned { id, version } => {
375 let item = self.compile_type(item);
376 let id = (id, version);
377 Variable::Versioned { id, item, variable }
378 }
379 ir::VariableKind::LocalConst { id } => {
380 let item = self.compile_type(item);
381 Variable::LocalBinding { id, item, variable }
382 }
383 ir::VariableKind::Builtin(builtin) => self.compile_builtin(builtin),
384 ir::VariableKind::ConstantArray { id, length, .. } => {
385 let item = self.compile_type(item);
386 let id = self.state.const_arrays[id as usize].id;
387 Variable::ConstantArray(id, item, length)
388 }
389 ir::VariableKind::SharedMemory { id, length, .. } => {
390 let item = self.compile_type(item);
391 let id = self.state.shared_memories[&id].id;
392 Variable::SharedMemory(id, item, length)
393 }
394 ir::VariableKind::LocalArray {
395 id,
396 length,
397 unroll_factor,
398 } => {
399 let item = self.compile_type(item);
400 let id = if let Some(arr) = self.state.local_arrays.get(&id) {
401 arr.id
402 } else {
403 let arr_ty = Item::Array(Box::new(item.clone()), length);
404 let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
405 let arr_id = self.declare_function_variable(ptr_ty);
406 self.debug_var_name(arr_id, variable);
407 let arr = Array {
408 id: arr_id,
409 item: item.clone(),
410 len: length * unroll_factor,
411 var: variable,
412 alignment: None,
413 };
414 self.state.local_arrays.insert(id, arr);
415 arr_id
416 };
417 Variable::LocalArray(id, item, length)
418 }
419 ir::VariableKind::Matrix { id, mat } => {
420 let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
421 if self.state.matrices.contains_key(&id) {
422 Variable::CoopMatrix(id, elem)
423 } else {
424 let matrix = self.init_coop_matrix(mat, variable);
425 self.state.matrices.insert(id, matrix);
426 Variable::CoopMatrix(id, elem)
427 }
428 }
429 ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
430 ir::VariableKind::Barrier { .. } => panic!("Barrier not supported."),
431 ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
432 ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
433 }
434 }
435
436 pub fn read(&mut self, variable: &Variable) -> Word {
437 match variable {
438 Variable::Slice { ptr, .. } => self.read(ptr),
439 Variable::Local { id, item } => {
440 let ty = item.id(self);
441 self.load(ty, None, *id, None, vec![]).unwrap()
442 }
443 Variable::Named { id, item, .. } => {
444 let ty = item.id(self);
445 self.load(ty, None, *id, None, vec![]).unwrap()
446 }
447 ssa => ssa.id(self),
448 }
449 }
450
451 pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
452 if let Some(as_const) = variable.as_const() {
453 self.static_cast(as_const, &variable.elem(), item)
454 } else {
455 let id = self.read(variable);
456 variable.item().cast_to(self, None, id, item)
457 }
458 }
459
460 pub fn index(
461 &mut self,
462 variable: &Variable,
463 index: &Variable,
464 unchecked: bool,
465 ) -> IndexedVariable {
466 let access_chain = if unchecked {
467 Builder::in_bounds_access_chain
468 } else {
469 Builder::access_chain
470 };
471 let index_id = self.read(index);
472 match variable {
473 Variable::GlobalInputArray(id, item, _)
474 | Variable::GlobalOutputArray(id, item, _)
475 | Variable::Named { id, item, .. } => {
476 let ptr_ty =
477 Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
478 let zero = self.const_u32(0);
479 let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();
480
481 IndexedVariable::Pointer(id, item.clone())
482 }
483 Variable::Local {
484 id,
485 item: Item::Vector(elem, _),
486 } => {
487 let ptr_ty =
488 Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
489 let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
490
491 IndexedVariable::Pointer(id, Item::Scalar(*elem))
492 }
493 Variable::LocalBinding {
494 id,
495 item: Item::Vector(elem, vec),
496 variable,
497 } if index.as_const().is_some() => IndexedVariable::Composite(
498 self.get_binding(*id, variable),
499 index.as_const().unwrap().as_u32(),
500 Item::Vector(*elem, *vec),
501 ),
502 Variable::LocalBinding {
503 id,
504 item: Item::Vector(elem, vec),
505 variable,
506 } => IndexedVariable::DynamicComposite(
507 self.get_binding(*id, variable),
508 index_id,
509 Item::Vector(*elem, *vec),
510 ),
511 Variable::Versioned {
512 id,
513 item: Item::Vector(elem, vec),
514 variable,
515 } if index.as_const().is_some() => IndexedVariable::Composite(
516 self.get_versioned(*id, variable),
517 index.as_const().unwrap().as_u32(),
518 Item::Vector(*elem, *vec),
519 ),
520 Variable::Versioned {
521 id,
522 item: Item::Vector(elem, vec),
523 variable,
524 } => IndexedVariable::DynamicComposite(
525 self.get_versioned(*id, variable),
526 index_id,
527 Item::Vector(*elem, *vec),
528 ),
529 Variable::Local { .. } | Variable::LocalBinding { .. } | Variable::Versioned { .. } => {
530 IndexedVariable::Scalar(variable.clone())
531 }
532 Variable::Slice { ptr, offset, .. } => {
533 let item = Item::Scalar(Elem::Int(32, false));
534 let int = item.id(self);
535 let index = match index.as_const() {
536 Some(ConstVal::Bit32(0)) => *offset,
537 _ => self.i_add(int, None, *offset, index_id).unwrap(),
538 };
539 self.index(ptr, &Variable::Raw(index, item), unchecked)
540 }
541 Variable::SharedMemory(id, item, _) => {
542 let ptr_ty =
543 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
544 let mut index = vec![index_id];
545 if self.compilation_options.supports_explicit_smem {
546 index.insert(0, self.const_u32(0));
547 }
548 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
549 IndexedVariable::Pointer(id, item.clone())
550 }
551 Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
552 let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
553 let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
554 IndexedVariable::Pointer(id, item.clone())
555 }
556 var => unimplemented!("Can't index into {var:?}"),
557 }
558 }
559
560 pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
561 let always_in_bounds = is_always_in_bounds(variable, index);
562 let indexed = self.index(variable, index, always_in_bounds);
563
564 let read = |b: &mut Self| match indexed {
565 IndexedVariable::Pointer(ptr, item) => {
566 let ty = item.id(b);
567 let out_id = b.write_id(out);
568 b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
569 }
570 IndexedVariable::Composite(var, index, item) => {
571 let elem = item.elem();
572 let ty = elem.id(b);
573 let out_id = b.write_id(out);
574 b.composite_extract(ty, Some(out_id), var, vec![index])
575 .unwrap()
576 }
577 IndexedVariable::DynamicComposite(var, index, item) => {
578 let elem = item.elem();
579 let ty = elem.id(b);
580 let out_id = b.write_id(out);
581 b.vector_extract_dynamic(ty, Some(out_id), var, index)
582 .unwrap()
583 }
584 IndexedVariable::Scalar(var) => {
585 let ty = out.item().id(b);
586 let input = b.read(&var);
587 let out_id = b.write_id(out);
588 b.copy_object(ty, Some(out_id), input).unwrap();
589 b.write(out, out_id);
590 out_id
591 }
592 };
593
594 read(self)
595 }
596
597 pub fn read_indexed_unchecked(
598 &mut self,
599 out: &Variable,
600 variable: &Variable,
601 index: &Variable,
602 ) -> Word {
603 let indexed = self.index(variable, index, true);
604
605 match indexed {
606 IndexedVariable::Pointer(ptr, item) => {
607 let ty = item.id(self);
608 let out_id = self.write_id(out);
609 self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
610 }
611 IndexedVariable::Composite(var, index, item) => {
612 let elem = item.elem();
613 let ty = elem.id(self);
614 let out_id = self.write_id(out);
615 self.composite_extract(ty, Some(out_id), var, vec![index])
616 .unwrap()
617 }
618 IndexedVariable::DynamicComposite(var, index, item) => {
619 let elem = item.elem();
620 let ty = elem.id(self);
621 let out_id = self.write_id(out);
622 self.vector_extract_dynamic(ty, Some(out_id), var, index)
623 .unwrap()
624 }
625 IndexedVariable::Scalar(var) => {
626 let ty = out.item().id(self);
627 let input = self.read(&var);
628 let out_id = self.write_id(out);
629 self.copy_object(ty, Some(out_id), input).unwrap();
630 self.write(out, out_id);
631 out_id
632 }
633 }
634 }
635
636 pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
637 let always_in_bounds = is_always_in_bounds(var, index);
638 match self.index(var, index, always_in_bounds) {
639 IndexedVariable::Pointer(ptr, _) => ptr,
640 other => unreachable!("{other:?}"),
641 }
642 }
643
644 pub fn write_id(&mut self, variable: &Variable) -> Word {
645 match variable {
646 Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
647 Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
648 Variable::Local { .. } => self.id(),
649 Variable::GlobalScalar(id, _) => *id,
650 Variable::Raw(id, _) => *id,
651 Variable::ConstantScalar(_, _, _) => panic!("Can't write to constant scalar"),
652 Variable::GlobalInputArray(_, _, _)
653 | Variable::GlobalOutputArray(_, _, _)
654 | Variable::Slice { .. }
655 | Variable::Named { .. }
656 | Variable::SharedMemory(_, _, _)
657 | Variable::ConstantArray(_, _, _)
658 | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
659 global => panic!("Can't write to builtin {global:?}"),
660 }
661 }
662
663 pub fn write(&mut self, variable: &Variable, value: Word) {
664 match variable {
665 Variable::Local { id, .. } => self.store(*id, value, None, vec![]).unwrap(),
666 Variable::Slice { ptr, .. } => self.write(ptr, value),
667 _ => {}
668 }
669 }
670
671 pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
672 let always_in_bounds = is_always_in_bounds(out, index);
673 let variable = self.index(out, index, always_in_bounds);
674
675 let write = |b: &mut Self| match variable {
676 IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
677 IndexedVariable::Composite(var, index, item) => {
678 let ty = item.id(b);
679 let id = b
680 .composite_insert(ty, None, value, var, vec![index])
681 .unwrap();
682 b.write(out, id);
683 }
684 IndexedVariable::DynamicComposite(var, index, item) => {
685 let ty = item.id(b);
686 let id = b
687 .vector_insert_dynamic(ty, None, value, var, index)
688 .unwrap();
689 b.write(out, id);
690 }
691 IndexedVariable::Scalar(var) => b.write(&var, value),
692 };
693
694 write(self)
695 }
696
697 pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
698 let variable = self.index(out, index, true);
699
700 match variable {
701 IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
702 IndexedVariable::Composite(var, index, item) => {
703 let ty = item.id(self);
704 let out_id = self
705 .composite_insert(ty, None, value, var, vec![index])
706 .unwrap();
707 self.write(out, out_id);
708 }
709 IndexedVariable::DynamicComposite(var, index, item) => {
710 let ty = item.id(self);
711 let out_id = self
712 .vector_insert_dynamic(ty, None, value, var, index)
713 .unwrap();
714 self.write(out, out_id);
715 }
716 IndexedVariable::Scalar(var) => self.write(&var, value),
717 }
718 }
719}
720
721fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
722 let len = match var {
723 Variable::SharedMemory(_, _, len)
724 | Variable::ConstantArray(_, _, len)
725 | Variable::LocalArray(_, _, len)
726 | Variable::Slice {
727 const_len: Some(len),
728 ..
729 } => *len,
730 _ => return false,
731 };
732
733 let const_index = match index {
734 Variable::ConstantScalar(_, value, _) => value.as_u32(),
735 _ => return false,
736 };
737
738 const_index < len
739}