1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6 SpirvCompiler, SpirvTarget,
7 item::{Elem, Item},
8 lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantValue, Id};
11use rspirv::{
12 dr::Builder,
13 spirv::{self, FPEncoding, StorageClass, Word},
14};
15
/// Backend representation of a compiled IR variable.
///
/// Most variants carry the SPIR-V result id (`Word`) naming the variable together with its
/// [`Item`] type. `Versioned`, `LocalBinding` and `CoopMatrix` instead carry IR ids that are
/// resolved through the compiler's state when a SPIR-V id is actually needed.
#[derive(Debug, Clone, PartialEq)]
pub enum Variable {
    /// Global input buffer: `(id, item, pos)`; `compile_variable` takes `id` from
    /// `state.buffers[pos]`.
    GlobalInputArray(Word, Item, u32),
    /// Global output buffer: `(id, item, pos)`; `compile_variable` takes `id` from
    /// `state.buffers[pos]`.
    GlobalOutputArray(Word, Item, u32),
    /// Global scalar: `(id, element type)`.
    GlobalScalar(Word, Elem),
    /// Constant: `(id, bit pattern, item)`; deduplicated by `compile_variable`.
    Constant(Word, ConstVal, Item),
    /// Mutable local. `read`/`write` load from and store to `id`, so `id` is used as a pointer.
    Local {
        id: Word,
        item: Item,
    },
    /// Versioned value keyed by `(IR id, version)`; its SPIR-V id is resolved via
    /// `get_versioned`.
    Versioned {
        id: (Id, u16),
        item: Item,
        variable: ir::Variable,
    },
    /// Local binding keyed by IR id; its SPIR-V id is resolved via `get_binding`.
    LocalBinding {
        id: Id,
        item: Item,
        variable: ir::Variable,
    },
    /// An already-computed SPIR-V value id with a known item (e.g. a slice's adjusted index).
    Raw(Word, Item),
    /// View into `ptr`; indices are shifted by the value held in the id `offset`.
    Slice {
        ptr: Box<Variable>,
        offset: Word,
        // NOTE(review): `end` is not read in this file chunk; presumably the id of the
        // exclusive end index — confirm.
        end: Word,
        /// Length known at compile time, if any; lets constant in-range indices use
        /// in-bounds access chains.
        const_len: Option<u32>,
        item: Item,
    },
    /// Shared-memory array: `(id, item, length)`.
    SharedArray(Word, Item, u32),
    /// Shared-memory value (scalar or vector): `(id, item)`.
    Shared(Word, Item),
    /// Constant array: `(id, item, length)`.
    ConstantArray(Word, Item, u32),
    /// Function-local array: `(id, item, length)`, where `length` is the IR length.
    LocalArray(Word, Item, u32),
    /// Cooperative matrix keyed by IR id in `state.matrices`, with its element type.
    CoopMatrix(Id, Elem),
    /// Bare SPIR-V id with no associated item.
    Id(Word),
    /// Builtin variable: `(id, item)`.
    Builtin(Word, Item),
}
52
53impl Variable {
54 pub fn scope(&self) -> spirv::Scope {
55 match self {
56 Variable::GlobalInputArray(..)
57 | Variable::GlobalOutputArray(..)
58 | Variable::GlobalScalar(..) => spirv::Scope::Device,
59 Variable::SharedArray(..) | Variable::Shared(..) => spirv::Scope::Workgroup,
60 Variable::CoopMatrix(..) => spirv::Scope::Subgroup,
61 Variable::Slice { ptr, .. } => ptr.scope(),
62 Variable::Raw(..) => unimplemented!("Can't get scope of raw variable"),
63 Variable::Id(_) => unimplemented!("Can't get scope of raw id"),
64 _ => spirv::Scope::Invocation,
65 }
66 }
67}
68
/// Raw bit pattern of a constant.
///
/// Values of 32 bits or fewer (32/16/8-bit ints and floats, and booleans as 0/1) are stored
/// zero-extended in `Bit32`. 64-bit values use `Bit64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    /// Bit pattern of a constant up to 32 bits wide.
    Bit32(u32),
    /// Bit pattern of a 64-bit constant.
    Bit64(u64),
}
74
75impl ConstVal {
76 pub fn as_u64(&self) -> u64 {
77 match self {
78 ConstVal::Bit32(val) => *val as u64,
79 ConstVal::Bit64(val) => *val,
80 }
81 }
82
83 pub fn as_u32(&self) -> u32 {
84 match self {
85 ConstVal::Bit32(val) => *val,
86 ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
87 }
88 }
89
90 pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
91 match (width, encoding) {
92 (64, _) => f64::from_bits(self.as_u64()),
93 (32, _) => f32::from_bits(self.as_u32()) as f64,
94 (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
95 (_, Some(FPEncoding::BFloat16KHR)) => {
96 half::bf16::from_bits(self.as_u32() as u16).to_f64()
97 }
98 (_, Some(FPEncoding::Float8E4M3EXT)) => {
99 cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
100 }
101 (_, Some(FPEncoding::Float8E5M2EXT)) => {
102 cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
103 }
104 _ => unreachable!(),
105 }
106 }
107
108 pub fn as_int(&self, width: u32) -> i64 {
109 unsafe {
110 match width {
111 64 => transmute::<u64, i64>(self.as_u64()),
112 32 => transmute::<u32, i32>(self.as_u32()) as i64,
113 16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
114 8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
115 _ => unreachable!(),
116 }
117 }
118 }
119
120 pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
121 match (width, encoding) {
122 (64, _) => ConstVal::Bit64(value.to_bits()),
123 (32, _) => ConstVal::Bit32((value as f32).to_bits()),
124 (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
125 (_, Some(FPEncoding::BFloat16KHR)) => {
126 ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
127 }
128 (_, Some(FPEncoding::Float8E4M3EXT)) => {
129 ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
130 }
131 (_, Some(FPEncoding::Float8E5M2EXT)) => {
132 ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
133 }
134 _ => unreachable!(),
135 }
136 }
137
138 pub fn from_int(value: i64, width: u32) -> Self {
139 match width {
140 64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
141 32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
142 16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
143 8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
144 _ => unreachable!(),
145 }
146 }
147
148 pub fn from_uint(value: u64, width: u32) -> Self {
149 match width {
150 64 => ConstVal::Bit64(value),
151 32 => ConstVal::Bit32(value as u32),
152 16 => ConstVal::Bit32(value as u16 as u32),
153 8 => ConstVal::Bit32(value as u8 as u32),
154 _ => unreachable!(),
155 }
156 }
157
158 pub fn from_bool(value: bool) -> Self {
159 ConstVal::Bit32(value as u32)
160 }
161}
162
163impl From<(ConstantValue, Item)> for ConstVal {
164 fn from((value, ty): (ConstantValue, Item)) -> Self {
165 let elem = ty.elem();
166 let width = elem.size() * 8;
167 match value {
168 ConstantValue::Int(val) => ConstVal::from_int(val, width),
169 ConstantValue::Float(val) => ConstVal::from_float(val, width, elem.float_encoding()),
170 ConstantValue::UInt(val) => ConstVal::from_uint(val, width),
171 ConstantValue::Bool(val) => ConstVal::from_bool(val),
172 }
173 }
174}
175
176impl From<u32> for ConstVal {
177 fn from(value: u32) -> Self {
178 ConstVal::Bit32(value)
179 }
180}
181
182impl From<f32> for ConstVal {
183 fn from(value: f32) -> Self {
184 ConstVal::Bit32(value.to_bits())
185 }
186}
187
188impl Variable {
189 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
190 match self {
191 Variable::GlobalInputArray(id, _, _) => *id,
192 Variable::GlobalOutputArray(id, _, _) => *id,
193 Variable::GlobalScalar(id, _) => *id,
194 Variable::Constant(id, _, _) => *id,
195 Variable::Local { id, .. } => *id,
196 Variable::Versioned {
197 id, variable: var, ..
198 } => b.get_versioned(*id, var),
199 Variable::LocalBinding {
200 id, variable: var, ..
201 } => b.get_binding(*id, var),
202 Variable::Raw(id, _) => *id,
203 Variable::Slice { ptr, .. } => ptr.id(b),
204 Variable::SharedArray(id, _, _) => *id,
205 Variable::Shared(id, _) => *id,
206 Variable::ConstantArray(id, _, _) => *id,
207 Variable::LocalArray(id, _, _) => *id,
208 Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
209 Variable::Id(id) => *id,
210 Variable::Builtin(id, ..) => *id,
211 }
212 }
213
214 pub fn item(&self) -> Item {
215 match self {
216 Variable::GlobalInputArray(_, item, _) => item.clone(),
217 Variable::GlobalOutputArray(_, item, _) => item.clone(),
218 Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
219 Variable::Constant(_, _, item) => item.clone(),
220 Variable::Local { item, .. } => item.clone(),
221 Variable::Versioned { item, .. } => item.clone(),
222 Variable::LocalBinding { item, .. } => item.clone(),
223 Variable::Slice { item, .. } => item.clone(),
224 Variable::SharedArray(_, item, _) => item.clone(),
225 Variable::Shared(_, item) => item.clone(),
226 Variable::ConstantArray(_, item, _) => item.clone(),
227 Variable::LocalArray(_, item, _) => item.clone(),
228 Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
229 Variable::Builtin(_, item) => item.clone(),
230 Variable::Raw(_, item) => item.clone(),
231 Variable::Id(_) => unimplemented!("Can't get item of raw ID"),
232 }
233 }
234
235 pub fn indexed_item(&self) -> Item {
236 match self {
237 Variable::LocalBinding {
238 item: Item::Vector(elem, _),
239 ..
240 } => Item::Scalar(*elem),
241 Variable::Local {
242 item: Item::Vector(elem, _),
243 ..
244 } => Item::Scalar(*elem),
245 Variable::Versioned {
246 item: Item::Vector(elem, _),
247 ..
248 } => Item::Scalar(*elem),
249 other => other.item(),
250 }
251 }
252
253 pub fn elem(&self) -> Elem {
254 self.item().elem()
255 }
256
257 pub fn has_len(&self) -> bool {
258 matches!(
259 self,
260 Variable::GlobalInputArray(_, _, _)
261 | Variable::GlobalOutputArray(_, _, _)
262 | Variable::Slice { .. }
263 | Variable::SharedArray(_, _, _)
264 | Variable::ConstantArray(_, _, _)
265 | Variable::LocalArray(_, _, _)
266 )
267 }
268
269 pub fn has_buffer_len(&self) -> bool {
270 matches!(
271 self,
272 Variable::GlobalInputArray(_, _, _) | Variable::GlobalOutputArray(_, _, _)
273 )
274 }
275
276 pub fn as_const(&self) -> Option<ConstVal> {
277 match self {
278 Self::Constant(_, val, _) => Some(*val),
279 _ => None,
280 }
281 }
282
283 pub fn as_binding(&self) -> Option<Id> {
284 match self {
285 Self::LocalBinding { id, .. } => Some(*id),
286 _ => None,
287 }
288 }
289}
290
/// Result of `SpirvCompiler::index`: describes how an indexed element can be read or written.
#[derive(Debug)]
pub enum IndexedVariable {
    /// Pointer id produced by an access chain, plus the item it points to.
    Pointer(Word, Item),
    /// Vector value id, constant (literal) component index, and the vector item.
    Composite(Word, u32, Item),
    /// Vector value id, id of the runtime component index, and the vector item.
    DynamicComposite(Word, u32, Item),
    /// The variable itself.
    ///
    /// Returned for locals, shared values, bindings and versioned values that are not
    /// indexed as vectors. Also returned for constants, which are cast to their scalar item.
    Scalar(Variable),
}
298
299impl<T: SpirvTarget> SpirvCompiler<T> {
300 pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
301 let item = variable.ty;
302 match variable.kind {
303 ir::VariableKind::Constant(value) => {
304 let item = self.compile_type(item);
305 let const_val = (value, item.clone()).into();
306
307 if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
308 Variable::Constant(*existing, const_val, item)
309 } else {
310 let id = item.constant(self, const_val);
311 self.state.constants.insert((const_val, item.clone()), id);
312 Variable::Constant(id, const_val, item)
313 }
314 }
315 ir::VariableKind::GlobalInputArray(pos) => {
316 let id = self.state.buffers[pos as usize];
317 Variable::GlobalInputArray(id, self.compile_type(item), pos)
318 }
319 ir::VariableKind::GlobalOutputArray(pos) => {
320 let id = self.state.buffers[pos as usize];
321 Variable::GlobalOutputArray(id, self.compile_type(item), pos)
322 }
323 ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
324 ir::VariableKind::LocalMut { id } => {
325 let item = self.compile_type(item);
326 let var = self.get_local(id, &item, variable);
327 Variable::Local { id: var, item }
328 }
329 ir::VariableKind::Versioned { id, version } => {
330 let item = self.compile_type(item);
331 let id = (id, version);
332 Variable::Versioned { id, item, variable }
333 }
334 ir::VariableKind::LocalConst { id } => {
335 let item = self.compile_type(item);
336 Variable::LocalBinding { id, item, variable }
337 }
338 ir::VariableKind::Builtin(builtin) => {
339 let item = self.compile_type(item);
340 self.compile_builtin(builtin, item)
341 }
342 ir::VariableKind::ConstantArray { id, length, .. } => {
343 let item = self.compile_type(item);
344 let id = self.state.const_arrays[id as usize].id;
345 Variable::ConstantArray(id, item, length as u32)
346 }
347 ir::VariableKind::SharedArray { id, length, .. } => {
348 let item = self.compile_type(item);
349 let id = self.state.shared_arrays[&id].id;
350 Variable::SharedArray(id, item, length as u32)
351 }
352 ir::VariableKind::Shared { id } => {
353 let item = self.compile_type(item);
354 let id = self.state.shared[&id].id;
355 Variable::Shared(id, item)
356 }
357 ir::VariableKind::LocalArray {
358 id,
359 length,
360 unroll_factor,
361 } => {
362 let item = self.compile_type(item);
363 let id = if let Some(arr) = self.state.local_arrays.get(&id) {
364 arr.id
365 } else {
366 let arr_ty = Item::Array(Box::new(item.clone()), length as u32);
367 let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
368 let arr_id = self.declare_function_variable(ptr_ty);
369 self.debug_var_name(arr_id, variable);
370 let arr = Array {
371 id: arr_id,
372 item: item.clone(),
373 len: (length * unroll_factor) as u32,
374 var: variable,
375 alignment: None,
376 };
377 self.state.local_arrays.insert(id, arr);
378 arr_id
379 };
380 Variable::LocalArray(id, item, length as u32)
381 }
382 ir::VariableKind::Matrix { id, mat } => {
383 let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
384 if self.state.matrices.contains_key(&id) {
385 Variable::CoopMatrix(id, elem)
386 } else {
387 let matrix = self.init_coop_matrix(mat, variable);
388 self.state.matrices.insert(id, matrix);
389 Variable::CoopMatrix(id, elem)
390 }
391 }
392 ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
393 ir::VariableKind::BarrierToken { .. } => {
394 panic!("Barrier not supported.")
395 }
396 ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
397 ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
398 }
399 }
400
401 pub fn read(&mut self, variable: &Variable) -> Word {
402 match variable {
403 Variable::Slice { ptr, .. } => self.read(ptr),
404 Variable::Shared(id, item)
405 if self.compilation_options.vulkan.supports_explicit_smem =>
406 {
407 let ty = item.id(self);
408 let ptr_ty =
409 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
410 let index = vec![self.const_u32(0)];
411 let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
412 self.load(ty, None, access, None, []).unwrap()
413 }
414 Variable::Local { id, item } | Variable::Shared(id, item) => {
415 let ty = item.id(self);
416 self.load(ty, None, *id, None, []).unwrap()
417 }
418 ssa => ssa.id(self),
419 }
420 }
421
422 pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
423 if let Some(as_const) = variable.as_const() {
424 self.static_cast(as_const, &variable.elem(), item).0
425 } else {
426 let id = self.read(variable);
427 variable.item().cast_to(self, None, id, item)
428 }
429 }
430
431 pub fn index(
432 &mut self,
433 variable: &Variable,
434 index: &Variable,
435 unchecked: bool,
436 ) -> IndexedVariable {
437 let access_chain = if unchecked {
438 Builder::in_bounds_access_chain
439 } else {
440 Builder::access_chain
441 };
442 let index_id = self.read(index);
443 match variable {
444 Variable::GlobalInputArray(id, item, _) | Variable::GlobalOutputArray(id, item, _) => {
445 let ptr_ty =
446 Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
447 let zero = self.const_u32(0);
448 let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();
449
450 IndexedVariable::Pointer(id, item.clone())
451 }
452 Variable::Local {
453 id,
454 item: Item::Vector(elem, _),
455 } => {
456 let ptr_ty =
457 Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
458 let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
459
460 IndexedVariable::Pointer(id, Item::Scalar(*elem))
461 }
462 Variable::Shared(id, Item::Vector(elem, _)) => {
463 let ptr_ty =
464 Item::Pointer(StorageClass::Workgroup, Box::new(Item::Scalar(*elem))).id(self);
465
466 let mut index = vec![index_id];
467 if self.compilation_options.vulkan.supports_explicit_smem {
468 index.insert(0, self.const_u32(0));
469 }
470
471 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
472
473 IndexedVariable::Pointer(id, Item::Scalar(*elem))
474 }
475 Variable::LocalBinding {
476 id,
477 item: Item::Vector(elem, vec),
478 variable,
479 } if index.as_const().is_some() => IndexedVariable::Composite(
480 self.get_binding(*id, variable),
481 index.as_const().unwrap().as_u64() as u32,
482 Item::Vector(*elem, *vec),
483 ),
484 Variable::LocalBinding {
485 id,
486 item: Item::Vector(elem, vec),
487 variable,
488 } => IndexedVariable::DynamicComposite(
489 self.get_binding(*id, variable),
490 index_id,
491 Item::Vector(*elem, *vec),
492 ),
493 Variable::Versioned {
494 id,
495 item: Item::Vector(elem, vec),
496 variable,
497 } if index.as_const().is_some() => IndexedVariable::Composite(
498 self.get_versioned(*id, variable),
499 index.as_const().unwrap().as_u64() as u32,
500 Item::Vector(*elem, *vec),
501 ),
502 Variable::Versioned {
503 id,
504 item: Item::Vector(elem, vec),
505 variable,
506 } => IndexedVariable::DynamicComposite(
507 self.get_versioned(*id, variable),
508 index_id,
509 Item::Vector(*elem, *vec),
510 ),
511 Variable::Shared(id, item)
512 if self.compilation_options.vulkan.supports_explicit_smem =>
513 {
514 let ptr_ty =
515 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
516 let index = vec![self.const_u32(0)];
517 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
518 IndexedVariable::Pointer(id, item.clone())
519 }
520 Variable::Local { .. }
521 | Variable::Shared(..)
522 | Variable::LocalBinding { .. }
523 | Variable::Versioned { .. } => IndexedVariable::Scalar(variable.clone()),
524 Variable::Constant(_, val, item) => {
525 let scalar_item = Item::Scalar(item.elem());
526 let (id, val) = self.static_cast(*val, &item.elem(), &scalar_item);
527 IndexedVariable::Scalar(Variable::Constant(id, val, scalar_item))
528 }
529 Variable::Slice { ptr, offset, .. } => {
530 let item = Item::Scalar(Elem::Int(32, false));
531 let int = item.id(self);
532 let index = match index.as_const() {
533 Some(ConstVal::Bit32(0)) => *offset,
534 _ => self.i_add(int, None, *offset, index_id).unwrap(),
535 };
536 self.index(ptr, &Variable::Raw(index, item), unchecked)
537 }
538 Variable::SharedArray(id, item, _) => {
539 let ptr_ty =
540 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
541 let mut index = vec![index_id];
542 if self.compilation_options.vulkan.supports_explicit_smem {
543 index.insert(0, self.const_u32(0));
544 }
545 let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
546 IndexedVariable::Pointer(id, item.clone())
547 }
548 Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
549 let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
550 let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
551 IndexedVariable::Pointer(id, item.clone())
552 }
553 var => unimplemented!("Can't index into {var:?}"),
554 }
555 }
556
557 pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
558 let always_in_bounds = is_always_in_bounds(variable, index);
559 let indexed = self.index(variable, index, always_in_bounds);
560
561 let read = |b: &mut Self| match indexed {
562 IndexedVariable::Pointer(ptr, item) => {
563 let ty = item.id(b);
564 let out_id = b.write_id(out);
565 b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
566 }
567 IndexedVariable::Composite(var, index, item) => {
568 let elem = item.elem();
569 let ty = elem.id(b);
570 let out_id = b.write_id(out);
571 b.composite_extract(ty, Some(out_id), var, vec![index])
572 .unwrap()
573 }
574 IndexedVariable::DynamicComposite(var, index, item) => {
575 let elem = item.elem();
576 let ty = elem.id(b);
577 let out_id = b.write_id(out);
578 b.vector_extract_dynamic(ty, Some(out_id), var, index)
579 .unwrap()
580 }
581 IndexedVariable::Scalar(var) => {
582 let ty = out.item().id(b);
583 let input = b.read(&var);
584 let out_id = b.write_id(out);
585 b.copy_object(ty, Some(out_id), input).unwrap();
586 b.write(out, out_id);
587 out_id
588 }
589 };
590
591 read(self)
592 }
593
594 pub fn read_indexed_unchecked(
595 &mut self,
596 out: &Variable,
597 variable: &Variable,
598 index: &Variable,
599 ) -> Word {
600 let indexed = self.index(variable, index, true);
601
602 match indexed {
603 IndexedVariable::Pointer(ptr, item) => {
604 let ty = item.id(self);
605 let out_id = self.write_id(out);
606 self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
607 }
608 IndexedVariable::Composite(var, index, item) => {
609 let elem = item.elem();
610 let ty = elem.id(self);
611 let out_id = self.write_id(out);
612 self.composite_extract(ty, Some(out_id), var, vec![index])
613 .unwrap()
614 }
615 IndexedVariable::DynamicComposite(var, index, item) => {
616 let elem = item.elem();
617 let ty = elem.id(self);
618 let out_id = self.write_id(out);
619 self.vector_extract_dynamic(ty, Some(out_id), var, index)
620 .unwrap()
621 }
622 IndexedVariable::Scalar(var) => {
623 let ty = out.item().id(self);
624 let input = self.read(&var);
625 let out_id = self.write_id(out);
626 self.copy_object(ty, Some(out_id), input).unwrap();
627 self.write(out, out_id);
628 out_id
629 }
630 }
631 }
632
633 pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
634 let always_in_bounds = is_always_in_bounds(var, index);
635 match self.index(var, index, always_in_bounds) {
636 IndexedVariable::Pointer(ptr, _) => ptr,
637 other => unreachable!("{other:?}"),
638 }
639 }
640
641 pub fn write_id(&mut self, variable: &Variable) -> Word {
642 match variable {
643 Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
644 Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
645 Variable::Local { .. } => self.id(),
646 Variable::Shared(..) => self.id(),
647 Variable::GlobalScalar(id, _) => *id,
648 Variable::Raw(id, _) => *id,
649 Variable::Constant(_, _, _) => panic!("Can't write to constant scalar"),
650 Variable::GlobalInputArray(_, _, _)
651 | Variable::GlobalOutputArray(_, _, _)
652 | Variable::Slice { .. }
653 | Variable::SharedArray(_, _, _)
654 | Variable::ConstantArray(_, _, _)
655 | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
656 global => panic!("Can't write to builtin {global:?}"),
657 }
658 }
659
660 pub fn write(&mut self, variable: &Variable, value: Word) {
661 match variable {
662 Variable::Shared(id, item)
663 if self.compilation_options.vulkan.supports_explicit_smem =>
664 {
665 let ptr_ty =
666 Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
667 let index = vec![self.const_u32(0)];
668 let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
669 self.store(access, value, None, []).unwrap()
670 }
671 Variable::Local { id, .. } | Variable::Shared(id, _) => {
672 self.store(*id, value, None, []).unwrap()
673 }
674
675 Variable::Slice { ptr, .. } => self.write(ptr, value),
676 _ => {}
677 }
678 }
679
680 pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
681 let always_in_bounds = is_always_in_bounds(out, index);
682 let variable = self.index(out, index, always_in_bounds);
683
684 let write = |b: &mut Self| match variable {
685 IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
686 IndexedVariable::Composite(var, index, item) => {
687 let ty = item.id(b);
688 let id = b
689 .composite_insert(ty, None, value, var, vec![index])
690 .unwrap();
691 b.write(out, id);
692 }
693 IndexedVariable::DynamicComposite(var, index, item) => {
694 let ty = item.id(b);
695 let id = b
696 .vector_insert_dynamic(ty, None, value, var, index)
697 .unwrap();
698 b.write(out, id);
699 }
700 IndexedVariable::Scalar(var) => b.write(&var, value),
701 };
702
703 write(self)
704 }
705
706 pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
707 let variable = self.index(out, index, true);
708
709 match variable {
710 IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
711 IndexedVariable::Composite(var, index, item) => {
712 let ty = item.id(self);
713 let out_id = self
714 .composite_insert(ty, None, value, var, vec![index])
715 .unwrap();
716 self.write(out, out_id);
717 }
718 IndexedVariable::DynamicComposite(var, index, item) => {
719 let ty = item.id(self);
720 let out_id = self
721 .vector_insert_dynamic(ty, None, value, var, index)
722 .unwrap();
723 self.write(out, out_id);
724 }
725 IndexedVariable::Scalar(var) => self.write(&var, value),
726 }
727 }
728}
729
730fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
731 let len = match var {
732 Variable::SharedArray(_, _, len)
733 | Variable::ConstantArray(_, _, len)
734 | Variable::LocalArray(_, _, len)
735 | Variable::Slice {
736 const_len: Some(len),
737 ..
738 } => *len,
739 _ => return false,
740 };
741
742 let const_index = match index {
743 Variable::Constant(_, value, _) => value.as_u64(),
744 _ => return false,
745 };
746
747 const_index < len as u64
748}