cubecl_core/compute/
builder.rs1use std::{
2 rc::Rc,
3 sync::atomic::{AtomicI8, Ordering},
4};
5
6use alloc::collections::BTreeMap;
7
8use cubecl_ir::{ExpandElement, Scope, StorageType, TargetProperties, Variable, VariableKind};
9use cubecl_runtime::config::{GlobalConfig, compilation::CompilationLogLevel};
10
11use crate::ir::{Id, Type};
12use crate::prelude::KernelDefinition;
13use crate::{BufferInfo, KernelSettings, ScalarInfo};
14use crate::{KernelExpansion, KernelIntegrator};
15
16use super::Visibility;
17
18pub struct KernelBuilder {
20 pub scope: Scope,
22 buffers: Vec<BufferInfo>,
23 scalars: BTreeMap<StorageType, usize>,
24 tensor_maps: Vec<BufferInfo>,
25}
26
27static DEBUG: AtomicI8 = AtomicI8::new(-1);
28
29impl KernelBuilder {
30 pub fn scalar(&mut self, storage: StorageType) -> ExpandElement {
32 let id = self.scalars.entry(storage).or_default();
33 let expand = self.scope.scalar(*id as Id, storage);
34 *id += 1;
35 expand
36 }
37
38 fn buffer_id(&self) -> Id {
39 self.buffers.len() as Id + self.tensor_maps.len() as Id
40 }
41
42 pub fn output_tensor(&mut self, item: Type) -> ExpandElement {
44 let id = self.buffer_id();
45 self.buffers.push(BufferInfo {
46 id,
47 item,
48 visibility: Visibility::ReadWrite,
49 has_extended_meta: true,
50 });
51 self.scope.output(id, item)
52 }
53
54 pub fn input_tensor_map(&mut self, item: Type) -> ExpandElement {
56 let id = self.buffer_id();
57 self.tensor_maps.push(BufferInfo {
58 id,
59 item,
60 visibility: Visibility::ReadWrite,
61 has_extended_meta: true,
62 });
63 ExpandElement::Plain(Variable::new(VariableKind::TensorMapInput(id), item))
64 }
65
66 pub fn output_tensor_map(&mut self, item: Type) -> ExpandElement {
68 let id = self.buffer_id();
69 self.tensor_maps.push(BufferInfo {
70 id,
71 item,
72 visibility: Visibility::Read,
73 has_extended_meta: true,
74 });
75 ExpandElement::Plain(Variable::new(VariableKind::TensorMapOutput(id), item))
76 }
77
78 pub fn input_tensor(&mut self, item: Type) -> ExpandElement {
80 let id = self.buffer_id();
81 self.buffers.push(BufferInfo {
82 id,
83 item,
84 visibility: Visibility::Read,
85 has_extended_meta: true,
86 });
87 self.scope.input(id, item)
88 }
89
90 pub fn output_array(&mut self, item: Type) -> ExpandElement {
92 let id = self.buffer_id();
93 self.buffers.push(BufferInfo {
94 id,
95 item,
96 visibility: Visibility::ReadWrite,
97 has_extended_meta: false,
98 });
99 self.scope.output(id, item)
100 }
101
102 pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
104 let input = self
105 .buffers
106 .get_mut(position as usize)
107 .expect("Position valid");
108
109 input.visibility = Visibility::ReadWrite;
110 self.scope.input(position, input.item)
111 }
112
113 pub fn input_array(&mut self, item: Type) -> ExpandElement {
115 let id = self.buffer_id();
116 self.buffers.push(BufferInfo {
117 id,
118 item,
119 visibility: Visibility::Read,
120 has_extended_meta: false,
121 });
122 self.scope.input(id, item)
123 }
124
125 pub fn runtime_properties(&mut self, properties: TargetProperties) {
126 self.scope.runtime_properties = Rc::new(properties);
127 }
128
129 pub fn build(self, settings: KernelSettings) -> KernelDefinition {
131 let scalars = self
132 .scalars
133 .into_iter()
134 .map(|(ty, count)| ScalarInfo { ty, count })
135 .collect();
136 KernelIntegrator::new(KernelExpansion {
137 scope: self.scope,
138 buffers: self.buffers,
139 scalars,
140 tensor_maps: self.tensor_maps,
141 })
142 .integrate(settings)
143 }
144
145 pub fn new() -> Self {
146 let debug = DEBUG.load(Ordering::Relaxed);
147 let debug = if debug == -1 {
148 let val = match GlobalConfig::get().compilation.logger.level {
149 CompilationLogLevel::Full => 1,
150 _ => 0,
151 };
152
153 DEBUG.store(val, Ordering::Relaxed);
154 val == 1
155 } else {
156 debug == 1
157 };
158 Self {
159 scope: Scope::root(debug),
160 buffers: Default::default(),
161 scalars: Default::default(),
162 tensor_maps: Default::default(),
163 }
164 }
165}
166
167impl Default for KernelBuilder {
168 fn default() -> Self {
169 Self::new()
170 }
171}