cubecl_core/compute/
builder.rs1use std::sync::atomic::{AtomicI8, Ordering};
2
3use alloc::collections::BTreeMap;
4
5use cubecl_ir::{ExpandElement, Scope, Variable, VariableKind};
6use cubecl_runtime::config::{GlobalConfig, compilation::CompilationLogLevel};
7
8use crate::ir::{Elem, Id, Item};
9use crate::prelude::KernelDefinition;
10use crate::{BufferInfo, KernelSettings, ScalarInfo};
11use crate::{KernelExpansion, KernelIntegrator};
12
13use super::Visibility;
14
15pub struct KernelBuilder {
17 pub scope: Scope,
19 buffers: Vec<BufferInfo>,
20 scalars: BTreeMap<Elem, usize>,
21 tensor_maps: Vec<Id>,
22}
23
/// Process-wide cache of the "emit debug info" decision taken by `KernelBuilder::new`:
/// `-1` = not computed yet, `0` = disabled, `1` = enabled.
static DEBUG: AtomicI8 = AtomicI8::new(-1_i8);
25
26impl KernelBuilder {
27 pub fn scalar(&mut self, elem: Elem) -> ExpandElement {
29 let id = self.scalars.entry(elem).or_default();
30 let expand = self.scope.scalar(*id as Id, elem);
31 *id += 1;
32 expand
33 }
34
35 fn buffer_id(&self) -> Id {
36 self.buffers.len() as Id + self.tensor_maps.len() as Id
37 }
38
39 pub fn output_tensor(&mut self, item: Item) -> ExpandElement {
41 let id = self.buffer_id();
42 self.buffers.push(BufferInfo {
43 id,
44 item,
45 visibility: Visibility::ReadWrite,
46 has_extended_meta: true,
47 });
48 self.scope.output(id, item)
49 }
50
51 pub fn tensor_map(&mut self) -> ExpandElement {
53 let id = self.buffer_id();
54 self.tensor_maps.push(id);
55 ExpandElement::Plain(Variable::new(
56 VariableKind::TensorMap(id),
57 Item::new(Elem::Bool),
58 ))
59 }
60
61 pub fn input_tensor(&mut self, item: Item) -> ExpandElement {
63 let id = self.buffer_id();
64 self.buffers.push(BufferInfo {
65 id,
66 item,
67 visibility: Visibility::Read,
68 has_extended_meta: true,
69 });
70 self.scope.input(id, item)
71 }
72
73 pub fn output_array(&mut self, item: Item) -> ExpandElement {
75 let id = self.buffer_id();
76 self.buffers.push(BufferInfo {
77 id,
78 item,
79 visibility: Visibility::ReadWrite,
80 has_extended_meta: false,
81 });
82 self.scope.output(id, item)
83 }
84
85 pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
87 let input = self
88 .buffers
89 .get_mut(position as usize)
90 .expect("Position valid");
91
92 input.visibility = Visibility::ReadWrite;
93 self.scope.input(position, input.item)
94 }
95
96 pub fn input_array(&mut self, item: Item) -> ExpandElement {
98 let id = self.buffer_id();
99 self.buffers.push(BufferInfo {
100 id,
101 item,
102 visibility: Visibility::Read,
103 has_extended_meta: false,
104 });
105 self.scope.input(id, item)
106 }
107
108 pub fn build(self, settings: KernelSettings) -> KernelDefinition {
110 let scalars = self
111 .scalars
112 .into_iter()
113 .map(|(elem, count)| ScalarInfo { elem, count })
114 .collect();
115 KernelIntegrator::new(KernelExpansion {
116 scope: self.scope,
117 buffers: self.buffers,
118 scalars,
119 tensor_maps: self.tensor_maps,
120 })
121 .integrate(settings)
122 }
123
124 pub fn new() -> Self {
125 let debug = DEBUG.load(Ordering::Relaxed);
126 let debug = if debug == -1 {
127 let val = match GlobalConfig::get().compilation.logger.level {
128 CompilationLogLevel::Full => 1,
129 _ => 0,
130 };
131
132 DEBUG.store(val, Ordering::Relaxed);
133 val == 1
134 } else {
135 debug == 1
136 };
137 Self {
138 scope: Scope::root(debug),
139 buffers: Default::default(),
140 scalars: Default::default(),
141 tensor_maps: Default::default(),
142 }
143 }
144}
145
146impl Default for KernelBuilder {
147 fn default() -> Self {
148 Self::new()
149 }
150}