cubecl_spirv/
variable.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6    SpirvCompiler, SpirvTarget,
7    item::{Elem, Item},
8    lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantValue, Id};
11use rspirv::{
12    dr::Builder,
13    spirv::{self, FPEncoding, StorageClass, Word},
14};
15
/// A variable after lowering to SPIR-V, as produced by `SpirvCompiler::compile_variable`.
///
/// Most variants carry the SPIR-V id (`Word`) that `Variable::id` returns directly,
/// together with the compiled `Item` type.
#[derive(Debug, Clone, PartialEq)]
pub enum Variable {
    /// Input buffer: id taken from `state.buffers`, element item, and buffer position.
    GlobalInputArray(Word, Item, u32),
    /// Output buffer: id taken from `state.buffers`, element item, and buffer position.
    GlobalOutputArray(Word, Item, u32),
    /// Global scalar: id and scalar element type.
    GlobalScalar(Word, Elem),
    /// Constant: declared id, raw constant bits, and item type.
    Constant(Word, ConstVal, Item),
    /// Mutable local whose `id` is loaded from by `read` and stored to by `write`.
    Local {
        id: Word,
        item: Item,
    },
    /// SSA value keyed by `(id, version)`; its id is resolved via `get_versioned`.
    Versioned {
        id: (Id, u16),
        item: Item,
        variable: ir::Variable,
    },
    /// Immutable local binding; its id is resolved via `get_binding`.
    LocalBinding {
        id: Id,
        item: Item,
        variable: ir::Variable,
    },
    /// Already-computed value id with its item type.
    Raw(Word, Item),
    /// Named variable; `SpirvCompiler::index` accesses it like a storage buffer.
    Named {
        id: Word,
        item: Item,
        is_array: bool,
    },
    /// View into `ptr`. `offset` is an id that `SpirvCompiler::index` adds to every
    /// index; `const_len` is the slice length when it is known at compile time.
    /// NOTE(review): `end` is presumably the id of the end index — confirm.
    Slice {
        ptr: Box<Variable>,
        offset: Word,
        end: Word,
        const_len: Option<u32>,
        item: Item,
    },
    /// Workgroup (shared-memory) array: id, element item, and length.
    SharedArray(Word, Item, u32),
    /// Single workgroup (shared-memory) value.
    Shared(Word, Item),
    /// Constant array: id taken from `state.const_arrays`, element item, and length.
    ConstantArray(Word, Item, u32),
    /// Function-storage array: id, element item, and length.
    LocalArray(Word, Item, u32),
    /// Cooperative matrix: key into `state.matrices` and its element type.
    CoopMatrix(Id, Elem),
    /// Bare id without type information.
    Id(Word),
    /// Builtin: id and item type.
    Builtin(Word, Item),
}
57
58impl Variable {
59    pub fn scope(&self) -> spirv::Scope {
60        match self {
61            Variable::GlobalInputArray(..)
62            | Variable::GlobalOutputArray(..)
63            | Variable::Named { .. }
64            | Variable::GlobalScalar(..) => spirv::Scope::Device,
65            Variable::SharedArray(..) | Variable::Shared(..) => spirv::Scope::Workgroup,
66            Variable::CoopMatrix(..) => spirv::Scope::Subgroup,
67            Variable::Slice { ptr, .. } => ptr.scope(),
68            Variable::Raw(..) => unimplemented!("Can't get scope of raw variable"),
69            Variable::Id(_) => unimplemented!("Can't get scope of raw id"),
70            _ => spirv::Scope::Invocation,
71        }
72    }
73}
74
/// Raw bit pattern of a constant value.
///
/// Values of 32 bits or fewer (including 8/16-bit ints and floats and booleans)
/// are stored zero-extended in `Bit32`; 64-bit values are stored in `Bit64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    /// Bit pattern of a value that is at most 32 bits wide.
    Bit32(u32),
    /// Bit pattern of a 64-bit value.
    Bit64(u64),
}
80
81impl ConstVal {
82    pub fn as_u64(&self) -> u64 {
83        match self {
84            ConstVal::Bit32(val) => *val as u64,
85            ConstVal::Bit64(val) => *val,
86        }
87    }
88
89    pub fn as_u32(&self) -> u32 {
90        match self {
91            ConstVal::Bit32(val) => *val,
92            ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
93        }
94    }
95
96    pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
97        match (width, encoding) {
98            (64, _) => f64::from_bits(self.as_u64()),
99            (32, _) => f32::from_bits(self.as_u32()) as f64,
100            (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
101            (_, Some(FPEncoding::BFloat16KHR)) => {
102                half::bf16::from_bits(self.as_u32() as u16).to_f64()
103            }
104            (_, Some(FPEncoding::Float8E4M3EXT)) => {
105                cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
106            }
107            (_, Some(FPEncoding::Float8E5M2EXT)) => {
108                cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
109            }
110            _ => unreachable!(),
111        }
112    }
113
114    pub fn as_int(&self, width: u32) -> i64 {
115        unsafe {
116            match width {
117                64 => transmute::<u64, i64>(self.as_u64()),
118                32 => transmute::<u32, i32>(self.as_u32()) as i64,
119                16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
120                8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
121                _ => unreachable!(),
122            }
123        }
124    }
125
126    pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
127        match (width, encoding) {
128            (64, _) => ConstVal::Bit64(value.to_bits()),
129            (32, _) => ConstVal::Bit32((value as f32).to_bits()),
130            (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
131            (_, Some(FPEncoding::BFloat16KHR)) => {
132                ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
133            }
134            (_, Some(FPEncoding::Float8E4M3EXT)) => {
135                ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
136            }
137            (_, Some(FPEncoding::Float8E5M2EXT)) => {
138                ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
139            }
140            _ => unreachable!(),
141        }
142    }
143
144    pub fn from_int(value: i64, width: u32) -> Self {
145        match width {
146            64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
147            32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
148            16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
149            8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
150            _ => unreachable!(),
151        }
152    }
153
154    pub fn from_uint(value: u64, width: u32) -> Self {
155        match width {
156            64 => ConstVal::Bit64(value),
157            32 => ConstVal::Bit32(value as u32),
158            16 => ConstVal::Bit32(value as u16 as u32),
159            8 => ConstVal::Bit32(value as u8 as u32),
160            _ => unreachable!(),
161        }
162    }
163
164    pub fn from_bool(value: bool) -> Self {
165        ConstVal::Bit32(value as u32)
166    }
167}
168
169impl From<(ConstantValue, Item)> for ConstVal {
170    fn from((value, ty): (ConstantValue, Item)) -> Self {
171        let elem = ty.elem();
172        let width = elem.size() * 8;
173        match value {
174            ConstantValue::Int(val) => ConstVal::from_int(val, width),
175            ConstantValue::Float(val) => ConstVal::from_float(val, width, elem.float_encoding()),
176            ConstantValue::UInt(val) => ConstVal::from_uint(val, width),
177            ConstantValue::Bool(val) => ConstVal::from_bool(val),
178        }
179    }
180}
181
182impl From<u32> for ConstVal {
183    fn from(value: u32) -> Self {
184        ConstVal::Bit32(value)
185    }
186}
187
188impl From<f32> for ConstVal {
189    fn from(value: f32) -> Self {
190        ConstVal::Bit32(value.to_bits())
191    }
192}
193
194impl Variable {
195    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
196        match self {
197            Variable::GlobalInputArray(id, _, _) => *id,
198            Variable::GlobalOutputArray(id, _, _) => *id,
199            Variable::GlobalScalar(id, _) => *id,
200            Variable::Constant(id, _, _) => *id,
201            Variable::Local { id, .. } => *id,
202            Variable::Versioned {
203                id, variable: var, ..
204            } => b.get_versioned(*id, var),
205            Variable::LocalBinding {
206                id, variable: var, ..
207            } => b.get_binding(*id, var),
208            Variable::Raw(id, _) => *id,
209            Variable::Named { id, .. } => *id,
210            Variable::Slice { ptr, .. } => ptr.id(b),
211            Variable::SharedArray(id, _, _) => *id,
212            Variable::Shared(id, _) => *id,
213            Variable::ConstantArray(id, _, _) => *id,
214            Variable::LocalArray(id, _, _) => *id,
215            Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
216            Variable::Id(id) => *id,
217            Variable::Builtin(id, ..) => *id,
218        }
219    }
220
221    pub fn item(&self) -> Item {
222        match self {
223            Variable::GlobalInputArray(_, item, _) => item.clone(),
224            Variable::GlobalOutputArray(_, item, _) => item.clone(),
225            Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
226            Variable::Constant(_, _, item) => item.clone(),
227            Variable::Local { item, .. } => item.clone(),
228            Variable::Versioned { item, .. } => item.clone(),
229            Variable::LocalBinding { item, .. } => item.clone(),
230            Variable::Named { item, .. } => item.clone(),
231            Variable::Slice { item, .. } => item.clone(),
232            Variable::SharedArray(_, item, _) => item.clone(),
233            Variable::Shared(_, item) => item.clone(),
234            Variable::ConstantArray(_, item, _) => item.clone(),
235            Variable::LocalArray(_, item, _) => item.clone(),
236            Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
237            Variable::Builtin(_, item) => item.clone(),
238            Variable::Raw(_, item) => item.clone(),
239            Variable::Id(_) => unimplemented!("Can't get item of raw ID"),
240        }
241    }
242
243    pub fn indexed_item(&self) -> Item {
244        match self {
245            Variable::LocalBinding {
246                item: Item::Vector(elem, _),
247                ..
248            } => Item::Scalar(*elem),
249            Variable::Local {
250                item: Item::Vector(elem, _),
251                ..
252            } => Item::Scalar(*elem),
253            Variable::Versioned {
254                item: Item::Vector(elem, _),
255                ..
256            } => Item::Scalar(*elem),
257            other => other.item(),
258        }
259    }
260
261    pub fn elem(&self) -> Elem {
262        self.item().elem()
263    }
264
265    pub fn has_len(&self) -> bool {
266        matches!(
267            self,
268            Variable::GlobalInputArray(_, _, _)
269                | Variable::GlobalOutputArray(_, _, _)
270                | Variable::Named {
271                    is_array: false,
272                    ..
273                }
274                | Variable::Slice { .. }
275                | Variable::SharedArray(_, _, _)
276                | Variable::ConstantArray(_, _, _)
277                | Variable::LocalArray(_, _, _)
278        )
279    }
280
281    pub fn has_buffer_len(&self) -> bool {
282        matches!(
283            self,
284            Variable::GlobalInputArray(_, _, _)
285                | Variable::GlobalOutputArray(_, _, _)
286                | Variable::Named {
287                    is_array: false,
288                    ..
289                }
290        )
291    }
292
293    pub fn as_const(&self) -> Option<ConstVal> {
294        match self {
295            Self::Constant(_, val, _) => Some(*val),
296            _ => None,
297        }
298    }
299
300    pub fn as_binding(&self) -> Option<Id> {
301        match self {
302            Self::LocalBinding { id, .. } => Some(*id),
303            _ => None,
304        }
305    }
306}
307
/// Result of `SpirvCompiler::index`: describes how one indexed element is accessed.
#[derive(Debug)]
pub enum IndexedVariable {
    /// Pointer id to the element (built with an access chain) and the element item.
    Pointer(Word, Item),
    /// Vector value id, constant component index, and the vector's item.
    Composite(Word, u32, Item),
    /// Vector value id, id of the runtime component index, and the vector's item.
    DynamicComposite(Word, u32, Item),
    /// Variable that is read or written as a whole rather than through an index.
    Scalar(Variable),
}
315
316impl<T: SpirvTarget> SpirvCompiler<T> {
317    pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
318        let item = variable.ty;
319        match variable.kind {
320            ir::VariableKind::Constant(value) => {
321                let item = self.compile_type(item);
322                let const_val = (value, item.clone()).into();
323
324                if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
325                    Variable::Constant(*existing, const_val, item)
326                } else {
327                    let id = item.constant(self, const_val);
328                    self.state.constants.insert((const_val, item.clone()), id);
329                    Variable::Constant(id, const_val, item)
330                }
331            }
332            ir::VariableKind::GlobalInputArray(pos) => {
333                let id = self.state.buffers[pos as usize];
334                Variable::GlobalInputArray(id, self.compile_type(item), pos)
335            }
336            ir::VariableKind::GlobalOutputArray(pos) => {
337                let id = self.state.buffers[pos as usize];
338                Variable::GlobalOutputArray(id, self.compile_type(item), pos)
339            }
340            ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
341            ir::VariableKind::LocalMut { id } => {
342                let item = self.compile_type(item);
343                let var = self.get_local(id, &item, variable);
344                Variable::Local { id: var, item }
345            }
346            ir::VariableKind::Versioned { id, version } => {
347                let item = self.compile_type(item);
348                let id = (id, version);
349                Variable::Versioned { id, item, variable }
350            }
351            ir::VariableKind::LocalConst { id } => {
352                let item = self.compile_type(item);
353                Variable::LocalBinding { id, item, variable }
354            }
355            ir::VariableKind::Builtin(builtin) => {
356                let item = self.compile_type(item);
357                self.compile_builtin(builtin, item)
358            }
359            ir::VariableKind::ConstantArray { id, length, .. } => {
360                let item = self.compile_type(item);
361                let id = self.state.const_arrays[id as usize].id;
362                Variable::ConstantArray(id, item, length as u32)
363            }
364            ir::VariableKind::SharedArray { id, length, .. } => {
365                let item = self.compile_type(item);
366                let id = self.state.shared_arrays[&id].id;
367                Variable::SharedArray(id, item, length as u32)
368            }
369            ir::VariableKind::Shared { id } => {
370                let item = self.compile_type(item);
371                let id = self.state.shared[&id].id;
372                Variable::Shared(id, item)
373            }
374            ir::VariableKind::LocalArray {
375                id,
376                length,
377                unroll_factor,
378            } => {
379                let item = self.compile_type(item);
380                let id = if let Some(arr) = self.state.local_arrays.get(&id) {
381                    arr.id
382                } else {
383                    let arr_ty = Item::Array(Box::new(item.clone()), length as u32);
384                    let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
385                    let arr_id = self.declare_function_variable(ptr_ty);
386                    self.debug_var_name(arr_id, variable);
387                    let arr = Array {
388                        id: arr_id,
389                        item: item.clone(),
390                        len: (length * unroll_factor) as u32,
391                        var: variable,
392                        alignment: None,
393                    };
394                    self.state.local_arrays.insert(id, arr);
395                    arr_id
396                };
397                Variable::LocalArray(id, item, length as u32)
398            }
399            ir::VariableKind::Matrix { id, mat } => {
400                let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
401                if self.state.matrices.contains_key(&id) {
402                    Variable::CoopMatrix(id, elem)
403                } else {
404                    let matrix = self.init_coop_matrix(mat, variable);
405                    self.state.matrices.insert(id, matrix);
406                    Variable::CoopMatrix(id, elem)
407                }
408            }
409            ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
410            ir::VariableKind::BarrierToken { .. } => {
411                panic!("Barrier not supported.")
412            }
413            ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
414            ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
415        }
416    }
417
418    pub fn read(&mut self, variable: &Variable) -> Word {
419        match variable {
420            Variable::Slice { ptr, .. } => self.read(ptr),
421            Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
422                let ty = item.id(self);
423                let ptr_ty =
424                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
425                let index = vec![self.const_u32(0)];
426                let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
427                self.load(ty, None, access, None, []).unwrap()
428            }
429            Variable::Local { id, item } | Variable::Shared(id, item) => {
430                let ty = item.id(self);
431                self.load(ty, None, *id, None, []).unwrap()
432            }
433            Variable::Named { id, item, .. } => {
434                let ty = item.id(self);
435                self.load(ty, None, *id, None, []).unwrap()
436            }
437            ssa => ssa.id(self),
438        }
439    }
440
441    pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
442        if let Some(as_const) = variable.as_const() {
443            self.static_cast(as_const, &variable.elem(), item).0
444        } else {
445            let id = self.read(variable);
446            variable.item().cast_to(self, None, id, item)
447        }
448    }
449
450    pub fn index(
451        &mut self,
452        variable: &Variable,
453        index: &Variable,
454        unchecked: bool,
455    ) -> IndexedVariable {
456        let access_chain = if unchecked {
457            Builder::in_bounds_access_chain
458        } else {
459            Builder::access_chain
460        };
461        let index_id = self.read(index);
462        match variable {
463            Variable::GlobalInputArray(id, item, _)
464            | Variable::GlobalOutputArray(id, item, _)
465            | Variable::Named { id, item, .. } => {
466                let ptr_ty =
467                    Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
468                let zero = self.const_u32(0);
469                let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();
470
471                IndexedVariable::Pointer(id, item.clone())
472            }
473            Variable::Local {
474                id,
475                item: Item::Vector(elem, _),
476            } => {
477                let ptr_ty =
478                    Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
479                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
480
481                IndexedVariable::Pointer(id, Item::Scalar(*elem))
482            }
483            Variable::Shared(id, Item::Vector(elem, _)) => {
484                let ptr_ty =
485                    Item::Pointer(StorageClass::Workgroup, Box::new(Item::Scalar(*elem))).id(self);
486
487                let mut index = vec![index_id];
488                if self.compilation_options.supports_explicit_smem {
489                    index.insert(0, self.const_u32(0));
490                }
491
492                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
493
494                IndexedVariable::Pointer(id, Item::Scalar(*elem))
495            }
496            Variable::LocalBinding {
497                id,
498                item: Item::Vector(elem, vec),
499                variable,
500            } if index.as_const().is_some() => IndexedVariable::Composite(
501                self.get_binding(*id, variable),
502                index.as_const().unwrap().as_u32(),
503                Item::Vector(*elem, *vec),
504            ),
505            Variable::LocalBinding {
506                id,
507                item: Item::Vector(elem, vec),
508                variable,
509            } => IndexedVariable::DynamicComposite(
510                self.get_binding(*id, variable),
511                index_id,
512                Item::Vector(*elem, *vec),
513            ),
514            Variable::Versioned {
515                id,
516                item: Item::Vector(elem, vec),
517                variable,
518            } if index.as_const().is_some() => IndexedVariable::Composite(
519                self.get_versioned(*id, variable),
520                index.as_const().unwrap().as_u32(),
521                Item::Vector(*elem, *vec),
522            ),
523            Variable::Versioned {
524                id,
525                item: Item::Vector(elem, vec),
526                variable,
527            } => IndexedVariable::DynamicComposite(
528                self.get_versioned(*id, variable),
529                index_id,
530                Item::Vector(*elem, *vec),
531            ),
532            Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
533                let ptr_ty =
534                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
535                let index = vec![self.const_u32(0)];
536                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
537                IndexedVariable::Pointer(id, item.clone())
538            }
539            Variable::Local { .. }
540            | Variable::Shared(..)
541            | Variable::LocalBinding { .. }
542            | Variable::Versioned { .. } => IndexedVariable::Scalar(variable.clone()),
543            Variable::Constant(_, val, item) => {
544                let scalar_item = Item::Scalar(item.elem());
545                let (id, val) = self.static_cast(*val, &item.elem(), &scalar_item);
546                IndexedVariable::Scalar(Variable::Constant(id, val, scalar_item))
547            }
548            Variable::Slice { ptr, offset, .. } => {
549                let item = Item::Scalar(Elem::Int(32, false));
550                let int = item.id(self);
551                let index = match index.as_const() {
552                    Some(ConstVal::Bit32(0)) => *offset,
553                    _ => self.i_add(int, None, *offset, index_id).unwrap(),
554                };
555                self.index(ptr, &Variable::Raw(index, item), unchecked)
556            }
557            Variable::SharedArray(id, item, _) => {
558                let ptr_ty =
559                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
560                let mut index = vec![index_id];
561                if self.compilation_options.supports_explicit_smem {
562                    index.insert(0, self.const_u32(0));
563                }
564                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
565                IndexedVariable::Pointer(id, item.clone())
566            }
567            Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
568                let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
569                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
570                IndexedVariable::Pointer(id, item.clone())
571            }
572            var => unimplemented!("Can't index into {var:?}"),
573        }
574    }
575
576    pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
577        let always_in_bounds = is_always_in_bounds(variable, index);
578        let indexed = self.index(variable, index, always_in_bounds);
579
580        let read = |b: &mut Self| match indexed {
581            IndexedVariable::Pointer(ptr, item) => {
582                let ty = item.id(b);
583                let out_id = b.write_id(out);
584                b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
585            }
586            IndexedVariable::Composite(var, index, item) => {
587                let elem = item.elem();
588                let ty = elem.id(b);
589                let out_id = b.write_id(out);
590                b.composite_extract(ty, Some(out_id), var, vec![index])
591                    .unwrap()
592            }
593            IndexedVariable::DynamicComposite(var, index, item) => {
594                let elem = item.elem();
595                let ty = elem.id(b);
596                let out_id = b.write_id(out);
597                b.vector_extract_dynamic(ty, Some(out_id), var, index)
598                    .unwrap()
599            }
600            IndexedVariable::Scalar(var) => {
601                let ty = out.item().id(b);
602                let input = b.read(&var);
603                let out_id = b.write_id(out);
604                b.copy_object(ty, Some(out_id), input).unwrap();
605                b.write(out, out_id);
606                out_id
607            }
608        };
609
610        read(self)
611    }
612
613    pub fn read_indexed_unchecked(
614        &mut self,
615        out: &Variable,
616        variable: &Variable,
617        index: &Variable,
618    ) -> Word {
619        let indexed = self.index(variable, index, true);
620
621        match indexed {
622            IndexedVariable::Pointer(ptr, item) => {
623                let ty = item.id(self);
624                let out_id = self.write_id(out);
625                self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
626            }
627            IndexedVariable::Composite(var, index, item) => {
628                let elem = item.elem();
629                let ty = elem.id(self);
630                let out_id = self.write_id(out);
631                self.composite_extract(ty, Some(out_id), var, vec![index])
632                    .unwrap()
633            }
634            IndexedVariable::DynamicComposite(var, index, item) => {
635                let elem = item.elem();
636                let ty = elem.id(self);
637                let out_id = self.write_id(out);
638                self.vector_extract_dynamic(ty, Some(out_id), var, index)
639                    .unwrap()
640            }
641            IndexedVariable::Scalar(var) => {
642                let ty = out.item().id(self);
643                let input = self.read(&var);
644                let out_id = self.write_id(out);
645                self.copy_object(ty, Some(out_id), input).unwrap();
646                self.write(out, out_id);
647                out_id
648            }
649        }
650    }
651
652    pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
653        let always_in_bounds = is_always_in_bounds(var, index);
654        match self.index(var, index, always_in_bounds) {
655            IndexedVariable::Pointer(ptr, _) => ptr,
656            other => unreachable!("{other:?}"),
657        }
658    }
659
660    pub fn write_id(&mut self, variable: &Variable) -> Word {
661        match variable {
662            Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
663            Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
664            Variable::Local { .. } => self.id(),
665            Variable::Shared(..) => self.id(),
666            Variable::GlobalScalar(id, _) => *id,
667            Variable::Raw(id, _) => *id,
668            Variable::Constant(_, _, _) => panic!("Can't write to constant scalar"),
669            Variable::GlobalInputArray(_, _, _)
670            | Variable::GlobalOutputArray(_, _, _)
671            | Variable::Slice { .. }
672            | Variable::Named { .. }
673            | Variable::SharedArray(_, _, _)
674            | Variable::ConstantArray(_, _, _)
675            | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
676            global => panic!("Can't write to builtin {global:?}"),
677        }
678    }
679
680    pub fn write(&mut self, variable: &Variable, value: Word) {
681        match variable {
682            Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
683                let ptr_ty =
684                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
685                let index = vec![self.const_u32(0)];
686                let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
687                self.store(access, value, None, []).unwrap()
688            }
689            Variable::Local { id, .. } | Variable::Shared(id, _) => {
690                self.store(*id, value, None, []).unwrap()
691            }
692
693            Variable::Slice { ptr, .. } => self.write(ptr, value),
694            _ => {}
695        }
696    }
697
698    pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
699        let always_in_bounds = is_always_in_bounds(out, index);
700        let variable = self.index(out, index, always_in_bounds);
701
702        let write = |b: &mut Self| match variable {
703            IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
704            IndexedVariable::Composite(var, index, item) => {
705                let ty = item.id(b);
706                let id = b
707                    .composite_insert(ty, None, value, var, vec![index])
708                    .unwrap();
709                b.write(out, id);
710            }
711            IndexedVariable::DynamicComposite(var, index, item) => {
712                let ty = item.id(b);
713                let id = b
714                    .vector_insert_dynamic(ty, None, value, var, index)
715                    .unwrap();
716                b.write(out, id);
717            }
718            IndexedVariable::Scalar(var) => b.write(&var, value),
719        };
720
721        write(self)
722    }
723
724    pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
725        let variable = self.index(out, index, true);
726
727        match variable {
728            IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
729            IndexedVariable::Composite(var, index, item) => {
730                let ty = item.id(self);
731                let out_id = self
732                    .composite_insert(ty, None, value, var, vec![index])
733                    .unwrap();
734                self.write(out, out_id);
735            }
736            IndexedVariable::DynamicComposite(var, index, item) => {
737                let ty = item.id(self);
738                let out_id = self
739                    .vector_insert_dynamic(ty, None, value, var, index)
740                    .unwrap();
741                self.write(out, out_id);
742            }
743            IndexedVariable::Scalar(var) => self.write(&var, value),
744        }
745    }
746}
747
748fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
749    let len = match var {
750        Variable::SharedArray(_, _, len)
751        | Variable::ConstantArray(_, _, len)
752        | Variable::LocalArray(_, _, len)
753        | Variable::Slice {
754            const_len: Some(len),
755            ..
756        } => *len,
757        _ => return false,
758    };
759
760    let const_index = match index {
761        Variable::Constant(_, value, _) => value.as_u32(),
762        _ => return false,
763    };
764
765    const_index < len
766}