cubecl_core/compute/
builder.rs1use std::{
2 rc::Rc,
3 sync::atomic::{AtomicI8, Ordering},
4};
5
6use alloc::collections::BTreeMap;
7
8use cubecl_ir::{ExpandElement, Scope, StorageType, TargetProperties, Variable, VariableKind};
9use cubecl_runtime::{
10 config::{GlobalConfig, compilation::CompilationLogLevel},
11 kernel::Visibility,
12};
13
14use crate::ir::{Id, Type};
15use crate::prelude::KernelDefinition;
16use crate::{BufferInfo, KernelSettings, ScalarInfo};
17use crate::{KernelExpansion, KernelIntegrator};
18
/// Incrementally collects the arguments of a kernel (buffers, scalars and tensor maps)
/// together with its root [`Scope`], and turns them into a [`KernelDefinition`] through
/// [`KernelBuilder::build`].
pub struct KernelBuilder {
    /// Root scope the kernel body is expanded into (created with `Scope::root`).
    pub scope: Scope,
    /// Registered array/tensor arguments, in registration order.
    buffers: Vec<BufferInfo>,
    /// Number of scalar arguments registered so far for each storage type.
    scalars: BTreeMap<StorageType, usize>,
    /// Registered tensor-map arguments; they draw ids from the same counter as `buffers`.
    tensor_maps: Vec<BufferInfo>,
}
27
/// Process-wide cache of the debug flag passed to `Scope::root` by `KernelBuilder::new`.
///
/// `-1` means "not resolved yet". On first use, `KernelBuilder::new` resolves it from the
/// global compilation logger level and stores `1` (level is `Full`) or `0` (any other level).
static DEBUG: AtomicI8 = AtomicI8::new(-1);
29
30impl KernelBuilder {
31 pub fn scalar(&mut self, storage: StorageType) -> ExpandElement {
33 let id = self.scalars.entry(storage).or_default();
34 let expand = self.scope.scalar(*id as Id, storage);
35 *id += 1;
36 expand
37 }
38
39 fn buffer_id(&self) -> Id {
40 self.buffers.len() as Id + self.tensor_maps.len() as Id
41 }
42
43 pub fn output_tensor(&mut self, item: Type) -> ExpandElement {
45 let id = self.buffer_id();
46 self.buffers.push(BufferInfo {
47 id,
48 item,
49 visibility: Visibility::ReadWrite,
50 has_extended_meta: true,
51 });
52 self.scope.output(id, item)
53 }
54
55 pub fn input_tensor_map(&mut self, item: Type) -> ExpandElement {
57 let id = self.buffer_id();
58 self.tensor_maps.push(BufferInfo {
59 id,
60 item,
61 visibility: Visibility::ReadWrite,
62 has_extended_meta: true,
63 });
64 ExpandElement::Plain(Variable::new(VariableKind::TensorMapInput(id), item))
65 }
66
67 pub fn output_tensor_map(&mut self, item: Type) -> ExpandElement {
69 let id = self.buffer_id();
70 self.tensor_maps.push(BufferInfo {
71 id,
72 item,
73 visibility: Visibility::Read,
74 has_extended_meta: true,
75 });
76 ExpandElement::Plain(Variable::new(VariableKind::TensorMapOutput(id), item))
77 }
78
79 pub fn input_tensor(&mut self, item: Type) -> ExpandElement {
81 let id = self.buffer_id();
82 self.buffers.push(BufferInfo {
83 id,
84 item,
85 visibility: Visibility::Read,
86 has_extended_meta: true,
87 });
88 self.scope.input(id, item)
89 }
90
91 pub fn output_array(&mut self, item: Type) -> ExpandElement {
93 let id = self.buffer_id();
94 self.buffers.push(BufferInfo {
95 id,
96 item,
97 visibility: Visibility::ReadWrite,
98 has_extended_meta: false,
99 });
100 self.scope.output(id, item)
101 }
102
103 pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
105 let input = self
106 .buffers
107 .get_mut(position as usize)
108 .expect("Position valid");
109
110 input.visibility = Visibility::ReadWrite;
111 self.scope.input(position, input.item)
112 }
113
114 pub fn input_array(&mut self, item: Type) -> ExpandElement {
116 let id = self.buffer_id();
117 self.buffers.push(BufferInfo {
118 id,
119 item,
120 visibility: Visibility::Read,
121 has_extended_meta: false,
122 });
123 self.scope.input(id, item)
124 }
125
126 pub fn runtime_properties(&mut self, properties: TargetProperties) {
127 self.scope.runtime_properties = Rc::new(properties);
128 }
129
130 pub fn build(self, settings: KernelSettings) -> KernelDefinition {
132 let scalars = self
133 .scalars
134 .into_iter()
135 .map(|(ty, count)| ScalarInfo { ty, count })
136 .collect();
137 KernelIntegrator::new(KernelExpansion {
138 scope: self.scope,
139 buffers: self.buffers,
140 scalars,
141 tensor_maps: self.tensor_maps,
142 })
143 .integrate(settings)
144 }
145
146 pub fn new() -> Self {
147 let debug = DEBUG.load(Ordering::Relaxed);
148 let debug = if debug == -1 {
149 let val = match GlobalConfig::get().compilation.logger.level {
150 CompilationLogLevel::Full => 1,
151 _ => 0,
152 };
153
154 DEBUG.store(val, Ordering::Relaxed);
155 val == 1
156 } else {
157 debug == 1
158 };
159 Self {
160 scope: Scope::root(debug),
161 buffers: Default::default(),
162 scalars: Default::default(),
163 tensor_maps: Default::default(),
164 }
165 }
166}
167
168impl Default for KernelBuilder {
169 fn default() -> Self {
170 Self::new()
171 }
172}