cubecl_core/codegen/
integrator.rs1use cubecl_ir::{Id, Scope, StorageType, Type};
2use cubecl_runtime::{
3 kernel::{Binding, KernelDefinition, KernelOptions, Location, ScalarBinding, Visibility},
4 server::CubeDim,
5};
6
7use crate::prelude::AddressType;
8
9#[derive(Clone)]
12pub struct KernelIntegrator {
13 expansion: KernelExpansion,
14 buffer_bindings: Vec<Binding>,
15 scalar_bindings: Vec<ScalarBinding>,
16 tensor_maps: Vec<Binding>,
17}
18
19#[derive(Clone)]
21pub struct KernelExpansion {
22 pub buffers: Vec<BufferInfo>,
23 pub scalars: Vec<ScalarInfo>,
24 pub tensor_maps: Vec<BufferInfo>,
25 pub scope: Scope,
26}
27
28#[derive(Clone, Debug, PartialEq, Eq)]
29pub struct KernelSettings {
30 pub cube_dim: CubeDim,
31 pub address_type: AddressType,
32 pub options: KernelOptions,
33}
34
35impl Default for KernelSettings {
36 fn default() -> Self {
37 Self {
38 cube_dim: CubeDim::new_1d(1),
39 address_type: AddressType::U32,
40 options: Default::default(),
41 }
42 }
43}
44
45impl KernelSettings {
46 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
48 self.cube_dim = cube_dim;
49 self
50 }
51
52 pub fn address_type(mut self, ty: AddressType) -> Self {
54 self.address_type = ty;
55 self
56 }
57
58 pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
60 self.options.kernel_name = name.as_ref().to_string();
61 self
62 }
63
64 pub fn debug_symbols(mut self) -> Self {
66 self.options.debug_symbols = true;
67 self
68 }
69
70 pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
72 self.options.cluster_dim = Some(cluster_dim);
73 self
74 }
75}
76
77#[derive(Clone, Debug)]
79pub struct BufferInfo {
80 pub id: Id,
81 pub item: Type,
82 pub visibility: Visibility,
83 pub has_extended_meta: bool,
85}
86
87#[derive(Clone, Debug)]
89pub struct ScalarInfo {
90 pub ty: StorageType,
91 pub count: usize,
92}
93
94impl KernelIntegrator {
95 pub fn new(info: KernelExpansion) -> Self {
97 Self {
98 expansion: info,
99 buffer_bindings: Default::default(),
100 scalar_bindings: Default::default(),
101 tensor_maps: Default::default(),
102 }
103 }
104
105 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
107 pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
108 self.register_buffers();
109 self.register_scalars();
110 self.register_tensor_maps();
111
112 self.scalar_bindings.sort_by_key(|binding| binding.ty);
113
114 KernelDefinition {
115 buffers: self.buffer_bindings,
116 tensor_maps: self.tensor_maps,
117 scalars: self.scalar_bindings,
118 cube_dim: settings.cube_dim,
119 body: self.expansion.scope,
120 options: settings.options,
121 }
122 }
123
124 fn register_buffers(&mut self) {
125 for buffer in self.expansion.buffers.drain(..) {
126 self.buffer_bindings.push(Binding {
127 id: buffer.id,
128 ty: buffer.item,
129 visibility: buffer.visibility,
130 location: Location::Storage,
131 has_extended_meta: buffer.has_extended_meta,
132 size: None,
133 });
134 }
135 }
136
137 fn register_scalars(&mut self) {
138 for scalar in self.expansion.scalars.drain(..) {
139 self.scalar_bindings.push(ScalarBinding {
140 ty: scalar.ty,
141 count: scalar.count,
142 });
143 }
144 }
145
146 fn register_tensor_maps(&mut self) {
147 for buffer in self.expansion.tensor_maps.drain(..) {
148 self.tensor_maps.push(Binding {
149 id: buffer.id,
150 ty: buffer.item,
151 visibility: buffer.visibility,
152 location: Location::Storage,
153 has_extended_meta: buffer.has_extended_meta,
154 size: None,
155 });
156 }
157 }
158}