1use core::num::NonZero;
2use core::{fmt::Display, hash::Hash};
3
4use crate::{BarrierLevel, TypeHash};
5
6use super::{Elem, FloatKind, IntKind, Item, Matrix, UIntKind};
7use float_ord::FloatOrd;
8
9#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
11#[allow(missing_docs)]
12pub struct Variable {
13 pub kind: VariableKind,
14 pub item: Item,
15}
16
17impl Variable {
18 pub fn new(kind: VariableKind, item: Item) -> Self {
19 Self { kind, item }
20 }
21
22 pub fn builtin(builtin: Builtin) -> Self {
23 Self::new(
24 VariableKind::Builtin(builtin),
25 Item::new(Elem::UInt(UIntKind::U32)),
26 )
27 }
28
29 pub fn constant(scalar: ConstantScalarValue) -> Self {
30 let elem = match scalar {
31 ConstantScalarValue::Int(_, int_kind) => Elem::Int(int_kind),
32 ConstantScalarValue::Float(_, float_kind) => Elem::Float(float_kind),
33 ConstantScalarValue::UInt(_, kind) => Elem::UInt(kind),
34 ConstantScalarValue::Bool(_) => Elem::Bool,
35 };
36 Self::new(VariableKind::ConstantScalar(scalar), Item::new(elem))
37 }
38
39 pub fn elem(&self) -> Elem {
40 self.item.elem
41 }
42}
43
/// Identifier carried by most [`VariableKind`] variants.
pub type Id = u32;
45
46#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
47#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
48pub enum VariableKind {
49 GlobalInputArray(Id),
50 GlobalOutputArray(Id),
51 GlobalScalar(Id),
52 TensorMap(Id),
53 LocalArray {
54 id: Id,
55 length: u32,
56 },
57 LocalMut {
58 id: Id,
59 },
60 LocalConst {
61 id: Id,
62 },
63 Versioned {
64 id: Id,
65 version: u16,
66 },
67 ConstantScalar(ConstantScalarValue),
68 ConstantArray {
69 id: Id,
70 length: u32,
71 },
72 SharedMemory {
73 id: Id,
74 length: u32,
75 alignment: Option<u32>,
76 },
77 Matrix {
78 id: Id,
79 mat: Matrix,
80 },
81 Builtin(Builtin),
82 Pipeline {
83 id: Id,
84 item: Item,
85 num_stages: u8,
86 },
87 Barrier {
88 id: Id,
89 item: Item,
90 level: BarrierLevel,
91 },
92}
93
94#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
95#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
96pub enum Builtin {
97 UnitPos,
98 UnitPosX,
99 UnitPosY,
100 UnitPosZ,
101 CubePosCluster,
102 CubePosClusterX,
103 CubePosClusterY,
104 CubePosClusterZ,
105 CubePos,
106 CubePosX,
107 CubePosY,
108 CubePosZ,
109 CubeDim,
110 CubeDimX,
111 CubeDimY,
112 CubeDimZ,
113 CubeClusterDim,
114 CubeClusterDimX,
115 CubeClusterDimY,
116 CubeClusterDimZ,
117 CubeCount,
118 CubeCountX,
119 CubeCountY,
120 CubeCountZ,
121 PlaneDim,
122 UnitPosPlane,
123 AbsolutePos,
124 AbsolutePosX,
125 AbsolutePosY,
126 AbsolutePosZ,
127}
128
129impl Variable {
130 pub fn is_immutable(&self) -> bool {
133 match self.kind {
134 VariableKind::GlobalOutputArray { .. } => false,
135 VariableKind::TensorMap(_) => false,
136 VariableKind::LocalMut { .. } => false,
137 VariableKind::SharedMemory { .. } => false,
138 VariableKind::Matrix { .. } => false,
139 VariableKind::LocalArray { .. } => false,
140 VariableKind::GlobalInputArray { .. } => false,
141 VariableKind::GlobalScalar { .. } => true,
142 VariableKind::Versioned { .. } => true,
143 VariableKind::LocalConst { .. } => true,
144 VariableKind::ConstantScalar(_) => true,
145 VariableKind::ConstantArray { .. } => true,
146 VariableKind::Builtin(_) => true,
147 VariableKind::Pipeline { .. } => false,
148 VariableKind::Barrier { .. } => false,
149 }
150 }
151
152 pub fn is_array(&self) -> bool {
155 matches!(
156 self.kind,
157 VariableKind::GlobalInputArray { .. }
158 | VariableKind::GlobalOutputArray { .. }
159 | VariableKind::ConstantArray { .. }
160 | VariableKind::SharedMemory { .. }
161 | VariableKind::LocalArray { .. }
162 | VariableKind::Matrix { .. }
163 )
164 }
165
166 pub fn has_length(&self) -> bool {
167 matches!(
168 self.kind,
169 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
170 )
171 }
172
173 pub fn has_buffer_length(&self) -> bool {
174 matches!(
175 self.kind,
176 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
177 )
178 }
179
180 pub fn is_constant(&self, value: i64) -> bool {
182 match self.kind {
183 VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
184 VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
185 VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
186 _ => false,
187 }
188 }
189
190 pub fn is_true(&self) -> bool {
192 match self.kind {
193 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
194 _ => false,
195 }
196 }
197
198 pub fn is_false(&self) -> bool {
200 match self.kind {
201 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
202 _ => false,
203 }
204 }
205}
206
207#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
210#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
211#[allow(missing_docs)]
212pub enum ConstantScalarValue {
213 Int(i64, IntKind),
214 Float(f64, FloatKind),
215 UInt(u64, UIntKind),
216 Bool(bool),
217}
218
219impl Eq for ConstantScalarValue {}
220impl Hash for ConstantScalarValue {
221 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
222 core::mem::discriminant(self).hash(ra_expand_state);
223 match self {
224 ConstantScalarValue::Int(f0, f1) => {
225 f0.hash(ra_expand_state);
226 f1.hash(ra_expand_state);
227 }
228 ConstantScalarValue::Float(f0, f1) => {
229 FloatOrd(*f0).hash(ra_expand_state);
230 f1.hash(ra_expand_state);
231 }
232 ConstantScalarValue::UInt(f0, f1) => {
233 f0.hash(ra_expand_state);
234 f1.hash(ra_expand_state);
235 }
236 ConstantScalarValue::Bool(f0) => {
237 f0.hash(ra_expand_state);
238 }
239 }
240 }
241}
242
243impl ConstantScalarValue {
244 pub fn elem(&self) -> Elem {
246 match self {
247 ConstantScalarValue::Int(_, kind) => Elem::Int(*kind),
248 ConstantScalarValue::Float(_, kind) => Elem::Float(*kind),
249 ConstantScalarValue::UInt(_, kind) => Elem::UInt(*kind),
250 ConstantScalarValue::Bool(_) => Elem::Bool,
251 }
252 }
253
254 pub fn try_as_usize(&self) -> Option<usize> {
258 match self {
259 ConstantScalarValue::UInt(val, _) => Some(*val as usize),
260 ConstantScalarValue::Int(val, _) => Some(*val as usize),
261 ConstantScalarValue::Float(_, _) => None,
262 ConstantScalarValue::Bool(_) => None,
263 }
264 }
265
266 pub fn as_usize(&self) -> usize {
270 self.try_as_usize()
271 .expect("Only Int and UInt kind can be made into usize.")
272 }
273
274 pub fn try_as_u32(&self) -> Option<u32> {
278 match self {
279 ConstantScalarValue::UInt(val, _) => Some(*val as u32),
280 ConstantScalarValue::Int(val, _) => Some(*val as u32),
281 ConstantScalarValue::Float(_, _) => None,
282 ConstantScalarValue::Bool(_) => None,
283 }
284 }
285
286 pub fn as_u32(&self) -> u32 {
290 self.try_as_u32()
291 .expect("Only Int and UInt kind can be made into u32.")
292 }
293
294 pub fn try_as_u64(&self) -> Option<u64> {
298 match self {
299 ConstantScalarValue::UInt(val, _) => Some(*val),
300 ConstantScalarValue::Int(val, _) => Some(*val as u64),
301 ConstantScalarValue::Float(_, _) => None,
302 ConstantScalarValue::Bool(_) => None,
303 }
304 }
305
306 pub fn as_u64(&self) -> u64 {
310 self.try_as_u64()
311 .expect("Only Int and UInt kind can be made into u64.")
312 }
313
314 pub fn try_as_i64(&self) -> Option<i64> {
318 match self {
319 ConstantScalarValue::UInt(val, _) => Some(*val as i64),
320 ConstantScalarValue::Int(val, _) => Some(*val),
321 ConstantScalarValue::Float(_, _) => None,
322 ConstantScalarValue::Bool(_) => None,
323 }
324 }
325
326 pub fn as_i64(&self) -> i64 {
330 self.try_as_i64()
331 .expect("Only Int and UInt kind can be made into i64.")
332 }
333
334 pub fn try_as_bool(&self) -> Option<bool> {
336 match self {
337 ConstantScalarValue::Bool(val) => Some(*val),
338 _ => None,
339 }
340 }
341
342 pub fn as_bool(&self) -> bool {
346 self.try_as_bool()
347 .expect("Only bool can be made into a bool")
348 }
349
350 pub fn is_zero(&self) -> bool {
351 match self {
352 ConstantScalarValue::Int(val, _) => *val == 0,
353 ConstantScalarValue::Float(val, _) => *val == 0.0,
354 ConstantScalarValue::UInt(val, _) => *val == 0,
355 ConstantScalarValue::Bool(_) => false,
356 }
357 }
358
359 pub fn is_one(&self) -> bool {
360 match self {
361 ConstantScalarValue::Int(val, _) => *val == 1,
362 ConstantScalarValue::Float(val, _) => *val == 1.0,
363 ConstantScalarValue::UInt(val, _) => *val == 1,
364 ConstantScalarValue::Bool(_) => false,
365 }
366 }
367
368 pub fn cast_to(&self, other: Elem) -> ConstantScalarValue {
369 match (self, other) {
370 (ConstantScalarValue::Int(val, _), Elem::Float(float_kind)) => {
371 ConstantScalarValue::Float(*val as f64, float_kind)
372 }
373 (ConstantScalarValue::Int(val, _), Elem::Int(int_kind)) => {
374 ConstantScalarValue::Int(*val, int_kind)
375 }
376 (ConstantScalarValue::Int(val, _), Elem::UInt(kind)) => {
377 ConstantScalarValue::UInt(*val as u64, kind)
378 }
379 (ConstantScalarValue::Int(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
380 (ConstantScalarValue::Float(val, _), Elem::Float(float_kind)) => {
381 ConstantScalarValue::Float(*val, float_kind)
382 }
383 (ConstantScalarValue::Float(val, _), Elem::Int(int_kind)) => {
384 ConstantScalarValue::Int(*val as i64, int_kind)
385 }
386 (ConstantScalarValue::Float(val, _), Elem::UInt(kind)) => {
387 ConstantScalarValue::UInt(*val as u64, kind)
388 }
389 (ConstantScalarValue::Float(val, _), Elem::Bool) => {
390 ConstantScalarValue::Bool(*val == 0.0)
391 }
392 (ConstantScalarValue::UInt(val, _), Elem::Float(float_kind)) => {
393 ConstantScalarValue::Float(*val as f64, float_kind)
394 }
395 (ConstantScalarValue::UInt(val, _), Elem::Int(int_kind)) => {
396 ConstantScalarValue::Int(*val as i64, int_kind)
397 }
398 (ConstantScalarValue::UInt(val, _), Elem::UInt(kind)) => {
399 ConstantScalarValue::UInt(*val, kind)
400 }
401 (ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
402 (ConstantScalarValue::Bool(val), Elem::Float(float_kind)) => {
403 ConstantScalarValue::Float(*val as u32 as f64, float_kind)
404 }
405 (ConstantScalarValue::Bool(val), Elem::Int(int_kind)) => {
406 ConstantScalarValue::Int(*val as i64, int_kind)
407 }
408 (ConstantScalarValue::Bool(val), Elem::UInt(kind)) => {
409 ConstantScalarValue::UInt(*val as u64, kind)
410 }
411 (ConstantScalarValue::Bool(val), Elem::Bool) => ConstantScalarValue::Bool(*val),
412 _ => unreachable!(),
413 }
414 }
415}
416
417impl Display for ConstantScalarValue {
418 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
419 match self {
420 ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
421 ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
422 ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
423 ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
424 ConstantScalarValue::Float(val, FloatKind::E2M1) => write!(f, "{val}e2m1"),
425 ConstantScalarValue::Float(val, FloatKind::E2M3) => write!(f, "{val}e2m3"),
426 ConstantScalarValue::Float(val, FloatKind::E3M2) => write!(f, "{val}e3m2"),
427 ConstantScalarValue::Float(val, FloatKind::E4M3) => write!(f, "{val}e4m3"),
428 ConstantScalarValue::Float(val, FloatKind::E5M2) => write!(f, "{val}e5m2"),
429 ConstantScalarValue::Float(val, FloatKind::UE8M0) => write!(f, "{val}ue8m0"),
430 ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
431 ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
432 ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
433 ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
434 ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
435 ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
436 ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
437 ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
438 ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
439 ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
440 ConstantScalarValue::Bool(val) => write!(f, "{val}"),
441 }
442 }
443}
444
445impl Variable {
446 pub fn vectorization_factor(&self) -> u8 {
447 self.item.vectorization.map(NonZero::get).unwrap_or(1u8)
448 }
449
450 pub fn index(&self) -> Option<Id> {
451 match self.kind {
452 VariableKind::GlobalInputArray(id)
453 | VariableKind::GlobalOutputArray(id)
454 | VariableKind::TensorMap(id)
455 | VariableKind::GlobalScalar(id)
456 | VariableKind::LocalMut { id, .. }
457 | VariableKind::Versioned { id, .. }
458 | VariableKind::LocalConst { id, .. }
459 | VariableKind::ConstantArray { id, .. }
460 | VariableKind::SharedMemory { id, .. }
461 | VariableKind::LocalArray { id, .. }
462 | VariableKind::Matrix { id, .. } => Some(id),
463 _ => None,
464 }
465 }
466
467 pub fn as_const(&self) -> Option<ConstantScalarValue> {
468 match self.kind {
469 VariableKind::ConstantScalar(constant) => Some(constant),
470 _ => None,
471 }
472 }
473}
474
475impl Display for Variable {
476 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
477 match self.kind {
478 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
479 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
480 VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
481 VariableKind::TensorMap(id) => write!(f, "tensor_map({id})"),
482 VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
483 VariableKind::LocalMut { id } => write!(f, "local({id})"),
484 VariableKind::Versioned { id, version } => {
485 write!(f, "local({id}).v{version}")
486 }
487 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
488 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
489 VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
490 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
491 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
492 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
493 VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
494 VariableKind::Barrier { id, .. } => write!(f, "barrier({id})"),
495 }
496 }
497}
498
499impl From<&Variable> for Variable {
501 fn from(value: &Variable) -> Self {
502 *value
503 }
504}