cubecl_core/codegen/
integrator.rs

1use cubecl_ir::{Id, Scope, StorageType, Type};
2use cubecl_runtime::{
3    kernel::{Binding, KernelDefinition, KernelOptions, Location, ScalarBinding, Visibility},
4    server::CubeDim,
5};
6
7/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
8/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
9#[derive(Clone)]
10pub struct KernelIntegrator {
11    expansion: KernelExpansion,
12    buffer_bindings: Vec<Binding>,
13    scalar_bindings: Vec<ScalarBinding>,
14    tensor_maps: Vec<Binding>,
15}
16
17/// The information necessary to compile a [kernel definition](KernelDefinition).
18#[derive(Clone)]
19pub struct KernelExpansion {
20    pub buffers: Vec<BufferInfo>,
21    pub scalars: Vec<ScalarInfo>,
22    pub tensor_maps: Vec<BufferInfo>,
23    pub scope: Scope,
24}
25
26#[derive(Clone, Debug, Hash, PartialEq, Eq)]
27pub struct KernelSettings {
28    pub cube_dim: CubeDim,
29    pub options: KernelOptions,
30}
31
32impl Default for KernelSettings {
33    fn default() -> Self {
34        Self {
35            cube_dim: CubeDim::new_1d(1),
36            options: Default::default(),
37        }
38    }
39}
40
41impl KernelSettings {
42    /// Set cube dimension.
43    #[allow(dead_code)]
44    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
45        self.cube_dim = cube_dim;
46        self
47    }
48
49    /// Set kernel name.
50    #[allow(dead_code)]
51    pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
52        self.options.kernel_name = name.as_ref().to_string();
53        self
54    }
55
56    /// Activate debug symbols
57    pub fn debug_symbols(mut self) -> Self {
58        self.options.debug_symbols = true;
59        self
60    }
61
62    /// Set cluster dim
63    pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
64        self.options.cluster_dim = Some(cluster_dim);
65        self
66    }
67}
68
69/// Information related to a buffer binding.
70#[derive(Clone, Debug)]
71pub struct BufferInfo {
72    pub id: Id,
73    pub item: Type,
74    pub visibility: Visibility,
75    /// Whether this input has extended metadata (rank, shape, strides)
76    pub has_extended_meta: bool,
77}
78
79/// Information related to a scalar input.
80#[derive(Clone, Debug)]
81pub struct ScalarInfo {
82    pub ty: StorageType,
83    pub count: usize,
84}
85
86impl KernelIntegrator {
87    /// Starts a new compilation.
88    pub fn new(info: KernelExpansion) -> Self {
89        Self {
90            expansion: info,
91            buffer_bindings: Default::default(),
92            scalar_bindings: Default::default(),
93            tensor_maps: Default::default(),
94        }
95    }
96
97    /// Performs the compilation with the provided [settings](KernelSettings).
98    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
99    pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
100        self.register_buffers();
101        self.register_scalars();
102        self.register_tensor_maps();
103
104        self.scalar_bindings.sort_by_key(|binding| binding.ty);
105
106        KernelDefinition {
107            buffers: self.buffer_bindings,
108            tensor_maps: self.tensor_maps,
109            scalars: self.scalar_bindings,
110            cube_dim: settings.cube_dim,
111            body: self.expansion.scope,
112            options: settings.options,
113        }
114    }
115
116    fn register_buffers(&mut self) {
117        for buffer in self.expansion.buffers.drain(..) {
118            self.buffer_bindings.push(Binding {
119                id: buffer.id,
120                ty: buffer.item,
121                visibility: buffer.visibility,
122                location: Location::Storage,
123                has_extended_meta: buffer.has_extended_meta,
124                size: None,
125            });
126        }
127    }
128
129    fn register_scalars(&mut self) {
130        for scalar in self.expansion.scalars.drain(..) {
131            self.scalar_bindings.push(ScalarBinding {
132                ty: scalar.ty,
133                count: scalar.count,
134            });
135        }
136    }
137
138    fn register_tensor_maps(&mut self) {
139        for buffer in self.expansion.tensor_maps.drain(..) {
140            self.tensor_maps.push(Binding {
141                id: buffer.id,
142                ty: buffer.item,
143                visibility: buffer.visibility,
144                location: Location::Storage,
145                has_extended_meta: buffer.has_extended_meta,
146                size: None,
147            });
148        }
149    }
150}