cubecl_core/codegen/
integrator.rs1use cubecl_ir::{Id, Scope, StorageType, Type};
2use cubecl_runtime::{
3 kernel::{Binding, KernelDefinition, KernelOptions, Location, ScalarBinding, Visibility},
4 server::CubeDim,
5};
6
7#[derive(Clone)]
10pub struct KernelIntegrator {
11 expansion: KernelExpansion,
12 buffer_bindings: Vec<Binding>,
13 scalar_bindings: Vec<ScalarBinding>,
14 tensor_maps: Vec<Binding>,
15}
16
17#[derive(Clone)]
19pub struct KernelExpansion {
20 pub buffers: Vec<BufferInfo>,
21 pub scalars: Vec<ScalarInfo>,
22 pub tensor_maps: Vec<BufferInfo>,
23 pub scope: Scope,
24}
25
26#[derive(Clone, Debug, Hash, PartialEq, Eq)]
27pub struct KernelSettings {
28 pub cube_dim: CubeDim,
29 pub options: KernelOptions,
30}
31
32impl Default for KernelSettings {
33 fn default() -> Self {
34 Self {
35 cube_dim: CubeDim::new_1d(1),
36 options: Default::default(),
37 }
38 }
39}
40
41impl KernelSettings {
42 #[allow(dead_code)]
44 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
45 self.cube_dim = cube_dim;
46 self
47 }
48
49 #[allow(dead_code)]
51 pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
52 self.options.kernel_name = name.as_ref().to_string();
53 self
54 }
55
56 pub fn debug_symbols(mut self) -> Self {
58 self.options.debug_symbols = true;
59 self
60 }
61
62 pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
64 self.options.cluster_dim = Some(cluster_dim);
65 self
66 }
67}
68
69#[derive(Clone, Debug)]
71pub struct BufferInfo {
72 pub id: Id,
73 pub item: Type,
74 pub visibility: Visibility,
75 pub has_extended_meta: bool,
77}
78
79#[derive(Clone, Debug)]
81pub struct ScalarInfo {
82 pub ty: StorageType,
83 pub count: usize,
84}
85
86impl KernelIntegrator {
87 pub fn new(info: KernelExpansion) -> Self {
89 Self {
90 expansion: info,
91 buffer_bindings: Default::default(),
92 scalar_bindings: Default::default(),
93 tensor_maps: Default::default(),
94 }
95 }
96
97 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
99 pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
100 self.register_buffers();
101 self.register_scalars();
102 self.register_tensor_maps();
103
104 self.scalar_bindings.sort_by_key(|binding| binding.ty);
105
106 KernelDefinition {
107 buffers: self.buffer_bindings,
108 tensor_maps: self.tensor_maps,
109 scalars: self.scalar_bindings,
110 cube_dim: settings.cube_dim,
111 body: self.expansion.scope,
112 options: settings.options,
113 }
114 }
115
116 fn register_buffers(&mut self) {
117 for buffer in self.expansion.buffers.drain(..) {
118 self.buffer_bindings.push(Binding {
119 id: buffer.id,
120 ty: buffer.item,
121 visibility: buffer.visibility,
122 location: Location::Storage,
123 has_extended_meta: buffer.has_extended_meta,
124 size: None,
125 });
126 }
127 }
128
129 fn register_scalars(&mut self) {
130 for scalar in self.expansion.scalars.drain(..) {
131 self.scalar_bindings.push(ScalarBinding {
132 ty: scalar.ty,
133 count: scalar.count,
134 });
135 }
136 }
137
138 fn register_tensor_maps(&mut self) {
139 for buffer in self.expansion.tensor_maps.drain(..) {
140 self.tensor_maps.push(Binding {
141 id: buffer.id,
142 ty: buffer.item,
143 visibility: buffer.visibility,
144 location: Location::Storage,
145 has_extended_meta: buffer.has_extended_meta,
146 size: None,
147 });
148 }
149 }
150}