cubecl_core/codegen/
integrator.rs1use cubecl_common::CubeDim;
2use cubecl_ir::{Id, Scope, StorageType, Type};
3use cubecl_runtime::kernel::{
4 Binding, KernelDefinition, KernelOptions, Location, ScalarBinding, Visibility,
5};
6
7#[derive(Clone)]
10pub struct KernelIntegrator {
11 expansion: KernelExpansion,
12 buffer_bindings: Vec<Binding>,
13 scalar_bindings: Vec<ScalarBinding>,
14 tensor_maps: Vec<Binding>,
15}
16
17#[derive(Clone)]
19pub struct KernelExpansion {
20 pub buffers: Vec<BufferInfo>,
21 pub scalars: Vec<ScalarInfo>,
22 pub tensor_maps: Vec<BufferInfo>,
23 pub scope: Scope,
24}
25
26#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
27pub struct KernelSettings {
28 pub cube_dim: CubeDim,
29 pub options: KernelOptions,
30}
31
32impl KernelSettings {
33 #[allow(dead_code)]
35 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
36 self.cube_dim = cube_dim;
37 self
38 }
39
40 #[allow(dead_code)]
42 pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
43 self.options.kernel_name = name.as_ref().to_string();
44 self
45 }
46
47 pub fn debug_symbols(mut self) -> Self {
49 self.options.debug_symbols = true;
50 self
51 }
52
53 pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
55 self.options.cluster_dim = Some(cluster_dim);
56 self
57 }
58}
59
60#[derive(Clone, Debug)]
62pub struct BufferInfo {
63 pub id: Id,
64 pub item: Type,
65 pub visibility: Visibility,
66 pub has_extended_meta: bool,
68}
69
70#[derive(Clone, Debug)]
72pub struct ScalarInfo {
73 pub ty: StorageType,
74 pub count: usize,
75}
76
77impl KernelIntegrator {
78 pub fn new(info: KernelExpansion) -> Self {
80 Self {
81 expansion: info,
82 buffer_bindings: Default::default(),
83 scalar_bindings: Default::default(),
84 tensor_maps: Default::default(),
85 }
86 }
87
88 pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
90 self.register_buffers();
91 self.register_scalars();
92 self.register_tensor_maps();
93
94 self.scalar_bindings.sort_by_key(|binding| binding.ty);
95
96 KernelDefinition {
97 buffers: self.buffer_bindings,
98 tensor_maps: self.tensor_maps,
99 scalars: self.scalar_bindings,
100 cube_dim: settings.cube_dim,
101 body: self.expansion.scope,
102 options: settings.options,
103 }
104 }
105
106 fn register_buffers(&mut self) {
107 for buffer in self.expansion.buffers.drain(..) {
108 self.buffer_bindings.push(Binding {
109 id: buffer.id,
110 ty: buffer.item,
111 visibility: buffer.visibility,
112 location: Location::Storage,
113 has_extended_meta: buffer.has_extended_meta,
114 size: None,
115 });
116 }
117 }
118
119 fn register_scalars(&mut self) {
120 for scalar in self.expansion.scalars.drain(..) {
121 self.scalar_bindings.push(ScalarBinding {
122 ty: scalar.ty,
123 count: scalar.count,
124 });
125 }
126 }
127
128 fn register_tensor_maps(&mut self) {
129 for buffer in self.expansion.tensor_maps.drain(..) {
130 self.tensor_maps.push(Binding {
131 id: buffer.id,
132 ty: buffer.item,
133 visibility: buffer.visibility,
134 location: Location::Storage,
135 has_extended_meta: buffer.has_extended_meta,
136 size: None,
137 });
138 }
139 }
140}