Skip to main content

cubecl_core/codegen/
integrator.rs

1use alloc::{string::ToString, vec::Vec};
2use cubecl_ir::{Id, Scope, StorageType, Type};
3use cubecl_runtime::{
4    kernel::{Binding, KernelDefinition, KernelOptions, Location, ScalarBinding, Visibility},
5    server::CubeDim,
6};
7
8use crate::prelude::AddressType;
9
10/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
11/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
12#[derive(Clone)]
13pub struct KernelIntegrator {
14    expansion: KernelExpansion,
15    buffer_bindings: Vec<Binding>,
16    scalar_bindings: Vec<ScalarBinding>,
17    tensor_maps: Vec<Binding>,
18}
19
20/// The information necessary to compile a [kernel definition](KernelDefinition).
21#[derive(Clone)]
22pub struct KernelExpansion {
23    pub buffers: Vec<BufferInfo>,
24    pub scalars: Vec<ScalarInfo>,
25    pub tensor_maps: Vec<BufferInfo>,
26    pub scope: Scope,
27}
28
29#[derive(Clone, Debug, PartialEq, Eq)]
30pub struct KernelSettings {
31    pub cube_dim: CubeDim,
32    pub address_type: AddressType,
33    pub options: KernelOptions,
34}
35
36impl Default for KernelSettings {
37    fn default() -> Self {
38        Self {
39            cube_dim: CubeDim::new_1d(1),
40            address_type: AddressType::U32,
41            options: Default::default(),
42        }
43    }
44}
45
46impl KernelSettings {
47    /// Set cube dimension.
48    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
49        self.cube_dim = cube_dim;
50        self
51    }
52
53    /// Set address type.
54    pub fn address_type(mut self, ty: AddressType) -> Self {
55        self.address_type = ty;
56        self
57    }
58
59    /// Set kernel name.
60    pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
61        self.options.kernel_name = name.as_ref().to_string();
62        self
63    }
64
65    /// Activate debug symbols
66    pub fn debug_symbols(mut self) -> Self {
67        self.options.debug_symbols = true;
68        self
69    }
70
71    /// Set cluster dim
72    pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
73        self.options.cluster_dim = Some(cluster_dim);
74        self
75    }
76}
77
78/// Information related to a buffer binding.
79#[derive(Clone, Debug)]
80pub struct BufferInfo {
81    pub id: Id,
82    pub item: Type,
83    pub visibility: Visibility,
84    /// Whether this input has extended metadata (rank, shape, strides)
85    pub has_extended_meta: bool,
86}
87
88/// Information related to a scalar input.
89#[derive(Clone, Debug)]
90pub struct ScalarInfo {
91    pub ty: StorageType,
92    pub count: usize,
93}
94
95impl KernelIntegrator {
96    /// Starts a new compilation.
97    pub fn new(info: KernelExpansion) -> Self {
98        Self {
99            expansion: info,
100            buffer_bindings: Default::default(),
101            scalar_bindings: Default::default(),
102            tensor_maps: Default::default(),
103        }
104    }
105
106    /// Performs the compilation with the provided [settings](KernelSettings).
107    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
108    pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
109        self.register_buffers();
110        self.register_scalars();
111        self.register_tensor_maps();
112
113        self.scalar_bindings.sort_by_key(|binding| binding.ty);
114
115        KernelDefinition {
116            buffers: self.buffer_bindings,
117            tensor_maps: self.tensor_maps,
118            scalars: self.scalar_bindings,
119            cube_dim: settings.cube_dim,
120            body: self.expansion.scope,
121            options: settings.options,
122        }
123    }
124
125    fn register_buffers(&mut self) {
126        for buffer in self.expansion.buffers.drain(..) {
127            self.buffer_bindings.push(Binding {
128                id: buffer.id,
129                ty: buffer.item,
130                visibility: buffer.visibility,
131                location: Location::Storage,
132                has_extended_meta: buffer.has_extended_meta,
133                size: None,
134            });
135        }
136    }
137
138    fn register_scalars(&mut self) {
139        for scalar in self.expansion.scalars.drain(..) {
140            self.scalar_bindings.push(ScalarBinding {
141                ty: scalar.ty,
142                count: scalar.count,
143            });
144        }
145    }
146
147    fn register_tensor_maps(&mut self) {
148        for buffer in self.expansion.tensor_maps.drain(..) {
149            self.tensor_maps.push(Binding {
150                id: buffer.id,
151                ty: buffer.item,
152                visibility: buffer.visibility,
153                location: Location::Storage,
154                has_extended_meta: buffer.has_extended_meta,
155                size: None,
156            });
157        }
158    }
159}