cubecl_spirv/
variable.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6    SpirvCompiler, SpirvTarget,
7    item::{Elem, Item},
8    lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantScalarValue, FloatKind, Id};
11use rspirv::{
12    dr::Builder,
13    spirv::{FPEncoding, StorageClass, Word},
14};
15
16#[derive(Debug, Clone, PartialEq)]
17pub enum Variable {
18    SubgroupSize(Word),
19    GlobalInputArray(Word, Item, u32),
20    GlobalOutputArray(Word, Item, u32),
21    GlobalScalar(Word, Elem),
22    ConstantScalar(Word, ConstVal, Elem),
23    Local {
24        id: Word,
25        item: Item,
26    },
27    Versioned {
28        id: (Id, u16),
29        item: Item,
30        variable: ir::Variable,
31    },
32    LocalBinding {
33        id: Id,
34        item: Item,
35        variable: ir::Variable,
36    },
37    Raw(Word, Item),
38    Named {
39        id: Word,
40        item: Item,
41        is_array: bool,
42    },
43    Slice {
44        ptr: Box<Variable>,
45        offset: Word,
46        end: Word,
47        const_len: Option<u32>,
48        item: Item,
49    },
50    SharedMemory(Word, Item, u32),
51    ConstantArray(Word, Item, u32),
52    LocalArray(Word, Item, u32),
53    CoopMatrix(Id, Elem),
54    Id(Word),
55    LocalInvocationIndex(Word),
56    LocalInvocationIdX(Word),
57    LocalInvocationIdY(Word),
58    LocalInvocationIdZ(Word),
59    Rank(Word),
60    WorkgroupId(Word),
61    WorkgroupIdX(Word),
62    WorkgroupIdY(Word),
63    WorkgroupIdZ(Word),
64    GlobalInvocationIndex(Word),
65    GlobalInvocationIdX(Word),
66    GlobalInvocationIdY(Word),
67    GlobalInvocationIdZ(Word),
68    WorkgroupSize(Word),
69    WorkgroupSizeX(Word),
70    WorkgroupSizeY(Word),
71    WorkgroupSizeZ(Word),
72    NumWorkgroups(Word),
73    NumWorkgroupsX(Word),
74    NumWorkgroupsY(Word),
75    NumWorkgroupsZ(Word),
76}
77
/// Raw bit pattern of a scalar constant, stored in the narrowest container that fits.
///
/// Values of 32 bits or fewer (including 8- and 16-bit types and booleans) live in `Bit32`;
/// only 64-bit values use `Bit64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    /// Bit pattern of a constant that is at most 32 bits wide.
    Bit32(u32),
    /// Bit pattern of a 64-bit constant.
    Bit64(u64),
}
83
84impl ConstVal {
85    pub fn as_u64(&self) -> u64 {
86        match self {
87            ConstVal::Bit32(val) => *val as u64,
88            ConstVal::Bit64(val) => *val,
89        }
90    }
91
92    pub fn as_u32(&self) -> u32 {
93        match self {
94            ConstVal::Bit32(val) => *val,
95            ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
96        }
97    }
98
99    pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
100        match (width, encoding) {
101            (64, _) => f64::from_bits(self.as_u64()),
102            (32, _) => f32::from_bits(self.as_u32()) as f64,
103            (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
104            (_, Some(FPEncoding::BFloat16KHR)) => {
105                half::bf16::from_bits(self.as_u32() as u16).to_f64()
106            }
107            (_, Some(FPEncoding::Float8E4M3EXT)) => {
108                cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
109            }
110            (_, Some(FPEncoding::Float8E5M2EXT)) => {
111                cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
112            }
113            _ => unreachable!(),
114        }
115    }
116
117    pub fn as_int(&self, width: u32) -> i64 {
118        unsafe {
119            match width {
120                64 => transmute::<u64, i64>(self.as_u64()),
121                32 => transmute::<u32, i32>(self.as_u32()) as i64,
122                16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
123                8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
124                _ => unreachable!(),
125            }
126        }
127    }
128
129    pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
130        match (width, encoding) {
131            (64, _) => ConstVal::Bit64(value.to_bits()),
132            (32, _) => ConstVal::Bit32((value as f32).to_bits()),
133            (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
134            (_, Some(FPEncoding::BFloat16KHR)) => {
135                ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
136            }
137            (_, Some(FPEncoding::Float8E4M3EXT)) => {
138                ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
139            }
140            (_, Some(FPEncoding::Float8E5M2EXT)) => {
141                ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
142            }
143            _ => unreachable!(),
144        }
145    }
146
147    pub fn from_int(value: i64, width: u32) -> Self {
148        match width {
149            64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
150            32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
151            16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
152            8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
153            _ => unreachable!(),
154        }
155    }
156
157    pub fn from_uint(value: u64, width: u32) -> Self {
158        match width {
159            64 => ConstVal::Bit64(value),
160            32 => ConstVal::Bit32(value as u32),
161            16 => ConstVal::Bit32(value as u16 as u32),
162            8 => ConstVal::Bit32(value as u8 as u32),
163            _ => unreachable!(),
164        }
165    }
166
167    pub fn from_bool(value: bool) -> Self {
168        ConstVal::Bit32(value as u32)
169    }
170}
171
172impl From<ConstantScalarValue> for ConstVal {
173    fn from(value: ConstantScalarValue) -> Self {
174        let width = value.elem_type().size() as u32 * 8;
175        match value {
176            ConstantScalarValue::Int(val, _) => ConstVal::from_int(val, width),
177            ConstantScalarValue::Float(val, FloatKind::BF16) => {
178                ConstVal::from_float(val, width, Some(FPEncoding::BFloat16KHR))
179            }
180            ConstantScalarValue::Float(val, FloatKind::E4M3) => {
181                ConstVal::from_float(val, width, Some(FPEncoding::Float8E4M3EXT))
182            }
183            ConstantScalarValue::Float(val, FloatKind::E5M2) => {
184                ConstVal::from_float(val, width, Some(FPEncoding::Float8E5M2EXT))
185            }
186            ConstantScalarValue::Float(val, _) => ConstVal::from_float(val, width, None),
187            ConstantScalarValue::UInt(val, _) => ConstVal::from_uint(val, width),
188            ConstantScalarValue::Bool(val) => ConstVal::from_bool(val),
189        }
190    }
191}
192
193impl From<u32> for ConstVal {
194    fn from(value: u32) -> Self {
195        ConstVal::Bit32(value)
196    }
197}
198
199impl From<f32> for ConstVal {
200    fn from(value: f32) -> Self {
201        ConstVal::Bit32(value.to_bits())
202    }
203}
204
205impl Variable {
206    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
207        match self {
208            Variable::GlobalInputArray(id, _, _) => *id,
209            Variable::GlobalOutputArray(id, _, _) => *id,
210            Variable::GlobalScalar(id, _) => *id,
211            Variable::ConstantScalar(id, _, _) => *id,
212            Variable::Local { id, .. } => *id,
213            Variable::Versioned {
214                id, variable: var, ..
215            } => b.get_versioned(*id, var),
216            Variable::LocalBinding {
217                id, variable: var, ..
218            } => b.get_binding(*id, var),
219            Variable::Raw(id, _) => *id,
220            Variable::Named { id, .. } => *id,
221            Variable::Slice { ptr, .. } => ptr.id(b),
222            Variable::SharedMemory(id, _, _) => *id,
223            Variable::ConstantArray(id, _, _) => *id,
224            Variable::LocalArray(id, _, _) => *id,
225            Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
226            Variable::SubgroupSize(id) => *id,
227            Variable::Id(id) => *id,
228            Variable::LocalInvocationIndex(id) => *id,
229            Variable::LocalInvocationIdX(id) => *id,
230            Variable::LocalInvocationIdY(id) => *id,
231            Variable::LocalInvocationIdZ(id) => *id,
232            Variable::Rank(id) => *id,
233            Variable::WorkgroupId(id) => *id,
234            Variable::WorkgroupIdX(id) => *id,
235            Variable::WorkgroupIdY(id) => *id,
236            Variable::WorkgroupIdZ(id) => *id,
237            Variable::GlobalInvocationIndex(id) => *id,
238            Variable::GlobalInvocationIdX(id) => *id,
239            Variable::GlobalInvocationIdY(id) => *id,
240            Variable::GlobalInvocationIdZ(id) => *id,
241            Variable::WorkgroupSize(id) => *id,
242            Variable::WorkgroupSizeX(id) => *id,
243            Variable::WorkgroupSizeY(id) => *id,
244            Variable::WorkgroupSizeZ(id) => *id,
245            Variable::NumWorkgroups(id) => *id,
246            Variable::NumWorkgroupsX(id) => *id,
247            Variable::NumWorkgroupsY(id) => *id,
248            Variable::NumWorkgroupsZ(id) => *id,
249        }
250    }
251
252    pub fn item(&self) -> Item {
253        match self {
254            Variable::GlobalInputArray(_, item, _) => item.clone(),
255            Variable::GlobalOutputArray(_, item, _) => item.clone(),
256            Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
257            Variable::ConstantScalar(_, _, elem) => Item::Scalar(*elem),
258            Variable::Local { item, .. } => item.clone(),
259            Variable::Versioned { item, .. } => item.clone(),
260            Variable::LocalBinding { item, .. } => item.clone(),
261            Variable::Named { item, .. } => item.clone(),
262            Variable::Slice { item, .. } => item.clone(),
263            Variable::SharedMemory(_, item, _) => item.clone(),
264            Variable::ConstantArray(_, item, _) => item.clone(),
265            Variable::LocalArray(_, item, _) => item.clone(),
266            Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
267            _ => Item::Scalar(Elem::Int(32, false)), // builtin
268        }
269    }
270
271    pub fn indexed_item(&self) -> Item {
272        match self {
273            Variable::LocalBinding {
274                item: Item::Vector(elem, _),
275                ..
276            } => Item::Scalar(*elem),
277            Variable::Local {
278                item: Item::Vector(elem, _),
279                ..
280            } => Item::Scalar(*elem),
281            Variable::Versioned {
282                item: Item::Vector(elem, _),
283                ..
284            } => Item::Scalar(*elem),
285            other => other.item(),
286        }
287    }
288
289    pub fn elem(&self) -> Elem {
290        self.item().elem()
291    }
292
293    pub fn has_len(&self) -> bool {
294        matches!(
295            self,
296            Variable::GlobalInputArray(_, _, _)
297                | Variable::GlobalOutputArray(_, _, _)
298                | Variable::Named {
299                    is_array: false,
300                    ..
301                }
302                | Variable::Slice { .. }
303                | Variable::SharedMemory(_, _, _)
304                | Variable::ConstantArray(_, _, _)
305                | Variable::LocalArray(_, _, _)
306        )
307    }
308
309    pub fn has_buffer_len(&self) -> bool {
310        matches!(
311            self,
312            Variable::GlobalInputArray(_, _, _)
313                | Variable::GlobalOutputArray(_, _, _)
314                | Variable::Named {
315                    is_array: false,
316                    ..
317                }
318        )
319    }
320
321    pub fn as_const(&self) -> Option<ConstVal> {
322        match self {
323            Self::ConstantScalar(_, val, _) => Some(*val),
324            _ => None,
325        }
326    }
327
328    pub fn as_binding(&self) -> Option<Id> {
329        match self {
330            Self::LocalBinding { id, .. } => Some(*id),
331            _ => None,
332        }
333    }
334}
335
336#[derive(Debug)]
337pub enum IndexedVariable {
338    Pointer(Word, Item),
339    Composite(Word, u32, Item),
340    DynamicComposite(Word, u32, Item),
341    Scalar(Variable),
342}
343
344impl<T: SpirvTarget> SpirvCompiler<T> {
345    pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
346        let item = variable.ty;
347        match variable.kind {
348            ir::VariableKind::ConstantScalar(value) => {
349                let item = self.compile_type(ir::Type::new(value.storage_type()));
350                let const_val = value.into();
351
352                if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
353                    Variable::ConstantScalar(*existing, const_val, item.elem())
354                } else {
355                    let id = item.elem().constant(self, const_val);
356                    self.state.constants.insert((const_val, item.clone()), id);
357                    Variable::ConstantScalar(id, const_val, item.elem())
358                }
359            }
360            ir::VariableKind::GlobalInputArray(pos) => {
361                let id = self.state.buffers[pos as usize];
362                Variable::GlobalInputArray(id, self.compile_type(item), pos)
363            }
364            ir::VariableKind::GlobalOutputArray(pos) => {
365                let id = self.state.buffers[pos as usize];
366                Variable::GlobalOutputArray(id, self.compile_type(item), pos)
367            }
368            ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
369            ir::VariableKind::LocalMut { id } => {
370                let item = self.compile_type(item);
371                let var = self.get_local(id, &item, variable);
372                Variable::Local { id: var, item }
373            }
374            ir::VariableKind::Versioned { id, version } => {
375                let item = self.compile_type(item);
376                let id = (id, version);
377                Variable::Versioned { id, item, variable }
378            }
379            ir::VariableKind::LocalConst { id } => {
380                let item = self.compile_type(item);
381                Variable::LocalBinding { id, item, variable }
382            }
383            ir::VariableKind::Builtin(builtin) => self.compile_builtin(builtin),
384            ir::VariableKind::ConstantArray { id, length, .. } => {
385                let item = self.compile_type(item);
386                let id = self.state.const_arrays[id as usize].id;
387                Variable::ConstantArray(id, item, length)
388            }
389            ir::VariableKind::SharedMemory { id, length, .. } => {
390                let item = self.compile_type(item);
391                let id = self.state.shared_memories[&id].id;
392                Variable::SharedMemory(id, item, length)
393            }
394            ir::VariableKind::LocalArray {
395                id,
396                length,
397                unroll_factor,
398            } => {
399                let item = self.compile_type(item);
400                let id = if let Some(arr) = self.state.local_arrays.get(&id) {
401                    arr.id
402                } else {
403                    let arr_ty = Item::Array(Box::new(item.clone()), length);
404                    let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
405                    let arr_id = self.declare_function_variable(ptr_ty);
406                    self.debug_var_name(arr_id, variable);
407                    let arr = Array {
408                        id: arr_id,
409                        item: item.clone(),
410                        len: length * unroll_factor,
411                        var: variable,
412                        alignment: None,
413                    };
414                    self.state.local_arrays.insert(id, arr);
415                    arr_id
416                };
417                Variable::LocalArray(id, item, length)
418            }
419            ir::VariableKind::Matrix { id, mat } => {
420                let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
421                if self.state.matrices.contains_key(&id) {
422                    Variable::CoopMatrix(id, elem)
423                } else {
424                    let matrix = self.init_coop_matrix(mat, variable);
425                    self.state.matrices.insert(id, matrix);
426                    Variable::CoopMatrix(id, elem)
427                }
428            }
429            ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
430            ir::VariableKind::Barrier { .. } | ir::VariableKind::BarrierToken { .. } => {
431                panic!("Barrier not supported.")
432            }
433            ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
434            ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
435        }
436    }
437
438    pub fn read(&mut self, variable: &Variable) -> Word {
439        match variable {
440            Variable::Slice { ptr, .. } => self.read(ptr),
441            Variable::Local { id, item } => {
442                let ty = item.id(self);
443                self.load(ty, None, *id, None, vec![]).unwrap()
444            }
445            Variable::Named { id, item, .. } => {
446                let ty = item.id(self);
447                self.load(ty, None, *id, None, vec![]).unwrap()
448            }
449            ssa => ssa.id(self),
450        }
451    }
452
453    pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
454        if let Some(as_const) = variable.as_const() {
455            self.static_cast(as_const, &variable.elem(), item)
456        } else {
457            let id = self.read(variable);
458            variable.item().cast_to(self, None, id, item)
459        }
460    }
461
462    pub fn index(
463        &mut self,
464        variable: &Variable,
465        index: &Variable,
466        unchecked: bool,
467    ) -> IndexedVariable {
468        let access_chain = if unchecked {
469            Builder::in_bounds_access_chain
470        } else {
471            Builder::access_chain
472        };
473        let index_id = self.read(index);
474        match variable {
475            Variable::GlobalInputArray(id, item, _)
476            | Variable::GlobalOutputArray(id, item, _)
477            | Variable::Named { id, item, .. } => {
478                let ptr_ty =
479                    Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
480                let zero = self.const_u32(0);
481                let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();
482
483                IndexedVariable::Pointer(id, item.clone())
484            }
485            Variable::Local {
486                id,
487                item: Item::Vector(elem, _),
488            } => {
489                let ptr_ty =
490                    Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
491                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
492
493                IndexedVariable::Pointer(id, Item::Scalar(*elem))
494            }
495            Variable::LocalBinding {
496                id,
497                item: Item::Vector(elem, vec),
498                variable,
499            } if index.as_const().is_some() => IndexedVariable::Composite(
500                self.get_binding(*id, variable),
501                index.as_const().unwrap().as_u32(),
502                Item::Vector(*elem, *vec),
503            ),
504            Variable::LocalBinding {
505                id,
506                item: Item::Vector(elem, vec),
507                variable,
508            } => IndexedVariable::DynamicComposite(
509                self.get_binding(*id, variable),
510                index_id,
511                Item::Vector(*elem, *vec),
512            ),
513            Variable::Versioned {
514                id,
515                item: Item::Vector(elem, vec),
516                variable,
517            } if index.as_const().is_some() => IndexedVariable::Composite(
518                self.get_versioned(*id, variable),
519                index.as_const().unwrap().as_u32(),
520                Item::Vector(*elem, *vec),
521            ),
522            Variable::Versioned {
523                id,
524                item: Item::Vector(elem, vec),
525                variable,
526            } => IndexedVariable::DynamicComposite(
527                self.get_versioned(*id, variable),
528                index_id,
529                Item::Vector(*elem, *vec),
530            ),
531            Variable::Local { .. } | Variable::LocalBinding { .. } | Variable::Versioned { .. } => {
532                IndexedVariable::Scalar(variable.clone())
533            }
534            Variable::Slice { ptr, offset, .. } => {
535                let item = Item::Scalar(Elem::Int(32, false));
536                let int = item.id(self);
537                let index = match index.as_const() {
538                    Some(ConstVal::Bit32(0)) => *offset,
539                    _ => self.i_add(int, None, *offset, index_id).unwrap(),
540                };
541                self.index(ptr, &Variable::Raw(index, item), unchecked)
542            }
543            Variable::SharedMemory(id, item, _) => {
544                let ptr_ty =
545                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
546                let mut index = vec![index_id];
547                if self.compilation_options.supports_explicit_smem {
548                    index.insert(0, self.const_u32(0));
549                }
550                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
551                IndexedVariable::Pointer(id, item.clone())
552            }
553            Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
554                let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
555                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
556                IndexedVariable::Pointer(id, item.clone())
557            }
558            var => unimplemented!("Can't index into {var:?}"),
559        }
560    }
561
562    pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
563        let always_in_bounds = is_always_in_bounds(variable, index);
564        let indexed = self.index(variable, index, always_in_bounds);
565
566        let read = |b: &mut Self| match indexed {
567            IndexedVariable::Pointer(ptr, item) => {
568                let ty = item.id(b);
569                let out_id = b.write_id(out);
570                b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
571            }
572            IndexedVariable::Composite(var, index, item) => {
573                let elem = item.elem();
574                let ty = elem.id(b);
575                let out_id = b.write_id(out);
576                b.composite_extract(ty, Some(out_id), var, vec![index])
577                    .unwrap()
578            }
579            IndexedVariable::DynamicComposite(var, index, item) => {
580                let elem = item.elem();
581                let ty = elem.id(b);
582                let out_id = b.write_id(out);
583                b.vector_extract_dynamic(ty, Some(out_id), var, index)
584                    .unwrap()
585            }
586            IndexedVariable::Scalar(var) => {
587                let ty = out.item().id(b);
588                let input = b.read(&var);
589                let out_id = b.write_id(out);
590                b.copy_object(ty, Some(out_id), input).unwrap();
591                b.write(out, out_id);
592                out_id
593            }
594        };
595
596        read(self)
597    }
598
599    pub fn read_indexed_unchecked(
600        &mut self,
601        out: &Variable,
602        variable: &Variable,
603        index: &Variable,
604    ) -> Word {
605        let indexed = self.index(variable, index, true);
606
607        match indexed {
608            IndexedVariable::Pointer(ptr, item) => {
609                let ty = item.id(self);
610                let out_id = self.write_id(out);
611                self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
612            }
613            IndexedVariable::Composite(var, index, item) => {
614                let elem = item.elem();
615                let ty = elem.id(self);
616                let out_id = self.write_id(out);
617                self.composite_extract(ty, Some(out_id), var, vec![index])
618                    .unwrap()
619            }
620            IndexedVariable::DynamicComposite(var, index, item) => {
621                let elem = item.elem();
622                let ty = elem.id(self);
623                let out_id = self.write_id(out);
624                self.vector_extract_dynamic(ty, Some(out_id), var, index)
625                    .unwrap()
626            }
627            IndexedVariable::Scalar(var) => {
628                let ty = out.item().id(self);
629                let input = self.read(&var);
630                let out_id = self.write_id(out);
631                self.copy_object(ty, Some(out_id), input).unwrap();
632                self.write(out, out_id);
633                out_id
634            }
635        }
636    }
637
638    pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
639        let always_in_bounds = is_always_in_bounds(var, index);
640        match self.index(var, index, always_in_bounds) {
641            IndexedVariable::Pointer(ptr, _) => ptr,
642            other => unreachable!("{other:?}"),
643        }
644    }
645
646    pub fn write_id(&mut self, variable: &Variable) -> Word {
647        match variable {
648            Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
649            Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
650            Variable::Local { .. } => self.id(),
651            Variable::GlobalScalar(id, _) => *id,
652            Variable::Raw(id, _) => *id,
653            Variable::ConstantScalar(_, _, _) => panic!("Can't write to constant scalar"),
654            Variable::GlobalInputArray(_, _, _)
655            | Variable::GlobalOutputArray(_, _, _)
656            | Variable::Slice { .. }
657            | Variable::Named { .. }
658            | Variable::SharedMemory(_, _, _)
659            | Variable::ConstantArray(_, _, _)
660            | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
661            global => panic!("Can't write to builtin {global:?}"),
662        }
663    }
664
665    pub fn write(&mut self, variable: &Variable, value: Word) {
666        match variable {
667            Variable::Local { id, .. } => self.store(*id, value, None, vec![]).unwrap(),
668            Variable::Slice { ptr, .. } => self.write(ptr, value),
669            _ => {}
670        }
671    }
672
673    pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
674        let always_in_bounds = is_always_in_bounds(out, index);
675        let variable = self.index(out, index, always_in_bounds);
676
677        let write = |b: &mut Self| match variable {
678            IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
679            IndexedVariable::Composite(var, index, item) => {
680                let ty = item.id(b);
681                let id = b
682                    .composite_insert(ty, None, value, var, vec![index])
683                    .unwrap();
684                b.write(out, id);
685            }
686            IndexedVariable::DynamicComposite(var, index, item) => {
687                let ty = item.id(b);
688                let id = b
689                    .vector_insert_dynamic(ty, None, value, var, index)
690                    .unwrap();
691                b.write(out, id);
692            }
693            IndexedVariable::Scalar(var) => b.write(&var, value),
694        };
695
696        write(self)
697    }
698
699    pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
700        let variable = self.index(out, index, true);
701
702        match variable {
703            IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
704            IndexedVariable::Composite(var, index, item) => {
705                let ty = item.id(self);
706                let out_id = self
707                    .composite_insert(ty, None, value, var, vec![index])
708                    .unwrap();
709                self.write(out, out_id);
710            }
711            IndexedVariable::DynamicComposite(var, index, item) => {
712                let ty = item.id(self);
713                let out_id = self
714                    .vector_insert_dynamic(ty, None, value, var, index)
715                    .unwrap();
716                self.write(out, out_id);
717            }
718            IndexedVariable::Scalar(var) => self.write(&var, value),
719        }
720    }
721}
722
723fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
724    let len = match var {
725        Variable::SharedMemory(_, _, len)
726        | Variable::ConstantArray(_, _, len)
727        | Variable::LocalArray(_, _, len)
728        | Variable::Slice {
729            const_len: Some(len),
730            ..
731        } => *len,
732        _ => return false,
733    };
734
735    let const_index = match index {
736        Variable::ConstantScalar(_, value, _) => value.as_u32(),
737        _ => return false,
738    };
739
740    const_index < len
741}