cubecl_core/compute/
builder.rs1use alloc::collections::BTreeMap;
2
3use cubecl_ir::{ExpandElement, Scope, Variable, VariableKind};
4use cubecl_runtime::debug::DebugLogger;
5
6use crate::ir::{Elem, Id, Item};
7use crate::prelude::KernelDefinition;
8use crate::{BufferInfo, KernelSettings, ScalarInfo};
9use crate::{KernelExpansion, KernelIntegrator};
10
11use super::Visibility;
12
13pub struct KernelBuilder {
15 pub context: Scope,
17 buffers: Vec<BufferInfo>,
18 scalars: BTreeMap<Elem, usize>,
19 tensor_maps: Vec<Id>,
20}
21
22impl KernelBuilder {
23 pub fn scalar(&mut self, elem: Elem) -> ExpandElement {
25 let id = self.scalars.entry(elem).or_default();
26 let expand = self.context.scalar(*id as Id, elem);
27 *id += 1;
28 expand
29 }
30
31 fn buffer_id(&self) -> Id {
32 self.buffers.len() as Id + self.tensor_maps.len() as Id
33 }
34
35 pub fn output_tensor(&mut self, item: Item) -> ExpandElement {
37 let id = self.buffer_id();
38 self.buffers.push(BufferInfo {
39 id,
40 item,
41 visibility: Visibility::ReadWrite,
42 has_extended_meta: true,
43 });
44 self.context.output(id, item)
45 }
46
47 pub fn tensor_map(&mut self) -> ExpandElement {
49 let id = self.buffer_id();
50 self.tensor_maps.push(id);
51 ExpandElement::Plain(Variable::new(
52 VariableKind::TensorMap(id),
53 Item::new(Elem::Bool),
54 ))
55 }
56
57 pub fn input_tensor(&mut self, item: Item) -> ExpandElement {
59 let id = self.buffer_id();
60 self.buffers.push(BufferInfo {
61 id,
62 item,
63 visibility: Visibility::Read,
64 has_extended_meta: true,
65 });
66 self.context.input(id, item)
67 }
68
69 pub fn output_array(&mut self, item: Item) -> ExpandElement {
71 let id = self.buffer_id();
72 self.buffers.push(BufferInfo {
73 id,
74 item,
75 visibility: Visibility::ReadWrite,
76 has_extended_meta: false,
77 });
78 self.context.output(id, item)
79 }
80
81 pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
83 let input = self
84 .buffers
85 .get_mut(position as usize)
86 .expect("Position valid");
87
88 input.visibility = Visibility::ReadWrite;
89 self.context.input(position, input.item)
90 }
91
92 pub fn input_array(&mut self, item: Item) -> ExpandElement {
94 let id = self.buffer_id();
95 self.buffers.push(BufferInfo {
96 id,
97 item,
98 visibility: Visibility::Read,
99 has_extended_meta: false,
100 });
101 self.context.input(id, item)
102 }
103
104 pub fn build(self, settings: KernelSettings) -> KernelDefinition {
106 let scalars = self
107 .scalars
108 .into_iter()
109 .map(|(elem, count)| ScalarInfo { elem, count })
110 .collect();
111 KernelIntegrator::new(KernelExpansion {
112 scope: self.context,
113 buffers: self.buffers,
114 scalars,
115 tensor_maps: self.tensor_maps,
116 })
117 .integrate(settings)
118 }
119
120 pub fn new() -> Self {
121 Self {
122 context: Scope::root(DebugLogger::default().is_activated()),
123 buffers: Default::default(),
124 scalars: Default::default(),
125 tensor_maps: Default::default(),
126 }
127 }
128}
129
130impl Default for KernelBuilder {
131 fn default() -> Self {
132 Self::new()
133 }
134}