Skip to main content

cubecl_spirv/
variable.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6    SpirvCompiler, SpirvTarget,
7    item::{Elem, Item},
8    lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantValue, Id};
11use rspirv::{
12    dr::Builder,
13    spirv::{self, FPEncoding, StorageClass, Word},
14};
15
16#[derive(Debug, Clone, PartialEq)]
17pub enum Variable {
18    GlobalInputArray(Word, Item, u32),
19    GlobalOutputArray(Word, Item, u32),
20    GlobalScalar(Word, Elem),
21    Constant(Word, ConstVal, Item),
22    Local {
23        id: Word,
24        item: Item,
25    },
26    Versioned {
27        id: (Id, u16),
28        item: Item,
29        variable: ir::Variable,
30    },
31    LocalBinding {
32        id: Id,
33        item: Item,
34        variable: ir::Variable,
35    },
36    Raw(Word, Item),
37    Slice {
38        ptr: Box<Variable>,
39        offset: Word,
40        end: Word,
41        const_len: Option<u32>,
42        item: Item,
43    },
44    SharedArray(Word, Item, u32),
45    Shared(Word, Item),
46    ConstantArray(Word, Item, u32),
47    LocalArray(Word, Item, u32),
48    CoopMatrix(Id, Elem),
49    Id(Word),
50    Builtin(Word, Item),
51}
52
53impl Variable {
54    pub fn scope(&self) -> spirv::Scope {
55        match self {
56            Variable::GlobalInputArray(..)
57            | Variable::GlobalOutputArray(..)
58            | Variable::GlobalScalar(..) => spirv::Scope::Device,
59            Variable::SharedArray(..) | Variable::Shared(..) => spirv::Scope::Workgroup,
60            Variable::CoopMatrix(..) => spirv::Scope::Subgroup,
61            Variable::Slice { ptr, .. } => ptr.scope(),
62            Variable::Raw(..) => unimplemented!("Can't get scope of raw variable"),
63            Variable::Id(_) => unimplemented!("Can't get scope of raw id"),
64            _ => spirv::Scope::Invocation,
65        }
66    }
67}
68
/// Raw bit pattern of a compile-time constant.
///
/// Values of 32 bits or fewer (including 16/8-bit ints and floats, and bools) are stored
/// zero-extended in `Bit32`; 64-bit values use `Bit64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    /// Bits of a value at most 32 bits wide.
    Bit32(u32),
    /// Bits of a 64-bit value.
    Bit64(u64),
}
74
75impl ConstVal {
76    pub fn as_u64(&self) -> u64 {
77        match self {
78            ConstVal::Bit32(val) => *val as u64,
79            ConstVal::Bit64(val) => *val,
80        }
81    }
82
83    pub fn as_u32(&self) -> u32 {
84        match self {
85            ConstVal::Bit32(val) => *val,
86            ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
87        }
88    }
89
90    pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
91        match (width, encoding) {
92            (64, _) => f64::from_bits(self.as_u64()),
93            (32, _) => f32::from_bits(self.as_u32()) as f64,
94            (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
95            (_, Some(FPEncoding::BFloat16KHR)) => {
96                half::bf16::from_bits(self.as_u32() as u16).to_f64()
97            }
98            (_, Some(FPEncoding::Float8E4M3EXT)) => {
99                cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
100            }
101            (_, Some(FPEncoding::Float8E5M2EXT)) => {
102                cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
103            }
104            _ => unreachable!(),
105        }
106    }
107
108    pub fn as_int(&self, width: u32) -> i64 {
109        unsafe {
110            match width {
111                64 => transmute::<u64, i64>(self.as_u64()),
112                32 => transmute::<u32, i32>(self.as_u32()) as i64,
113                16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
114                8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
115                _ => unreachable!(),
116            }
117        }
118    }
119
120    pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
121        match (width, encoding) {
122            (64, _) => ConstVal::Bit64(value.to_bits()),
123            (32, _) => ConstVal::Bit32((value as f32).to_bits()),
124            (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
125            (_, Some(FPEncoding::BFloat16KHR)) => {
126                ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
127            }
128            (_, Some(FPEncoding::Float8E4M3EXT)) => {
129                ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
130            }
131            (_, Some(FPEncoding::Float8E5M2EXT)) => {
132                ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
133            }
134            _ => unreachable!(),
135        }
136    }
137
138    pub fn from_int(value: i64, width: u32) -> Self {
139        match width {
140            64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
141            32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
142            16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
143            8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
144            _ => unreachable!(),
145        }
146    }
147
148    pub fn from_uint(value: u64, width: u32) -> Self {
149        match width {
150            64 => ConstVal::Bit64(value),
151            32 => ConstVal::Bit32(value as u32),
152            16 => ConstVal::Bit32(value as u16 as u32),
153            8 => ConstVal::Bit32(value as u8 as u32),
154            _ => unreachable!(),
155        }
156    }
157
158    pub fn from_bool(value: bool) -> Self {
159        ConstVal::Bit32(value as u32)
160    }
161}
162
163impl From<(ConstantValue, Item)> for ConstVal {
164    fn from((value, ty): (ConstantValue, Item)) -> Self {
165        let elem = ty.elem();
166        let width = elem.size() * 8;
167        match value {
168            ConstantValue::Int(val) => ConstVal::from_int(val, width),
169            ConstantValue::Float(val) => ConstVal::from_float(val, width, elem.float_encoding()),
170            ConstantValue::UInt(val) => ConstVal::from_uint(val, width),
171            ConstantValue::Bool(val) => ConstVal::from_bool(val),
172        }
173    }
174}
175
176impl From<u32> for ConstVal {
177    fn from(value: u32) -> Self {
178        ConstVal::Bit32(value)
179    }
180}
181
182impl From<f32> for ConstVal {
183    fn from(value: f32) -> Self {
184        ConstVal::Bit32(value.to_bits())
185    }
186}
187
188impl Variable {
189    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
190        match self {
191            Variable::GlobalInputArray(id, _, _) => *id,
192            Variable::GlobalOutputArray(id, _, _) => *id,
193            Variable::GlobalScalar(id, _) => *id,
194            Variable::Constant(id, _, _) => *id,
195            Variable::Local { id, .. } => *id,
196            Variable::Versioned {
197                id, variable: var, ..
198            } => b.get_versioned(*id, var),
199            Variable::LocalBinding {
200                id, variable: var, ..
201            } => b.get_binding(*id, var),
202            Variable::Raw(id, _) => *id,
203            Variable::Slice { ptr, .. } => ptr.id(b),
204            Variable::SharedArray(id, _, _) => *id,
205            Variable::Shared(id, _) => *id,
206            Variable::ConstantArray(id, _, _) => *id,
207            Variable::LocalArray(id, _, _) => *id,
208            Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
209            Variable::Id(id) => *id,
210            Variable::Builtin(id, ..) => *id,
211        }
212    }
213
214    pub fn item(&self) -> Item {
215        match self {
216            Variable::GlobalInputArray(_, item, _) => item.clone(),
217            Variable::GlobalOutputArray(_, item, _) => item.clone(),
218            Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
219            Variable::Constant(_, _, item) => item.clone(),
220            Variable::Local { item, .. } => item.clone(),
221            Variable::Versioned { item, .. } => item.clone(),
222            Variable::LocalBinding { item, .. } => item.clone(),
223            Variable::Slice { item, .. } => item.clone(),
224            Variable::SharedArray(_, item, _) => item.clone(),
225            Variable::Shared(_, item) => item.clone(),
226            Variable::ConstantArray(_, item, _) => item.clone(),
227            Variable::LocalArray(_, item, _) => item.clone(),
228            Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
229            Variable::Builtin(_, item) => item.clone(),
230            Variable::Raw(_, item) => item.clone(),
231            Variable::Id(_) => unimplemented!("Can't get item of raw ID"),
232        }
233    }
234
235    pub fn indexed_item(&self) -> Item {
236        match self {
237            Variable::LocalBinding {
238                item: Item::Vector(elem, _),
239                ..
240            } => Item::Scalar(*elem),
241            Variable::Local {
242                item: Item::Vector(elem, _),
243                ..
244            } => Item::Scalar(*elem),
245            Variable::Versioned {
246                item: Item::Vector(elem, _),
247                ..
248            } => Item::Scalar(*elem),
249            other => other.item(),
250        }
251    }
252
253    pub fn elem(&self) -> Elem {
254        self.item().elem()
255    }
256
257    pub fn has_len(&self) -> bool {
258        matches!(
259            self,
260            Variable::GlobalInputArray(_, _, _)
261                | Variable::GlobalOutputArray(_, _, _)
262                | Variable::Slice { .. }
263                | Variable::SharedArray(_, _, _)
264                | Variable::ConstantArray(_, _, _)
265                | Variable::LocalArray(_, _, _)
266        )
267    }
268
269    pub fn has_buffer_len(&self) -> bool {
270        matches!(
271            self,
272            Variable::GlobalInputArray(_, _, _) | Variable::GlobalOutputArray(_, _, _)
273        )
274    }
275
276    pub fn as_const(&self) -> Option<ConstVal> {
277        match self {
278            Self::Constant(_, val, _) => Some(*val),
279            _ => None,
280        }
281    }
282
283    pub fn as_binding(&self) -> Option<Id> {
284        match self {
285            Self::LocalBinding { id, .. } => Some(*id),
286            _ => None,
287        }
288    }
289}
290
291#[derive(Debug)]
292pub enum IndexedVariable {
293    Pointer(Word, Item),
294    Composite(Word, u32, Item),
295    DynamicComposite(Word, u32, Item),
296    Scalar(Variable),
297}
298
299impl<T: SpirvTarget> SpirvCompiler<T> {
300    pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
301        let item = variable.ty;
302        match variable.kind {
303            ir::VariableKind::Constant(value) => {
304                let item = self.compile_type(item);
305                let const_val = (value, item.clone()).into();
306
307                if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
308                    Variable::Constant(*existing, const_val, item)
309                } else {
310                    let id = item.constant(self, const_val);
311                    self.state.constants.insert((const_val, item.clone()), id);
312                    Variable::Constant(id, const_val, item)
313                }
314            }
315            ir::VariableKind::GlobalInputArray(pos) => {
316                let id = self.state.buffers[pos as usize];
317                Variable::GlobalInputArray(id, self.compile_type(item), pos)
318            }
319            ir::VariableKind::GlobalOutputArray(pos) => {
320                let id = self.state.buffers[pos as usize];
321                Variable::GlobalOutputArray(id, self.compile_type(item), pos)
322            }
323            ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
324            ir::VariableKind::LocalMut { id } => {
325                let item = self.compile_type(item);
326                let var = self.get_local(id, &item, variable);
327                Variable::Local { id: var, item }
328            }
329            ir::VariableKind::Versioned { id, version } => {
330                let item = self.compile_type(item);
331                let id = (id, version);
332                Variable::Versioned { id, item, variable }
333            }
334            ir::VariableKind::LocalConst { id } => {
335                let item = self.compile_type(item);
336                Variable::LocalBinding { id, item, variable }
337            }
338            ir::VariableKind::Builtin(builtin) => {
339                let item = self.compile_type(item);
340                self.compile_builtin(builtin, item)
341            }
342            ir::VariableKind::ConstantArray { id, length, .. } => {
343                let item = self.compile_type(item);
344                let id = self.state.const_arrays[id as usize].id;
345                Variable::ConstantArray(id, item, length as u32)
346            }
347            ir::VariableKind::SharedArray { id, length, .. } => {
348                let item = self.compile_type(item);
349                let id = self.state.shared_arrays[&id].id;
350                Variable::SharedArray(id, item, length as u32)
351            }
352            ir::VariableKind::Shared { id } => {
353                let item = self.compile_type(item);
354                let id = self.state.shared[&id].id;
355                Variable::Shared(id, item)
356            }
357            ir::VariableKind::LocalArray {
358                id,
359                length,
360                unroll_factor,
361            } => {
362                let item = self.compile_type(item);
363                let id = if let Some(arr) = self.state.local_arrays.get(&id) {
364                    arr.id
365                } else {
366                    let arr_ty = Item::Array(Box::new(item.clone()), length as u32);
367                    let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
368                    let arr_id = self.declare_function_variable(ptr_ty);
369                    self.debug_var_name(arr_id, variable);
370                    let arr = Array {
371                        id: arr_id,
372                        item: item.clone(),
373                        len: (length * unroll_factor) as u32,
374                        var: variable,
375                        alignment: None,
376                    };
377                    self.state.local_arrays.insert(id, arr);
378                    arr_id
379                };
380                Variable::LocalArray(id, item, length as u32)
381            }
382            ir::VariableKind::Matrix { id, mat } => {
383                let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
384                if self.state.matrices.contains_key(&id) {
385                    Variable::CoopMatrix(id, elem)
386                } else {
387                    let matrix = self.init_coop_matrix(mat, variable);
388                    self.state.matrices.insert(id, matrix);
389                    Variable::CoopMatrix(id, elem)
390                }
391            }
392            ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
393            ir::VariableKind::BarrierToken { .. } => {
394                panic!("Barrier not supported.")
395            }
396            ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
397            ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
398        }
399    }
400
401    pub fn read(&mut self, variable: &Variable) -> Word {
402        match variable {
403            Variable::Slice { ptr, .. } => self.read(ptr),
404            Variable::Shared(id, item)
405                if self.compilation_options.vulkan.supports_explicit_smem =>
406            {
407                let ty = item.id(self);
408                let ptr_ty =
409                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
410                let index = vec![self.const_u32(0)];
411                let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
412                self.load(ty, None, access, None, []).unwrap()
413            }
414            Variable::Local { id, item } | Variable::Shared(id, item) => {
415                let ty = item.id(self);
416                self.load(ty, None, *id, None, []).unwrap()
417            }
418            ssa => ssa.id(self),
419        }
420    }
421
422    pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
423        if let Some(as_const) = variable.as_const() {
424            self.static_cast(as_const, &variable.elem(), item).0
425        } else {
426            let id = self.read(variable);
427            variable.item().cast_to(self, None, id, item)
428        }
429    }
430
431    pub fn index(
432        &mut self,
433        variable: &Variable,
434        index: &Variable,
435        unchecked: bool,
436    ) -> IndexedVariable {
437        let access_chain = if unchecked {
438            Builder::in_bounds_access_chain
439        } else {
440            Builder::access_chain
441        };
442        let index_id = self.read(index);
443        match variable {
444            Variable::GlobalInputArray(id, item, _) | Variable::GlobalOutputArray(id, item, _) => {
445                let ptr_ty =
446                    Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
447                let zero = self.const_u32(0);
448                let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();
449
450                IndexedVariable::Pointer(id, item.clone())
451            }
452            Variable::Local {
453                id,
454                item: Item::Vector(elem, _),
455            } => {
456                let ptr_ty =
457                    Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
458                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
459
460                IndexedVariable::Pointer(id, Item::Scalar(*elem))
461            }
462            Variable::Shared(id, Item::Vector(elem, _)) => {
463                let ptr_ty =
464                    Item::Pointer(StorageClass::Workgroup, Box::new(Item::Scalar(*elem))).id(self);
465
466                let mut index = vec![index_id];
467                if self.compilation_options.vulkan.supports_explicit_smem {
468                    index.insert(0, self.const_u32(0));
469                }
470
471                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
472
473                IndexedVariable::Pointer(id, Item::Scalar(*elem))
474            }
475            Variable::LocalBinding {
476                id,
477                item: Item::Vector(elem, vec),
478                variable,
479            } if index.as_const().is_some() => IndexedVariable::Composite(
480                self.get_binding(*id, variable),
481                index.as_const().unwrap().as_u64() as u32,
482                Item::Vector(*elem, *vec),
483            ),
484            Variable::LocalBinding {
485                id,
486                item: Item::Vector(elem, vec),
487                variable,
488            } => IndexedVariable::DynamicComposite(
489                self.get_binding(*id, variable),
490                index_id,
491                Item::Vector(*elem, *vec),
492            ),
493            Variable::Versioned {
494                id,
495                item: Item::Vector(elem, vec),
496                variable,
497            } if index.as_const().is_some() => IndexedVariable::Composite(
498                self.get_versioned(*id, variable),
499                index.as_const().unwrap().as_u64() as u32,
500                Item::Vector(*elem, *vec),
501            ),
502            Variable::Versioned {
503                id,
504                item: Item::Vector(elem, vec),
505                variable,
506            } => IndexedVariable::DynamicComposite(
507                self.get_versioned(*id, variable),
508                index_id,
509                Item::Vector(*elem, *vec),
510            ),
511            Variable::Shared(id, item)
512                if self.compilation_options.vulkan.supports_explicit_smem =>
513            {
514                let ptr_ty =
515                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
516                let index = vec![self.const_u32(0)];
517                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
518                IndexedVariable::Pointer(id, item.clone())
519            }
520            Variable::Local { .. }
521            | Variable::Shared(..)
522            | Variable::LocalBinding { .. }
523            | Variable::Versioned { .. } => IndexedVariable::Scalar(variable.clone()),
524            Variable::Constant(_, val, item) => {
525                let scalar_item = Item::Scalar(item.elem());
526                let (id, val) = self.static_cast(*val, &item.elem(), &scalar_item);
527                IndexedVariable::Scalar(Variable::Constant(id, val, scalar_item))
528            }
529            Variable::Slice { ptr, offset, .. } => {
530                let item = Item::Scalar(Elem::Int(32, false));
531                let int = item.id(self);
532                let index = match index.as_const() {
533                    Some(ConstVal::Bit32(0)) => *offset,
534                    _ => self.i_add(int, None, *offset, index_id).unwrap(),
535                };
536                self.index(ptr, &Variable::Raw(index, item), unchecked)
537            }
538            Variable::SharedArray(id, item, _) => {
539                let ptr_ty =
540                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
541                let mut index = vec![index_id];
542                if self.compilation_options.vulkan.supports_explicit_smem {
543                    index.insert(0, self.const_u32(0));
544                }
545                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
546                IndexedVariable::Pointer(id, item.clone())
547            }
548            Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
549                let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
550                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
551                IndexedVariable::Pointer(id, item.clone())
552            }
553            var => unimplemented!("Can't index into {var:?}"),
554        }
555    }
556
557    pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
558        let always_in_bounds = is_always_in_bounds(variable, index);
559        let indexed = self.index(variable, index, always_in_bounds);
560
561        let read = |b: &mut Self| match indexed {
562            IndexedVariable::Pointer(ptr, item) => {
563                let ty = item.id(b);
564                let out_id = b.write_id(out);
565                b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
566            }
567            IndexedVariable::Composite(var, index, item) => {
568                let elem = item.elem();
569                let ty = elem.id(b);
570                let out_id = b.write_id(out);
571                b.composite_extract(ty, Some(out_id), var, vec![index])
572                    .unwrap()
573            }
574            IndexedVariable::DynamicComposite(var, index, item) => {
575                let elem = item.elem();
576                let ty = elem.id(b);
577                let out_id = b.write_id(out);
578                b.vector_extract_dynamic(ty, Some(out_id), var, index)
579                    .unwrap()
580            }
581            IndexedVariable::Scalar(var) => {
582                let ty = out.item().id(b);
583                let input = b.read(&var);
584                let out_id = b.write_id(out);
585                b.copy_object(ty, Some(out_id), input).unwrap();
586                b.write(out, out_id);
587                out_id
588            }
589        };
590
591        read(self)
592    }
593
594    pub fn read_indexed_unchecked(
595        &mut self,
596        out: &Variable,
597        variable: &Variable,
598        index: &Variable,
599    ) -> Word {
600        let indexed = self.index(variable, index, true);
601
602        match indexed {
603            IndexedVariable::Pointer(ptr, item) => {
604                let ty = item.id(self);
605                let out_id = self.write_id(out);
606                self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
607            }
608            IndexedVariable::Composite(var, index, item) => {
609                let elem = item.elem();
610                let ty = elem.id(self);
611                let out_id = self.write_id(out);
612                self.composite_extract(ty, Some(out_id), var, vec![index])
613                    .unwrap()
614            }
615            IndexedVariable::DynamicComposite(var, index, item) => {
616                let elem = item.elem();
617                let ty = elem.id(self);
618                let out_id = self.write_id(out);
619                self.vector_extract_dynamic(ty, Some(out_id), var, index)
620                    .unwrap()
621            }
622            IndexedVariable::Scalar(var) => {
623                let ty = out.item().id(self);
624                let input = self.read(&var);
625                let out_id = self.write_id(out);
626                self.copy_object(ty, Some(out_id), input).unwrap();
627                self.write(out, out_id);
628                out_id
629            }
630        }
631    }
632
633    pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
634        let always_in_bounds = is_always_in_bounds(var, index);
635        match self.index(var, index, always_in_bounds) {
636            IndexedVariable::Pointer(ptr, _) => ptr,
637            other => unreachable!("{other:?}"),
638        }
639    }
640
641    pub fn write_id(&mut self, variable: &Variable) -> Word {
642        match variable {
643            Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
644            Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
645            Variable::Local { .. } => self.id(),
646            Variable::Shared(..) => self.id(),
647            Variable::GlobalScalar(id, _) => *id,
648            Variable::Raw(id, _) => *id,
649            Variable::Constant(_, _, _) => panic!("Can't write to constant scalar"),
650            Variable::GlobalInputArray(_, _, _)
651            | Variable::GlobalOutputArray(_, _, _)
652            | Variable::Slice { .. }
653            | Variable::SharedArray(_, _, _)
654            | Variable::ConstantArray(_, _, _)
655            | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
656            global => panic!("Can't write to builtin {global:?}"),
657        }
658    }
659
660    pub fn write(&mut self, variable: &Variable, value: Word) {
661        match variable {
662            Variable::Shared(id, item)
663                if self.compilation_options.vulkan.supports_explicit_smem =>
664            {
665                let ptr_ty =
666                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
667                let index = vec![self.const_u32(0)];
668                let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
669                self.store(access, value, None, []).unwrap()
670            }
671            Variable::Local { id, .. } | Variable::Shared(id, _) => {
672                self.store(*id, value, None, []).unwrap()
673            }
674
675            Variable::Slice { ptr, .. } => self.write(ptr, value),
676            _ => {}
677        }
678    }
679
680    pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
681        let always_in_bounds = is_always_in_bounds(out, index);
682        let variable = self.index(out, index, always_in_bounds);
683
684        let write = |b: &mut Self| match variable {
685            IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
686            IndexedVariable::Composite(var, index, item) => {
687                let ty = item.id(b);
688                let id = b
689                    .composite_insert(ty, None, value, var, vec![index])
690                    .unwrap();
691                b.write(out, id);
692            }
693            IndexedVariable::DynamicComposite(var, index, item) => {
694                let ty = item.id(b);
695                let id = b
696                    .vector_insert_dynamic(ty, None, value, var, index)
697                    .unwrap();
698                b.write(out, id);
699            }
700            IndexedVariable::Scalar(var) => b.write(&var, value),
701        };
702
703        write(self)
704    }
705
706    pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
707        let variable = self.index(out, index, true);
708
709        match variable {
710            IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
711            IndexedVariable::Composite(var, index, item) => {
712                let ty = item.id(self);
713                let out_id = self
714                    .composite_insert(ty, None, value, var, vec![index])
715                    .unwrap();
716                self.write(out, out_id);
717            }
718            IndexedVariable::DynamicComposite(var, index, item) => {
719                let ty = item.id(self);
720                let out_id = self
721                    .vector_insert_dynamic(ty, None, value, var, index)
722                    .unwrap();
723                self.write(out, out_id);
724            }
725            IndexedVariable::Scalar(var) => self.write(&var, value),
726        }
727    }
728}
729
730fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
731    let len = match var {
732        Variable::SharedArray(_, _, len)
733        | Variable::ConstantArray(_, _, len)
734        | Variable::LocalArray(_, _, len)
735        | Variable::Slice {
736            const_len: Some(len),
737            ..
738        } => *len,
739        _ => return false,
740    };
741
742    let const_index = match index {
743        Variable::Constant(_, value, _) => value.as_u64(),
744        _ => return false,
745    };
746
747    const_index < len as u64
748}