// cubecl_core/ir/kernel.rs

1use super::{ConstantScalarValue, Scope, Variable, VariableKind};
2use crate::PLANE_DIM_APPROX;
3use serde::{Deserialize, Serialize};
4use std::fmt::Display;
5use std::num::NonZero;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8#[allow(missing_docs)]
9pub struct KernelDefinition {
10    pub inputs: Vec<Binding>,
11    pub outputs: Vec<Binding>,
12    pub named: Vec<(String, Binding)>,
13    pub cube_dim: CubeDim,
14    pub body: Scope,
15    pub kernel_name: String,
16}
17
18#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
19#[allow(missing_docs)]
20pub enum Location {
21    Storage,
22    Cube,
23}
24
25#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
26#[allow(missing_docs)]
27pub enum Visibility {
28    Read,
29    ReadWrite,
30}
31
32#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
33#[allow(missing_docs)]
34pub enum FloatKind {
35    F16,
36    BF16,
37    Flex32,
38    F32,
39    TF32,
40    F64,
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
44#[allow(missing_docs)]
45pub enum IntKind {
46    I8,
47    I16,
48    I32,
49    I64,
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
53#[allow(missing_docs)]
54pub enum UIntKind {
55    U8,
56    U16,
57    U32,
58    U64,
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
62#[allow(missing_docs)]
63pub enum Elem {
64    Float(FloatKind),
65    Int(IntKind),
66    UInt(UIntKind),
67    AtomicFloat(FloatKind),
68    AtomicInt(IntKind),
69    AtomicUInt(UIntKind),
70    Bool,
71}
72
73impl Elem {
74    /// Create a constant scalar from a float.
75    ///
76    /// The output will have the same type as the element.
77    pub fn constant_from_f64(&self, val: f64) -> Variable {
78        Variable::constant(match self {
79            Elem::Float(kind) => ConstantScalarValue::Float(val, *kind),
80            Elem::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
81            Elem::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
82            Elem::Bool => ConstantScalarValue::Bool(val > 0.0),
83            Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind),
84            Elem::AtomicUInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
85            Elem::AtomicFloat(kind) => ConstantScalarValue::Float(val, *kind),
86        })
87    }
88    /// Create a constant scalar from a signed integer.
89    ///
90    /// The output will have the same type as the element.
91    pub fn constant_from_i64(&self, val: i64) -> Variable {
92        Variable::constant(match self {
93            Elem::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
94            Elem::Int(kind) => ConstantScalarValue::Int(val, *kind),
95            Elem::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
96            Elem::Bool => ConstantScalarValue::Bool(val > 0),
97            Elem::AtomicInt(kind) => ConstantScalarValue::Int(val, *kind),
98            Elem::AtomicUInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
99            Elem::AtomicFloat(kind) => ConstantScalarValue::Float(val as f64, *kind),
100        })
101    }
102    /// Create a constant scalar from a unsigned integer.
103    ///
104    /// The output will have the same type as the element.
105    pub fn constant_from_u64(&self, val: u64) -> Variable {
106        Variable::constant(match self {
107            Elem::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
108            Elem::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
109            Elem::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
110            Elem::Bool => ConstantScalarValue::Bool(val > 0),
111            Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind),
112            Elem::AtomicUInt(kind) => ConstantScalarValue::UInt(val, *kind),
113            Elem::AtomicFloat(kind) => ConstantScalarValue::Float(val as f64, *kind),
114        })
115    }
116    /// Create a constant scalar from a boolean.
117    ///
118    /// The output will have the same type as the element.
119    pub fn constant_from_bool(&self, val: bool) -> Variable {
120        Variable::constant(match self {
121            Elem::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
122            Elem::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
123            Elem::AtomicInt(kind) => ConstantScalarValue::Int(val as i64, *kind),
124            Elem::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
125            Elem::AtomicUInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
126            Elem::AtomicFloat(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
127            Elem::Bool => ConstantScalarValue::Bool(val),
128        })
129    }
130
131    /// Ensure that the variable provided, when a constant, is the same type as elem.
132    pub fn from_constant(&self, constant: Variable) -> Variable {
133        let value = match constant.kind {
134            VariableKind::ConstantScalar(value) => value,
135            _ => return constant,
136        };
137
138        match value {
139            ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
140            ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
141            ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
142            ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
143        }
144    }
145    /// Get the size in bytes.
146    pub const fn size(&self) -> usize {
147        match self {
148            Elem::Float(kind) | Elem::AtomicFloat(kind) => match kind {
149                FloatKind::F16 => core::mem::size_of::<half::f16>(),
150                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
151                FloatKind::F32 => core::mem::size_of::<f32>(),
152                FloatKind::F64 => core::mem::size_of::<f64>(),
153                FloatKind::Flex32 => core::mem::size_of::<f32>(),
154                FloatKind::TF32 => core::mem::size_of::<f32>(),
155            },
156            Elem::Int(kind) | Elem::AtomicInt(kind) => match kind {
157                IntKind::I8 => core::mem::size_of::<i8>(),
158                IntKind::I16 => core::mem::size_of::<i16>(),
159                IntKind::I32 => core::mem::size_of::<i32>(),
160                IntKind::I64 => core::mem::size_of::<i64>(),
161            },
162            Elem::UInt(kind) | Elem::AtomicUInt(kind) => match kind {
163                UIntKind::U8 => core::mem::size_of::<u8>(),
164                UIntKind::U16 => core::mem::size_of::<u16>(),
165                UIntKind::U32 => core::mem::size_of::<u32>(),
166                UIntKind::U64 => core::mem::size_of::<u64>(),
167            },
168            Elem::Bool => core::mem::size_of::<bool>(),
169        }
170    }
171
172    pub fn is_atomic(&self) -> bool {
173        matches!(
174            self,
175            Elem::AtomicFloat(_) | Elem::AtomicInt(_) | Elem::AtomicUInt(_)
176        )
177    }
178
179    pub fn is_int(&self) -> bool {
180        matches!(
181            self,
182            Elem::Int(_) | Elem::AtomicInt(_) | Elem::UInt(_) | Elem::AtomicUInt(_)
183        )
184    }
185
186    pub fn max_variable(&self) -> Variable {
187        let value = match self {
188            Elem::Float(kind) | Elem::AtomicFloat(kind) => match kind {
189                FloatKind::F16 => {
190                    ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
191                }
192                FloatKind::BF16 => {
193                    ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
194                }
195                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
196                FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
197                FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
198                FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
199            },
200            Elem::Int(kind) | Elem::AtomicInt(kind) => match kind {
201                IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
202                IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
203                IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
204                IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
205            },
206            Elem::UInt(kind) | Elem::AtomicUInt(kind) => match kind {
207                UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
208                UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
209                UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
210                UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
211            },
212            Elem::Bool => ConstantScalarValue::Bool(true),
213        };
214
215        Variable::new(VariableKind::ConstantScalar(value), Item::new(*self))
216    }
217
218    pub fn min_variable(&self) -> Variable {
219        let value = match self {
220            Elem::Float(kind) | Elem::AtomicFloat(kind) => match kind {
221                FloatKind::F16 => {
222                    ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
223                }
224                FloatKind::BF16 => {
225                    ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
226                }
227                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
228                FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
229                FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
230                FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
231            },
232            Elem::Int(kind) | Elem::AtomicInt(kind) => match kind {
233                IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
234                IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
235                IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
236                IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
237            },
238            Elem::UInt(kind) | Elem::AtomicUInt(kind) => match kind {
239                UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
240                UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
241                UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
242                UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
243            },
244            Elem::Bool => ConstantScalarValue::Bool(false),
245        };
246
247        Variable::new(VariableKind::ConstantScalar(value), Item::new(*self))
248    }
249}
250
251impl From<Elem> for Item {
252    fn from(val: Elem) -> Self {
253        Item::new(val)
254    }
255}
256
257impl Display for Elem {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        match self {
260            Self::Float(kind) => match kind {
261                FloatKind::F16 => f.write_str("f16"),
262                FloatKind::BF16 => f.write_str("bf16"),
263                FloatKind::Flex32 => f.write_str("flex32"),
264                FloatKind::TF32 => f.write_str("tf32"),
265                FloatKind::F32 => f.write_str("f32"),
266                FloatKind::F64 => f.write_str("f64"),
267            },
268            Self::AtomicFloat(kind) => write!(f, "atomic<{}>", Elem::Float(*kind)),
269            Self::Int(kind) => match kind {
270                IntKind::I8 => f.write_str("i8"),
271                IntKind::I16 => f.write_str("i16"),
272                IntKind::I32 => f.write_str("i32"),
273                IntKind::I64 => f.write_str("i64"),
274            },
275            Self::AtomicInt(kind) => write!(f, "atomic<{}>", Elem::Int(*kind)),
276            Self::UInt(kind) => match kind {
277                UIntKind::U8 => f.write_str("u8"),
278                UIntKind::U16 => f.write_str("u16"),
279                UIntKind::U32 => f.write_str("u32"),
280                UIntKind::U64 => f.write_str("u64"),
281            },
282            Self::AtomicUInt(kind) => write!(f, "atomic<{}>", Elem::UInt(*kind)),
283            Self::Bool => f.write_str("bool"),
284        }
285    }
286}
287
288#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash, PartialOrd, Ord)]
289pub struct Item {
290    pub elem: Elem,
291    pub vectorization: Vectorization,
292}
293
294pub type Vectorization = Option<NonZero<u8>>;
295
296impl Item {}
297
298impl Item {
299    /// Fetch the elem of the item.
300    pub fn elem(&self) -> Elem {
301        self.elem
302    }
303
304    /// Create a new item without vectorization
305    pub fn new(elem: Elem) -> Self {
306        Self {
307            elem,
308            vectorization: None,
309        }
310    }
311
312    /// Create a new item with vectorization
313    pub fn vectorized(elem: Elem, vectorization: Vectorization) -> Self {
314        Self {
315            elem,
316            vectorization,
317        }
318    }
319
320    pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Item {
321        Item {
322            elem: self.elem,
323            vectorization,
324        }
325    }
326}
327
328impl Display for Item {
329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330        match self.vectorization {
331            Some(vec) if vec.get() > 1 => {
332                write!(f, "vector{}<{}>", vec.get(), self.elem)
333            }
334            _ => write!(f, "{}", self.elem),
335        }
336    }
337}
338
339#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
340#[allow(missing_docs)]
341pub struct Binding {
342    pub location: Location,
343    pub visibility: Visibility,
344    pub item: Item,
345    pub size: Option<usize>,
346    pub has_extended_meta: bool,
347}
348
349#[derive(new, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, Hash)]
350#[allow(missing_docs)]
351pub struct CubeDim {
352    pub x: u32,
353    pub y: u32,
354    pub z: u32,
355}
356
357impl CubeDim {
358    /// Create a new cube dim with x = y = z = 1.
359    pub const fn new_single() -> Self {
360        Self { x: 1, y: 1, z: 1 }
361    }
362
363    /// Create a new cube dim with the given x, and y = z = 1.
364    pub const fn new_1d(x: u32) -> Self {
365        Self { x, y: 1, z: 1 }
366    }
367
368    /// Create a new cube dim with the given x and y, and z = 1.
369    pub const fn new_2d(x: u32, y: u32) -> Self {
370        Self { x, y, z: 1 }
371    }
372
373    /// Create a new cube dim with the given x, y and z.
374    /// This is equivalent to the [new](CubeDim::new) function.
375    pub const fn new_3d(x: u32, y: u32, z: u32) -> Self {
376        Self { x, y, z }
377    }
378
379    pub const fn num_elems(&self) -> u32 {
380        self.x * self.y * self.z
381    }
382}
383
384impl Default for CubeDim {
385    fn default() -> Self {
386        Self {
387            x: PLANE_DIM_APPROX as u32,
388            y: PLANE_DIM_APPROX as u32,
389            z: 1,
390        }
391    }
392}