cubecl_core/codegen/
integrator.rs1use cubecl_common::CubeDim;
2use cubecl_ir::{Id, Scope, StorageType, Type};
3
4use crate::compute::{Binding, KernelDefinition, Location, ScalarBinding, Visibility};
5
6#[derive(Clone)]
9pub struct KernelIntegrator {
10 expansion: KernelExpansion,
11 buffer_bindings: Vec<Binding>,
12 scalar_bindings: Vec<ScalarBinding>,
13 tensor_maps: Vec<Binding>,
14}
15
16#[derive(Clone)]
18pub struct KernelExpansion {
19 pub buffers: Vec<BufferInfo>,
20 pub scalars: Vec<ScalarInfo>,
21 pub tensor_maps: Vec<BufferInfo>,
22 pub scope: Scope,
23}
24
25#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
26pub struct KernelSettings {
27 pub cube_dim: CubeDim,
28 pub options: KernelOptions,
29}
30
31#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
32pub struct KernelOptions {
33 pub kernel_name: String,
34 pub debug_symbols: bool,
35 pub cluster_dim: Option<CubeDim>,
36}
37
38impl KernelSettings {
39 #[allow(dead_code)]
41 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
42 self.cube_dim = cube_dim;
43 self
44 }
45
46 #[allow(dead_code)]
48 pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
49 self.options.kernel_name = name.as_ref().to_string();
50 self
51 }
52
53 pub fn debug_symbols(mut self) -> Self {
55 self.options.debug_symbols = true;
56 self
57 }
58
59 pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
61 self.options.cluster_dim = Some(cluster_dim);
62 self
63 }
64}
65
66#[derive(Clone, Debug)]
68pub struct BufferInfo {
69 pub id: Id,
70 pub item: Type,
71 pub visibility: Visibility,
72 pub has_extended_meta: bool,
74}
75
76#[derive(Clone, Debug)]
78pub struct ScalarInfo {
79 pub ty: StorageType,
80 pub count: usize,
81}
82
83impl KernelIntegrator {
84 pub fn new(info: KernelExpansion) -> Self {
86 Self {
87 expansion: info,
88 buffer_bindings: Default::default(),
89 scalar_bindings: Default::default(),
90 tensor_maps: Default::default(),
91 }
92 }
93
94 pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
96 self.register_buffers();
97 self.register_scalars();
98 self.register_tensor_maps();
99
100 self.scalar_bindings.sort_by_key(|binding| binding.ty);
101
102 KernelDefinition {
103 buffers: self.buffer_bindings,
104 tensor_maps: self.tensor_maps,
105 scalars: self.scalar_bindings,
106 cube_dim: settings.cube_dim,
107 body: self.expansion.scope,
108 options: settings.options,
109 }
110 }
111
112 fn register_buffers(&mut self) {
113 for buffer in self.expansion.buffers.drain(..) {
114 self.buffer_bindings.push(Binding {
115 id: buffer.id,
116 ty: buffer.item,
117 visibility: buffer.visibility,
118 location: Location::Storage,
119 has_extended_meta: buffer.has_extended_meta,
120 size: None,
121 });
122 }
123 }
124
125 fn register_scalars(&mut self) {
126 for scalar in self.expansion.scalars.drain(..) {
127 self.scalar_bindings.push(ScalarBinding {
128 ty: scalar.ty,
129 count: scalar.count,
130 });
131 }
132 }
133
134 fn register_tensor_maps(&mut self) {
135 for buffer in self.expansion.tensor_maps.drain(..) {
136 self.tensor_maps.push(Binding {
137 id: buffer.id,
138 ty: buffer.item,
139 visibility: buffer.visibility,
140 location: Location::Storage,
141 has_extended_meta: buffer.has_extended_meta,
142 size: None,
143 });
144 }
145 }
146}