1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, StorageType, TypeHash};
4
5use super::{ElemType, FloatKind, IntKind, Matrix, Type, UIntKind};
6use float_ord::FloatOrd;
7
/// A variable: what it refers to (`kind`) paired with its type (`ty`).
///
/// The type is `Copy`, so variables are passed and stored by value.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// What the variable refers to (global binding, local, constant, builtin, ...).
    pub kind: VariableKind,
    /// The variable's type.
    pub ty: Type,
}
15
16impl Variable {
17 pub fn new(kind: VariableKind, item: Type) -> Self {
18 Self { kind, ty: item }
19 }
20
21 pub fn builtin(builtin: Builtin) -> Self {
22 Self::new(
23 VariableKind::Builtin(builtin),
24 Type::scalar(ElemType::UInt(UIntKind::U32)),
25 )
26 }
27
28 pub fn constant(scalar: ConstantScalarValue) -> Self {
29 let elem = match scalar {
30 ConstantScalarValue::Int(_, int_kind) => ElemType::Int(int_kind),
31 ConstantScalarValue::Float(_, float_kind) => ElemType::Float(float_kind),
32 ConstantScalarValue::UInt(_, kind) => ElemType::UInt(kind),
33 ConstantScalarValue::Bool(_) => ElemType::Bool,
34 };
35 Self::new(VariableKind::ConstantScalar(scalar), Type::scalar(elem))
36 }
37
38 pub fn elem_type(&self) -> ElemType {
39 self.ty.elem_type()
40 }
41
42 pub fn storage_type(&self) -> StorageType {
43 self.ty.storage_type()
44 }
45}
46
/// Numeric identifier carried by most [`VariableKind`] variants.
pub type Id = u32;
48
/// What a [`Variable`] refers to.
///
/// Most variants carry an [`Id`]; `Variable::index` reports which of them expose it.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array (displayed as `input(id)`). Not classified as immutable.
    GlobalInputArray(Id),
    /// Global output array (displayed as `output(id)`).
    GlobalOutputArray(Id),
    /// Global scalar (displayed as `scalar(id)`).
    GlobalScalar(Id),
    /// Input tensor map.
    TensorMapInput(Id),
    /// Output tensor map.
    TensorMapOutput(Id),
    /// Local array.
    LocalArray {
        id: Id,
        /// Presumably the element count — confirm.
        length: u32,
        // NOTE(review): meaning of `unroll_factor` is not visible in this file — confirm.
        unroll_factor: u32,
    },
    /// Mutable local variable (displayed as `local(id)`).
    LocalMut {
        id: Id,
    },
    /// Local binding classified as immutable (displayed as `binding(id)`).
    LocalConst {
        id: Id,
    },
    /// One version of a local (displayed as `local(id).v{version}`); classified as immutable.
    Versioned {
        id: Id,
        version: u16,
    },
    /// Inline scalar constant.
    ConstantScalar(ConstantScalarValue),
    /// Constant array.
    ConstantArray {
        id: Id,
        length: u32,
        unroll_factor: u32,
    },
    /// Shared array.
    SharedArray {
        id: Id,
        length: u32,
        unroll_factor: u32,
        /// Optional alignment. NOTE(review): unit (bytes?) is not visible here — confirm.
        alignment: Option<u32>,
    },
    /// Shared, non-array variable (displayed as `shared(id)`).
    Shared {
        id: Id,
    },
    /// Matrix variable described by `mat`.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// A builtin value.
    Builtin(Builtin),
    /// Pipeline with `num_stages` stages.
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier token at the given barrier level.
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
}
101
/// Builtin values a variable can refer to through `VariableKind::Builtin`.
///
/// `Variable::builtin` types every builtin as a scalar `u32`. Variants ending in `X`/`Y`/`Z`
/// presumably select a single axis of the corresponding unsuffixed value — confirm against
/// the backends.
// The derived `PartialOrd`/`Ord` follow declaration order, and the implicit `repr(u8)`
// discriminants are assigned in that same order, so reordering variants changes both.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
137
138impl Variable {
139 pub fn is_immutable(&self) -> bool {
142 match self.kind {
143 VariableKind::GlobalOutputArray { .. } => false,
144 VariableKind::TensorMapInput(_) => true,
145 VariableKind::TensorMapOutput(_) => false,
146 VariableKind::LocalMut { .. } => false,
147 VariableKind::SharedArray { .. } => false,
148 VariableKind::Shared { .. } => false,
149 VariableKind::Matrix { .. } => false,
150 VariableKind::LocalArray { .. } => false,
151 VariableKind::GlobalInputArray { .. } => false,
152 VariableKind::GlobalScalar { .. } => true,
153 VariableKind::Versioned { .. } => true,
154 VariableKind::LocalConst { .. } => true,
155 VariableKind::ConstantScalar(_) => true,
156 VariableKind::ConstantArray { .. } => true,
157 VariableKind::Builtin(_) => true,
158 VariableKind::Pipeline { .. } => false,
159 VariableKind::BarrierToken { .. } => false,
160 }
161 }
162
163 pub fn is_array(&self) -> bool {
166 matches!(
167 self.kind,
168 VariableKind::GlobalInputArray { .. }
169 | VariableKind::GlobalOutputArray { .. }
170 | VariableKind::ConstantArray { .. }
171 | VariableKind::SharedArray { .. }
172 | VariableKind::LocalArray { .. }
173 | VariableKind::Matrix { .. }
174 )
175 }
176
177 pub fn has_length(&self) -> bool {
178 matches!(
179 self.kind,
180 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
181 )
182 }
183
184 pub fn has_buffer_length(&self) -> bool {
185 matches!(
186 self.kind,
187 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
188 )
189 }
190
191 pub fn is_constant(&self, value: i64) -> bool {
193 match self.kind {
194 VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
195 VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
196 VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
197 _ => false,
198 }
199 }
200
201 pub fn is_true(&self) -> bool {
203 match self.kind {
204 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
205 _ => false,
206 }
207 }
208
209 pub fn is_false(&self) -> bool {
211 match self.kind {
212 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
213 _ => false,
214 }
215 }
216}
217
/// A scalar constant tagged with the kind of its element type.
///
/// Payloads are stored in the widest type of each family (`i64`, `f64`, `u64`); the kind
/// names the element type the constant is meant to have (see `ConstantScalarValue::elem_type`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
#[allow(missing_docs)]
pub enum ConstantScalarValue {
    /// Signed integer constant.
    Int(i64, IntKind),
    /// Floating-point constant, including the low-precision float kinds.
    Float(f64, FloatKind),
    /// Unsigned integer constant.
    UInt(u64, UIntKind),
    /// Boolean constant.
    Bool(bool),
}
229
// `f64` is only `PartialEq`, so `Eq` cannot be derived and is asserted here by hand.
// NOTE(review): the derived `PartialEq` uses IEEE float comparison, so a `Float` constant
// holding NaN is not equal to itself; `Eq`'s reflexivity does not hold for NaN values.
impl Eq for ConstantScalarValue {}
231impl Hash for ConstantScalarValue {
232 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
233 core::mem::discriminant(self).hash(ra_expand_state);
234 match self {
235 ConstantScalarValue::Int(f0, f1) => {
236 f0.hash(ra_expand_state);
237 f1.hash(ra_expand_state);
238 }
239 ConstantScalarValue::Float(f0, f1) => {
240 FloatOrd(*f0).hash(ra_expand_state);
241 f1.hash(ra_expand_state);
242 }
243 ConstantScalarValue::UInt(f0, f1) => {
244 f0.hash(ra_expand_state);
245 f1.hash(ra_expand_state);
246 }
247 ConstantScalarValue::Bool(f0) => {
248 f0.hash(ra_expand_state);
249 }
250 }
251 }
252}
253
254impl ConstantScalarValue {
255 pub fn elem_type(&self) -> ElemType {
257 match self {
258 ConstantScalarValue::Int(_, kind) => ElemType::Int(*kind),
259 ConstantScalarValue::Float(_, kind) => ElemType::Float(*kind),
260 ConstantScalarValue::UInt(_, kind) => ElemType::UInt(*kind),
261 ConstantScalarValue::Bool(_) => ElemType::Bool,
262 }
263 }
264
265 pub fn storage_type(&self) -> StorageType {
266 self.elem_type().into()
267 }
268
269 pub fn try_as_usize(&self) -> Option<usize> {
273 match self {
274 ConstantScalarValue::UInt(val, _) => Some(*val as usize),
275 ConstantScalarValue::Int(val, _) => Some(*val as usize),
276 ConstantScalarValue::Float(_, _) => None,
277 ConstantScalarValue::Bool(_) => None,
278 }
279 }
280
281 pub fn as_usize(&self) -> usize {
285 self.try_as_usize()
286 .expect("Only Int and UInt kind can be made into usize.")
287 }
288
289 pub fn try_as_u32(&self) -> Option<u32> {
293 match self {
294 ConstantScalarValue::UInt(val, _) => Some(*val as u32),
295 ConstantScalarValue::Int(val, _) => Some(*val as u32),
296 ConstantScalarValue::Float(_, _) => None,
297 ConstantScalarValue::Bool(_) => None,
298 }
299 }
300
301 pub fn as_u32(&self) -> u32 {
305 self.try_as_u32()
306 .expect("Only Int and UInt kind can be made into u32.")
307 }
308
309 pub fn try_as_u64(&self) -> Option<u64> {
313 match self {
314 ConstantScalarValue::UInt(val, _) => Some(*val),
315 ConstantScalarValue::Int(val, _) => Some(*val as u64),
316 ConstantScalarValue::Float(_, _) => None,
317 ConstantScalarValue::Bool(_) => None,
318 }
319 }
320
321 pub fn as_u64(&self) -> u64 {
325 self.try_as_u64()
326 .expect("Only Int and UInt kind can be made into u64.")
327 }
328
329 pub fn try_as_i64(&self) -> Option<i64> {
333 match self {
334 ConstantScalarValue::UInt(val, _) => Some(*val as i64),
335 ConstantScalarValue::Int(val, _) => Some(*val),
336 ConstantScalarValue::Float(_, _) => None,
337 ConstantScalarValue::Bool(_) => None,
338 }
339 }
340
341 pub fn as_i64(&self) -> i64 {
345 self.try_as_i64()
346 .expect("Only Int and UInt kind can be made into i64.")
347 }
348
349 pub fn try_as_f64(&self) -> Option<f64> {
353 match self {
354 ConstantScalarValue::Float(val, _) => Some(*val),
355 _ => None,
356 }
357 }
358
359 pub fn as_f64(&self) -> f64 {
363 self.try_as_f64()
364 .expect("Only Float kind can be made into f64.")
365 }
366
367 pub fn try_as_bool(&self) -> Option<bool> {
369 match self {
370 ConstantScalarValue::Bool(val) => Some(*val),
371 _ => None,
372 }
373 }
374
375 pub fn as_bool(&self) -> bool {
379 self.try_as_bool()
380 .expect("Only bool can be made into a bool")
381 }
382
383 pub fn is_zero(&self) -> bool {
384 match self {
385 ConstantScalarValue::Int(val, _) => *val == 0,
386 ConstantScalarValue::Float(val, _) => *val == 0.0,
387 ConstantScalarValue::UInt(val, _) => *val == 0,
388 ConstantScalarValue::Bool(_) => false,
389 }
390 }
391
392 pub fn is_one(&self) -> bool {
393 match self {
394 ConstantScalarValue::Int(val, _) => *val == 1,
395 ConstantScalarValue::Float(val, _) => *val == 1.0,
396 ConstantScalarValue::UInt(val, _) => *val == 1,
397 ConstantScalarValue::Bool(_) => false,
398 }
399 }
400
401 pub fn cast_to(&self, other: StorageType) -> ConstantScalarValue {
402 match (self, other.elem_type()) {
403 (ConstantScalarValue::Int(val, _), ElemType::Float(float_kind)) => {
404 ConstantScalarValue::Float(*val as f64, float_kind)
405 }
406 (ConstantScalarValue::Int(val, _), ElemType::Int(int_kind)) => {
407 ConstantScalarValue::Int(*val, int_kind)
408 }
409 (ConstantScalarValue::Int(val, _), ElemType::UInt(kind)) => {
410 ConstantScalarValue::UInt(*val as u64, kind)
411 }
412 (ConstantScalarValue::Int(val, _), ElemType::Bool) => {
413 ConstantScalarValue::Bool(*val == 1)
414 }
415 (ConstantScalarValue::Float(val, _), ElemType::Float(float_kind)) => {
416 ConstantScalarValue::Float(*val, float_kind)
417 }
418 (ConstantScalarValue::Float(val, _), ElemType::Int(int_kind)) => {
419 ConstantScalarValue::Int(*val as i64, int_kind)
420 }
421 (ConstantScalarValue::Float(val, _), ElemType::UInt(kind)) => {
422 ConstantScalarValue::UInt(*val as u64, kind)
423 }
424 (ConstantScalarValue::Float(val, _), ElemType::Bool) => {
425 ConstantScalarValue::Bool(*val == 0.0)
426 }
427 (ConstantScalarValue::UInt(val, _), ElemType::Float(float_kind)) => {
428 ConstantScalarValue::Float(*val as f64, float_kind)
429 }
430 (ConstantScalarValue::UInt(val, _), ElemType::Int(int_kind)) => {
431 ConstantScalarValue::Int(*val as i64, int_kind)
432 }
433 (ConstantScalarValue::UInt(val, _), ElemType::UInt(kind)) => {
434 ConstantScalarValue::UInt(*val, kind)
435 }
436 (ConstantScalarValue::UInt(val, _), ElemType::Bool) => {
437 ConstantScalarValue::Bool(*val == 1)
438 }
439 (ConstantScalarValue::Bool(val), ElemType::Float(float_kind)) => {
440 ConstantScalarValue::Float(*val as u32 as f64, float_kind)
441 }
442 (ConstantScalarValue::Bool(val), ElemType::Int(int_kind)) => {
443 ConstantScalarValue::Int(*val as i64, int_kind)
444 }
445 (ConstantScalarValue::Bool(val), ElemType::UInt(kind)) => {
446 ConstantScalarValue::UInt(*val as u64, kind)
447 }
448 (ConstantScalarValue::Bool(val), ElemType::Bool) => ConstantScalarValue::Bool(*val),
449 }
450 }
451}
452
453impl Display for ConstantScalarValue {
454 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
455 match self {
456 ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
457 ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
458 ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
459 ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
460 ConstantScalarValue::Float(val, FloatKind::E2M1) => write!(f, "{val}e2m1"),
461 ConstantScalarValue::Float(val, FloatKind::E2M3) => write!(f, "{val}e2m3"),
462 ConstantScalarValue::Float(val, FloatKind::E3M2) => write!(f, "{val}e3m2"),
463 ConstantScalarValue::Float(val, FloatKind::E4M3) => write!(f, "{val}e4m3"),
464 ConstantScalarValue::Float(val, FloatKind::E5M2) => write!(f, "{val}e5m2"),
465 ConstantScalarValue::Float(val, FloatKind::UE8M0) => write!(f, "{val}ue8m0"),
466 ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
467 ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
468 ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
469 ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
470 ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
471 ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
472 ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
473 ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
474 ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
475 ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
476 ConstantScalarValue::Bool(val) => write!(f, "{val}"),
477 }
478 }
479}
480
481impl Variable {
482 pub fn line_size(&self) -> u32 {
483 self.ty.line_size()
484 }
485
486 pub fn index(&self) -> Option<Id> {
487 match self.kind {
488 VariableKind::GlobalInputArray(id)
489 | VariableKind::GlobalOutputArray(id)
490 | VariableKind::TensorMapInput(id)
491 | VariableKind::TensorMapOutput(id)
492 | VariableKind::GlobalScalar(id)
493 | VariableKind::LocalMut { id, .. }
494 | VariableKind::Versioned { id, .. }
495 | VariableKind::LocalConst { id, .. }
496 | VariableKind::ConstantArray { id, .. }
497 | VariableKind::SharedArray { id, .. }
498 | VariableKind::Shared { id, .. }
499 | VariableKind::LocalArray { id, .. }
500 | VariableKind::Matrix { id, .. } => Some(id),
501 _ => None,
502 }
503 }
504
505 pub fn as_const(&self) -> Option<ConstantScalarValue> {
506 match self.kind {
507 VariableKind::ConstantScalar(constant) => Some(constant),
508 _ => None,
509 }
510 }
511}
512
513impl Display for Variable {
514 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
515 match self.kind {
516 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
517 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
518 VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
519 VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
520 VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
521 VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
522 VariableKind::LocalMut { id } => write!(f, "local({id})"),
523 VariableKind::Versioned { id, version } => {
524 write!(f, "local({id}).v{version}")
525 }
526 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
527 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
528 VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
529 VariableKind::Shared { id } => write!(f, "shared({id})"),
530 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
531 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
532 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
533 VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
534 VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
535 }
536 }
537}
538
539impl From<&Variable> for Variable {
541 fn from(value: &Variable) -> Self {
542 *value
543 }
544}