1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, StorageType, TypeHash};
4
5use super::{ElemType, FloatKind, IntKind, Matrix, Type, UIntKind};
6use float_ord::FloatOrd;
7
/// A variable: what kind of variable it is (`kind`) paired with the type of its value (`ty`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// Which kind of variable this is (global array, local, constant, builtin, ...).
    pub kind: VariableKind,
    /// The type of the variable; `elem_type`, `storage_type` and `line_size` delegate to it.
    pub ty: Type,
}
15
16impl Variable {
17 pub fn new(kind: VariableKind, item: Type) -> Self {
18 Self { kind, ty: item }
19 }
20
21 pub fn builtin(builtin: Builtin) -> Self {
22 Self::new(
23 VariableKind::Builtin(builtin),
24 Type::scalar(ElemType::UInt(UIntKind::U32)),
25 )
26 }
27
28 pub fn constant(scalar: ConstantScalarValue) -> Self {
29 let elem = match scalar {
30 ConstantScalarValue::Int(_, int_kind) => ElemType::Int(int_kind),
31 ConstantScalarValue::Float(_, float_kind) => ElemType::Float(float_kind),
32 ConstantScalarValue::UInt(_, kind) => ElemType::UInt(kind),
33 ConstantScalarValue::Bool(_) => ElemType::Bool,
34 };
35 Self::new(VariableKind::ConstantScalar(scalar), Type::scalar(elem))
36 }
37
38 pub fn elem_type(&self) -> ElemType {
39 self.ty.elem_type()
40 }
41
42 pub fn storage_type(&self) -> StorageType {
43 self.ty.storage_type()
44 }
45}
46
/// Numeric identifier carried by most [`VariableKind`] variants.
pub type Id = u32;
48
/// The kind of a [`Variable`].
///
/// Most variants carry an [`Id`]; `Variable::index` lists which ones report it.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Displayed as `input(id)`; `is_array` and `has_length` are true for it.
    GlobalInputArray(Id),
    /// Displayed as `output(id)`; `is_array` and `has_length` are true for it.
    GlobalOutputArray(Id),
    /// Displayed as `scalar(id)`; `is_immutable` is true for it.
    GlobalScalar(Id),
    /// Displayed as `tensor_map(id)`; `is_immutable` is true for it.
    TensorMapInput(Id),
    /// Displayed as `tensor_map(id)`; `is_immutable` is false for it.
    TensorMapOutput(Id),
    /// Displayed as `array(id)`; `is_array` is true for it.
    LocalArray {
        id: Id,
        length: u32,
        unroll_factor: u32,
    },
    /// Displayed as `local(id)`; `is_immutable` is false for it.
    LocalMut {
        id: Id,
    },
    /// Displayed as `binding(id)`; `is_immutable` is true for it.
    LocalConst {
        id: Id,
    },
    /// Displayed as `local(id).v{version}`; `is_immutable` is true for it.
    // NOTE(review): presumably an SSA-style version of a local — confirm.
    Versioned {
        id: Id,
        version: u16,
    },
    /// An inline scalar constant, displayed with [`ConstantScalarValue`]'s `Display`.
    ConstantScalar(ConstantScalarValue),
    /// Displayed as `const_array(id)`; `is_array` and `is_immutable` are true for it.
    ConstantArray {
        id: Id,
        length: u32,
        unroll_factor: u32,
    },
    /// Displayed as `shared(id)`; `is_array` is true for it.
    SharedMemory {
        id: Id,
        length: u32,
        unroll_factor: u32,
        alignment: Option<u32>,
    },
    /// Displayed as `matrix(id)`; `is_array` is true for it.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// A built-in value; `Variable::builtin` types it as a scalar `u32`.
    Builtin(Builtin),
    /// Displayed as `pipeline(id)`; `Variable::index` returns `None` for it.
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Displayed as `barrier(id)`; `Variable::index` returns `None` for it.
    Barrier {
        id: Id,
        level: BarrierLevel,
    },
}
98
/// Built-in values that can be referenced as variables.
///
/// `Variable::builtin` gives each of them a scalar `u32` type, and `Display for Variable`
/// prints them with their `Debug` name (e.g. `UnitPosX`).
///
/// Declaration order determines the `#[repr(u8)]` discriminants and the derived
/// `PartialOrd`/`Ord`; reordering variants changes both.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
134
135impl Variable {
136 pub fn is_immutable(&self) -> bool {
139 match self.kind {
140 VariableKind::GlobalOutputArray { .. } => false,
141 VariableKind::TensorMapInput(_) => true,
142 VariableKind::TensorMapOutput(_) => false,
143 VariableKind::LocalMut { .. } => false,
144 VariableKind::SharedMemory { .. } => false,
145 VariableKind::Matrix { .. } => false,
146 VariableKind::LocalArray { .. } => false,
147 VariableKind::GlobalInputArray { .. } => false,
148 VariableKind::GlobalScalar { .. } => true,
149 VariableKind::Versioned { .. } => true,
150 VariableKind::LocalConst { .. } => true,
151 VariableKind::ConstantScalar(_) => true,
152 VariableKind::ConstantArray { .. } => true,
153 VariableKind::Builtin(_) => true,
154 VariableKind::Pipeline { .. } => false,
155 VariableKind::Barrier { .. } => false,
156 }
157 }
158
159 pub fn is_array(&self) -> bool {
162 matches!(
163 self.kind,
164 VariableKind::GlobalInputArray { .. }
165 | VariableKind::GlobalOutputArray { .. }
166 | VariableKind::ConstantArray { .. }
167 | VariableKind::SharedMemory { .. }
168 | VariableKind::LocalArray { .. }
169 | VariableKind::Matrix { .. }
170 )
171 }
172
173 pub fn has_length(&self) -> bool {
174 matches!(
175 self.kind,
176 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
177 )
178 }
179
180 pub fn has_buffer_length(&self) -> bool {
181 matches!(
182 self.kind,
183 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
184 )
185 }
186
187 pub fn is_constant(&self, value: i64) -> bool {
189 match self.kind {
190 VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
191 VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
192 VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
193 _ => false,
194 }
195 }
196
197 pub fn is_true(&self) -> bool {
199 match self.kind {
200 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
201 _ => false,
202 }
203 }
204
205 pub fn is_false(&self) -> bool {
207 match self.kind {
208 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
209 _ => false,
210 }
211 }
212}
213
/// A scalar constant: a 64-bit payload together with the element kind it represents.
///
/// `Eq` and `Hash` are implemented by hand because `f64` implements neither trait.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
#[allow(missing_docs)]
pub enum ConstantScalarValue {
    Int(i64, IntKind),
    Float(f64, FloatKind),
    UInt(u64, UIntKind),
    Bool(bool),
}
225
226impl Eq for ConstantScalarValue {}
227impl Hash for ConstantScalarValue {
228 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
229 core::mem::discriminant(self).hash(ra_expand_state);
230 match self {
231 ConstantScalarValue::Int(f0, f1) => {
232 f0.hash(ra_expand_state);
233 f1.hash(ra_expand_state);
234 }
235 ConstantScalarValue::Float(f0, f1) => {
236 FloatOrd(*f0).hash(ra_expand_state);
237 f1.hash(ra_expand_state);
238 }
239 ConstantScalarValue::UInt(f0, f1) => {
240 f0.hash(ra_expand_state);
241 f1.hash(ra_expand_state);
242 }
243 ConstantScalarValue::Bool(f0) => {
244 f0.hash(ra_expand_state);
245 }
246 }
247 }
248}
249
250impl ConstantScalarValue {
251 pub fn elem_type(&self) -> ElemType {
253 match self {
254 ConstantScalarValue::Int(_, kind) => ElemType::Int(*kind),
255 ConstantScalarValue::Float(_, kind) => ElemType::Float(*kind),
256 ConstantScalarValue::UInt(_, kind) => ElemType::UInt(*kind),
257 ConstantScalarValue::Bool(_) => ElemType::Bool,
258 }
259 }
260
261 pub fn storage_type(&self) -> StorageType {
262 self.elem_type().into()
263 }
264
265 pub fn try_as_usize(&self) -> Option<usize> {
269 match self {
270 ConstantScalarValue::UInt(val, _) => Some(*val as usize),
271 ConstantScalarValue::Int(val, _) => Some(*val as usize),
272 ConstantScalarValue::Float(_, _) => None,
273 ConstantScalarValue::Bool(_) => None,
274 }
275 }
276
277 pub fn as_usize(&self) -> usize {
281 self.try_as_usize()
282 .expect("Only Int and UInt kind can be made into usize.")
283 }
284
285 pub fn try_as_u32(&self) -> Option<u32> {
289 match self {
290 ConstantScalarValue::UInt(val, _) => Some(*val as u32),
291 ConstantScalarValue::Int(val, _) => Some(*val as u32),
292 ConstantScalarValue::Float(_, _) => None,
293 ConstantScalarValue::Bool(_) => None,
294 }
295 }
296
297 pub fn as_u32(&self) -> u32 {
301 self.try_as_u32()
302 .expect("Only Int and UInt kind can be made into u32.")
303 }
304
305 pub fn try_as_u64(&self) -> Option<u64> {
309 match self {
310 ConstantScalarValue::UInt(val, _) => Some(*val),
311 ConstantScalarValue::Int(val, _) => Some(*val as u64),
312 ConstantScalarValue::Float(_, _) => None,
313 ConstantScalarValue::Bool(_) => None,
314 }
315 }
316
317 pub fn as_u64(&self) -> u64 {
321 self.try_as_u64()
322 .expect("Only Int and UInt kind can be made into u64.")
323 }
324
325 pub fn try_as_i64(&self) -> Option<i64> {
329 match self {
330 ConstantScalarValue::UInt(val, _) => Some(*val as i64),
331 ConstantScalarValue::Int(val, _) => Some(*val),
332 ConstantScalarValue::Float(_, _) => None,
333 ConstantScalarValue::Bool(_) => None,
334 }
335 }
336
337 pub fn as_i64(&self) -> i64 {
341 self.try_as_i64()
342 .expect("Only Int and UInt kind can be made into i64.")
343 }
344
345 pub fn try_as_bool(&self) -> Option<bool> {
347 match self {
348 ConstantScalarValue::Bool(val) => Some(*val),
349 _ => None,
350 }
351 }
352
353 pub fn as_bool(&self) -> bool {
357 self.try_as_bool()
358 .expect("Only bool can be made into a bool")
359 }
360
361 pub fn is_zero(&self) -> bool {
362 match self {
363 ConstantScalarValue::Int(val, _) => *val == 0,
364 ConstantScalarValue::Float(val, _) => *val == 0.0,
365 ConstantScalarValue::UInt(val, _) => *val == 0,
366 ConstantScalarValue::Bool(_) => false,
367 }
368 }
369
370 pub fn is_one(&self) -> bool {
371 match self {
372 ConstantScalarValue::Int(val, _) => *val == 1,
373 ConstantScalarValue::Float(val, _) => *val == 1.0,
374 ConstantScalarValue::UInt(val, _) => *val == 1,
375 ConstantScalarValue::Bool(_) => false,
376 }
377 }
378
379 pub fn cast_to(&self, other: StorageType) -> ConstantScalarValue {
380 match (self, other.elem_type()) {
381 (ConstantScalarValue::Int(val, _), ElemType::Float(float_kind)) => {
382 ConstantScalarValue::Float(*val as f64, float_kind)
383 }
384 (ConstantScalarValue::Int(val, _), ElemType::Int(int_kind)) => {
385 ConstantScalarValue::Int(*val, int_kind)
386 }
387 (ConstantScalarValue::Int(val, _), ElemType::UInt(kind)) => {
388 ConstantScalarValue::UInt(*val as u64, kind)
389 }
390 (ConstantScalarValue::Int(val, _), ElemType::Bool) => {
391 ConstantScalarValue::Bool(*val == 1)
392 }
393 (ConstantScalarValue::Float(val, _), ElemType::Float(float_kind)) => {
394 ConstantScalarValue::Float(*val, float_kind)
395 }
396 (ConstantScalarValue::Float(val, _), ElemType::Int(int_kind)) => {
397 ConstantScalarValue::Int(*val as i64, int_kind)
398 }
399 (ConstantScalarValue::Float(val, _), ElemType::UInt(kind)) => {
400 ConstantScalarValue::UInt(*val as u64, kind)
401 }
402 (ConstantScalarValue::Float(val, _), ElemType::Bool) => {
403 ConstantScalarValue::Bool(*val == 0.0)
404 }
405 (ConstantScalarValue::UInt(val, _), ElemType::Float(float_kind)) => {
406 ConstantScalarValue::Float(*val as f64, float_kind)
407 }
408 (ConstantScalarValue::UInt(val, _), ElemType::Int(int_kind)) => {
409 ConstantScalarValue::Int(*val as i64, int_kind)
410 }
411 (ConstantScalarValue::UInt(val, _), ElemType::UInt(kind)) => {
412 ConstantScalarValue::UInt(*val, kind)
413 }
414 (ConstantScalarValue::UInt(val, _), ElemType::Bool) => {
415 ConstantScalarValue::Bool(*val == 1)
416 }
417 (ConstantScalarValue::Bool(val), ElemType::Float(float_kind)) => {
418 ConstantScalarValue::Float(*val as u32 as f64, float_kind)
419 }
420 (ConstantScalarValue::Bool(val), ElemType::Int(int_kind)) => {
421 ConstantScalarValue::Int(*val as i64, int_kind)
422 }
423 (ConstantScalarValue::Bool(val), ElemType::UInt(kind)) => {
424 ConstantScalarValue::UInt(*val as u64, kind)
425 }
426 (ConstantScalarValue::Bool(val), ElemType::Bool) => ConstantScalarValue::Bool(*val),
427 }
428 }
429}
430
431impl Display for ConstantScalarValue {
432 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
433 match self {
434 ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
435 ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
436 ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
437 ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
438 ConstantScalarValue::Float(val, FloatKind::E2M1) => write!(f, "{val}e2m1"),
439 ConstantScalarValue::Float(val, FloatKind::E2M3) => write!(f, "{val}e2m3"),
440 ConstantScalarValue::Float(val, FloatKind::E3M2) => write!(f, "{val}e3m2"),
441 ConstantScalarValue::Float(val, FloatKind::E4M3) => write!(f, "{val}e4m3"),
442 ConstantScalarValue::Float(val, FloatKind::E5M2) => write!(f, "{val}e5m2"),
443 ConstantScalarValue::Float(val, FloatKind::UE8M0) => write!(f, "{val}ue8m0"),
444 ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
445 ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
446 ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
447 ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
448 ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
449 ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
450 ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
451 ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
452 ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
453 ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
454 ConstantScalarValue::Bool(val) => write!(f, "{val}"),
455 }
456 }
457}
458
459impl Variable {
460 pub fn line_size(&self) -> u32 {
461 self.ty.line_size()
462 }
463
464 pub fn index(&self) -> Option<Id> {
465 match self.kind {
466 VariableKind::GlobalInputArray(id)
467 | VariableKind::GlobalOutputArray(id)
468 | VariableKind::TensorMapInput(id)
469 | VariableKind::TensorMapOutput(id)
470 | VariableKind::GlobalScalar(id)
471 | VariableKind::LocalMut { id, .. }
472 | VariableKind::Versioned { id, .. }
473 | VariableKind::LocalConst { id, .. }
474 | VariableKind::ConstantArray { id, .. }
475 | VariableKind::SharedMemory { id, .. }
476 | VariableKind::LocalArray { id, .. }
477 | VariableKind::Matrix { id, .. } => Some(id),
478 _ => None,
479 }
480 }
481
482 pub fn as_const(&self) -> Option<ConstantScalarValue> {
483 match self.kind {
484 VariableKind::ConstantScalar(constant) => Some(constant),
485 _ => None,
486 }
487 }
488}
489
490impl Display for Variable {
491 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
492 match self.kind {
493 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
494 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
495 VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
496 VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
497 VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
498 VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
499 VariableKind::LocalMut { id } => write!(f, "local({id})"),
500 VariableKind::Versioned { id, version } => {
501 write!(f, "local({id}).v{version}")
502 }
503 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
504 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
505 VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
506 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
507 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
508 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
509 VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
510 VariableKind::Barrier { id, .. } => write!(f, "barrier({id})"),
511 }
512 }
513}
514
515impl From<&Variable> for Variable {
517 fn from(value: &Variable) -> Self {
518 *value
519 }
520}