cubecl_spirv/
variable.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::mem::transmute;
4
5use crate::{
6    SpirvCompiler, SpirvTarget,
7    item::{Elem, Item},
8    lookups::Array,
9};
10use cubecl_core::ir::{self, ConstantScalarValue, FloatKind, Id};
11use rspirv::{
12    dr::Builder,
13    spirv::{FPEncoding, StorageClass, Word},
14};
15
16#[derive(Debug, Clone, PartialEq)]
17pub enum Variable {
18    SubgroupSize(Word),
19    GlobalInputArray(Word, Item, u32),
20    GlobalOutputArray(Word, Item, u32),
21    GlobalScalar(Word, Elem),
22    ConstantScalar(Word, ConstVal, Elem),
23    Local {
24        id: Word,
25        item: Item,
26    },
27    Versioned {
28        id: (Id, u16),
29        item: Item,
30        variable: ir::Variable,
31    },
32    LocalBinding {
33        id: Id,
34        item: Item,
35        variable: ir::Variable,
36    },
37    Raw(Word, Item),
38    Named {
39        id: Word,
40        item: Item,
41        is_array: bool,
42    },
43    Slice {
44        ptr: Box<Variable>,
45        offset: Word,
46        end: Word,
47        const_len: Option<u32>,
48        item: Item,
49    },
50    SharedArray(Word, Item, u32),
51    Shared(Word, Item),
52    ConstantArray(Word, Item, u32),
53    LocalArray(Word, Item, u32),
54    CoopMatrix(Id, Elem),
55    Id(Word),
56    LocalInvocationIndex(Word),
57    LocalInvocationIdX(Word),
58    LocalInvocationIdY(Word),
59    LocalInvocationIdZ(Word),
60    Rank(Word),
61    WorkgroupId(Word),
62    WorkgroupIdX(Word),
63    WorkgroupIdY(Word),
64    WorkgroupIdZ(Word),
65    GlobalInvocationIndex(Word),
66    GlobalInvocationIdX(Word),
67    GlobalInvocationIdY(Word),
68    GlobalInvocationIdZ(Word),
69    WorkgroupSize(Word),
70    WorkgroupSizeX(Word),
71    WorkgroupSizeY(Word),
72    WorkgroupSizeZ(Word),
73    NumWorkgroups(Word),
74    NumWorkgroupsX(Word),
75    NumWorkgroupsY(Word),
76    NumWorkgroupsZ(Word),
77}
78
/// Raw bit pattern of a scalar constant.
///
/// Values of 32 bits or narrower are stored zero-extended in `Bit32`; 64-bit values use
/// `Bit64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstVal {
    /// Bit pattern of a constant that is at most 32 bits wide.
    Bit32(u32),
    /// Bit pattern of a 64-bit constant.
    Bit64(u64),
}
84
85impl ConstVal {
86    pub fn as_u64(&self) -> u64 {
87        match self {
88            ConstVal::Bit32(val) => *val as u64,
89            ConstVal::Bit64(val) => *val,
90        }
91    }
92
93    pub fn as_u32(&self) -> u32 {
94        match self {
95            ConstVal::Bit32(val) => *val,
96            ConstVal::Bit64(_) => panic!("Truncating 64 bit variable to 32 bit"),
97        }
98    }
99
100    pub fn as_float(&self, width: u32, encoding: Option<FPEncoding>) -> f64 {
101        match (width, encoding) {
102            (64, _) => f64::from_bits(self.as_u64()),
103            (32, _) => f32::from_bits(self.as_u32()) as f64,
104            (16, None) => half::f16::from_bits(self.as_u32() as u16).to_f64(),
105            (_, Some(FPEncoding::BFloat16KHR)) => {
106                half::bf16::from_bits(self.as_u32() as u16).to_f64()
107            }
108            (_, Some(FPEncoding::Float8E4M3EXT)) => {
109                cubecl_common::e4m3::from_bits(self.as_u32() as u8).to_f64()
110            }
111            (_, Some(FPEncoding::Float8E5M2EXT)) => {
112                cubecl_common::e5m2::from_bits(self.as_u32() as u8).to_f64()
113            }
114            _ => unreachable!(),
115        }
116    }
117
118    pub fn as_int(&self, width: u32) -> i64 {
119        unsafe {
120            match width {
121                64 => transmute::<u64, i64>(self.as_u64()),
122                32 => transmute::<u32, i32>(self.as_u32()) as i64,
123                16 => transmute::<u16, i16>(self.as_u32() as u16) as i64,
124                8 => transmute::<u8, i8>(self.as_u32() as u8) as i64,
125                _ => unreachable!(),
126            }
127        }
128    }
129
130    pub fn from_float(value: f64, width: u32, encoding: Option<FPEncoding>) -> Self {
131        match (width, encoding) {
132            (64, _) => ConstVal::Bit64(value.to_bits()),
133            (32, _) => ConstVal::Bit32((value as f32).to_bits()),
134            (16, None) => ConstVal::Bit32(half::f16::from_f64(value).to_bits() as u32),
135            (_, Some(FPEncoding::BFloat16KHR)) => {
136                ConstVal::Bit32(half::bf16::from_f64(value).to_bits() as u32)
137            }
138            (_, Some(FPEncoding::Float8E4M3EXT)) => {
139                ConstVal::Bit32(cubecl_common::e4m3::from_f64(value).to_bits() as u32)
140            }
141            (_, Some(FPEncoding::Float8E5M2EXT)) => {
142                ConstVal::Bit32(cubecl_common::e5m2::from_f64(value).to_bits() as u32)
143            }
144            _ => unreachable!(),
145        }
146    }
147
148    pub fn from_int(value: i64, width: u32) -> Self {
149        match width {
150            64 => ConstVal::Bit64(unsafe { transmute::<i64, u64>(value) }),
151            32 => ConstVal::Bit32(unsafe { transmute::<i32, u32>(value as i32) }),
152            16 => ConstVal::Bit32(unsafe { transmute::<i16, u16>(value as i16) } as u32),
153            8 => ConstVal::Bit32(unsafe { transmute::<i8, u8>(value as i8) } as u32),
154            _ => unreachable!(),
155        }
156    }
157
158    pub fn from_uint(value: u64, width: u32) -> Self {
159        match width {
160            64 => ConstVal::Bit64(value),
161            32 => ConstVal::Bit32(value as u32),
162            16 => ConstVal::Bit32(value as u16 as u32),
163            8 => ConstVal::Bit32(value as u8 as u32),
164            _ => unreachable!(),
165        }
166    }
167
168    pub fn from_bool(value: bool) -> Self {
169        ConstVal::Bit32(value as u32)
170    }
171}
172
173impl From<ConstantScalarValue> for ConstVal {
174    fn from(value: ConstantScalarValue) -> Self {
175        let width = value.elem_type().size() as u32 * 8;
176        match value {
177            ConstantScalarValue::Int(val, _) => ConstVal::from_int(val, width),
178            ConstantScalarValue::Float(val, FloatKind::BF16) => {
179                ConstVal::from_float(val, width, Some(FPEncoding::BFloat16KHR))
180            }
181            ConstantScalarValue::Float(val, FloatKind::E4M3) => {
182                ConstVal::from_float(val, width, Some(FPEncoding::Float8E4M3EXT))
183            }
184            ConstantScalarValue::Float(val, FloatKind::E5M2) => {
185                ConstVal::from_float(val, width, Some(FPEncoding::Float8E5M2EXT))
186            }
187            ConstantScalarValue::Float(val, _) => ConstVal::from_float(val, width, None),
188            ConstantScalarValue::UInt(val, _) => ConstVal::from_uint(val, width),
189            ConstantScalarValue::Bool(val) => ConstVal::from_bool(val),
190        }
191    }
192}
193
194impl From<u32> for ConstVal {
195    fn from(value: u32) -> Self {
196        ConstVal::Bit32(value)
197    }
198}
199
200impl From<f32> for ConstVal {
201    fn from(value: f32) -> Self {
202        ConstVal::Bit32(value.to_bits())
203    }
204}
205
206impl Variable {
207    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
208        match self {
209            Variable::GlobalInputArray(id, _, _) => *id,
210            Variable::GlobalOutputArray(id, _, _) => *id,
211            Variable::GlobalScalar(id, _) => *id,
212            Variable::ConstantScalar(id, _, _) => *id,
213            Variable::Local { id, .. } => *id,
214            Variable::Versioned {
215                id, variable: var, ..
216            } => b.get_versioned(*id, var),
217            Variable::LocalBinding {
218                id, variable: var, ..
219            } => b.get_binding(*id, var),
220            Variable::Raw(id, _) => *id,
221            Variable::Named { id, .. } => *id,
222            Variable::Slice { ptr, .. } => ptr.id(b),
223            Variable::SharedArray(id, _, _) => *id,
224            Variable::Shared(id, _) => *id,
225            Variable::ConstantArray(id, _, _) => *id,
226            Variable::LocalArray(id, _, _) => *id,
227            Variable::CoopMatrix(_, _) => unimplemented!("Can't get ID from matrix var"),
228            Variable::SubgroupSize(id) => *id,
229            Variable::Id(id) => *id,
230            Variable::LocalInvocationIndex(id) => *id,
231            Variable::LocalInvocationIdX(id) => *id,
232            Variable::LocalInvocationIdY(id) => *id,
233            Variable::LocalInvocationIdZ(id) => *id,
234            Variable::Rank(id) => *id,
235            Variable::WorkgroupId(id) => *id,
236            Variable::WorkgroupIdX(id) => *id,
237            Variable::WorkgroupIdY(id) => *id,
238            Variable::WorkgroupIdZ(id) => *id,
239            Variable::GlobalInvocationIndex(id) => *id,
240            Variable::GlobalInvocationIdX(id) => *id,
241            Variable::GlobalInvocationIdY(id) => *id,
242            Variable::GlobalInvocationIdZ(id) => *id,
243            Variable::WorkgroupSize(id) => *id,
244            Variable::WorkgroupSizeX(id) => *id,
245            Variable::WorkgroupSizeY(id) => *id,
246            Variable::WorkgroupSizeZ(id) => *id,
247            Variable::NumWorkgroups(id) => *id,
248            Variable::NumWorkgroupsX(id) => *id,
249            Variable::NumWorkgroupsY(id) => *id,
250            Variable::NumWorkgroupsZ(id) => *id,
251        }
252    }
253
254    pub fn item(&self) -> Item {
255        match self {
256            Variable::GlobalInputArray(_, item, _) => item.clone(),
257            Variable::GlobalOutputArray(_, item, _) => item.clone(),
258            Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
259            Variable::ConstantScalar(_, _, elem) => Item::Scalar(*elem),
260            Variable::Local { item, .. } => item.clone(),
261            Variable::Versioned { item, .. } => item.clone(),
262            Variable::LocalBinding { item, .. } => item.clone(),
263            Variable::Named { item, .. } => item.clone(),
264            Variable::Slice { item, .. } => item.clone(),
265            Variable::SharedArray(_, item, _) => item.clone(),
266            Variable::Shared(_, item) => item.clone(),
267            Variable::ConstantArray(_, item, _) => item.clone(),
268            Variable::LocalArray(_, item, _) => item.clone(),
269            Variable::CoopMatrix(_, elem) => Item::Scalar(*elem),
270            _ => Item::Scalar(Elem::Int(32, false)), // builtin
271        }
272    }
273
274    pub fn indexed_item(&self) -> Item {
275        match self {
276            Variable::LocalBinding {
277                item: Item::Vector(elem, _),
278                ..
279            } => Item::Scalar(*elem),
280            Variable::Local {
281                item: Item::Vector(elem, _),
282                ..
283            } => Item::Scalar(*elem),
284            Variable::Versioned {
285                item: Item::Vector(elem, _),
286                ..
287            } => Item::Scalar(*elem),
288            other => other.item(),
289        }
290    }
291
292    pub fn elem(&self) -> Elem {
293        self.item().elem()
294    }
295
296    pub fn has_len(&self) -> bool {
297        matches!(
298            self,
299            Variable::GlobalInputArray(_, _, _)
300                | Variable::GlobalOutputArray(_, _, _)
301                | Variable::Named {
302                    is_array: false,
303                    ..
304                }
305                | Variable::Slice { .. }
306                | Variable::SharedArray(_, _, _)
307                | Variable::ConstantArray(_, _, _)
308                | Variable::LocalArray(_, _, _)
309        )
310    }
311
312    pub fn has_buffer_len(&self) -> bool {
313        matches!(
314            self,
315            Variable::GlobalInputArray(_, _, _)
316                | Variable::GlobalOutputArray(_, _, _)
317                | Variable::Named {
318                    is_array: false,
319                    ..
320                }
321        )
322    }
323
324    pub fn as_const(&self) -> Option<ConstVal> {
325        match self {
326            Self::ConstantScalar(_, val, _) => Some(*val),
327            _ => None,
328        }
329    }
330
331    pub fn as_binding(&self) -> Option<Id> {
332        match self {
333            Self::LocalBinding { id, .. } => Some(*id),
334            _ => None,
335        }
336    }
337}
338
339#[derive(Debug)]
340pub enum IndexedVariable {
341    Pointer(Word, Item),
342    Composite(Word, u32, Item),
343    DynamicComposite(Word, u32, Item),
344    Scalar(Variable),
345}
346
347impl<T: SpirvTarget> SpirvCompiler<T> {
348    pub fn compile_variable(&mut self, variable: ir::Variable) -> Variable {
349        let item = variable.ty;
350        match variable.kind {
351            ir::VariableKind::ConstantScalar(value) => {
352                let item = self.compile_type(ir::Type::new(value.storage_type()));
353                let const_val = value.into();
354
355                if let Some(existing) = self.state.constants.get(&(const_val, item.clone())) {
356                    Variable::ConstantScalar(*existing, const_val, item.elem())
357                } else {
358                    let id = item.elem().constant(self, const_val);
359                    self.state.constants.insert((const_val, item.clone()), id);
360                    Variable::ConstantScalar(id, const_val, item.elem())
361                }
362            }
363            ir::VariableKind::GlobalInputArray(pos) => {
364                let id = self.state.buffers[pos as usize];
365                Variable::GlobalInputArray(id, self.compile_type(item), pos)
366            }
367            ir::VariableKind::GlobalOutputArray(pos) => {
368                let id = self.state.buffers[pos as usize];
369                Variable::GlobalOutputArray(id, self.compile_type(item), pos)
370            }
371            ir::VariableKind::GlobalScalar(id) => self.global_scalar(id, item.storage_type()),
372            ir::VariableKind::LocalMut { id } => {
373                let item = self.compile_type(item);
374                let var = self.get_local(id, &item, variable);
375                Variable::Local { id: var, item }
376            }
377            ir::VariableKind::Versioned { id, version } => {
378                let item = self.compile_type(item);
379                let id = (id, version);
380                Variable::Versioned { id, item, variable }
381            }
382            ir::VariableKind::LocalConst { id } => {
383                let item = self.compile_type(item);
384                Variable::LocalBinding { id, item, variable }
385            }
386            ir::VariableKind::Builtin(builtin) => self.compile_builtin(builtin),
387            ir::VariableKind::ConstantArray { id, length, .. } => {
388                let item = self.compile_type(item);
389                let id = self.state.const_arrays[id as usize].id;
390                Variable::ConstantArray(id, item, length)
391            }
392            ir::VariableKind::SharedArray { id, length, .. } => {
393                let item = self.compile_type(item);
394                let id = self.state.shared_arrays[&id].id;
395                Variable::SharedArray(id, item, length)
396            }
397            ir::VariableKind::Shared { id } => {
398                let item = self.compile_type(item);
399                let id = self.state.shared[&id].id;
400                Variable::Shared(id, item)
401            }
402            ir::VariableKind::LocalArray {
403                id,
404                length,
405                unroll_factor,
406            } => {
407                let item = self.compile_type(item);
408                let id = if let Some(arr) = self.state.local_arrays.get(&id) {
409                    arr.id
410                } else {
411                    let arr_ty = Item::Array(Box::new(item.clone()), length);
412                    let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(arr_ty)).id(self);
413                    let arr_id = self.declare_function_variable(ptr_ty);
414                    self.debug_var_name(arr_id, variable);
415                    let arr = Array {
416                        id: arr_id,
417                        item: item.clone(),
418                        len: length * unroll_factor,
419                        var: variable,
420                        alignment: None,
421                    };
422                    self.state.local_arrays.insert(id, arr);
423                    arr_id
424                };
425                Variable::LocalArray(id, item, length)
426            }
427            ir::VariableKind::Matrix { id, mat } => {
428                let elem = self.compile_type(ir::Type::new(mat.storage)).elem();
429                if self.state.matrices.contains_key(&id) {
430                    Variable::CoopMatrix(id, elem)
431                } else {
432                    let matrix = self.init_coop_matrix(mat, variable);
433                    self.state.matrices.insert(id, matrix);
434                    Variable::CoopMatrix(id, elem)
435                }
436            }
437            ir::VariableKind::Pipeline { .. } => panic!("Pipeline not supported."),
438            ir::VariableKind::BarrierToken { .. } => {
439                panic!("Barrier not supported.")
440            }
441            ir::VariableKind::TensorMapInput(_) => panic!("Tensor map not supported."),
442            ir::VariableKind::TensorMapOutput(_) => panic!("Tensor map not supported."),
443        }
444    }
445
446    pub fn read(&mut self, variable: &Variable) -> Word {
447        match variable {
448            Variable::Slice { ptr, .. } => self.read(ptr),
449            Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
450                let ty = item.id(self);
451                let ptr_ty =
452                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
453                let index = vec![self.const_u32(0)];
454                let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
455                self.load(ty, None, access, None, []).unwrap()
456            }
457            Variable::Local { id, item } | Variable::Shared(id, item) => {
458                let ty = item.id(self);
459                self.load(ty, None, *id, None, []).unwrap()
460            }
461            Variable::Named { id, item, .. } => {
462                let ty = item.id(self);
463                self.load(ty, None, *id, None, []).unwrap()
464            }
465            ssa => ssa.id(self),
466        }
467    }
468
469    pub fn read_as(&mut self, variable: &Variable, item: &Item) -> Word {
470        if let Some(as_const) = variable.as_const() {
471            self.static_cast(as_const, &variable.elem(), item)
472        } else {
473            let id = self.read(variable);
474            variable.item().cast_to(self, None, id, item)
475        }
476    }
477
478    pub fn index(
479        &mut self,
480        variable: &Variable,
481        index: &Variable,
482        unchecked: bool,
483    ) -> IndexedVariable {
484        let access_chain = if unchecked {
485            Builder::in_bounds_access_chain
486        } else {
487            Builder::access_chain
488        };
489        let index_id = self.read(index);
490        match variable {
491            Variable::GlobalInputArray(id, item, _)
492            | Variable::GlobalOutputArray(id, item, _)
493            | Variable::Named { id, item, .. } => {
494                let ptr_ty =
495                    Item::Pointer(StorageClass::StorageBuffer, Box::new(item.clone())).id(self);
496                let zero = self.const_u32(0);
497                let id = access_chain(self, ptr_ty, None, *id, vec![zero, index_id]).unwrap();
498
499                IndexedVariable::Pointer(id, item.clone())
500            }
501            Variable::Local {
502                id,
503                item: Item::Vector(elem, _),
504            } => {
505                let ptr_ty =
506                    Item::Pointer(StorageClass::Function, Box::new(Item::Scalar(*elem))).id(self);
507                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
508
509                IndexedVariable::Pointer(id, Item::Scalar(*elem))
510            }
511            Variable::Shared(id, Item::Vector(elem, _)) => {
512                let ptr_ty =
513                    Item::Pointer(StorageClass::Workgroup, Box::new(Item::Scalar(*elem))).id(self);
514
515                let mut index = vec![index_id];
516                if self.compilation_options.supports_explicit_smem {
517                    index.insert(0, self.const_u32(0));
518                }
519
520                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
521
522                IndexedVariable::Pointer(id, Item::Scalar(*elem))
523            }
524            Variable::LocalBinding {
525                id,
526                item: Item::Vector(elem, vec),
527                variable,
528            } if index.as_const().is_some() => IndexedVariable::Composite(
529                self.get_binding(*id, variable),
530                index.as_const().unwrap().as_u32(),
531                Item::Vector(*elem, *vec),
532            ),
533            Variable::LocalBinding {
534                id,
535                item: Item::Vector(elem, vec),
536                variable,
537            } => IndexedVariable::DynamicComposite(
538                self.get_binding(*id, variable),
539                index_id,
540                Item::Vector(*elem, *vec),
541            ),
542            Variable::Versioned {
543                id,
544                item: Item::Vector(elem, vec),
545                variable,
546            } if index.as_const().is_some() => IndexedVariable::Composite(
547                self.get_versioned(*id, variable),
548                index.as_const().unwrap().as_u32(),
549                Item::Vector(*elem, *vec),
550            ),
551            Variable::Versioned {
552                id,
553                item: Item::Vector(elem, vec),
554                variable,
555            } => IndexedVariable::DynamicComposite(
556                self.get_versioned(*id, variable),
557                index_id,
558                Item::Vector(*elem, *vec),
559            ),
560            Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
561                let ptr_ty =
562                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
563                let index = vec![self.const_u32(0)];
564                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
565                IndexedVariable::Pointer(id, item.clone())
566            }
567            Variable::Local { .. }
568            | Variable::Shared(..)
569            | Variable::LocalBinding { .. }
570            | Variable::Versioned { .. } => IndexedVariable::Scalar(variable.clone()),
571            Variable::Slice { ptr, offset, .. } => {
572                let item = Item::Scalar(Elem::Int(32, false));
573                let int = item.id(self);
574                let index = match index.as_const() {
575                    Some(ConstVal::Bit32(0)) => *offset,
576                    _ => self.i_add(int, None, *offset, index_id).unwrap(),
577                };
578                self.index(ptr, &Variable::Raw(index, item), unchecked)
579            }
580            Variable::SharedArray(id, item, _) => {
581                let ptr_ty =
582                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
583                let mut index = vec![index_id];
584                if self.compilation_options.supports_explicit_smem {
585                    index.insert(0, self.const_u32(0));
586                }
587                let id = access_chain(self, ptr_ty, None, *id, index).unwrap();
588                IndexedVariable::Pointer(id, item.clone())
589            }
590            Variable::ConstantArray(id, item, _) | Variable::LocalArray(id, item, _) => {
591                let ptr_ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
592                let id = access_chain(self, ptr_ty, None, *id, vec![index_id]).unwrap();
593                IndexedVariable::Pointer(id, item.clone())
594            }
595            var => unimplemented!("Can't index into {var:?}"),
596        }
597    }
598
599    pub fn read_indexed(&mut self, out: &Variable, variable: &Variable, index: &Variable) -> Word {
600        let always_in_bounds = is_always_in_bounds(variable, index);
601        let indexed = self.index(variable, index, always_in_bounds);
602
603        let read = |b: &mut Self| match indexed {
604            IndexedVariable::Pointer(ptr, item) => {
605                let ty = item.id(b);
606                let out_id = b.write_id(out);
607                b.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
608            }
609            IndexedVariable::Composite(var, index, item) => {
610                let elem = item.elem();
611                let ty = elem.id(b);
612                let out_id = b.write_id(out);
613                b.composite_extract(ty, Some(out_id), var, vec![index])
614                    .unwrap()
615            }
616            IndexedVariable::DynamicComposite(var, index, item) => {
617                let elem = item.elem();
618                let ty = elem.id(b);
619                let out_id = b.write_id(out);
620                b.vector_extract_dynamic(ty, Some(out_id), var, index)
621                    .unwrap()
622            }
623            IndexedVariable::Scalar(var) => {
624                let ty = out.item().id(b);
625                let input = b.read(&var);
626                let out_id = b.write_id(out);
627                b.copy_object(ty, Some(out_id), input).unwrap();
628                b.write(out, out_id);
629                out_id
630            }
631        };
632
633        read(self)
634    }
635
636    pub fn read_indexed_unchecked(
637        &mut self,
638        out: &Variable,
639        variable: &Variable,
640        index: &Variable,
641    ) -> Word {
642        let indexed = self.index(variable, index, true);
643
644        match indexed {
645            IndexedVariable::Pointer(ptr, item) => {
646                let ty = item.id(self);
647                let out_id = self.write_id(out);
648                self.load(ty, Some(out_id), ptr, None, vec![]).unwrap()
649            }
650            IndexedVariable::Composite(var, index, item) => {
651                let elem = item.elem();
652                let ty = elem.id(self);
653                let out_id = self.write_id(out);
654                self.composite_extract(ty, Some(out_id), var, vec![index])
655                    .unwrap()
656            }
657            IndexedVariable::DynamicComposite(var, index, item) => {
658                let elem = item.elem();
659                let ty = elem.id(self);
660                let out_id = self.write_id(out);
661                self.vector_extract_dynamic(ty, Some(out_id), var, index)
662                    .unwrap()
663            }
664            IndexedVariable::Scalar(var) => {
665                let ty = out.item().id(self);
666                let input = self.read(&var);
667                let out_id = self.write_id(out);
668                self.copy_object(ty, Some(out_id), input).unwrap();
669                self.write(out, out_id);
670                out_id
671            }
672        }
673    }
674
675    pub fn index_ptr(&mut self, var: &Variable, index: &Variable) -> Word {
676        let always_in_bounds = is_always_in_bounds(var, index);
677        match self.index(var, index, always_in_bounds) {
678            IndexedVariable::Pointer(ptr, _) => ptr,
679            other => unreachable!("{other:?}"),
680        }
681    }
682
683    pub fn write_id(&mut self, variable: &Variable) -> Word {
684        match variable {
685            Variable::LocalBinding { id, variable, .. } => self.get_binding(*id, variable),
686            Variable::Versioned { id, variable, .. } => self.get_versioned(*id, variable),
687            Variable::Local { .. } => self.id(),
688            Variable::Shared(..) => self.id(),
689            Variable::GlobalScalar(id, _) => *id,
690            Variable::Raw(id, _) => *id,
691            Variable::ConstantScalar(_, _, _) => panic!("Can't write to constant scalar"),
692            Variable::GlobalInputArray(_, _, _)
693            | Variable::GlobalOutputArray(_, _, _)
694            | Variable::Slice { .. }
695            | Variable::Named { .. }
696            | Variable::SharedArray(_, _, _)
697            | Variable::ConstantArray(_, _, _)
698            | Variable::LocalArray(_, _, _) => panic!("Can't write to unindexed array"),
699            global => panic!("Can't write to builtin {global:?}"),
700        }
701    }
702
703    pub fn write(&mut self, variable: &Variable, value: Word) {
704        match variable {
705            Variable::Shared(id, item) if self.compilation_options.supports_explicit_smem => {
706                let ptr_ty =
707                    Item::Pointer(StorageClass::Workgroup, Box::new(item.clone())).id(self);
708                let index = vec![self.const_u32(0)];
709                let access = self.access_chain(ptr_ty, None, *id, index).unwrap();
710                self.store(access, value, None, []).unwrap()
711            }
712            Variable::Local { id, .. } | Variable::Shared(id, _) => {
713                self.store(*id, value, None, []).unwrap()
714            }
715
716            Variable::Slice { ptr, .. } => self.write(ptr, value),
717            _ => {}
718        }
719    }
720
721    pub fn write_indexed(&mut self, out: &Variable, index: &Variable, value: Word) {
722        let always_in_bounds = is_always_in_bounds(out, index);
723        let variable = self.index(out, index, always_in_bounds);
724
725        let write = |b: &mut Self| match variable {
726            IndexedVariable::Pointer(ptr, _) => b.store(ptr, value, None, vec![]).unwrap(),
727            IndexedVariable::Composite(var, index, item) => {
728                let ty = item.id(b);
729                let id = b
730                    .composite_insert(ty, None, value, var, vec![index])
731                    .unwrap();
732                b.write(out, id);
733            }
734            IndexedVariable::DynamicComposite(var, index, item) => {
735                let ty = item.id(b);
736                let id = b
737                    .vector_insert_dynamic(ty, None, value, var, index)
738                    .unwrap();
739                b.write(out, id);
740            }
741            IndexedVariable::Scalar(var) => b.write(&var, value),
742        };
743
744        write(self)
745    }
746
747    pub fn write_indexed_unchecked(&mut self, out: &Variable, index: &Variable, value: Word) {
748        let variable = self.index(out, index, true);
749
750        match variable {
751            IndexedVariable::Pointer(ptr, _) => self.store(ptr, value, None, vec![]).unwrap(),
752            IndexedVariable::Composite(var, index, item) => {
753                let ty = item.id(self);
754                let out_id = self
755                    .composite_insert(ty, None, value, var, vec![index])
756                    .unwrap();
757                self.write(out, out_id);
758            }
759            IndexedVariable::DynamicComposite(var, index, item) => {
760                let ty = item.id(self);
761                let out_id = self
762                    .vector_insert_dynamic(ty, None, value, var, index)
763                    .unwrap();
764                self.write(out, out_id);
765            }
766            IndexedVariable::Scalar(var) => self.write(&var, value),
767        }
768    }
769}
770
771fn is_always_in_bounds(var: &Variable, index: &Variable) -> bool {
772    let len = match var {
773        Variable::SharedArray(_, _, len)
774        | Variable::ConstantArray(_, _, len)
775        | Variable::LocalArray(_, _, len)
776        | Variable::Slice {
777            const_len: Some(len),
778            ..
779        } => *len,
780        _ => return false,
781    };
782
783    let const_index = match index {
784        Variable::ConstantScalar(_, value, _) => value.as_u32(),
785        _ => return false,
786    };
787
788    const_index < len
789}