cubecl_core/codegen/
integrator.rs1use alloc::{string::ToString, vec::Vec};
2use cubecl_ir::{Id, Scope, StorageType, Type};
3use cubecl_runtime::{
4 kernel::{Binding, KernelDefinition, KernelOptions, Location, ScalarBinding, Visibility},
5 server::CubeDim,
6};
7
8use crate::prelude::AddressType;
9
10#[derive(Clone)]
13pub struct KernelIntegrator {
14 expansion: KernelExpansion,
15 buffer_bindings: Vec<Binding>,
16 scalar_bindings: Vec<ScalarBinding>,
17 tensor_maps: Vec<Binding>,
18}
19
20#[derive(Clone)]
22pub struct KernelExpansion {
23 pub buffers: Vec<BufferInfo>,
24 pub scalars: Vec<ScalarInfo>,
25 pub tensor_maps: Vec<BufferInfo>,
26 pub scope: Scope,
27}
28
29#[derive(Clone, Debug, PartialEq, Eq)]
30pub struct KernelSettings {
31 pub cube_dim: CubeDim,
32 pub address_type: AddressType,
33 pub options: KernelOptions,
34}
35
36impl Default for KernelSettings {
37 fn default() -> Self {
38 Self {
39 cube_dim: CubeDim::new_1d(1),
40 address_type: AddressType::U32,
41 options: Default::default(),
42 }
43 }
44}
45
46impl KernelSettings {
47 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
49 self.cube_dim = cube_dim;
50 self
51 }
52
53 pub fn address_type(mut self, ty: AddressType) -> Self {
55 self.address_type = ty;
56 self
57 }
58
59 pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
61 self.options.kernel_name = name.as_ref().to_string();
62 self
63 }
64
65 pub fn debug_symbols(mut self) -> Self {
67 self.options.debug_symbols = true;
68 self
69 }
70
71 pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
73 self.options.cluster_dim = Some(cluster_dim);
74 self
75 }
76}
77
78#[derive(Clone, Debug)]
80pub struct BufferInfo {
81 pub id: Id,
82 pub item: Type,
83 pub visibility: Visibility,
84 pub has_extended_meta: bool,
86}
87
88#[derive(Clone, Debug)]
90pub struct ScalarInfo {
91 pub ty: StorageType,
92 pub count: usize,
93}
94
95impl KernelIntegrator {
96 pub fn new(info: KernelExpansion) -> Self {
98 Self {
99 expansion: info,
100 buffer_bindings: Default::default(),
101 scalar_bindings: Default::default(),
102 tensor_maps: Default::default(),
103 }
104 }
105
106 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
108 pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
109 self.register_buffers();
110 self.register_scalars();
111 self.register_tensor_maps();
112
113 self.scalar_bindings.sort_by_key(|binding| binding.ty);
114
115 KernelDefinition {
116 buffers: self.buffer_bindings,
117 tensor_maps: self.tensor_maps,
118 scalars: self.scalar_bindings,
119 cube_dim: settings.cube_dim,
120 body: self.expansion.scope,
121 options: settings.options,
122 }
123 }
124
125 fn register_buffers(&mut self) {
126 for buffer in self.expansion.buffers.drain(..) {
127 self.buffer_bindings.push(Binding {
128 id: buffer.id,
129 ty: buffer.item,
130 visibility: buffer.visibility,
131 location: Location::Storage,
132 has_extended_meta: buffer.has_extended_meta,
133 size: None,
134 });
135 }
136 }
137
138 fn register_scalars(&mut self) {
139 for scalar in self.expansion.scalars.drain(..) {
140 self.scalar_bindings.push(ScalarBinding {
141 ty: scalar.ty,
142 count: scalar.count,
143 });
144 }
145 }
146
147 fn register_tensor_maps(&mut self) {
148 for buffer in self.expansion.tensor_maps.drain(..) {
149 self.tensor_maps.push(Binding {
150 id: buffer.id,
151 ty: buffer.item,
152 visibility: buffer.visibility,
153 location: Location::Storage,
154 has_extended_meta: buffer.has_extended_meta,
155 size: None,
156 });
157 }
158 }
159}