cubecl_core/compute/
builder.rs1use std::{
2 rc::Rc,
3 sync::atomic::{AtomicI8, Ordering},
4};
5
6use crate::{
7 BufferInfo, KernelExpansion, KernelIntegrator, KernelSettings, ScalarInfo,
8 ir::{Id, Type},
9 prelude::KernelDefinition,
10};
11use alloc::collections::BTreeMap;
12use cubecl_ir::{
13 DeviceProperties, ExpandElement, Scope, StorageType, TargetProperties, Variable, VariableKind,
14};
15use cubecl_runtime::{
16 config::{GlobalConfig, compilation::CompilationLogLevel},
17 kernel::Visibility,
18};
19
20pub struct KernelBuilder {
22 pub scope: Scope,
24 buffers: Vec<BufferInfo>,
25 scalars: BTreeMap<StorageType, usize>,
26 tensor_maps: Vec<BufferInfo>,
27}
28
/// Process-wide cache of the debug flag computed by [`KernelBuilder::new`]:
/// * `-1`: not computed yet (initial value),
/// * `1`: the compilation logger level is `Full`,
/// * `0`: any other logger level.
static DEBUG: AtomicI8 = AtomicI8::new(-1i8);
30
31impl KernelBuilder {
32 pub fn scalar(&mut self, storage: StorageType) -> ExpandElement {
34 let id = self.scalars.entry(storage).or_default();
35 let expand = self.scope.scalar(*id as Id, storage);
36 *id += 1;
37 expand
38 }
39
40 fn buffer_id(&self) -> Id {
41 self.buffers.len() as Id + self.tensor_maps.len() as Id
42 }
43
44 pub fn output_tensor(&mut self, item: Type) -> ExpandElement {
46 let id = self.buffer_id();
47 self.buffers.push(BufferInfo {
48 id,
49 item,
50 visibility: Visibility::ReadWrite,
51 has_extended_meta: true,
52 });
53 self.scope.output(id, item)
54 }
55
56 pub fn input_tensor_map(&mut self, item: Type) -> ExpandElement {
58 let id = self.buffer_id();
59 self.tensor_maps.push(BufferInfo {
60 id,
61 item,
62 visibility: Visibility::ReadWrite,
63 has_extended_meta: true,
64 });
65 ExpandElement::Plain(Variable::new(VariableKind::TensorMapInput(id), item))
66 }
67
68 pub fn output_tensor_map(&mut self, item: Type) -> ExpandElement {
70 let id = self.buffer_id();
71 self.tensor_maps.push(BufferInfo {
72 id,
73 item,
74 visibility: Visibility::Read,
75 has_extended_meta: true,
76 });
77 ExpandElement::Plain(Variable::new(VariableKind::TensorMapOutput(id), item))
78 }
79
80 pub fn input_tensor(&mut self, item: Type) -> ExpandElement {
82 let id = self.buffer_id();
83 self.buffers.push(BufferInfo {
84 id,
85 item,
86 visibility: Visibility::Read,
87 has_extended_meta: true,
88 });
89 self.scope.input(id, item)
90 }
91
92 pub fn output_array(&mut self, item: Type) -> ExpandElement {
94 let id = self.buffer_id();
95 self.buffers.push(BufferInfo {
96 id,
97 item,
98 visibility: Visibility::ReadWrite,
99 has_extended_meta: false,
100 });
101 self.scope.output(id, item)
102 }
103
104 pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
106 let input = self
107 .buffers
108 .get_mut(position as usize)
109 .expect("Position valid");
110
111 input.visibility = Visibility::ReadWrite;
112 self.scope.input(position, input.item)
113 }
114
115 pub fn input_array(&mut self, item: Type) -> ExpandElement {
117 let id = self.buffer_id();
118 self.buffers.push(BufferInfo {
119 id,
120 item,
121 visibility: Visibility::Read,
122 has_extended_meta: false,
123 });
124 self.scope.input(id, item)
125 }
126
127 pub fn runtime_properties(&mut self, properties: TargetProperties) {
128 self.scope.runtime_properties = Rc::new(properties);
129 }
130
131 pub fn device_properties(&mut self, properties: &DeviceProperties) {
132 self.scope.device_properties(properties);
133 }
134
135 pub fn build(self, settings: KernelSettings) -> KernelDefinition {
137 let scalars = self
138 .scalars
139 .into_iter()
140 .map(|(ty, count)| ScalarInfo { ty, count })
141 .collect();
142 KernelIntegrator::new(KernelExpansion {
143 scope: self.scope,
144 buffers: self.buffers,
145 scalars,
146 tensor_maps: self.tensor_maps,
147 })
148 .integrate(settings)
149 }
150
151 pub fn new() -> Self {
152 let debug = DEBUG.load(Ordering::Relaxed);
153 let debug = if debug == -1 {
154 let val = match GlobalConfig::get().compilation.logger.level {
155 CompilationLogLevel::Full => 1,
156 _ => 0,
157 };
158
159 DEBUG.store(val, Ordering::Relaxed);
160 val == 1
161 } else {
162 debug == 1
163 };
164
165 Self {
166 scope: Scope::root(debug),
167 buffers: Default::default(),
168 scalars: Default::default(),
169 tensor_maps: Default::default(),
170 }
171 }
172}
173
174impl Default for KernelBuilder {
175 fn default() -> Self {
176 Self::new()
177 }
178}