1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6 SpirvCompiler, SpirvTarget,
7 item::{Elem, Item},
8 lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantValue, Id};
11use rspirv::{
12 dr::Builder,
13 spirv::{self, FPEncoding, StorageClass, Word},
14};
15
/// Backend-side representation of a compiled IR variable.
///
/// Most variants store the SPIR-V result id (`Word`) that `Variable::id` returns
/// directly. `Versioned` and `LocalBinding` instead keep the IR id and resolve it
/// through `get_versioned` / `get_binding`. `CoopMatrix` has no plain SPIR-V id.
#[derive(Debug, Clone, PartialEq)]
pub enum Variable {
    /// Input buffer: `(buffer id, element item, buffer position)`.
    GlobalInputArray(Word, Item, u32),
    /// Output buffer: `(buffer id, element item, buffer position)`.
    GlobalOutputArray(Word, Item, u32),
    /// Global scalar: `(id, element type)`.
    GlobalScalar(Word, Elem),
    /// Interned constant: `(constant id, raw bits, item)`.
    Constant(Word, ConstVal, Item),
    /// Mutable local: `id` is loaded from on read and stored to on write.
    Local {
        id: Word,
        item: Item,
    },
    /// SSA value identified by `(IR id, version)`.
    Versioned {
        id: (Id, u16),
        item: Item,
        variable: ir::Variable,
    },
    /// Immutable IR binding, resolved to an SPIR-V id on demand.
    LocalBinding {
        id: Id,
        item: Item,
        variable: ir::Variable,
    },
    /// Already-computed SPIR-V value; reading it returns `id` unchanged.
    Raw(Word, Item),
    /// Named variable indexed as a storage buffer.
    /// NOTE(review): `has_len` only reports a length when `is_array` is `false` —
    /// confirm the intended meaning of this flag.
    Named {
        id: Word,
        item: Item,
        is_array: bool,
    },
    /// View into another variable: indices are offset by `offset` before indexing
    /// `ptr`. `const_len` is the statically known length, if any. `end` is not read
    /// anywhere in this file.
    Slice {
        ptr: Box<Variable>,
        offset: Word,
        end: Word,
        const_len: Option<u32>,
        item: Item,
    },
    /// Workgroup array: `(id, element item, length)`.
    SharedArray(Word, Item, u32),
    /// Single workgroup value: `(id, item)`.
    Shared(Word, Item),
    /// Constant array: `(id, element item, length)`.
    ConstantArray(Word, Item, u32),
    /// Function-local array: `(id, element item, length)`.
    LocalArray(Word, Item, u32),
    /// Cooperative matrix keyed by its IR id in `state.matrices`.
    CoopMatrix(Id, Elem),
    /// Bare SPIR-V id without an associated item.
    Id(Word),
    /// Builtin variable: `(id, item)`.
    Builtin(Word, Item),
}
57
58impl Variable {
59 pub fn scope(&self) -> spirv::Scope {
60 match self {
61 Variable::GlobalInputArray(..)
62 | Variable::GlobalOutputArray(..)
63 | Variable::Named { .. }
64 | Variable::GlobalScalar(..) => spirv::Scope::Device,
65 Variable::SharedArray(..) | Variable::Shared(..) => spirv::Scope::Workgroup,
66 Variable::CoopMatrix(..) => spirv::Scope::Subgroup,
67 Variable::Slice { ptr, .. } => ptr.scope(),
68 Variable::Raw(..) => unimplemented!("Can't get scope of raw variable"),
69 Variable::Id(_) => unimplemented!("Can't get scope of raw id"),
70 _ => spirv::Scope::Invocation,
71 }
72 }
73}
74
/// Raw bit pattern of a scalar constant.
///
/// Values of 32 bits or fewer are stored zero-extended in `Bit32`. This covers
/// 8/16/32-bit integers, f16/bf16/f8/f32 floats and booleans. 64-bit values use
/// `Bit64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    Bit32(u32),
    Bit64(u64),
}
80
81impl ConstVal {
82 pub fn as_u64(&self) -> u64 {
83 match self {
84 ConstVal::Bit32(val) => *val as u64,
85 ConstVal::Bit64(val) => *val,
86 }
87 }
88
89 pub fn as_u32(&self) -> u32 {
90 match self {
91 ConstVal::Bit32(val) => *val,
92 ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
93 }
94 }
95
96 pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
97 match (width, encoding) {
98 (64, _) => f64::from_bits(self.as_u64()),
99 (32, _) => f32::from_bits(self.as_u32()) as f64,
100 (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
101 (_, Some(FPEncoding::BFloat16KHR)) => {
102 half::bf16::from_bits(self.as_u32() as u16).to_f64()
103 }
104 (_, Some(FPEncoding::Float8E4M3EXT)) => {
105 cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
106 }
107 (_, Some(FPEncoding::Float8E5M2EXT)) => {
108 cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
109 }
110 _ => unreachable!(),
111 }
112 }
113
114 pub fn as_int(&self, width: u32) -> i64 {
115 unsafe {
116 match width {
117 64 => transmute::<u64, i64>(self.as_u64()),
118 32 => transmute::<u32, i32>(self.as_u32()) as i64,
119 16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
120 8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
121 _ => unreachable!(),
122 }
123 }
124 }
125
126 pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
127 match (width, encoding) {
128 (64, _) => ConstVal::Bit64(value.to_bits()),
129 (32, _) => ConstVal::Bit32((value as f32).to_bits()),
130 (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
131 (_, Some(FPEncoding::BFloat16KHR)) => {
132 ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
133 }
134 (_, Some(FPEncoding::Float8E4M3EXT)) => {
135 ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
136 }
137 (_, Some(FPEncoding::Float8E5M2EXT)) => {
138 ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
139 }
140 _ => unreachable!(),
141 }
142 }
143
144 pub fn from_int(value: i64, width: u32) -> Self {
145 match width {
146 64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
147 32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
148 16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
149 8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
150 _ => unreachable!(),
151 }
152 }
153
154 pub fn from_uint(value: u64, width: u32) -> Self {
155 match width {
156 64 => ConstVal::Bit64(value),
157 32 => ConstVal::Bit32(value as u32),
158 16 => ConstVal::Bit32(value as u16 as u32),
159 8 => ConstVal::Bit32(value as u8 as u32),
160 _ => unreachable!(),
161 }
162 }
163
164 pub fn from_bool(value: bool) -> Self {
165 ConstVal::Bit32(value as u32)
166 }
167}
168
169impl From<(ConstantValue, Item)> for ConstVal {
170 fn from((value, ty): (ConstantValue, Item)) -> Self {
171 let elem = ty.elem();
172 let width = elem.size() * 8;
173 match value {
174 ConstantValue::Int(val) => ConstVal::from_int(val, width),
175 ConstantValue::Float(val) => ConstVal::from_float(val, width, elem.float_encoding()),
176 ConstantValue::UInt(val) => ConstVal::from_uint(val, width),
177 ConstantValue::Bool(val) => ConstVal::from_bool(val),
178 }
179 }
180}
181
182impl From<u32> for ConstVal {
183 fn from(value: u32) -> Self {
184 ConstVal::Bit32(value)
185 }
186}
187
188impl From<f32> for ConstVal {
189 fn from(value: f32) -> Self {
190 ConstVal::Bit32(value.to_bits())
191 }
192}
193
194impl Variable {
195 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
196 match self {
197 Variable::GlobalInputArray(id, _, _) => *id,
198 Variable::GlobalOutputArray(id, _, _) => *id,
199 Variable::GlobalScalar(id, _) => *id,
200 Variable::Constant(id, _, _) => *id,
201 Variable::Local { id, .. } => *id,
202 Variable::Versioned {
203 id, variable: var, ..
204 } => b.get_versioned(*id, var),
205 Variable::LocalBinding {
206 id, variable: var, ..
207 } => b.get_binding(*id, var),
208 Variable::Raw(id, _) => *id,
209 Variable::Named { id, .. } => *id,
210 Variable::Slice { ptr, .. } => ptr.id(b),
211 Variable::SharedArray(id, _, _) => *id,
212 Variable::Shared(id, _) => *id,
213 Variable::ConstantArray(id, _, _) => *id,
214 Variable::LocalArray(id, _, _) => *id,
215 Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
216 Variable::Id(id) => *id,
217 Variable::Builtin(id, ..) => *id,
218 }
219 }
220
221 pub fn item(&self) -> Item {
222 match self {
223 Variable::GlobalInputArray(_, item, _) => item.clone(),
224 Variable::GlobalOutputArray(_, item, _) => item.clone(),
225 Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
226 Variable::Constant(_, _, item) => item.clone(),
227 Variable::Local { item, .. } => item.clone(),
228 Variable::Versioned { item, .. } => item.clone(),
229 Variable::LocalBinding { item, .. } => item.clone(),
230 Variable::Named { item, .. } => item.clone(),
231 Variable::Slice { item, .. } => item.clone(),
232 Variable::SharedArray(_, item, _) => item.clone(),
233 Variable::Shared(_, item) => item.clone(),
234 Variable::ConstantArray(_, item, _) => item.clone(),
235 Variable::LocalArray(_, item, _) => item.clone(),
236 Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
237 Variable::Builtin(_, item) => item.clone(),
238 Variable::Raw(_, item) => item.clone(),
239 Variable::Id(_) => unimplemented!("Can't get item of raw ID"),
240 }
241 }
242
243 pub fn indexed_item(&self) -> Item {
244 match self {
245 Variable::LocalBinding {
246 item: Item::Vector(elem, _),
247 ..
248 } => Item::Scalar(*elem),
249 Variable::Local {
250 item: Item::Vector(elem, _),
251 ..
252 } => Item::Scalar(*elem),
253 Variable::Versioned {
254 item: Item::Vector(elem, _),
255 ..
256 } => Item::Scalar(*elem),
257 other => other.item(),
258 }
259 }
260
261 pub fn elem(&self) -> Elem {
262 self.item().elem()
263 }
264
265 pub fn has_len(&self) -> bool {
266 matches!(
267 self,
268 Variable::GlobalInputArray(_, _, _)
269 | Variable::GlobalOutputArray(_, _, _)
270 | Variable::Named {
271 is_array: false,
272 ..
273 }
274 | Variable::Slice { .. }
275 | Variable::SharedArray(_, _, _)
276 | Variable::ConstantArray(_, _, _)
277 | Variable::LocalArray(_, _, _)
278 )
279 }
280
281 pub fn has_buffer_len(&self) -> bool {
282 matches!(
283 self,
284 Variable::GlobalInputArray(_, _, _)
285 | Variable::GlobalOutputArray(_, _, _)
286 | Variable::Named {
287 is_array: false,
288 ..
289 }
290 )
291 }
292
293 pub fn as_const(&self) -> Option<ConstVal> {
294 match self {
295 Self::Constant(_, val, _) => Some(*val),
296 _ => None,
297 }
298 }
299
300 pub fn as_binding(&self) -> Option<Id> {
301 match self {
302 Self::LocalBinding { id, .. } => Some(*id),
303 _ => None,
304 }
305 }
306}
307
/// Result of `SpirvCompiler::index`: how the selected element can be accessed.
#[derive(Debug)]
pub enum IndexedVariable {
    /// Pointer to the element: `(pointer id, pointee item)`; accessed with load/store.
    Pointer(Word, Item),
    /// Vector value with a constant lane: `(vector id, lane literal, vector item)`.
    Composite(Word, u32, Item),
    /// Vector value with a runtime lane: `(vector id, lane index id, vector item)`.
    DynamicComposite(Word, u32, Item),
    /// Non-vector value. The index is not applied and the variable is read or
    /// written as a whole.
    Scalar(Variable),
}
315
316impl<T: SpirvTarget> SpirvCompiler<T> {
317 pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
318 let item = variable.ty;
319 match variable.kind {
320 ir::VariableKind::Constant(value) => {
321 let item = self.compile_type(item);
322 let const_val = (value, item.clone()).into();
323
324 if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
325 Variable::Constant(*existing, const_val, item)
326 } else {
327 let id = item.constant(self, const_val);
328 self.state.constants.insert((const_val, item.clone()), id);
329 Variable::Constant(id, const_val, item)
330 }
331 }
332 ir::VariableKind::GlobalInputArray(pos) => {
333 let id = self.state.buffers[pos as usize];
334 Variable::GlobalInputArray(id, self.compile_type(item), pos)
335 }
336 ir::VariableKind::GlobalOutputArray(pos) => {
337 let id = self.state.buffers[pos as usize];
338 Variable::GlobalOutputArray(id, self.compile_type(item), pos)
339 }
340 ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
341 ir::VariableKind::LocalMut { id } => {
342 let item = self.compile_type(item);
343 let var = self.get_local(id, &item, variable);
344 Variable::Local { id: var, item }
345 }
346 ir::VariableKind::Versioned { id, version } => {
347 let item = self.compile_type(item);
348 let id = (id, version);
349 Variable::Versioned { id, item, variable }
350 }
351 ir::VariableKind::LocalConst { id } => {
352 let item = self.compile_type(item);
353 Variable::LocalBinding { id, item, variable }
354 }
355 ir::VariableKind::Builtin(builtin) => {
356 let item = self.compile_type(item);
357 self.compile_builtin(builtin, item)
358 }
359 ir::VariableKind::ConstantArray { id, length, .. } => {
360 let item = self.compile_type(item);
361 let id = self.state.const_arrays[id as usize].id;
362 Variable::ConstantArray(id, item, length as u32)
363 }
364 ir::VariableKind::SharedArray { id, length, .. } => {
365 let item = self.compile_type(item);
366 let id = self.state.shared_arrays[&id].id;
367 Variable::SharedArray(id, item, length as u32)
368 }
369 ir::VariableKind::Shared { id } => {
370 let item = self.compile_type(item);
371 let id = self.state.shared[&id].id;
372 Variable::Shared(id, item)
373 }
374 ir::VariableKind::LocalArray {
375 id,
376 length,
377 unroll_factor,
378 } => {
379 let item = self.compile_type(item);
380 let id = if let Some(arr) = self.state.local_arrays.get(&id) {
381 arr.id
382 } else {
383 let arr_ty = Item::Array(Box::new(item.clone()), length as u32);
384 let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
385 let arr_id = self.declare_function_variable(ptr_ty);
386 self.debug_var_name(arr_id, variable);
387 let arr = Array {
388 id: arr_id,
389 item: item.clone(),
390 len: (length * unroll_factor) as u32,
391 var: variable,
392 alignment: None,
393 };
394 self.state.local_arrays.insert(id, arr);
395 arr_id
396 };
397 Variable::LocalArray(id, item, length as u32)
398 }
399 ir::VariableKind::Matrix { id, mat } => {
400 let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
401 if self.state.matrices.contains_key(&id) {
402 Variable::CoopMatrix(id, elem)
403 } else {
404 let matrix = self.init_coop_matrix(mat, variable);
405 self.state.matrices.insert(id, matrix);
406 Variable::CoopMatrix(id, elem)
407 }
408 }
409 ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
410 ir::VariableKind::BarrierToken { .. } => {
411 panic!("Barrier not supported.")
412 }
413 ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
414 ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
415 }
416 }
417
418 pub fn read(&mut self, variable: &Variable) -> Word {
419 match variable {
420 Variable::Slice { ptr, .. } => self.read(ptr),
421 Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
422 let ty = item.id(self);
423 let ptr_ty =
424 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
425 let index = vec![self.const_u32(0)];
426 let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
427 self.load(ty, None, access, None, []).unwrap()
428 }
429 Variable::Local { id, item } | Variable::Shared(id, item) => {
430 let ty = item.id(self);
431 self.load(ty, None, *id, None, []).unwrap()
432 }
433 Variable::Named { id, item, .. } => {
434 let ty = item.id(self);
435 self.load(ty, None, *id, None, []).unwrap()
436 }
437 ssa => ssa.id(self),
438 }
439 }
440
441 pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
442 if let Some(as_const) = variable.as_const() {
443 self.static_cast(as_const, &variable.elem(), item).0
444 } else {
445 let id = self.read(variable);
446 variable.item().cast_to(self, None, id, item)
447 }
448 }
449
450 pub fn index(
451 &mut self,
452 variable: &Variable,
453 index: &Variable,
454 unchecked: bool,
455 ) -> IndexedVariable {
456 let access_chain = if unchecked {
457 Builder::in_bounds_access_chain
458 } else {
459 Builder::access_chain
460 };
461 let index_id = self.read(index);
462 match variable {
463 Variable::GlobalInputArray(id, item, _)
464 | Variable::GlobalOutputArray(id, item, _)
465 | Variable::Named { id, item, .. } => {
466 let ptr_ty =
467 Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
468 let zero = self.const_u32(0);
469 let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();
470
471 IndexedVariable::Pointer(id, item.clone())
472 }
473 Variable::Local {
474 id,
475 item: Item::Vector(elem, _),
476 } => {
477 let ptr_ty =
478 Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
479 let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
480
481 IndexedVariable::Pointer(id, Item::Scalar(*elem))
482 }
483 Variable::Shared(id, Item::Vector(elem, _)) => {
484 let ptr_ty =
485 Item::Pointer(StorageClass::Workgroup, Box::new(Item::Scalar(*elem))).id(self);
486
487 let mut index = vec![index_id];
488 if self.compilation_options.supports_explicit_smem {
489 index.insert(0, self.const_u32(0));
490 }
491
492 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
493
494 IndexedVariable::Pointer(id, Item::Scalar(*elem))
495 }
496 Variable::LocalBinding {
497 id,
498 item: Item::Vector(elem, vec),
499 variable,
500 } if index.as_const().is_some() => IndexedVariable::Composite(
501 self.get_binding(*id, variable),
502 index.as_const().unwrap().as_u32(),
503 Item::Vector(*elem, *vec),
504 ),
505 Variable::LocalBinding {
506 id,
507 item: Item::Vector(elem, vec),
508 variable,
509 } => IndexedVariable::DynamicComposite(
510 self.get_binding(*id, variable),
511 index_id,
512 Item::Vector(*elem, *vec),
513 ),
514 Variable::Versioned {
515 id,
516 item: Item::Vector(elem, vec),
517 variable,
518 } if index.as_const().is_some() => IndexedVariable::Composite(
519 self.get_versioned(*id, variable),
520 index.as_const().unwrap().as_u32(),
521 Item::Vector(*elem, *vec),
522 ),
523 Variable::Versioned {
524 id,
525 item: Item::Vector(elem, vec),
526 variable,
527 } => IndexedVariable::DynamicComposite(
528 self.get_versioned(*id, variable),
529 index_id,
530 Item::Vector(*elem, *vec),
531 ),
532 Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
533 let ptr_ty =
534 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
535 let index = vec![self.const_u32(0)];
536 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
537 IndexedVariable::Pointer(id, item.clone())
538 }
539 Variable::Local { .. }
540 | Variable::Shared(..)
541 | Variable::LocalBinding { .. }
542 | Variable::Versioned { .. } => IndexedVariable::Scalar(variable.clone()),
543 Variable::Constant(_, val, item) => {
544 let scalar_item = Item::Scalar(item.elem());
545 let (id, val) = self.static_cast(*val, &item.elem(), &scalar_item);
546 IndexedVariable::Scalar(Variable::Constant(id, val, scalar_item))
547 }
548 Variable::Slice { ptr, offset, .. } => {
549 let item = Item::Scalar(Elem::Int(32, false));
550 let int = item.id(self);
551 let index = match index.as_const() {
552 Some(ConstVal::Bit32(0)) => *offset,
553 _ => self.i_add(int, None, *offset, index_id).unwrap(),
554 };
555 self.index(ptr, &Variable::Raw(index, item), unchecked)
556 }
557 Variable::SharedArray(id, item, _) => {
558 let ptr_ty =
559 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
560 let mut index = vec![index_id];
561 if self.compilation_options.supports_explicit_smem {
562 index.insert(0, self.const_u32(0));
563 }
564 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
565 IndexedVariable::Pointer(id, item.clone())
566 }
567 Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
568 let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
569 let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
570 IndexedVariable::Pointer(id, item.clone())
571 }
572 var => unimplemented!("Can't index into {var:?}"),
573 }
574 }
575
576 pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
577 let always_in_bounds = is_always_in_bounds(variable, index);
578 let indexed = self.index(variable, index, always_in_bounds);
579
580 let read = |b: &mut Self| match indexed {
581 IndexedVariable::Pointer(ptr, item) => {
582 let ty = item.id(b);
583 let out_id = b.write_id(out);
584 b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
585 }
586 IndexedVariable::Composite(var, index, item) => {
587 let elem = item.elem();
588 let ty = elem.id(b);
589 let out_id = b.write_id(out);
590 b.composite_extract(ty, Some(out_id), var, vec![index])
591 .unwrap()
592 }
593 IndexedVariable::DynamicComposite(var, index, item) => {
594 let elem = item.elem();
595 let ty = elem.id(b);
596 let out_id = b.write_id(out);
597 b.vector_extract_dynamic(ty, Some(out_id), var, index)
598 .unwrap()
599 }
600 IndexedVariable::Scalar(var) => {
601 let ty = out.item().id(b);
602 let input = b.read(&var);
603 let out_id = b.write_id(out);
604 b.copy_object(ty, Some(out_id), input).unwrap();
605 b.write(out, out_id);
606 out_id
607 }
608 };
609
610 read(self)
611 }
612
613 pub fn read_indexed_unchecked(
614 &mut self,
615 out: &Variable,
616 variable: &Variable,
617 index: &Variable,
618 ) -> Word {
619 let indexed = self.index(variable, index, true);
620
621 match indexed {
622 IndexedVariable::Pointer(ptr, item) => {
623 let ty = item.id(self);
624 let out_id = self.write_id(out);
625 self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
626 }
627 IndexedVariable::Composite(var, index, item) => {
628 let elem = item.elem();
629 let ty = elem.id(self);
630 let out_id = self.write_id(out);
631 self.composite_extract(ty, Some(out_id), var, vec![index])
632 .unwrap()
633 }
634 IndexedVariable::DynamicComposite(var, index, item) => {
635 let elem = item.elem();
636 let ty = elem.id(self);
637 let out_id = self.write_id(out);
638 self.vector_extract_dynamic(ty, Some(out_id), var, index)
639 .unwrap()
640 }
641 IndexedVariable::Scalar(var) => {
642 let ty = out.item().id(self);
643 let input = self.read(&var);
644 let out_id = self.write_id(out);
645 self.copy_object(ty, Some(out_id), input).unwrap();
646 self.write(out, out_id);
647 out_id
648 }
649 }
650 }
651
652 pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
653 let always_in_bounds = is_always_in_bounds(var, index);
654 match self.index(var, index, always_in_bounds) {
655 IndexedVariable::Pointer(ptr, _) => ptr,
656 other => unreachable!("{other:?}"),
657 }
658 }
659
660 pub fn write_id(&mut self, variable: &Variable) -> Word {
661 match variable {
662 Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
663 Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
664 Variable::Local { .. } => self.id(),
665 Variable::Shared(..) => self.id(),
666 Variable::GlobalScalar(id, _) => *id,
667 Variable::Raw(id, _) => *id,
668 Variable::Constant(_, _, _) => panic!("Can't write to constant scalar"),
669 Variable::GlobalInputArray(_, _, _)
670 | Variable::GlobalOutputArray(_, _, _)
671 | Variable::Slice { .. }
672 | Variable::Named { .. }
673 | Variable::SharedArray(_, _, _)
674 | Variable::ConstantArray(_, _, _)
675 | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
676 global => panic!("Can't write to builtin {global:?}"),
677 }
678 }
679
680 pub fn write(&mut self, variable: &Variable, value: Word) {
681 match variable {
682 Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
683 let ptr_ty =
684 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
685 let index = vec![self.const_u32(0)];
686 let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
687 self.store(access, value, None, []).unwrap()
688 }
689 Variable::Local { id, .. } | Variable::Shared(id, _) => {
690 self.store(*id, value, None, []).unwrap()
691 }
692
693 Variable::Slice { ptr, .. } => self.write(ptr, value),
694 _ => {}
695 }
696 }
697
698 pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
699 let always_in_bounds = is_always_in_bounds(out, index);
700 let variable = self.index(out, index, always_in_bounds);
701
702 let write = |b: &mut Self| match variable {
703 IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
704 IndexedVariable::Composite(var, index, item) => {
705 let ty = item.id(b);
706 let id = b
707 .composite_insert(ty, None, value, var, vec![index])
708 .unwrap();
709 b.write(out, id);
710 }
711 IndexedVariable::DynamicComposite(var, index, item) => {
712 let ty = item.id(b);
713 let id = b
714 .vector_insert_dynamic(ty, None, value, var, index)
715 .unwrap();
716 b.write(out, id);
717 }
718 IndexedVariable::Scalar(var) => b.write(&var, value),
719 };
720
721 write(self)
722 }
723
724 pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
725 let variable = self.index(out, index, true);
726
727 match variable {
728 IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
729 IndexedVariable::Composite(var, index, item) => {
730 let ty = item.id(self);
731 let out_id = self
732 .composite_insert(ty, None, value, var, vec![index])
733 .unwrap();
734 self.write(out, out_id);
735 }
736 IndexedVariable::DynamicComposite(var, index, item) => {
737 let ty = item.id(self);
738 let out_id = self
739 .vector_insert_dynamic(ty, None, value, var, index)
740 .unwrap();
741 self.write(out, out_id);
742 }
743 IndexedVariable::Scalar(var) => self.write(&var, value),
744 }
745 }
746}
747
748fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
749 let len = match var {
750 Variable::SharedArray(_, _, len)
751 | Variable::ConstantArray(_, _, len)
752 | Variable::LocalArray(_, _, len)
753 | Variable::Slice {
754 const_len: Some(len),
755 ..
756 } => *len,
757 _ => return false,
758 };
759
760 let const_index = match index {
761 Variable::Constant(_, value, _) => value.as_u32(),
762 _ => return false,
763 };
764
765 const_index < len
766}