cubecl_core/codegen/
integrator.rs1use cubecl_common::CubeDim;
2use cubecl_ir::{Elem, Id, Item, Scope};
3
4use crate::{
5 compute::{Binding, KernelDefinition, Location, ScalarBinding, Visibility},
6 prelude::FastMath,
7};
8
9#[derive(Clone)]
12pub struct KernelIntegrator {
13 expansion: KernelExpansion,
14 buffer_bindings: Vec<Binding>,
15 scalar_bindings: Vec<ScalarBinding>,
16 tensor_maps: Vec<Id>,
17}
18
19#[derive(Clone)]
21pub struct KernelExpansion {
22 pub buffers: Vec<BufferInfo>,
23 pub scalars: Vec<ScalarInfo>,
24 pub tensor_maps: Vec<Id>,
25 pub scope: Scope,
26}
27
28#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
29pub struct KernelSettings {
30 pub cube_dim: CubeDim,
31 pub options: KernelOptions,
32}
33
34#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
35pub struct KernelOptions {
36 pub kernel_name: String,
37 pub debug_symbols: bool,
38 pub fp_math_mode: FastMath,
39 pub cluster_dim: Option<CubeDim>,
40}
41
42impl KernelSettings {
43 #[allow(dead_code)]
45 pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
46 self.cube_dim = cube_dim;
47 self
48 }
49
50 #[allow(dead_code)]
52 pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
53 self.options.kernel_name = name.as_ref().to_string();
54 self
55 }
56
57 pub fn debug_symbols(mut self) -> Self {
59 self.options.debug_symbols = true;
60 self
61 }
62
63 pub fn fp_math_mode(mut self, mode: FastMath) -> Self {
65 self.options.fp_math_mode = mode;
66 self
67 }
68
69 pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
71 self.options.cluster_dim = Some(cluster_dim);
72 self
73 }
74}
75
76#[derive(Clone, Debug)]
78pub struct BufferInfo {
79 pub id: Id,
80 pub item: Item,
81 pub visibility: Visibility,
82 pub has_extended_meta: bool,
84}
85
86#[derive(Clone, Debug)]
88pub struct ScalarInfo {
89 pub elem: Elem,
90 pub count: usize,
91}
92
93impl KernelIntegrator {
94 pub fn new(info: KernelExpansion) -> Self {
96 Self {
97 expansion: info,
98 buffer_bindings: Default::default(),
99 scalar_bindings: Default::default(),
100 tensor_maps: Default::default(),
101 }
102 }
103
104 pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
106 self.register_buffers();
107 self.register_scalars();
108 self.register_tensor_maps();
109
110 self.scalar_bindings.sort_by_key(|binding| binding.elem);
111
112 KernelDefinition {
113 buffers: self.buffer_bindings,
114 tensor_maps: self.tensor_maps,
115 scalars: self.scalar_bindings,
116 cube_dim: settings.cube_dim,
117 body: self.expansion.scope,
118 options: settings.options,
119 }
120 }
121
122 fn register_buffers(&mut self) {
123 for buffer in self.expansion.buffers.drain(..) {
124 self.buffer_bindings.push(Binding {
125 id: buffer.id,
126 item: buffer.item,
127 visibility: buffer.visibility,
128 location: Location::Storage,
129 has_extended_meta: buffer.has_extended_meta,
130 size: None,
131 });
132 }
133 }
134
135 fn register_scalars(&mut self) {
136 for scalar in self.expansion.scalars.drain(..) {
137 self.scalar_bindings.push(ScalarBinding {
138 elem: scalar.elem,
139 count: scalar.count,
140 });
141 }
142 }
143
144 fn register_tensor_maps(&mut self) {
145 for id in self.expansion.tensor_maps.drain(..) {
146 self.tensor_maps.push(id);
147 }
148 }
149}