cubecl_core/compute/
builder.rs1use alloc::{rc::Rc, vec::Vec};
2use core::sync::atomic::{AtomicI8, Ordering};
3
4use crate::{
5 BufferInfo, KernelExpansion, KernelIntegrator, KernelSettings, ScalarInfo,
6 ir::{Id, Type},
7 prelude::KernelDefinition,
8};
9use alloc::collections::BTreeMap;
10use cubecl_ir::{
11 DeviceProperties, ExpandElement, Scope, StorageType, TargetProperties, Variable, VariableKind,
12};
13use cubecl_runtime::{
14 config::{GlobalConfig, compilation::CompilationLogLevel},
15 kernel::Visibility,
16};
17
18pub struct KernelBuilder {
20 pub scope: Scope,
22 buffers: Vec<BufferInfo>,
23 scalars: BTreeMap<StorageType, usize>,
24 tensor_maps: Vec<BufferInfo>,
25}
26
27static DEBUG: AtomicI8 = AtomicI8::new(-1);
28
29impl KernelBuilder {
30 pub fn scalar(&mut self, storage: StorageType) -> ExpandElement {
32 let id = self.scalars.entry(storage).or_default();
33 let expand = self.scope.scalar(*id as Id, storage);
34 *id += 1;
35 expand
36 }
37
38 fn buffer_id(&self) -> Id {
39 self.buffers.len() as Id + self.tensor_maps.len() as Id
40 }
41
42 pub fn output_tensor(&mut self, item: Type) -> ExpandElement {
44 let id = self.buffer_id();
45 self.buffers.push(BufferInfo {
46 id,
47 item,
48 visibility: Visibility::ReadWrite,
49 has_extended_meta: true,
50 });
51 self.scope.output(id, item)
52 }
53
54 pub fn input_tensor_map(&mut self, item: Type) -> ExpandElement {
56 let id = self.buffer_id();
57 self.tensor_maps.push(BufferInfo {
58 id,
59 item,
60 visibility: Visibility::ReadWrite,
61 has_extended_meta: true,
62 });
63 ExpandElement::Plain(Variable::new(VariableKind::TensorMapInput(id), item))
64 }
65
66 pub fn output_tensor_map(&mut self, item: Type) -> ExpandElement {
68 let id = self.buffer_id();
69 self.tensor_maps.push(BufferInfo {
70 id,
71 item,
72 visibility: Visibility::Read,
73 has_extended_meta: true,
74 });
75 ExpandElement::Plain(Variable::new(VariableKind::TensorMapOutput(id), item))
76 }
77
78 pub fn input_tensor(&mut self, item: Type) -> ExpandElement {
80 let id = self.buffer_id();
81 self.buffers.push(BufferInfo {
82 id,
83 item,
84 visibility: Visibility::Read,
85 has_extended_meta: true,
86 });
87 self.scope.input(id, item)
88 }
89
90 pub fn output_array(&mut self, item: Type) -> ExpandElement {
92 let id = self.buffer_id();
93 self.buffers.push(BufferInfo {
94 id,
95 item,
96 visibility: Visibility::ReadWrite,
97 has_extended_meta: false,
98 });
99 self.scope.output(id, item)
100 }
101
102 pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
104 let input = self
105 .buffers
106 .get_mut(position as usize)
107 .expect("Position valid");
108
109 input.visibility = Visibility::ReadWrite;
110 self.scope.input(position, input.item)
111 }
112
113 pub fn input_array(&mut self, item: Type) -> ExpandElement {
115 let id = self.buffer_id();
116 self.buffers.push(BufferInfo {
117 id,
118 item,
119 visibility: Visibility::Read,
120 has_extended_meta: false,
121 });
122 self.scope.input(id, item)
123 }
124
125 pub fn runtime_properties(&mut self, properties: TargetProperties) {
126 self.scope.runtime_properties = Rc::new(properties);
127 }
128
129 pub fn device_properties(&mut self, properties: &DeviceProperties) {
130 self.scope.device_properties(properties);
131 }
132
133 pub fn build(self, settings: KernelSettings) -> KernelDefinition {
135 let scalars = self
136 .scalars
137 .into_iter()
138 .map(|(ty, count)| ScalarInfo { ty, count })
139 .collect();
140 KernelIntegrator::new(KernelExpansion {
141 scope: self.scope,
142 buffers: self.buffers,
143 scalars,
144 tensor_maps: self.tensor_maps,
145 })
146 .integrate(settings)
147 }
148
149 pub fn new() -> Self {
150 let debug = DEBUG.load(Ordering::Relaxed);
151 let debug = if debug == -1 {
152 let val = match GlobalConfig::get().compilation.logger.level {
153 CompilationLogLevel::Full => 1,
154 _ => 0,
155 };
156
157 DEBUG.store(val, Ordering::Relaxed);
158 val == 1
159 } else {
160 debug == 1
161 };
162
163 Self {
164 scope: Scope::root(debug),
165 buffers: Default::default(),
166 scalars: Default::default(),
167 tensor_maps: Default::default(),
168 }
169 }
170}
171
172impl Default for KernelBuilder {
173 fn default() -> Self {
174 Self::new()
175 }
176}