cubecl_spirv/
variable.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6    SpirvCompiler, SpirvTarget,
7    item::{Elem, Item},
8    lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantScalarValue, FloatKind, Id};
11use rspirv::{
12    dr::Builder,
13    spirv::{FPEncoding, StorageClass, Word},
14};
15
16#[derive(Debug, Clone, PartialEq)]
17pub enum Variable {
18    SubgroupSize(Word),
19    GlobalInputArray(Word, Item, u32),
20    GlobalOutputArray(Word, Item, u32),
21    GlobalScalar(Word, Elem),
22    ConstantScalar(Word, ConstVal, Elem),
23    Local {
24        id: Word,
25        item: Item,
26    },
27    Versioned {
28        id: (Id, u16),
29        item: Item,
30        variable: ir::Variable,
31    },
32    LocalBinding {
33        id: Id,
34        item: Item,
35        variable: ir::Variable,
36    },
37    Raw(Word, Item),
38    Named {
39        id: Word,
40        item: Item,
41        is_array: bool,
42    },
43    Slice {
44        ptr: Box<Variable>,
45        offset: Word,
46        end: Word,
47        const_len: Option<u32>,
48        item: Item,
49    },
50    SharedMemory(Word, Item, u32),
51    ConstantArray(Word, Item, u32),
52    LocalArray(Word, Item, u32),
53    CoopMatrix(Id, Elem),
54    Id(Word),
55    LocalInvocationIndex(Word),
56    LocalInvocationIdX(Word),
57    LocalInvocationIdY(Word),
58    LocalInvocationIdZ(Word),
59    Rank(Word),
60    WorkgroupId(Word),
61    WorkgroupIdX(Word),
62    WorkgroupIdY(Word),
63    WorkgroupIdZ(Word),
64    GlobalInvocationIndex(Word),
65    GlobalInvocationIdX(Word),
66    GlobalInvocationIdY(Word),
67    GlobalInvocationIdZ(Word),
68    WorkgroupSize(Word),
69    WorkgroupSizeX(Word),
70    WorkgroupSizeY(Word),
71    WorkgroupSizeZ(Word),
72    NumWorkgroups(Word),
73    NumWorkgroupsX(Word),
74    NumWorkgroupsY(Word),
75    NumWorkgroupsZ(Word),
76}
77
/// Raw bit pattern of a scalar constant.
///
/// Constants of width 64 are kept in a 64-bit word; every narrower constant is
/// kept zero-extended in a 32-bit word.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    /// Bits of a constant whose width is at most 32.
    Bit32(u32),
    /// Bits of a 64-bit constant.
    Bit64(u64),
}
83
84impl ConstVal {
85    pub fn as_u64(&self) -> u64 {
86        match self {
87            ConstVal::Bit32(val) => *val as u64,
88            ConstVal::Bit64(val) => *val,
89        }
90    }
91
92    pub fn as_u32(&self) -> u32 {
93        match self {
94            ConstVal::Bit32(val) => *val,
95            ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
96        }
97    }
98
99    pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
100        match (width, encoding) {
101            (64, _) => f64::from_bits(self.as_u64()),
102            (32, _) => f32::from_bits(self.as_u32()) as f64,
103            (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
104            (_, Some(FPEncoding::BFloat16KHR)) => {
105                half::bf16::from_bits(self.as_u32() as u16).to_f64()
106            }
107            (_, Some(FPEncoding::Float8E4M3EXT)) => {
108                cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
109            }
110            (_, Some(FPEncoding::Float8E5M2EXT)) => {
111                cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
112            }
113            _ => unreachable!(),
114        }
115    }
116
117    pub fn as_int(&self, width: u32) -> i64 {
118        unsafe {
119            match width {
120                64 => transmute::<u64, i64>(self.as_u64()),
121                32 => transmute::<u32, i32>(self.as_u32()) as i64,
122                16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
123                8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
124                _ => unreachable!(),
125            }
126        }
127    }
128
129    pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
130        match (width, encoding) {
131            (64, _) => ConstVal::Bit64(value.to_bits()),
132            (32, _) => ConstVal::Bit32((value as f32).to_bits()),
133            (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
134            (_, Some(FPEncoding::BFloat16KHR)) => {
135                ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
136            }
137            (_, Some(FPEncoding::Float8E4M3EXT)) => {
138                ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
139            }
140            (_, Some(FPEncoding::Float8E5M2EXT)) => {
141                ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
142            }
143            _ => unreachable!(),
144        }
145    }
146
147    pub fn from_int(value: i64, width: u32) -> Self {
148        match width {
149            64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
150            32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
151            16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
152            8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
153            _ => unreachable!(),
154        }
155    }
156
157    pub fn from_uint(value: u64, width: u32) -> Self {
158        match width {
159            64 => ConstVal::Bit64(value),
160            32 => ConstVal::Bit32(value as u32),
161            16 => ConstVal::Bit32(value as u16 as u32),
162            8 => ConstVal::Bit32(value as u8 as u32),
163            _ => unreachable!(),
164        }
165    }
166
167    pub fn from_bool(value: bool) -> Self {
168        ConstVal::Bit32(value as u32)
169    }
170}
171
172impl From<ConstantScalarValue> for ConstVal {
173    fn from(value: ConstantScalarValue) -> Self {
174        let width = value.elem_type().size() as u32 * 8;
175        match value {
176            ConstantScalarValue::Int(val, _) => ConstVal::from_int(val, width),
177            ConstantScalarValue::Float(val, FloatKind::BF16) => {
178                ConstVal::from_float(val, width, Some(FPEncoding::BFloat16KHR))
179            }
180            ConstantScalarValue::Float(val, FloatKind::E4M3) => {
181                ConstVal::from_float(val, width, Some(FPEncoding::Float8E4M3EXT))
182            }
183            ConstantScalarValue::Float(val, FloatKind::E5M2) => {
184                ConstVal::from_float(val, width, Some(FPEncoding::Float8E5M2EXT))
185            }
186            ConstantScalarValue::Float(val, _) => ConstVal::from_float(val, width, None),
187            ConstantScalarValue::UInt(val, _) => ConstVal::from_uint(val, width),
188            ConstantScalarValue::Bool(val) => ConstVal::from_bool(val),
189        }
190    }
191}
192
193impl From<u32> for ConstVal {
194    fn from(value: u32) -> Self {
195        ConstVal::Bit32(value)
196    }
197}
198
199impl From<f32> for ConstVal {
200    fn from(value: f32) -> Self {
201        ConstVal::Bit32(value.to_bits())
202    }
203}
204
205impl Variable {
206    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
207        match self {
208            Variable::GlobalInputArray(id, _, _) => *id,
209            Variable::GlobalOutputArray(id, _, _) => *id,
210            Variable::GlobalScalar(id, _) => *id,
211            Variable::ConstantScalar(id, _, _) => *id,
212            Variable::Local { id, .. } => *id,
213            Variable::Versioned {
214                id, variable: var, ..
215            } => b.get_versioned(*id, var),
216            Variable::LocalBinding {
217                id, variable: var, ..
218            } => b.get_binding(*id, var),
219            Variable::Raw(id, _) => *id,
220            Variable::Named { id, .. } => *id,
221            Variable::Slice { ptr, .. } => ptr.id(b),
222            Variable::SharedMemory(id, _, _) => *id,
223            Variable::ConstantArray(id, _, _) => *id,
224            Variable::LocalArray(id, _, _) => *id,
225            Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
226            Variable::SubgroupSize(id) => *id,
227            Variable::Id(id) => *id,
228            Variable::LocalInvocationIndex(id) => *id,
229            Variable::LocalInvocationIdX(id) => *id,
230            Variable::LocalInvocationIdY(id) => *id,
231            Variable::LocalInvocationIdZ(id) => *id,
232            Variable::Rank(id) => *id,
233            Variable::WorkgroupId(id) => *id,
234            Variable::WorkgroupIdX(id) => *id,
235            Variable::WorkgroupIdY(id) => *id,
236            Variable::WorkgroupIdZ(id) => *id,
237            Variable::GlobalInvocationIndex(id) => *id,
238            Variable::GlobalInvocationIdX(id) => *id,
239            Variable::GlobalInvocationIdY(id) => *id,
240            Variable::GlobalInvocationIdZ(id) => *id,
241            Variable::WorkgroupSize(id) => *id,
242            Variable::WorkgroupSizeX(id) => *id,
243            Variable::WorkgroupSizeY(id) => *id,
244            Variable::WorkgroupSizeZ(id) => *id,
245            Variable::NumWorkgroups(id) => *id,
246            Variable::NumWorkgroupsX(id) => *id,
247            Variable::NumWorkgroupsY(id) => *id,
248            Variable::NumWorkgroupsZ(id) => *id,
249        }
250    }
251
252    pub fn item(&self) -> Item {
253        match self {
254            Variable::GlobalInputArray(_, item, _) => item.clone(),
255            Variable::GlobalOutputArray(_, item, _) => item.clone(),
256            Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
257            Variable::ConstantScalar(_, _, elem) => Item::Scalar(*elem),
258            Variable::Local { item, .. } => item.clone(),
259            Variable::Versioned { item, .. } => item.clone(),
260            Variable::LocalBinding { item, .. } => item.clone(),
261            Variable::Named { item, .. } => item.clone(),
262            Variable::Slice { item, .. } => item.clone(),
263            Variable::SharedMemory(_, item, _) => item.clone(),
264            Variable::ConstantArray(_, item, _) => item.clone(),
265            Variable::LocalArray(_, item, _) => item.clone(),
266            Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
267            _ => Item::Scalar(Elem::Int(32, false)), // builtin
268        }
269    }
270
271    pub fn indexed_item(&self) -> Item {
272        match self {
273            Variable::LocalBinding {
274                item: Item::Vector(elem, _),
275                ..
276            } => Item::Scalar(*elem),
277            Variable::Local {
278                item: Item::Vector(elem, _),
279                ..
280            } => Item::Scalar(*elem),
281            Variable::Versioned {
282                item: Item::Vector(elem, _),
283                ..
284            } => Item::Scalar(*elem),
285            other => other.item(),
286        }
287    }
288
289    pub fn elem(&self) -> Elem {
290        self.item().elem()
291    }
292
293    pub fn has_len(&self) -> bool {
294        matches!(
295            self,
296            Variable::GlobalInputArray(_, _, _)
297                | Variable::GlobalOutputArray(_, _, _)
298                | Variable::Named {
299                    is_array: false,
300                    ..
301                }
302                | Variable::Slice { .. }
303                | Variable::SharedMemory(_, _, _)
304                | Variable::ConstantArray(_, _, _)
305                | Variable::LocalArray(_, _, _)
306        )
307    }
308
309    pub fn has_buffer_len(&self) -> bool {
310        matches!(
311            self,
312            Variable::GlobalInputArray(_, _, _)
313                | Variable::GlobalOutputArray(_, _, _)
314                | Variable::Named {
315                    is_array: false,
316                    ..
317                }
318        )
319    }
320
321    pub fn as_const(&self) -> Option<ConstVal> {
322        match self {
323            Self::ConstantScalar(_, val, _) => Some(*val),
324            _ => None,
325        }
326    }
327
328    pub fn as_binding(&self) -> Option<Id> {
329        match self {
330            Self::LocalBinding { id, .. } => Some(*id),
331            _ => None,
332        }
333    }
334}
335
336#[derive(Debug)]
337pub enum IndexedVariable {
338    Pointer(Word, Item),
339    Composite(Word, u32, Item),
340    DynamicComposite(Word, u32, Item),
341    Scalar(Variable),
342}
343
344impl<T: SpirvTarget> SpirvCompiler<T> {
345    pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
346        let item = variable.ty;
347        match variable.kind {
348            ir::VariableKind::ConstantScalar(value) => {
349                let item = self.compile_type(ir::Type::new(value.storage_type()));
350                let const_val = value.into();
351
352                if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
353                    Variable::ConstantScalar(*existing, const_val, item.elem())
354                } else {
355                    let id = item.elem().constant(self, const_val);
356                    self.state.constants.insert((const_val, item.clone()), id);
357                    Variable::ConstantScalar(id, const_val, item.elem())
358                }
359            }
360            ir::VariableKind::GlobalInputArray(pos) => {
361                let id = self.state.buffers[pos as usize];
362                Variable::GlobalInputArray(id, self.compile_type(item), pos)
363            }
364            ir::VariableKind::GlobalOutputArray(pos) => {
365                let id = self.state.buffers[pos as usize];
366                Variable::GlobalOutputArray(id, self.compile_type(item), pos)
367            }
368            ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
369            ir::VariableKind::LocalMut { id } => {
370                let item = self.compile_type(item);
371                let var = self.get_local(id, &item, variable);
372                Variable::Local { id: var, item }
373            }
374            ir::VariableKind::Versioned { id, version } => {
375                let item = self.compile_type(item);
376                let id = (id, version);
377                Variable::Versioned { id, item, variable }
378            }
379            ir::VariableKind::LocalConst { id } => {
380                let item = self.compile_type(item);
381                Variable::LocalBinding { id, item, variable }
382            }
383            ir::VariableKind::Builtin(builtin) => self.compile_builtin(builtin),
384            ir::VariableKind::ConstantArray { id, length, .. } => {
385                let item = self.compile_type(item);
386                let id = self.state.const_arrays[id as usize].id;
387                Variable::ConstantArray(id, item, length)
388            }
389            ir::VariableKind::SharedMemory { id, length, .. } => {
390                let item = self.compile_type(item);
391                let id = self.state.shared_memories[&id].id;
392                Variable::SharedMemory(id, item, length)
393            }
394            ir::VariableKind::LocalArray {
395                id,
396                length,
397                unroll_factor,
398            } => {
399                let item = self.compile_type(item);
400                let id = if let Some(arr) = self.state.local_arrays.get(&id) {
401                    arr.id
402                } else {
403                    let arr_ty = Item::Array(Box::new(item.clone()), length);
404                    let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
405                    let arr_id = self.declare_function_variable(ptr_ty);
406                    self.debug_var_name(arr_id, variable);
407                    let arr = Array {
408                        id: arr_id,
409                        item: item.clone(),
410                        len: length * unroll_factor,
411                        var: variable,
412                        alignment: None,
413                    };
414                    self.state.local_arrays.insert(id, arr);
415                    arr_id
416                };
417                Variable::LocalArray(id, item, length)
418            }
419            ir::VariableKind::Matrix { id, mat } => {
420                let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
421                if self.state.matrices.contains_key(&id) {
422                    Variable::CoopMatrix(id, elem)
423                } else {
424                    let matrix = self.init_coop_matrix(mat, variable);
425                    self.state.matrices.insert(id, matrix);
426                    Variable::CoopMatrix(id, elem)
427                }
428            }
429            ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
430            ir::VariableKind::Barrier { .. } => panic!("Barrier not supported."),
431            ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
432            ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
433        }
434    }
435
436    pub fn read(&mut self, variable: &Variable) -> Word {
437        match variable {
438            Variable::Slice { ptr, .. } => self.read(ptr),
439            Variable::Local { id, item } => {
440                let ty = item.id(self);
441                self.load(ty, None, *id, None, vec![]).unwrap()
442            }
443            Variable::Named { id, item, .. } => {
444                let ty = item.id(self);
445                self.load(ty, None, *id, None, vec![]).unwrap()
446            }
447            ssa => ssa.id(self),
448        }
449    }
450
451    pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
452        if let Some(as_const) = variable.as_const() {
453            self.static_cast(as_const, &variable.elem(), item)
454        } else {
455            let id = self.read(variable);
456            variable.item().cast_to(self, None, id, item)
457        }
458    }
459
460    pub fn index(
461        &mut self,
462        variable: &Variable,
463        index: &Variable,
464        unchecked: bool,
465    ) -> IndexedVariable {
466        let access_chain = if unchecked {
467            Builder::in_bounds_access_chain
468        } else {
469            Builder::access_chain
470        };
471        let index_id = self.read(index);
472        match variable {
473            Variable::GlobalInputArray(id, item, _)
474            | Variable::GlobalOutputArray(id, item, _)
475            | Variable::Named { id, item, .. } => {
476                let ptr_ty =
477                    Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
478                let zero = self.const_u32(0);
479                let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();
480
481                IndexedVariable::Pointer(id, item.clone())
482            }
483            Variable::Local {
484                id,
485                item: Item::Vector(elem, _),
486            } => {
487                let ptr_ty =
488                    Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
489                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
490
491                IndexedVariable::Pointer(id, Item::Scalar(*elem))
492            }
493            Variable::LocalBinding {
494                id,
495                item: Item::Vector(elem, vec),
496                variable,
497            } if index.as_const().is_some() => IndexedVariable::Composite(
498                self.get_binding(*id, variable),
499                index.as_const().unwrap().as_u32(),
500                Item::Vector(*elem, *vec),
501            ),
502            Variable::LocalBinding {
503                id,
504                item: Item::Vector(elem, vec),
505                variable,
506            } => IndexedVariable::DynamicComposite(
507                self.get_binding(*id, variable),
508                index_id,
509                Item::Vector(*elem, *vec),
510            ),
511            Variable::Versioned {
512                id,
513                item: Item::Vector(elem, vec),
514                variable,
515            } if index.as_const().is_some() => IndexedVariable::Composite(
516                self.get_versioned(*id, variable),
517                index.as_const().unwrap().as_u32(),
518                Item::Vector(*elem, *vec),
519            ),
520            Variable::Versioned {
521                id,
522                item: Item::Vector(elem, vec),
523                variable,
524            } => IndexedVariable::DynamicComposite(
525                self.get_versioned(*id, variable),
526                index_id,
527                Item::Vector(*elem, *vec),
528            ),
529            Variable::Local { .. } | Variable::LocalBinding { .. } | Variable::Versioned { .. } => {
530                IndexedVariable::Scalar(variable.clone())
531            }
532            Variable::Slice { ptr, offset, .. } => {
533                let item = Item::Scalar(Elem::Int(32, false));
534                let int = item.id(self);
535                let index = match index.as_const() {
536                    Some(ConstVal::Bit32(0)) => *offset,
537                    _ => self.i_add(int, None, *offset, index_id).unwrap(),
538                };
539                self.index(ptr, &Variable::Raw(index, item), unchecked)
540            }
541            Variable::SharedMemory(id, item, _) => {
542                let ptr_ty =
543                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
544                let mut index = vec![index_id];
545                if self.compilation_options.supports_explicit_smem {
546                    index.insert(0, self.const_u32(0));
547                }
548                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
549                IndexedVariable::Pointer(id, item.clone())
550            }
551            Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
552                let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
553                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
554                IndexedVariable::Pointer(id, item.clone())
555            }
556            var => unimplemented!("Can't index into {var:?}"),
557        }
558    }
559
560    pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
561        let always_in_bounds = is_always_in_bounds(variable, index);
562        let indexed = self.index(variable, index, always_in_bounds);
563
564        let read = |b: &mut Self| match indexed {
565            IndexedVariable::Pointer(ptr, item) => {
566                let ty = item.id(b);
567                let out_id = b.write_id(out);
568                b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
569            }
570            IndexedVariable::Composite(var, index, item) => {
571                let elem = item.elem();
572                let ty = elem.id(b);
573                let out_id = b.write_id(out);
574                b.composite_extract(ty, Some(out_id), var, vec![index])
575                    .unwrap()
576            }
577            IndexedVariable::DynamicComposite(var, index, item) => {
578                let elem = item.elem();
579                let ty = elem.id(b);
580                let out_id = b.write_id(out);
581                b.vector_extract_dynamic(ty, Some(out_id), var, index)
582                    .unwrap()
583            }
584            IndexedVariable::Scalar(var) => {
585                let ty = out.item().id(b);
586                let input = b.read(&var);
587                let out_id = b.write_id(out);
588                b.copy_object(ty, Some(out_id), input).unwrap();
589                b.write(out, out_id);
590                out_id
591            }
592        };
593
594        read(self)
595    }
596
597    pub fn read_indexed_unchecked(
598        &mut self,
599        out: &Variable,
600        variable: &Variable,
601        index: &Variable,
602    ) -> Word {
603        let indexed = self.index(variable, index, true);
604
605        match indexed {
606            IndexedVariable::Pointer(ptr, item) => {
607                let ty = item.id(self);
608                let out_id = self.write_id(out);
609                self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
610            }
611            IndexedVariable::Composite(var, index, item) => {
612                let elem = item.elem();
613                let ty = elem.id(self);
614                let out_id = self.write_id(out);
615                self.composite_extract(ty, Some(out_id), var, vec![index])
616                    .unwrap()
617            }
618            IndexedVariable::DynamicComposite(var, index, item) => {
619                let elem = item.elem();
620                let ty = elem.id(self);
621                let out_id = self.write_id(out);
622                self.vector_extract_dynamic(ty, Some(out_id), var, index)
623                    .unwrap()
624            }
625            IndexedVariable::Scalar(var) => {
626                let ty = out.item().id(self);
627                let input = self.read(&var);
628                let out_id = self.write_id(out);
629                self.copy_object(ty, Some(out_id), input).unwrap();
630                self.write(out, out_id);
631                out_id
632            }
633        }
634    }
635
636    pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
637        let always_in_bounds = is_always_in_bounds(var, index);
638        match self.index(var, index, always_in_bounds) {
639            IndexedVariable::Pointer(ptr, _) => ptr,
640            other => unreachable!("{other:?}"),
641        }
642    }
643
644    pub fn write_id(&mut self, variable: &Variable) -> Word {
645        match variable {
646            Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
647            Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
648            Variable::Local { .. } => self.id(),
649            Variable::GlobalScalar(id, _) => *id,
650            Variable::Raw(id, _) => *id,
651            Variable::ConstantScalar(_, _, _) => panic!("Can't write to constant scalar"),
652            Variable::GlobalInputArray(_, _, _)
653            | Variable::GlobalOutputArray(_, _, _)
654            | Variable::Slice { .. }
655            | Variable::Named { .. }
656            | Variable::SharedMemory(_, _, _)
657            | Variable::ConstantArray(_, _, _)
658            | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
659            global => panic!("Can't write to builtin {global:?}"),
660        }
661    }
662
663    pub fn write(&mut self, variable: &Variable, value: Word) {
664        match variable {
665            Variable::Local { id, .. } => self.store(*id, value, None, vec![]).unwrap(),
666            Variable::Slice { ptr, .. } => self.write(ptr, value),
667            _ => {}
668        }
669    }
670
671    pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
672        let always_in_bounds = is_always_in_bounds(out, index);
673        let variable = self.index(out, index, always_in_bounds);
674
675        let write = |b: &mut Self| match variable {
676            IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
677            IndexedVariable::Composite(var, index, item) => {
678                let ty = item.id(b);
679                let id = b
680                    .composite_insert(ty, None, value, var, vec![index])
681                    .unwrap();
682                b.write(out, id);
683            }
684            IndexedVariable::DynamicComposite(var, index, item) => {
685                let ty = item.id(b);
686                let id = b
687                    .vector_insert_dynamic(ty, None, value, var, index)
688                    .unwrap();
689                b.write(out, id);
690            }
691            IndexedVariable::Scalar(var) => b.write(&var, value),
692        };
693
694        write(self)
695    }
696
697    pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
698        let variable = self.index(out, index, true);
699
700        match variable {
701            IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
702            IndexedVariable::Composite(var, index, item) => {
703                let ty = item.id(self);
704                let out_id = self
705                    .composite_insert(ty, None, value, var, vec![index])
706                    .unwrap();
707                self.write(out, out_id);
708            }
709            IndexedVariable::DynamicComposite(var, index, item) => {
710                let ty = item.id(self);
711                let out_id = self
712                    .vector_insert_dynamic(ty, None, value, var, index)
713                    .unwrap();
714                self.write(out, out_id);
715            }
716            IndexedVariable::Scalar(var) => self.write(&var, value),
717        }
718    }
719}
720
721fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
722    let len = match var {
723        Variable::SharedMemory(_, _, len)
724        | Variable::ConstantArray(_, _, len)
725        | Variable::LocalArray(_, _, len)
726        | Variable::Slice {
727            const_len: Some(len),
728            ..
729        } => *len,
730        _ => return false,
731    };
732
733    let const_index = match index {
734        Variable::ConstantScalar(_, value, _) => value.as_u32(),
735        _ => return false,
736    };
737
738    const_index < len
739}