1use std::fmt::Display;
2use std::num::NonZero;
3
4use super::{Elem, FloatKind, IntKind, Item, Matrix, UIntKind};
5use serde::{Deserialize, Serialize};
6
/// A variable: what it refers to ([`VariableKind`]) paired with its [`Item`].
///
/// The item carries the element type (see `Variable::builtin`, which builds it
/// with `Item::new(Elem::...)`) and the vectorization read by
/// `Variable::vectorization_factor`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct Variable {
    /// Which kind of storage or value this variable refers to.
    pub kind: VariableKind,
    /// Element type and vectorization of the variable.
    pub item: Item,
}
13
14impl Variable {
15 pub fn new(kind: VariableKind, item: Item) -> Self {
16 Self { kind, item }
17 }
18
19 pub fn builtin(builtin: Builtin) -> Self {
20 Self::new(
21 VariableKind::Builtin(builtin),
22 Item::new(Elem::UInt(UIntKind::U32)),
23 )
24 }
25
26 pub fn constant(scalar: ConstantScalarValue) -> Self {
27 let elem = match scalar {
28 ConstantScalarValue::Int(_, int_kind) => Elem::Int(int_kind),
29 ConstantScalarValue::Float(_, float_kind) => Elem::Float(float_kind),
30 ConstantScalarValue::UInt(_, kind) => Elem::UInt(kind),
31 ConstantScalarValue::Bool(_) => Elem::Bool,
32 };
33 Self::new(VariableKind::ConstantScalar(scalar), Item::new(elem))
34 }
35}
36
/// Numeric identifier carried by most [`VariableKind`] variants.
///
/// NOTE(review): the scope in which an id is unique (per kind, per kernel, ...)
/// is not visible here — confirm with whatever allocates ids.
pub type Id = u32;

/// What a [`Variable`] refers to.
///
/// Mutability per kind is decided by `Variable::is_immutable`; the variant
/// order and shapes are part of the serialized form, so change them with care.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum VariableKind {
    /// Global input array (reported as mutable by `Variable::is_immutable`).
    GlobalInputArray(Id),
    /// Global output array.
    GlobalOutputArray(Id),
    /// Global scalar (immutable).
    GlobalScalar(Id),
    /// Local array with a declared `length`.
    LocalArray { id: Id, length: u32 },
    /// Mutable local.
    LocalMut { id: Id },
    /// Immutable local binding.
    LocalConst { id: Id },
    /// Version `version` of local `id` (immutable).
    /// NOTE(review): presumably an SSA-style version of a local — confirm.
    Versioned { id: Id, version: u16 },
    /// Constant scalar stored inline in the variable.
    ConstantScalar(ConstantScalarValue),
    /// Constant array with a declared `length` (immutable).
    ConstantArray { id: Id, length: u32 },
    /// Shared-memory array with a declared `length`.
    SharedMemory { id: Id, length: u32 },
    /// Matrix variable described by `mat`.
    Matrix { id: Id, mat: Matrix },
    /// Slice; `Variable::has_length` reports a length for it.
    Slice { id: Id },
    /// Builtin value, see [`Builtin`].
    Builtin(Builtin),
}
55
/// A builtin value a [`Variable`] can refer to through [`VariableKind::Builtin`].
///
/// `Variable::builtin` gives every builtin a `u32` element type, and the
/// `Display` impl of `Variable` prints a builtin by its `Debug` name. The
/// derived `PartialOrd`/`Ord` follow declaration order, so reordering variants
/// changes how builtins compare.
///
/// NOTE(review): the `X`/`Y`/`Z` variants look like per-axis components of the
/// un-suffixed builtin of the same name (e.g. `CubePosX` of `CubePos`) —
/// confirm against the backends before relying on that.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
81
82impl Variable {
83 pub fn is_immutable(&self) -> bool {
86 match self.kind {
87 VariableKind::GlobalOutputArray { .. } => false,
88 VariableKind::LocalMut { .. } => false,
89 VariableKind::SharedMemory { .. } => false,
90 VariableKind::Matrix { .. } => false,
91 VariableKind::Slice { .. } => false,
92 VariableKind::LocalArray { .. } => false,
93 VariableKind::GlobalInputArray { .. } => false,
94 VariableKind::GlobalScalar { .. } => true,
95 VariableKind::Versioned { .. } => true,
96 VariableKind::LocalConst { .. } => true,
97 VariableKind::ConstantScalar(_) => true,
98 VariableKind::ConstantArray { .. } => true,
99 VariableKind::Builtin(_) => true,
100 }
101 }
102
103 pub fn is_array(&self) -> bool {
106 matches!(
107 self.kind,
108 VariableKind::GlobalInputArray { .. }
109 | VariableKind::GlobalOutputArray { .. }
110 | VariableKind::ConstantArray { .. }
111 | VariableKind::SharedMemory { .. }
112 | VariableKind::LocalArray { .. }
113 | VariableKind::Matrix { .. }
114 | VariableKind::Slice { .. }
115 )
116 }
117
118 pub fn has_length(&self) -> bool {
119 matches!(
120 self.kind,
121 VariableKind::GlobalInputArray { .. }
122 | VariableKind::GlobalOutputArray { .. }
123 | VariableKind::Slice { .. }
124 )
125 }
126
127 pub fn has_buffer_length(&self) -> bool {
128 matches!(
129 self.kind,
130 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
131 )
132 }
133
134 pub fn is_constant(&self, value: i64) -> bool {
136 match self.kind {
137 VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
138 VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
139 VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
140 _ => false,
141 }
142 }
143
144 pub fn is_true(&self) -> bool {
146 match self.kind {
147 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
148 _ => false,
149 }
150 }
151
152 pub fn is_false(&self) -> bool {
154 match self.kind {
155 VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
156 _ => false,
157 }
158 }
159}
160
/// A scalar constant together with the kind of its element type.
///
/// Integers are stored widened to `i64`/`u64` and floats to `f64`, whatever
/// their declared kind. NOTE(review): nothing in this type checks that the
/// stored value fits the declared kind. The derived `PartialOrd` compares the
/// variant first (declaration order), so keep the variant order stable.
#[derive(Debug, Clone, PartialEq, Copy, Serialize, Deserialize, PartialOrd)]
#[allow(missing_docs)]
pub enum ConstantScalarValue {
    /// Signed integer value and its [`IntKind`].
    Int(i64, IntKind),
    /// Floating-point value and its [`FloatKind`].
    Float(f64, FloatKind),
    /// Unsigned integer value and its [`UIntKind`].
    UInt(u64, UIntKind),
    /// Boolean value.
    Bool(bool),
}
171
172impl ConstantScalarValue {
173 pub fn elem(&self) -> Elem {
175 match self {
176 ConstantScalarValue::Int(_, kind) => Elem::Int(*kind),
177 ConstantScalarValue::Float(_, kind) => Elem::Float(*kind),
178 ConstantScalarValue::UInt(_, kind) => Elem::UInt(*kind),
179 ConstantScalarValue::Bool(_) => Elem::Bool,
180 }
181 }
182
183 pub fn try_as_usize(&self) -> Option<usize> {
187 match self {
188 ConstantScalarValue::UInt(val, _) => Some(*val as usize),
189 ConstantScalarValue::Int(val, _) => Some(*val as usize),
190 ConstantScalarValue::Float(_, _) => None,
191 ConstantScalarValue::Bool(_) => None,
192 }
193 }
194
195 pub fn as_usize(&self) -> usize {
199 self.try_as_usize()
200 .expect("Only Int and UInt kind can be made into usize.")
201 }
202
203 pub fn try_as_u32(&self) -> Option<u32> {
207 match self {
208 ConstantScalarValue::UInt(val, _) => Some(*val as u32),
209 ConstantScalarValue::Int(val, _) => Some(*val as u32),
210 ConstantScalarValue::Float(_, _) => None,
211 ConstantScalarValue::Bool(_) => None,
212 }
213 }
214
215 pub fn as_u32(&self) -> u32 {
219 self.try_as_u32()
220 .expect("Only Int and UInt kind can be made into u32.")
221 }
222
223 pub fn try_as_u64(&self) -> Option<u64> {
227 match self {
228 ConstantScalarValue::UInt(val, _) => Some(*val),
229 ConstantScalarValue::Int(val, _) => Some(*val as u64),
230 ConstantScalarValue::Float(_, _) => None,
231 ConstantScalarValue::Bool(_) => None,
232 }
233 }
234
235 pub fn as_u64(&self) -> u64 {
239 self.try_as_u64()
240 .expect("Only Int and UInt kind can be made into u64.")
241 }
242
243 pub fn try_as_i64(&self) -> Option<i64> {
247 match self {
248 ConstantScalarValue::UInt(val, _) => Some(*val as i64),
249 ConstantScalarValue::Int(val, _) => Some(*val),
250 ConstantScalarValue::Float(_, _) => None,
251 ConstantScalarValue::Bool(_) => None,
252 }
253 }
254
255 pub fn as_i64(&self) -> i64 {
259 self.try_as_i64()
260 .expect("Only Int and UInt kind can be made into i64.")
261 }
262
263 pub fn try_as_bool(&self) -> Option<bool> {
265 match self {
266 ConstantScalarValue::Bool(val) => Some(*val),
267 _ => None,
268 }
269 }
270
271 pub fn as_bool(&self) -> bool {
275 self.try_as_bool()
276 .expect("Only bool can be made into a bool")
277 }
278
279 pub fn is_zero(&self) -> bool {
280 match self {
281 ConstantScalarValue::Int(val, _) => *val == 0,
282 ConstantScalarValue::Float(val, _) => *val == 0.0,
283 ConstantScalarValue::UInt(val, _) => *val == 0,
284 ConstantScalarValue::Bool(_) => false,
285 }
286 }
287
288 pub fn is_one(&self) -> bool {
289 match self {
290 ConstantScalarValue::Int(val, _) => *val == 1,
291 ConstantScalarValue::Float(val, _) => *val == 1.0,
292 ConstantScalarValue::UInt(val, _) => *val == 1,
293 ConstantScalarValue::Bool(_) => false,
294 }
295 }
296
297 pub fn cast_to(&self, other: Elem) -> ConstantScalarValue {
298 match (self, other) {
299 (ConstantScalarValue::Int(val, _), Elem::Float(float_kind)) => {
300 ConstantScalarValue::Float(*val as f64, float_kind)
301 }
302 (ConstantScalarValue::Int(val, _), Elem::Int(int_kind)) => {
303 ConstantScalarValue::Int(*val, int_kind)
304 }
305 (ConstantScalarValue::Int(val, _), Elem::UInt(kind)) => {
306 ConstantScalarValue::UInt(*val as u64, kind)
307 }
308 (ConstantScalarValue::Int(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
309 (ConstantScalarValue::Float(val, _), Elem::Float(float_kind)) => {
310 ConstantScalarValue::Float(*val, float_kind)
311 }
312 (ConstantScalarValue::Float(val, _), Elem::Int(int_kind)) => {
313 ConstantScalarValue::Int(*val as i64, int_kind)
314 }
315 (ConstantScalarValue::Float(val, _), Elem::UInt(kind)) => {
316 ConstantScalarValue::UInt(*val as u64, kind)
317 }
318 (ConstantScalarValue::Float(val, _), Elem::Bool) => {
319 ConstantScalarValue::Bool(*val == 0.0)
320 }
321 (ConstantScalarValue::UInt(val, _), Elem::Float(float_kind)) => {
322 ConstantScalarValue::Float(*val as f64, float_kind)
323 }
324 (ConstantScalarValue::UInt(val, _), Elem::Int(int_kind)) => {
325 ConstantScalarValue::Int(*val as i64, int_kind)
326 }
327 (ConstantScalarValue::UInt(val, _), Elem::UInt(kind)) => {
328 ConstantScalarValue::UInt(*val, kind)
329 }
330 (ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
331 (ConstantScalarValue::Bool(val), Elem::Float(float_kind)) => {
332 ConstantScalarValue::Float(*val as u32 as f64, float_kind)
333 }
334 (ConstantScalarValue::Bool(val), Elem::Int(int_kind)) => {
335 ConstantScalarValue::Int(*val as i64, int_kind)
336 }
337 (ConstantScalarValue::Bool(val), Elem::UInt(kind)) => {
338 ConstantScalarValue::UInt(*val as u64, kind)
339 }
340 (ConstantScalarValue::Bool(val), Elem::Bool) => ConstantScalarValue::Bool(*val),
341 _ => unreachable!(),
342 }
343 }
344}
345
346impl Display for ConstantScalarValue {
347 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348 match self {
349 ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
350 ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
351 ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
352 ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
353 ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
354 ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
355 ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
356 ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
357 ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
358 ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
359 ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
360 ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
361 ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
362 ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
363 ConstantScalarValue::Bool(val) => write!(f, "{val}"),
364 }
365 }
366}
367
368impl Variable {
369 pub fn vectorization_factor(&self) -> u8 {
370 self.item.vectorization.map(NonZero::get).unwrap_or(1u8)
371 }
372
373 pub fn index(&self) -> Option<Id> {
374 match self.kind {
375 VariableKind::GlobalInputArray(id)
376 | VariableKind::GlobalScalar(id)
377 | VariableKind::LocalMut { id, .. }
378 | VariableKind::Versioned { id, .. }
379 | VariableKind::LocalConst { id, .. }
380 | VariableKind::Slice { id, .. }
381 | VariableKind::GlobalOutputArray(id)
382 | VariableKind::ConstantArray { id, .. }
383 | VariableKind::SharedMemory { id, .. }
384 | VariableKind::LocalArray { id, .. }
385 | VariableKind::Matrix { id, .. } => Some(id),
386 _ => None,
387 }
388 }
389
390 pub fn as_const(&self) -> Option<ConstantScalarValue> {
391 match self.kind {
392 VariableKind::ConstantScalar(constant) => Some(constant),
393 _ => None,
394 }
395 }
396}
397
398impl Display for Variable {
399 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400 match self.kind {
401 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
402 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
403 VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
404 VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
405 VariableKind::LocalMut { id } => write!(f, "local({id})"),
406 VariableKind::Versioned { id, version } => {
407 write!(f, "local({id}).v{version}")
408 }
409 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
410 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
411 VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
412 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
413 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
414 VariableKind::Slice { id } => write!(f, "slice({id})"),
415 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
416 }
417 }
418}
419
420impl From<&Variable> for Variable {
422 fn from(value: &Variable) -> Self {
423 *value
424 }
425}