cubecl_core/codegen/
integrator.rs

1use cubecl_common::CubeDim;
2use cubecl_ir::{Id, Scope, StorageType, Type};
3use cubecl_runtime::kernel::{
4    Binding, KernelDefinition, KernelOptions, Location, ScalarBinding, Visibility,
5};
6
7/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
8/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
9#[derive(Clone)]
10pub struct KernelIntegrator {
11    expansion: KernelExpansion,
12    buffer_bindings: Vec<Binding>,
13    scalar_bindings: Vec<ScalarBinding>,
14    tensor_maps: Vec<Binding>,
15}
16
17/// The information necessary to compile a [kernel definition](KernelDefinition).
18#[derive(Clone)]
19pub struct KernelExpansion {
20    pub buffers: Vec<BufferInfo>,
21    pub scalars: Vec<ScalarInfo>,
22    pub tensor_maps: Vec<BufferInfo>,
23    pub scope: Scope,
24}
25
26#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
27pub struct KernelSettings {
28    pub cube_dim: CubeDim,
29    pub options: KernelOptions,
30}
31
32impl KernelSettings {
33    /// Set cube dimension.
34    #[allow(dead_code)]
35    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
36        self.cube_dim = cube_dim;
37        self
38    }
39
40    /// Set kernel name.
41    #[allow(dead_code)]
42    pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
43        self.options.kernel_name = name.as_ref().to_string();
44        self
45    }
46
47    /// Activate debug symbols
48    pub fn debug_symbols(mut self) -> Self {
49        self.options.debug_symbols = true;
50        self
51    }
52
53    /// Set cluster dim
54    pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
55        self.options.cluster_dim = Some(cluster_dim);
56        self
57    }
58}
59
60/// Information related to a buffer binding.
61#[derive(Clone, Debug)]
62pub struct BufferInfo {
63    pub id: Id,
64    pub item: Type,
65    pub visibility: Visibility,
66    /// Whether this input has extended metadata (rank, shape, strides)
67    pub has_extended_meta: bool,
68}
69
70/// Information related to a scalar input.
71#[derive(Clone, Debug)]
72pub struct ScalarInfo {
73    pub ty: StorageType,
74    pub count: usize,
75}
76
77impl KernelIntegrator {
78    /// Starts a new compilation.
79    pub fn new(info: KernelExpansion) -> Self {
80        Self {
81            expansion: info,
82            buffer_bindings: Default::default(),
83            scalar_bindings: Default::default(),
84            tensor_maps: Default::default(),
85        }
86    }
87
88    /// Performs the compilation with the provided [settings](KernelSettings).
89    pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
90        self.register_buffers();
91        self.register_scalars();
92        self.register_tensor_maps();
93
94        self.scalar_bindings.sort_by_key(|binding| binding.ty);
95
96        KernelDefinition {
97            buffers: self.buffer_bindings,
98            tensor_maps: self.tensor_maps,
99            scalars: self.scalar_bindings,
100            cube_dim: settings.cube_dim,
101            body: self.expansion.scope,
102            options: settings.options,
103        }
104    }
105
106    fn register_buffers(&mut self) {
107        for buffer in self.expansion.buffers.drain(..) {
108            self.buffer_bindings.push(Binding {
109                id: buffer.id,
110                ty: buffer.item,
111                visibility: buffer.visibility,
112                location: Location::Storage,
113                has_extended_meta: buffer.has_extended_meta,
114                size: None,
115            });
116        }
117    }
118
119    fn register_scalars(&mut self) {
120        for scalar in self.expansion.scalars.drain(..) {
121            self.scalar_bindings.push(ScalarBinding {
122                ty: scalar.ty,
123                count: scalar.count,
124            });
125        }
126    }
127
128    fn register_tensor_maps(&mut self) {
129        for buffer in self.expansion.tensor_maps.drain(..) {
130            self.tensor_maps.push(Binding {
131                id: buffer.id,
132                ty: buffer.item,
133                visibility: buffer.visibility,
134                location: Location::Storage,
135                has_extended_meta: buffer.has_extended_meta,
136                size: None,
137            });
138        }
139    }
140}