1use core::num::NonZero;
2use core::{fmt::Display, hash::Hash};
3
4use crate::{BarrierLevel, TypeHash};
5
6use super::{Elem, FloatKind, IntKind, Item, Matrix, UIntKind};
7use float_ord::FloatOrd;
8
/// A variable, described by its [`VariableKind`] and its [`Item`] type.
///
/// Field order is part of the derived `Hash`, serde and `TypeHash` behaviour;
/// do not reorder the fields.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// Which kind of variable this is (carries its id, or its value for constants).
    pub kind: VariableKind,
    /// Element type and optional vectorization of the variable.
    pub item: Item,
}
16
17impl Variable {
18 pub fn new(kind: VariableKind, item: Item) -> Self {
19 Self { kind, item }
20 }
21
22 pub fn builtin(builtin: Builtin) -> Self {
23 Self::new(
24 VariableKind::Builtin(builtin),
25 Item::new(Elem::UInt(UIntKind::U32)),
26 )
27 }
28
29 pub fn constant(scalar: ConstantScalarValue) -> Self {
30 let elem = match scalar {
31 ConstantScalarValue::Int(_, int_kind) => Elem::Int(int_kind),
32 ConstantScalarValue::Float(_, float_kind) => Elem::Float(float_kind),
33 ConstantScalarValue::UInt(_, kind) => Elem::UInt(kind),
34 ConstantScalarValue::Bool(_) => Elem::Bool,
35 };
36 Self::new(VariableKind::ConstantScalar(scalar), Item::new(elem))
37 }
38
39 pub fn elem(&self) -> Elem {
40 self.item.elem
41 }
42}
43
/// Numeric identifier carried by most [`VariableKind`] variants.
pub type Id = u32;
45
/// The kind of a [`Variable`]: which category it belongs to plus the data
/// (id, length, constant value, ...) that identifies it.
///
/// Variant order is part of the derived serde/`TypeHash`/`Hash` behaviour, so do
/// not reorder variants. `Variable::is_immutable` matches every variant
/// exhaustively and must be updated when a variant is added.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array with the given id; printed as `input(id)`.
    GlobalInputArray(Id),
    /// Global output array with the given id; printed as `output(id)`.
    GlobalOutputArray(Id),
    /// Global scalar with the given id; printed as `scalar(id)`.
    GlobalScalar(Id),
    /// Tensor map with the given id; printed as `tensor_map(id)`.
    TensorMap(Id),
    /// Local array; printed as `array(id)`.
    LocalArray {
        id: Id,
        // NOTE(review): presumably an element count — confirm.
        length: u32,
    },
    /// Mutable local; printed as `local(id)`.
    LocalMut {
        id: Id,
    },
    /// Immutable local binding; printed as `binding(id)`.
    LocalConst {
        id: Id,
    },
    /// A numbered version of local `id`; printed as `local(id).v{version}`.
    Versioned {
        id: Id,
        version: u16,
    },
    /// Inline scalar constant; printed using its own `Display` impl.
    ConstantScalar(ConstantScalarValue),
    /// Constant array; printed as `const_array(id)`.
    ConstantArray {
        id: Id,
        length: u32,
    },
    /// Shared memory; printed as `shared(id)`.
    SharedMemory {
        id: Id,
        length: u32,
        // NOTE(review): units of the alignment (bytes?) are not visible here — confirm.
        alignment: Option<u32>,
    },
    /// Matrix described by [`Matrix`]; printed as `matrix(id)`.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// Slice; printed as `slice(id)`.
    Slice {
        id: Id,
    },
    /// Builtin value; printed with the builtin's `Debug` name.
    Builtin(Builtin),
    /// Pipeline with `num_stages` stages and its own `item` type; printed as `pipeline(id)`.
    Pipeline {
        id: Id,
        item: Item,
        num_stages: u8,
    },
    /// Barrier at the given [`BarrierLevel`]; printed as `barrier(id)`.
    Barrier {
        id: Id,
        item: Item,
        level: BarrierLevel,
    },
}
96
/// Builtin values a variable can refer to (see `Variable::builtin`, which types
/// every builtin as a `u32`).
///
/// `PartialOrd`/`Ord` are derived, so variant order defines the ordering; do not
/// reorder variants. The `X`/`Y`/`Z` variants are per-axis forms of the variant
/// preceding them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
131
132impl Variable {
133 pub fn is_immutable(&self) -> bool {
136 match self.kind {
137 VariableKind::GlobalOutputArray { .. } => false,
138 VariableKind::TensorMap(_) => false,
139 VariableKind::LocalMut { .. } => false,
140 VariableKind::SharedMemory { .. } => false,
141 VariableKind::Matrix { .. } => false,
142 VariableKind::Slice { .. } => false,
143 VariableKind::LocalArray { .. } => false,
144 VariableKind::GlobalInputArray { .. } => false,
145 VariableKind::GlobalScalar { .. } => true,
146 VariableKind::Versioned { .. } => true,
147 VariableKind::LocalConst { .. } => true,
148 VariableKind::ConstantScalar(_) => true,
149 VariableKind::ConstantArray { .. } => true,
150 VariableKind::Builtin(_) => true,
151 VariableKind::Pipeline { .. } => false,
152 VariableKind::Barrier { .. } => false,
153 }
154 }
155
156 pub fn is_array(&self) -> bool {
159 matches!(
160 self.kind,
161 VariableKind::GlobalInputArray { .. }
162 | VariableKind::GlobalOutputArray { .. }
163 | VariableKind::ConstantArray { .. }
164 | VariableKind::SharedMemory { .. }
165 | VariableKind::LocalArray { .. }
166 | VariableKind::Matrix { .. }
167 | VariableKind::Slice { .. }
168 )
169 }
170
171 pub fn has_length(&self) -> bool {
172 matches!(
173 self.kind,
174 VariableKind::GlobalInputArray { .. }
175 | VariableKind::GlobalOutputArray { .. }
176 | VariableKind::Slice { .. }
177 )
178 }
179
180 pub fn has_buffer_length(&self) -> bool {
181 matches!(
182 self.kind,
183 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
184 )
185 }
186
187 pub fn is_constant(&self, value: i64) -> bool {
189 match self.kind {
190 VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
191 VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
192 VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
193 _ => false,
194 }
195 }
196
197 pub fn is_true(&self) -> bool {
199 match self.kind {
200 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
201 _ => false,
202 }
203 }
204
205 pub fn is_false(&self) -> bool {
207 match self.kind {
208 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
209 _ => false,
210 }
211 }
212}
213
/// A scalar constant tagged with its element kind.
///
/// Integers are stored widened to `i64`/`u64` and floats to `f64`, alongside the
/// kind they were declared with. `PartialEq`/`PartialOrd` are derived, so floats
/// compare with IEEE semantics (`0.0 == -0.0`, `NaN != NaN`). `Eq` and `Hash`
/// are implemented by hand because `f64` implements neither. Variant order is
/// used by the derived `PartialOrd` and by serde; do not reorder variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
#[allow(missing_docs)]
pub enum ConstantScalarValue {
    /// Signed integer constant and its declared kind.
    Int(i64, IntKind),
    /// Floating-point constant and its declared kind.
    Float(f64, FloatKind),
    /// Unsigned integer constant and its declared kind.
    UInt(u64, UIntKind),
    /// Boolean constant.
    Bool(bool),
}
225
226impl Eq for ConstantScalarValue {}
227impl Hash for ConstantScalarValue {
228 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
229 core::mem::discriminant(self).hash(ra_expand_state);
230 match self {
231 ConstantScalarValue::Int(f0, f1) => {
232 f0.hash(ra_expand_state);
233 f1.hash(ra_expand_state);
234 }
235 ConstantScalarValue::Float(f0, f1) => {
236 FloatOrd(*f0).hash(ra_expand_state);
237 f1.hash(ra_expand_state);
238 }
239 ConstantScalarValue::UInt(f0, f1) => {
240 f0.hash(ra_expand_state);
241 f1.hash(ra_expand_state);
242 }
243 ConstantScalarValue::Bool(f0) => {
244 f0.hash(ra_expand_state);
245 }
246 }
247 }
248}
249
250impl ConstantScalarValue {
251 pub fn elem(&self) -> Elem {
253 match self {
254 ConstantScalarValue::Int(_, kind) => Elem::Int(*kind),
255 ConstantScalarValue::Float(_, kind) => Elem::Float(*kind),
256 ConstantScalarValue::UInt(_, kind) => Elem::UInt(*kind),
257 ConstantScalarValue::Bool(_) => Elem::Bool,
258 }
259 }
260
261 pub fn try_as_usize(&self) -> Option<usize> {
265 match self {
266 ConstantScalarValue::UInt(val, _) => Some(*val as usize),
267 ConstantScalarValue::Int(val, _) => Some(*val as usize),
268 ConstantScalarValue::Float(_, _) => None,
269 ConstantScalarValue::Bool(_) => None,
270 }
271 }
272
273 pub fn as_usize(&self) -> usize {
277 self.try_as_usize()
278 .expect("Only Int and UInt kind can be made into usize.")
279 }
280
281 pub fn try_as_u32(&self) -> Option<u32> {
285 match self {
286 ConstantScalarValue::UInt(val, _) => Some(*val as u32),
287 ConstantScalarValue::Int(val, _) => Some(*val as u32),
288 ConstantScalarValue::Float(_, _) => None,
289 ConstantScalarValue::Bool(_) => None,
290 }
291 }
292
293 pub fn as_u32(&self) -> u32 {
297 self.try_as_u32()
298 .expect("Only Int and UInt kind can be made into u32.")
299 }
300
301 pub fn try_as_u64(&self) -> Option<u64> {
305 match self {
306 ConstantScalarValue::UInt(val, _) => Some(*val),
307 ConstantScalarValue::Int(val, _) => Some(*val as u64),
308 ConstantScalarValue::Float(_, _) => None,
309 ConstantScalarValue::Bool(_) => None,
310 }
311 }
312
313 pub fn as_u64(&self) -> u64 {
317 self.try_as_u64()
318 .expect("Only Int and UInt kind can be made into u64.")
319 }
320
321 pub fn try_as_i64(&self) -> Option<i64> {
325 match self {
326 ConstantScalarValue::UInt(val, _) => Some(*val as i64),
327 ConstantScalarValue::Int(val, _) => Some(*val),
328 ConstantScalarValue::Float(_, _) => None,
329 ConstantScalarValue::Bool(_) => None,
330 }
331 }
332
333 pub fn as_i64(&self) -> i64 {
337 self.try_as_i64()
338 .expect("Only Int and UInt kind can be made into i64.")
339 }
340
341 pub fn try_as_bool(&self) -> Option<bool> {
343 match self {
344 ConstantScalarValue::Bool(val) => Some(*val),
345 _ => None,
346 }
347 }
348
349 pub fn as_bool(&self) -> bool {
353 self.try_as_bool()
354 .expect("Only bool can be made into a bool")
355 }
356
357 pub fn is_zero(&self) -> bool {
358 match self {
359 ConstantScalarValue::Int(val, _) => *val == 0,
360 ConstantScalarValue::Float(val, _) => *val == 0.0,
361 ConstantScalarValue::UInt(val, _) => *val == 0,
362 ConstantScalarValue::Bool(_) => false,
363 }
364 }
365
366 pub fn is_one(&self) -> bool {
367 match self {
368 ConstantScalarValue::Int(val, _) => *val == 1,
369 ConstantScalarValue::Float(val, _) => *val == 1.0,
370 ConstantScalarValue::UInt(val, _) => *val == 1,
371 ConstantScalarValue::Bool(_) => false,
372 }
373 }
374
375 pub fn cast_to(&self, other: Elem) -> ConstantScalarValue {
376 match (self, other) {
377 (ConstantScalarValue::Int(val, _), Elem::Float(float_kind)) => {
378 ConstantScalarValue::Float(*val as f64, float_kind)
379 }
380 (ConstantScalarValue::Int(val, _), Elem::Int(int_kind)) => {
381 ConstantScalarValue::Int(*val, int_kind)
382 }
383 (ConstantScalarValue::Int(val, _), Elem::UInt(kind)) => {
384 ConstantScalarValue::UInt(*val as u64, kind)
385 }
386 (ConstantScalarValue::Int(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
387 (ConstantScalarValue::Float(val, _), Elem::Float(float_kind)) => {
388 ConstantScalarValue::Float(*val, float_kind)
389 }
390 (ConstantScalarValue::Float(val, _), Elem::Int(int_kind)) => {
391 ConstantScalarValue::Int(*val as i64, int_kind)
392 }
393 (ConstantScalarValue::Float(val, _), Elem::UInt(kind)) => {
394 ConstantScalarValue::UInt(*val as u64, kind)
395 }
396 (ConstantScalarValue::Float(val, _), Elem::Bool) => {
397 ConstantScalarValue::Bool(*val == 0.0)
398 }
399 (ConstantScalarValue::UInt(val, _), Elem::Float(float_kind)) => {
400 ConstantScalarValue::Float(*val as f64, float_kind)
401 }
402 (ConstantScalarValue::UInt(val, _), Elem::Int(int_kind)) => {
403 ConstantScalarValue::Int(*val as i64, int_kind)
404 }
405 (ConstantScalarValue::UInt(val, _), Elem::UInt(kind)) => {
406 ConstantScalarValue::UInt(*val, kind)
407 }
408 (ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
409 (ConstantScalarValue::Bool(val), Elem::Float(float_kind)) => {
410 ConstantScalarValue::Float(*val as u32 as f64, float_kind)
411 }
412 (ConstantScalarValue::Bool(val), Elem::Int(int_kind)) => {
413 ConstantScalarValue::Int(*val as i64, int_kind)
414 }
415 (ConstantScalarValue::Bool(val), Elem::UInt(kind)) => {
416 ConstantScalarValue::UInt(*val as u64, kind)
417 }
418 (ConstantScalarValue::Bool(val), Elem::Bool) => ConstantScalarValue::Bool(*val),
419 _ => unreachable!(),
420 }
421 }
422}
423
424impl Display for ConstantScalarValue {
425 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
426 match self {
427 ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
428 ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
429 ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
430 ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
431 ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
432 ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
433 ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
434 ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
435 ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
436 ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
437 ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
438 ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
439 ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
440 ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
441 ConstantScalarValue::Bool(val) => write!(f, "{val}"),
442 }
443 }
444}
445
446impl Variable {
447 pub fn vectorization_factor(&self) -> u8 {
448 self.item.vectorization.map(NonZero::get).unwrap_or(1u8)
449 }
450
451 pub fn index(&self) -> Option<Id> {
452 match self.kind {
453 VariableKind::GlobalInputArray(id)
454 | VariableKind::GlobalOutputArray(id)
455 | VariableKind::TensorMap(id)
456 | VariableKind::GlobalScalar(id)
457 | VariableKind::LocalMut { id, .. }
458 | VariableKind::Versioned { id, .. }
459 | VariableKind::LocalConst { id, .. }
460 | VariableKind::Slice { id, .. }
461 | VariableKind::ConstantArray { id, .. }
462 | VariableKind::SharedMemory { id, .. }
463 | VariableKind::LocalArray { id, .. }
464 | VariableKind::Matrix { id, .. } => Some(id),
465 _ => None,
466 }
467 }
468
469 pub fn as_const(&self) -> Option<ConstantScalarValue> {
470 match self.kind {
471 VariableKind::ConstantScalar(constant) => Some(constant),
472 _ => None,
473 }
474 }
475}
476
477impl Display for Variable {
478 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
479 match self.kind {
480 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
481 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
482 VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
483 VariableKind::TensorMap(id) => write!(f, "tensor_map({id})"),
484 VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
485 VariableKind::LocalMut { id } => write!(f, "local({id})"),
486 VariableKind::Versioned { id, version } => {
487 write!(f, "local({id}).v{version}")
488 }
489 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
490 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
491 VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
492 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
493 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
494 VariableKind::Slice { id } => write!(f, "slice({id})"),
495 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
496 VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
497 VariableKind::Barrier { id, .. } => write!(f, "barrier({id})"),
498 }
499 }
500}
501
502impl From<&Variable> for Variable {
504 fn from(value: &Variable) -> Self {
505 *value
506 }
507}