1use super::{ConstantScalarValue, Scope, Variable, VariableKind};
2use crate::PLANE_DIM_APPROX;
3use serde::{Deserialize, Serialize};
4use std::fmt::Display;
5use std::num::NonZero;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8#[allow(missing_docs)]
9pub struct KernelDefinition {
10 pub inputs: Vec<Binding>,
11 pub outputs: Vec<Binding>,
12 pub named: Vec<(String, Binding)>,
13 pub cube_dim: CubeDim,
14 pub body: Scope,
15 pub kernel_name: String,
16}
17
18#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
19#[allow(missing_docs)]
20pub enum Location {
21 Storage,
22 Cube,
23}
24
25#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
26#[allow(missing_docs)]
27pub enum Visibility {
28 Read,
29 ReadWrite,
30}
31
32#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
33#[allow(missing_docs)]
34pub enum FloatKind {
35 F16,
36 BF16,
37 Flex32,
38 F32,
39 TF32,
40 F64,
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
44#[allow(missing_docs)]
45pub enum IntKind {
46 I8,
47 I16,
48 I32,
49 I64,
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
53#[allow(missing_docs)]
54pub enum UIntKind {
55 U8,
56 U16,
57 U32,
58 U64,
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
62#[allow(missing_docs)]
63pub enum Elem {
64 Float(FloatKind),
65 Int(IntKind),
66 UInt(UIntKind),
67 AtomicFloat(FloatKind),
68 AtomicInt(IntKind),
69 AtomicUInt(UIntKind),
70 Bool,
71}
72
73impl Elem {
74 pub fn constant_from_f64(&self, val: f64) -> Variable {
78 Variable::constant(match self {
79 Elem::Float(kind) => ConstantScalarValue::Float(val, *kind),
80 Elem::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
81 Elem::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
82 Elem::Bool => ConstantScalarValue::Bool(val > 0.0),
83 Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind),
84 Elem::AtomicUInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
85 Elem::AtomicFloat(kind) => ConstantScalarValue::Float(val, *kind),
86 })
87 }
88 pub fn constant_from_i64(&self, val: i64) -> Variable {
92 Variable::constant(match self {
93 Elem::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
94 Elem::Int(kind) => ConstantScalarValue::Int(val, *kind),
95 Elem::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
96 Elem::Bool => ConstantScalarValue::Bool(val > 0),
97 Elem::AtomicInt(kind) => ConstantScalarValue::Int(val, *kind),
98 Elem::AtomicUInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
99 Elem::AtomicFloat(kind) => ConstantScalarValue::Float(val as f64, *kind),
100 })
101 }
102 pub fn constant_from_u64(&self, val: u64) -> Variable {
106 Variable::constant(match self {
107 Elem::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
108 Elem::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
109 Elem::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
110 Elem::Bool => ConstantScalarValue::Bool(val > 0),
111 Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind),
112 Elem::AtomicUInt(kind) => ConstantScalarValue::UInt(val, *kind),
113 Elem::AtomicFloat(kind) => ConstantScalarValue::Float(val as f64, *kind),
114 })
115 }
116 pub fn constant_from_bool(&self, val: bool) -> Variable {
120 Variable::constant(match self {
121 Elem::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
122 Elem::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
123 Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind),
124 Elem::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
125 Elem::AtomicUInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
126 Elem::AtomicFloat(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
127 Elem::Bool => ConstantScalarValue::Bool(val),
128 })
129 }
130
131 pub fn from_constant(&self, constant: Variable) -> Variable {
133 let value = match constant.kind {
134 VariableKind::ConstantScalar(value) => value,
135 _ => return constant,
136 };
137
138 match value {
139 ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
140 ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
141 ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
142 ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
143 }
144 }
145 pub const fn size(&self) -> usize {
147 match self {
148 Elem::Float(kind) | Elem::AtomicFloat(kind) => match kind {
149 FloatKind::F16 => core::mem::size_of::<half::f16>(),
150 FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
151 FloatKind::F32 => core::mem::size_of::<f32>(),
152 FloatKind::F64 => core::mem::size_of::<f64>(),
153 FloatKind::Flex32 => core::mem::size_of::<f32>(),
154 FloatKind::TF32 => core::mem::size_of::<f32>(),
155 },
156 Elem::Int(kind) | Elem::AtomicInt(kind) => match kind {
157 IntKind::I8 => core::mem::size_of::<i8>(),
158 IntKind::I16 => core::mem::size_of::<i16>(),
159 IntKind::I32 => core::mem::size_of::<i32>(),
160 IntKind::I64 => core::mem::size_of::<i64>(),
161 },
162 Elem::UInt(kind) | Elem::AtomicUInt(kind) => match kind {
163 UIntKind::U8 => core::mem::size_of::<u8>(),
164 UIntKind::U16 => core::mem::size_of::<u16>(),
165 UIntKind::U32 => core::mem::size_of::<u32>(),
166 UIntKind::U64 => core::mem::size_of::<u64>(),
167 },
168 Elem::Bool => core::mem::size_of::<bool>(),
169 }
170 }
171
172 pub fn is_atomic(&self) -> bool {
173 matches!(
174 self,
175 Elem::AtomicFloat(_) | Elem::AtomicInt(_) | Elem::AtomicUInt(_)
176 )
177 }
178
179 pub fn is_int(&self) -> bool {
180 matches!(
181 self,
182 Elem::Int(_) | Elem::AtomicInt(_) | Elem::UInt(_) | Elem::AtomicUInt(_)
183 )
184 }
185
186 pub fn max_variable(&self) -> Variable {
187 let value = match self {
188 Elem::Float(kind) | Elem::AtomicFloat(kind) => match kind {
189 FloatKind::F16 => {
190 ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
191 }
192 FloatKind::BF16 => {
193 ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
194 }
195 FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
196 FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
197 FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
198 FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
199 },
200 Elem::Int(kind) | Elem::AtomicInt(kind) => match kind {
201 IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
202 IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
203 IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
204 IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
205 },
206 Elem::UInt(kind) | Elem::AtomicUInt(kind) => match kind {
207 UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
208 UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
209 UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
210 UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
211 },
212 Elem::Bool => ConstantScalarValue::Bool(true),
213 };
214
215 Variable::new(VariableKind::ConstantScalar(value), Item::new(*self))
216 }
217
218 pub fn min_variable(&self) -> Variable {
219 let value = match self {
220 Elem::Float(kind) | Elem::AtomicFloat(kind) => match kind {
221 FloatKind::F16 => {
222 ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
223 }
224 FloatKind::BF16 => {
225 ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
226 }
227 FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
228 FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
229 FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
230 FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
231 },
232 Elem::Int(kind) | Elem::AtomicInt(kind) => match kind {
233 IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
234 IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
235 IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
236 IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
237 },
238 Elem::UInt(kind) | Elem::AtomicUInt(kind) => match kind {
239 UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
240 UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
241 UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
242 UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
243 },
244 Elem::Bool => ConstantScalarValue::Bool(false),
245 };
246
247 Variable::new(VariableKind::ConstantScalar(value), Item::new(*self))
248 }
249}
250
251impl From<Elem> for Item {
252 fn from(val: Elem) -> Self {
253 Item::new(val)
254 }
255}
256
257impl Display for Elem {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 match self {
260 Self::Float(kind) => match kind {
261 FloatKind::F16 => f.write_str("f16"),
262 FloatKind::BF16 => f.write_str("bf16"),
263 FloatKind::Flex32 => f.write_str("flex32"),
264 FloatKind::TF32 => f.write_str("tf32"),
265 FloatKind::F32 => f.write_str("f32"),
266 FloatKind::F64 => f.write_str("f64"),
267 },
268 Self::AtomicFloat(kind) => write!(f, "atomic<{}>", Elem::Float(*kind)),
269 Self::Int(kind) => match kind {
270 IntKind::I8 => f.write_str("i8"),
271 IntKind::I16 => f.write_str("i16"),
272 IntKind::I32 => f.write_str("i32"),
273 IntKind::I64 => f.write_str("i64"),
274 },
275 Self::AtomicInt(kind) => write!(f, "atomic<{}>", Elem::Int(*kind)),
276 Self::UInt(kind) => match kind {
277 UIntKind::U8 => f.write_str("u8"),
278 UIntKind::U16 => f.write_str("u16"),
279 UIntKind::U32 => f.write_str("u32"),
280 UIntKind::U64 => f.write_str("u64"),
281 },
282 Self::AtomicUInt(kind) => write!(f, "atomic<{}>", Elem::UInt(*kind)),
283 Self::Bool => f.write_str("bool"),
284 }
285 }
286}
287
288#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash, PartialOrd, Ord)]
289pub struct Item {
290 pub elem: Elem,
291 pub vectorization: Vectorization,
292}
293
294pub type Vectorization = Option<NonZero<u8>>;
295
296impl Item {}
297
298impl Item {
299 pub fn elem(&self) -> Elem {
301 self.elem
302 }
303
304 pub fn new(elem: Elem) -> Self {
306 Self {
307 elem,
308 vectorization: None,
309 }
310 }
311
312 pub fn vectorized(elem: Elem, vectorization: Vectorization) -> Self {
314 Self {
315 elem,
316 vectorization,
317 }
318 }
319
320 pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Item {
321 Item {
322 elem: self.elem,
323 vectorization,
324 }
325 }
326}
327
328impl Display for Item {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 match self.vectorization {
331 Some(vec) if vec.get() > 1 => {
332 write!(f, "vector{}<{}>", vec.get(), self.elem)
333 }
334 _ => write!(f, "{}", self.elem),
335 }
336 }
337}
338
339#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
340#[allow(missing_docs)]
341pub struct Binding {
342 pub location: Location,
343 pub visibility: Visibility,
344 pub item: Item,
345 pub size: Option<usize>,
346 pub has_extended_meta: bool,
347}
348
349#[derive(new, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, Hash)]
350#[allow(missing_docs)]
351pub struct CubeDim {
352 pub x: u32,
353 pub y: u32,
354 pub z: u32,
355}
356
357impl CubeDim {
358 pub const fn new_single() -> Self {
360 Self { x: 1, y: 1, z: 1 }
361 }
362
363 pub const fn new_1d(x: u32) -> Self {
365 Self { x, y: 1, z: 1 }
366 }
367
368 pub const fn new_2d(x: u32, y: u32) -> Self {
370 Self { x, y, z: 1 }
371 }
372
373 pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
376 Self { x, y, z }
377 }
378
379 pub const fn num_elems(&self) -> u32 {
380 self.x * self.y * self.z
381 }
382}
383
384impl Default for CubeDim {
385 fn default() -> Self {
386 Self {
387 x: PLANE_DIM_APPROX as u32,
388 y: PLANE_DIM_APPROX as u32,
389 z: 1,
390 }
391 }
392}