cubecl_core/codegen/
integrator.rs1use cubecl_common::CubeDim;
2use cubecl_ir::{Id, Scope, StorageType, Type};
3use enumset::EnumSet;
4
5use crate::{
6 compute::{Binding, KernelDefinition, Location, ScalarBinding, Visibility},
7 prelude::FastMath,
8};
9
10#[derive(Clone)]
13pub struct KernelIntegrator {
14 expansion: KernelExpansion,
15 buffer_bindings: Vec<Binding>,
16 scalar_bindings: Vec<ScalarBinding>,
17 tensor_maps: Vec<Binding>,
18}
19
20#[derive(Clone)]
22pub struct KernelExpansion {
23 pub buffers: Vec<BufferInfo>,
24 pub scalars: Vec<ScalarInfo>,
25 pub tensor_maps: Vec<BufferInfo>,
26 pub scope: Scope,
27}
28
29#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
30pub struct KernelSettings {
31 pub cube_dim: CubeDim,
32 pub options: KernelOptions,
33}
34
35#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
36pub struct KernelOptions {
37 pub kernel_name: String,
38 pub debug_symbols: bool,
39 pub fp_math_mode: EnumSet<FastMath>,
40 pub cluster_dim: Option<CubeDim>,
41}
42
43impl KernelSettings {
44 #[allow(dead_code)]
46 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
47 self.cube_dim = cube_dim;
48 self
49 }
50
51 #[allow(dead_code)]
53 pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
54 self.options.kernel_name = name.as_ref().to_string();
55 self
56 }
57
58 pub fn debug_symbols(mut self) -> Self {
60 self.options.debug_symbols = true;
61 self
62 }
63
64 pub fn fp_math_mode(mut self, mode: EnumSet<FastMath>) -> Self {
66 self.options.fp_math_mode = mode;
67 self
68 }
69
70 pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
72 self.options.cluster_dim = Some(cluster_dim);
73 self
74 }
75}
76
77#[derive(Clone, Debug)]
79pub struct BufferInfo {
80 pub id: Id,
81 pub item: Type,
82 pub visibility: Visibility,
83 pub has_extended_meta: bool,
85}
86
87#[derive(Clone, Debug)]
89pub struct ScalarInfo {
90 pub ty: StorageType,
91 pub count: usize,
92}
93
94impl KernelIntegrator {
95 pub fn new(info: KernelExpansion) -> Self {
97 Self {
98 expansion: info,
99 buffer_bindings: Default::default(),
100 scalar_bindings: Default::default(),
101 tensor_maps: Default::default(),
102 }
103 }
104
105 pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
107 self.register_buffers();
108 self.register_scalars();
109 self.register_tensor_maps();
110
111 self.scalar_bindings.sort_by_key(|binding| binding.ty);
112
113 KernelDefinition {
114 buffers: self.buffer_bindings,
115 tensor_maps: self.tensor_maps,
116 scalars: self.scalar_bindings,
117 cube_dim: settings.cube_dim,
118 body: self.expansion.scope,
119 options: settings.options,
120 }
121 }
122
123 fn register_buffers(&mut self) {
124 for buffer in self.expansion.buffers.drain(..) {
125 self.buffer_bindings.push(Binding {
126 id: buffer.id,
127 ty: buffer.item,
128 visibility: buffer.visibility,
129 location: Location::Storage,
130 has_extended_meta: buffer.has_extended_meta,
131 size: None,
132 });
133 }
134 }
135
136 fn register_scalars(&mut self) {
137 for scalar in self.expansion.scalars.drain(..) {
138 self.scalar_bindings.push(ScalarBinding {
139 ty: scalar.ty,
140 count: scalar.count,
141 });
142 }
143 }
144
145 fn register_tensor_maps(&mut self) {
146 for buffer in self.expansion.tensor_maps.drain(..) {
147 self.tensor_maps.push(Binding {
148 id: buffer.id,
149 ty: buffer.item,
150 visibility: buffer.visibility,
151 location: Location::Storage,
152 has_extended_meta: buffer.has_extended_meta,
153 size: None,
154 });
155 }
156 }
157}