// cubecl_core/codegen/integrator.rs

1use cubecl_common::CubeDim;
2use cubecl_ir::{Id, Scope, StorageType, Type};
3use enumset::EnumSet;
4
5use crate::{
6    compute::{Binding, KernelDefinition, Location, ScalarBinding, Visibility},
7    prelude::FastMath,
8};
9
10/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
11/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
12#[derive(Clone)]
13pub struct KernelIntegrator {
14    expansion: KernelExpansion,
15    buffer_bindings: Vec<Binding>,
16    scalar_bindings: Vec<ScalarBinding>,
17    tensor_maps: Vec<Binding>,
18}
19
20/// The information necessary to compile a [kernel definition](KernelDefinition).
21#[derive(Clone)]
22pub struct KernelExpansion {
23    pub buffers: Vec<BufferInfo>,
24    pub scalars: Vec<ScalarInfo>,
25    pub tensor_maps: Vec<BufferInfo>,
26    pub scope: Scope,
27}
28
29#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
30pub struct KernelSettings {
31    pub cube_dim: CubeDim,
32    pub options: KernelOptions,
33}
34
35#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
36pub struct KernelOptions {
37    pub kernel_name: String,
38    pub debug_symbols: bool,
39    pub fp_math_mode: EnumSet<FastMath>,
40    pub cluster_dim: Option<CubeDim>,
41}
42
43impl KernelSettings {
44    /// Set cube dimension.
45    #[allow(dead_code)]
46    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
47        self.cube_dim = cube_dim;
48        self
49    }
50
51    /// Set kernel name.
52    #[allow(dead_code)]
53    pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
54        self.options.kernel_name = name.as_ref().to_string();
55        self
56    }
57
58    /// Activate debug symbols
59    pub fn debug_symbols(mut self) -> Self {
60        self.options.debug_symbols = true;
61        self
62    }
63
64    /// Set FP math mode
65    pub fn fp_math_mode(mut self, mode: EnumSet<FastMath>) -> Self {
66        self.options.fp_math_mode = mode;
67        self
68    }
69
70    /// Set cluster dim
71    pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
72        self.options.cluster_dim = Some(cluster_dim);
73        self
74    }
75}
76
77/// Information related to a buffer binding.
78#[derive(Clone, Debug)]
79pub struct BufferInfo {
80    pub id: Id,
81    pub item: Type,
82    pub visibility: Visibility,
83    /// Whether this input has extended metadata (rank, shape, strides)
84    pub has_extended_meta: bool,
85}
86
87/// Information related to a scalar input.
88#[derive(Clone, Debug)]
89pub struct ScalarInfo {
90    pub ty: StorageType,
91    pub count: usize,
92}
93
94impl KernelIntegrator {
95    /// Starts a new compilation.
96    pub fn new(info: KernelExpansion) -> Self {
97        Self {
98            expansion: info,
99            buffer_bindings: Default::default(),
100            scalar_bindings: Default::default(),
101            tensor_maps: Default::default(),
102        }
103    }
104
105    /// Performs the compilation with the provided [settings](KernelSettings).
106    pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
107        self.register_buffers();
108        self.register_scalars();
109        self.register_tensor_maps();
110
111        self.scalar_bindings.sort_by_key(|binding| binding.ty);
112
113        KernelDefinition {
114            buffers: self.buffer_bindings,
115            tensor_maps: self.tensor_maps,
116            scalars: self.scalar_bindings,
117            cube_dim: settings.cube_dim,
118            body: self.expansion.scope,
119            options: settings.options,
120        }
121    }
122
123    fn register_buffers(&mut self) {
124        for buffer in self.expansion.buffers.drain(..) {
125            self.buffer_bindings.push(Binding {
126                id: buffer.id,
127                ty: buffer.item,
128                visibility: buffer.visibility,
129                location: Location::Storage,
130                has_extended_meta: buffer.has_extended_meta,
131                size: None,
132            });
133        }
134    }
135
136    fn register_scalars(&mut self) {
137        for scalar in self.expansion.scalars.drain(..) {
138            self.scalar_bindings.push(ScalarBinding {
139                ty: scalar.ty,
140                count: scalar.count,
141            });
142        }
143    }
144
145    fn register_tensor_maps(&mut self) {
146        for buffer in self.expansion.tensor_maps.drain(..) {
147            self.tensor_maps.push(Binding {
148                id: buffer.id,
149                ty: buffer.item,
150                visibility: buffer.visibility,
151                location: Location::Storage,
152                has_extended_meta: buffer.has_extended_meta,
153                size: None,
154            });
155        }
156    }
157}