cubecl_core/compute/
builder.rs1use crate::ir::{Elem, Id, Item, Visibility};
2use crate::prelude::KernelDefinition;
3use crate::KernelSettings;
4use crate::{
5 frontend::{CubeContext, ExpandElement},
6 InputInfo, KernelExpansion, KernelIntegrator, OutputInfo,
7};
8use std::collections::HashMap;
9
10pub struct KernelBuilder {
12 pub context: CubeContext,
14 inputs: Vec<InputInfo>,
15 outputs: Vec<OutputInfo>,
16 indices: HashMap<Elem, usize>,
17 num_input: Id,
18 num_output: Id,
19}
20
21impl KernelBuilder {
22 pub fn scalar(&mut self, elem: Elem) -> ExpandElement {
24 let index = match self.indices.get_mut(&elem) {
25 Some(index) => match self.inputs.get_mut(*index).unwrap() {
26 InputInfo::Scalar { elem: _, size } => {
27 *size += 1;
28 *size as Id - 1
29 }
30 _ => panic!("Should be a scalar."),
31 },
32 None => {
33 self.indices.insert(elem, self.inputs.len());
34 self.inputs.push(InputInfo::Scalar { size: 1, elem });
35 0
36 }
37 };
38
39 self.context.scalar(index, elem)
40 }
41
42 pub fn output_tensor(&mut self, item: Item) -> ExpandElement {
44 self.outputs.push(OutputInfo::Array {
45 item,
46 has_extended_meta: true,
47 });
48 let variable = self.context.output(self.num_output, item);
49 self.num_output += 1;
50
51 variable
52 }
53
54 pub fn input_tensor(&mut self, item: Item) -> ExpandElement {
56 self.inputs.push(InputInfo::Array {
57 item,
58 visibility: Visibility::Read,
59 has_extended_meta: true,
60 });
61 let variable = self.context.input(self.num_input, item);
62 self.num_input += 1;
63 variable
64 }
65
66 pub fn output_array(&mut self, item: Item) -> ExpandElement {
68 self.outputs.push(OutputInfo::Array {
69 item,
70 has_extended_meta: false,
71 });
72 let variable = self.context.output(self.num_output, item);
73 self.num_output += 1;
74
75 variable
76 }
77
78 pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
80 let input = self
81 .inputs
82 .get_mut(position as usize)
83 .expect("Position valid");
84
85 if let InputInfo::Array {
86 visibility, item, ..
87 } = input
88 {
89 *visibility = Visibility::ReadWrite;
90 let variable = self.context.input(position, *item);
91 return variable;
92 }
93
94 panic!("No input found at position {position}");
95 }
96
97 pub fn input_array(&mut self, item: Item) -> ExpandElement {
99 self.inputs.push(InputInfo::Array {
100 item,
101 visibility: Visibility::Read,
102 has_extended_meta: false,
103 });
104 let variable = self.context.input(self.num_input, item);
105 self.num_input += 1;
106 variable
107 }
108
109 pub fn build(self, settings: KernelSettings) -> KernelDefinition {
111 KernelIntegrator::new(KernelExpansion {
112 scope: self.context.into_scope(),
113 inputs: self.inputs,
114 outputs: self.outputs,
115 kernel_name: settings.kernel_name.clone(),
116 })
117 .integrate(settings)
118 }
119
120 pub fn new() -> Self {
121 Self {
122 context: CubeContext::root(),
123 inputs: Vec::new(),
124 outputs: Vec::new(),
125 indices: HashMap::new(),
126 num_input: 0,
127 num_output: 0,
128 }
129 }
130}
131
132impl Default for KernelBuilder {
133 fn default() -> Self {
134 Self::new()
135 }
136}