cubecl_core/codegen/integrator.rs

1use cubecl_ir::{Id, Scope, StorageType, Type};
2use cubecl_runtime::{
3    kernel::{Binding, KernelDefinition, KernelOptions, Location, ScalarBinding, Visibility},
4    server::CubeDim,
5};
6
7use crate::prelude::AddressType;
8
9/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
10/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
11#[derive(Clone)]
12pub struct KernelIntegrator {
13    expansion: KernelExpansion,
14    buffer_bindings: Vec<Binding>,
15    scalar_bindings: Vec<ScalarBinding>,
16    tensor_maps: Vec<Binding>,
17}
18
19/// The information necessary to compile a [kernel definition](KernelDefinition).
20#[derive(Clone)]
21pub struct KernelExpansion {
22    pub buffers: Vec<BufferInfo>,
23    pub scalars: Vec<ScalarInfo>,
24    pub tensor_maps: Vec<BufferInfo>,
25    pub scope: Scope,
26}
27
28#[derive(Clone, Debug, PartialEq, Eq)]
29pub struct KernelSettings {
30    pub cube_dim: CubeDim,
31    pub address_type: AddressType,
32    pub options: KernelOptions,
33}
34
35impl Default for KernelSettings {
36    fn default() -> Self {
37        Self {
38            cube_dim: CubeDim::new_1d(1),
39            address_type: AddressType::U32,
40            options: Default::default(),
41        }
42    }
43}
44
45impl KernelSettings {
46    /// Set cube dimension.
47    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
48        self.cube_dim = cube_dim;
49        self
50    }
51
52    /// Set address type.
53    pub fn address_type(mut self, ty: AddressType) -> Self {
54        self.address_type = ty;
55        self
56    }
57
58    /// Set kernel name.
59    pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
60        self.options.kernel_name = name.as_ref().to_string();
61        self
62    }
63
64    /// Activate debug symbols
65    pub fn debug_symbols(mut self) -> Self {
66        self.options.debug_symbols = true;
67        self
68    }
69
70    /// Set cluster dim
71    pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
72        self.options.cluster_dim = Some(cluster_dim);
73        self
74    }
75}
76
77/// Information related to a buffer binding.
78#[derive(Clone, Debug)]
79pub struct BufferInfo {
80    pub id: Id,
81    pub item: Type,
82    pub visibility: Visibility,
83    /// Whether this input has extended metadata (rank, shape, strides)
84    pub has_extended_meta: bool,
85}
86
87/// Information related to a scalar input.
88#[derive(Clone, Debug)]
89pub struct ScalarInfo {
90    pub ty: StorageType,
91    pub count: usize,
92}
93
94impl KernelIntegrator {
95    /// Starts a new compilation.
96    pub fn new(info: KernelExpansion) -> Self {
97        Self {
98            expansion: info,
99            buffer_bindings: Default::default(),
100            scalar_bindings: Default::default(),
101            tensor_maps: Default::default(),
102        }
103    }
104
105    /// Performs the compilation with the provided [settings](KernelSettings).
106    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
107    pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
108        self.register_buffers();
109        self.register_scalars();
110        self.register_tensor_maps();
111
112        self.scalar_bindings.sort_by_key(|binding| binding.ty);
113
114        KernelDefinition {
115            buffers: self.buffer_bindings,
116            tensor_maps: self.tensor_maps,
117            scalars: self.scalar_bindings,
118            cube_dim: settings.cube_dim,
119            body: self.expansion.scope,
120            options: settings.options,
121        }
122    }
123
124    fn register_buffers(&mut self) {
125        for buffer in self.expansion.buffers.drain(..) {
126            self.buffer_bindings.push(Binding {
127                id: buffer.id,
128                ty: buffer.item,
129                visibility: buffer.visibility,
130                location: Location::Storage,
131                has_extended_meta: buffer.has_extended_meta,
132                size: None,
133            });
134        }
135    }
136
137    fn register_scalars(&mut self) {
138        for scalar in self.expansion.scalars.drain(..) {
139            self.scalar_bindings.push(ScalarBinding {
140                ty: scalar.ty,
141                count: scalar.count,
142            });
143        }
144    }
145
146    fn register_tensor_maps(&mut self) {
147        for buffer in self.expansion.tensor_maps.drain(..) {
148            self.tensor_maps.push(Binding {
149                id: buffer.id,
150                ty: buffer.item,
151                visibility: buffer.visibility,
152                location: Location::Storage,
153                has_extended_meta: buffer.has_extended_meta,
154                size: None,
155            });
156        }
157    }
158}