1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, StorageType, TypeHash};
4
5use super::{ElemType, FloatKind, IntKind, Matrix, Type, UIntKind};
6use float_ord::FloatOrd;
7
/// A variable: a [`VariableKind`] describing what it refers to, paired with its [`Type`].
///
/// NOTE: field order feeds the derived `Hash` and serde representations; keep it stable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// What sort of variable this is (global array, local, constant, builtin, ...).
    pub kind: VariableKind,
    /// Type of the variable; `elem_type`, `storage_type` and `line_size` are derived from it.
    pub ty: Type,
}
15
16impl Variable {
17 pub fn new(kind: VariableKind, item: Type) -> Self {
18 Self { kind, ty: item }
19 }
20
21 pub fn builtin(builtin: Builtin) -> Self {
22 Self::new(
23 VariableKind::Builtin(builtin),
24 Type::scalar(ElemType::UInt(UIntKind::U32)),
25 )
26 }
27
28 pub fn constant(scalar: ConstantScalarValue) -> Self {
29 let elem = match scalar {
30 ConstantScalarValue::Int(_, int_kind) => ElemType::Int(int_kind),
31 ConstantScalarValue::Float(_, float_kind) => ElemType::Float(float_kind),
32 ConstantScalarValue::UInt(_, kind) => ElemType::UInt(kind),
33 ConstantScalarValue::Bool(_) => ElemType::Bool,
34 };
35 Self::new(VariableKind::ConstantScalar(scalar), Type::scalar(elem))
36 }
37
38 pub fn elem_type(&self) -> ElemType {
39 self.ty.elem_type()
40 }
41
42 pub fn storage_type(&self) -> StorageType {
43 self.ty.storage_type()
44 }
45}
46
/// Numeric identifier carried by most [`VariableKind`] variants.
pub type Id = u32;

/// What a [`Variable`] refers to, together with its identifier and any static properties.
///
/// NOTE(review): variant order is part of the derived serde encoding for index-based formats
/// (and presumably of the `TypeHash` derive) — avoid reordering variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array, displayed as `input(id)`.
    GlobalInputArray(Id),
    /// Global output array, displayed as `output(id)`.
    GlobalOutputArray(Id),
    /// Global scalar, displayed as `scalar(id)`.
    GlobalScalar(Id),
    /// Input tensor map, displayed as `tensor_map(id)`.
    TensorMapInput(Id),
    /// Output tensor map, also displayed as `tensor_map(id)`.
    TensorMapOutput(Id),
    /// Local array with a static `length`, displayed as `array(id)`.
    LocalArray {
        id: Id,
        length: u32,
        // NOTE(review): presumably the unrolling factor applied to accesses — confirm.
        unroll_factor: u32,
    },
    /// Mutable local, displayed as `local(id)`.
    LocalMut {
        id: Id,
    },
    /// Immutable local binding, displayed as `binding(id)`.
    LocalConst {
        id: Id,
    },
    /// Versioned local, displayed as `local(id).v{version}`; classified as immutable.
    // NOTE(review): presumably an SSA version of `LocalMut { id }` — confirm.
    Versioned {
        id: Id,
        version: u16,
    },
    /// Inline scalar constant.
    ConstantScalar(ConstantScalarValue),
    /// Constant array with a static `length`, displayed as `const_array(id)`.
    ConstantArray {
        id: Id,
        length: u32,
        unroll_factor: u32,
    },
    /// Shared memory with a static `length`, displayed as `shared(id)`.
    SharedMemory {
        id: Id,
        length: u32,
        unroll_factor: u32,
        // NOTE(review): presumably an explicit alignment (units unconfirmed); `None` = default.
        alignment: Option<u32>,
    },
    /// Matrix, displayed as `matrix(id)`.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// Builtin value; see [`Builtin`].
    Builtin(Builtin),
    /// Pipeline with `num_stages` stages, displayed as `pipeline(id)`.
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier at the given [`BarrierLevel`], displayed as `barrier(id)`.
    Barrier {
        id: Id,
        level: BarrierLevel,
    },
    /// Barrier token at the given [`BarrierLevel`], displayed as `barrier_token(id)`.
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
}
102
/// Builtin values; [`Variable::builtin`] types each of them as a scalar `u32`.
///
/// Most builtins come in groups: an aggregate variant (e.g. `UnitPos`) followed by
/// `X`/`Y`/`Z` variants, presumably the per-axis components — confirm.
///
/// NOTE(review): declaration order defines the `repr(u8)` discriminants and the derived
/// `PartialOrd`/`Ord` — avoid reordering variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
138
139impl Variable {
140 pub fn is_immutable(&self) -> bool {
143 match self.kind {
144 VariableKind::GlobalOutputArray { .. } => false,
145 VariableKind::TensorMapInput(_) => true,
146 VariableKind::TensorMapOutput(_) => false,
147 VariableKind::LocalMut { .. } => false,
148 VariableKind::SharedMemory { .. } => false,
149 VariableKind::Matrix { .. } => false,
150 VariableKind::LocalArray { .. } => false,
151 VariableKind::GlobalInputArray { .. } => false,
152 VariableKind::GlobalScalar { .. } => true,
153 VariableKind::Versioned { .. } => true,
154 VariableKind::LocalConst { .. } => true,
155 VariableKind::ConstantScalar(_) => true,
156 VariableKind::ConstantArray { .. } => true,
157 VariableKind::Builtin(_) => true,
158 VariableKind::Pipeline { .. } => false,
159 VariableKind::Barrier { .. } => false,
160 VariableKind::BarrierToken { .. } => false,
161 }
162 }
163
164 pub fn is_array(&self) -> bool {
167 matches!(
168 self.kind,
169 VariableKind::GlobalInputArray { .. }
170 | VariableKind::GlobalOutputArray { .. }
171 | VariableKind::ConstantArray { .. }
172 | VariableKind::SharedMemory { .. }
173 | VariableKind::LocalArray { .. }
174 | VariableKind::Matrix { .. }
175 )
176 }
177
178 pub fn has_length(&self) -> bool {
179 matches!(
180 self.kind,
181 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
182 )
183 }
184
185 pub fn has_buffer_length(&self) -> bool {
186 matches!(
187 self.kind,
188 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
189 )
190 }
191
192 pub fn is_constant(&self, value: i64) -> bool {
194 match self.kind {
195 VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
196 VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
197 VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
198 _ => false,
199 }
200 }
201
202 pub fn is_true(&self) -> bool {
204 match self.kind {
205 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
206 _ => false,
207 }
208 }
209
210 pub fn is_false(&self) -> bool {
212 match self.kind {
213 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
214 _ => false,
215 }
216 }
217}
218
/// A scalar constant: the value, held in a 64-bit container (or `bool`), plus the kind of
/// its element type, which also selects the suffix used by `Display`.
///
/// `Eq` and `Hash` are implemented by hand; floats are hashed through `FloatOrd`.
///
/// NOTE(review): variant order drives the derived `PartialOrd` (across variants) and the serde
/// encoding — avoid reordering.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
#[allow(missing_docs)]
pub enum ConstantScalarValue {
    /// Signed integer stored as `i64`; the kind gives the element type (e.g. `I32`).
    Int(i64, IntKind),
    /// Floating-point value stored as `f64`; the kind gives the element type (e.g. `F16`).
    Float(f64, FloatKind),
    /// Unsigned integer stored as `u64`; the kind gives the element type (e.g. `U32`).
    UInt(u64, UIntKind),
    /// Boolean constant.
    Bool(bool),
}
230
231impl Eq for ConstantScalarValue {}
232impl Hash for ConstantScalarValue {
233 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
234 core::mem::discriminant(self).hash(ra_expand_state);
235 match self {
236 ConstantScalarValue::Int(f0, f1) => {
237 f0.hash(ra_expand_state);
238 f1.hash(ra_expand_state);
239 }
240 ConstantScalarValue::Float(f0, f1) => {
241 FloatOrd(*f0).hash(ra_expand_state);
242 f1.hash(ra_expand_state);
243 }
244 ConstantScalarValue::UInt(f0, f1) => {
245 f0.hash(ra_expand_state);
246 f1.hash(ra_expand_state);
247 }
248 ConstantScalarValue::Bool(f0) => {
249 f0.hash(ra_expand_state);
250 }
251 }
252 }
253}
254
255impl ConstantScalarValue {
256 pub fn elem_type(&self) -> ElemType {
258 match self {
259 ConstantScalarValue::Int(_, kind) => ElemType::Int(*kind),
260 ConstantScalarValue::Float(_, kind) => ElemType::Float(*kind),
261 ConstantScalarValue::UInt(_, kind) => ElemType::UInt(*kind),
262 ConstantScalarValue::Bool(_) => ElemType::Bool,
263 }
264 }
265
266 pub fn storage_type(&self) -> StorageType {
267 self.elem_type().into()
268 }
269
270 pub fn try_as_usize(&self) -> Option<usize> {
274 match self {
275 ConstantScalarValue::UInt(val, _) => Some(*val as usize),
276 ConstantScalarValue::Int(val, _) => Some(*val as usize),
277 ConstantScalarValue::Float(_, _) => None,
278 ConstantScalarValue::Bool(_) => None,
279 }
280 }
281
282 pub fn as_usize(&self) -> usize {
286 self.try_as_usize()
287 .expect("Only Int and UInt kind can be made into usize.")
288 }
289
290 pub fn try_as_u32(&self) -> Option<u32> {
294 match self {
295 ConstantScalarValue::UInt(val, _) => Some(*val as u32),
296 ConstantScalarValue::Int(val, _) => Some(*val as u32),
297 ConstantScalarValue::Float(_, _) => None,
298 ConstantScalarValue::Bool(_) => None,
299 }
300 }
301
302 pub fn as_u32(&self) -> u32 {
306 self.try_as_u32()
307 .expect("Only Int and UInt kind can be made into u32.")
308 }
309
310 pub fn try_as_u64(&self) -> Option<u64> {
314 match self {
315 ConstantScalarValue::UInt(val, _) => Some(*val),
316 ConstantScalarValue::Int(val, _) => Some(*val as u64),
317 ConstantScalarValue::Float(_, _) => None,
318 ConstantScalarValue::Bool(_) => None,
319 }
320 }
321
322 pub fn as_u64(&self) -> u64 {
326 self.try_as_u64()
327 .expect("Only Int and UInt kind can be made into u64.")
328 }
329
330 pub fn try_as_i64(&self) -> Option<i64> {
334 match self {
335 ConstantScalarValue::UInt(val, _) => Some(*val as i64),
336 ConstantScalarValue::Int(val, _) => Some(*val),
337 ConstantScalarValue::Float(_, _) => None,
338 ConstantScalarValue::Bool(_) => None,
339 }
340 }
341
342 pub fn as_i64(&self) -> i64 {
346 self.try_as_i64()
347 .expect("Only Int and UInt kind can be made into i64.")
348 }
349
350 pub fn try_as_f64(&self) -> Option<f64> {
354 match self {
355 ConstantScalarValue::Float(val, _) => Some(*val),
356 _ => None,
357 }
358 }
359
360 pub fn as_f64(&self) -> f64 {
364 self.try_as_f64()
365 .expect("Only Float kind can be made into f64.")
366 }
367
368 pub fn try_as_bool(&self) -> Option<bool> {
370 match self {
371 ConstantScalarValue::Bool(val) => Some(*val),
372 _ => None,
373 }
374 }
375
376 pub fn as_bool(&self) -> bool {
380 self.try_as_bool()
381 .expect("Only bool can be made into a bool")
382 }
383
384 pub fn is_zero(&self) -> bool {
385 match self {
386 ConstantScalarValue::Int(val, _) => *val == 0,
387 ConstantScalarValue::Float(val, _) => *val == 0.0,
388 ConstantScalarValue::UInt(val, _) => *val == 0,
389 ConstantScalarValue::Bool(_) => false,
390 }
391 }
392
393 pub fn is_one(&self) -> bool {
394 match self {
395 ConstantScalarValue::Int(val, _) => *val == 1,
396 ConstantScalarValue::Float(val, _) => *val == 1.0,
397 ConstantScalarValue::UInt(val, _) => *val == 1,
398 ConstantScalarValue::Bool(_) => false,
399 }
400 }
401
402 pub fn cast_to(&self, other: StorageType) -> ConstantScalarValue {
403 match (self, other.elem_type()) {
404 (ConstantScalarValue::Int(val, _), ElemType::Float(float_kind)) => {
405 ConstantScalarValue::Float(*val as f64, float_kind)
406 }
407 (ConstantScalarValue::Int(val, _), ElemType::Int(int_kind)) => {
408 ConstantScalarValue::Int(*val, int_kind)
409 }
410 (ConstantScalarValue::Int(val, _), ElemType::UInt(kind)) => {
411 ConstantScalarValue::UInt(*val as u64, kind)
412 }
413 (ConstantScalarValue::Int(val, _), ElemType::Bool) => {
414 ConstantScalarValue::Bool(*val == 1)
415 }
416 (ConstantScalarValue::Float(val, _), ElemType::Float(float_kind)) => {
417 ConstantScalarValue::Float(*val, float_kind)
418 }
419 (ConstantScalarValue::Float(val, _), ElemType::Int(int_kind)) => {
420 ConstantScalarValue::Int(*val as i64, int_kind)
421 }
422 (ConstantScalarValue::Float(val, _), ElemType::UInt(kind)) => {
423 ConstantScalarValue::UInt(*val as u64, kind)
424 }
425 (ConstantScalarValue::Float(val, _), ElemType::Bool) => {
426 ConstantScalarValue::Bool(*val == 0.0)
427 }
428 (ConstantScalarValue::UInt(val, _), ElemType::Float(float_kind)) => {
429 ConstantScalarValue::Float(*val as f64, float_kind)
430 }
431 (ConstantScalarValue::UInt(val, _), ElemType::Int(int_kind)) => {
432 ConstantScalarValue::Int(*val as i64, int_kind)
433 }
434 (ConstantScalarValue::UInt(val, _), ElemType::UInt(kind)) => {
435 ConstantScalarValue::UInt(*val, kind)
436 }
437 (ConstantScalarValue::UInt(val, _), ElemType::Bool) => {
438 ConstantScalarValue::Bool(*val == 1)
439 }
440 (ConstantScalarValue::Bool(val), ElemType::Float(float_kind)) => {
441 ConstantScalarValue::Float(*val as u32 as f64, float_kind)
442 }
443 (ConstantScalarValue::Bool(val), ElemType::Int(int_kind)) => {
444 ConstantScalarValue::Int(*val as i64, int_kind)
445 }
446 (ConstantScalarValue::Bool(val), ElemType::UInt(kind)) => {
447 ConstantScalarValue::UInt(*val as u64, kind)
448 }
449 (ConstantScalarValue::Bool(val), ElemType::Bool) => ConstantScalarValue::Bool(*val),
450 }
451 }
452}
453
454impl Display for ConstantScalarValue {
455 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
456 match self {
457 ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
458 ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
459 ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
460 ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
461 ConstantScalarValue::Float(val, FloatKind::E2M1) => write!(f, "{val}e2m1"),
462 ConstantScalarValue::Float(val, FloatKind::E2M3) => write!(f, "{val}e2m3"),
463 ConstantScalarValue::Float(val, FloatKind::E3M2) => write!(f, "{val}e3m2"),
464 ConstantScalarValue::Float(val, FloatKind::E4M3) => write!(f, "{val}e4m3"),
465 ConstantScalarValue::Float(val, FloatKind::E5M2) => write!(f, "{val}e5m2"),
466 ConstantScalarValue::Float(val, FloatKind::UE8M0) => write!(f, "{val}ue8m0"),
467 ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
468 ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
469 ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
470 ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
471 ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
472 ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
473 ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
474 ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
475 ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
476 ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
477 ConstantScalarValue::Bool(val) => write!(f, "{val}"),
478 }
479 }
480}
481
482impl Variable {
483 pub fn line_size(&self) -> u32 {
484 self.ty.line_size()
485 }
486
487 pub fn index(&self) -> Option<Id> {
488 match self.kind {
489 VariableKind::GlobalInputArray(id)
490 | VariableKind::GlobalOutputArray(id)
491 | VariableKind::TensorMapInput(id)
492 | VariableKind::TensorMapOutput(id)
493 | VariableKind::GlobalScalar(id)
494 | VariableKind::LocalMut { id, .. }
495 | VariableKind::Versioned { id, .. }
496 | VariableKind::LocalConst { id, .. }
497 | VariableKind::ConstantArray { id, .. }
498 | VariableKind::SharedMemory { id, .. }
499 | VariableKind::LocalArray { id, .. }
500 | VariableKind::Matrix { id, .. } => Some(id),
501 _ => None,
502 }
503 }
504
505 pub fn as_const(&self) -> Option<ConstantScalarValue> {
506 match self.kind {
507 VariableKind::ConstantScalar(constant) => Some(constant),
508 _ => None,
509 }
510 }
511}
512
513impl Display for Variable {
514 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
515 match self.kind {
516 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
517 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
518 VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
519 VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
520 VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
521 VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
522 VariableKind::LocalMut { id } => write!(f, "local({id})"),
523 VariableKind::Versioned { id, version } => {
524 write!(f, "local({id}).v{version}")
525 }
526 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
527 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
528 VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
529 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
530 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
531 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
532 VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
533 VariableKind::Barrier { id, .. } => write!(f, "barrier({id})"),
534 VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
535 }
536 }
537}
538
539impl From<&Variable> for Variable {
541 fn from(value: &Variable) -> Self {
542 *value
543 }
544}