cubecl_core/compute/
builder.rs

1use crate::ir::{Elem, Id, Item, Visibility};
2use crate::prelude::KernelDefinition;
3use crate::KernelSettings;
4use crate::{
5    frontend::{CubeContext, ExpandElement},
6    InputInfo, KernelExpansion, KernelIntegrator, OutputInfo,
7};
8use std::collections::HashMap;
9
/// Prepare a kernel to create a [kernel definition](crate::KernelDefinition).
///
/// Inputs, outputs and scalars are registered one at a time through the
/// builder's methods; each registration returns the [element](ExpandElement)
/// to use during kernel expansion.
pub struct KernelBuilder {
    /// Cube [context](CubeContext).
    pub context: CubeContext,
    /// Registered inputs, in registration order. Holds the array entries as
    /// well as one scalar entry per distinct scalar element type.
    inputs: Vec<InputInfo>,
    /// Registered outputs, in registration order.
    outputs: Vec<OutputInfo>,
    /// For each scalar element type, the position in `inputs` of the scalar
    /// entry that groups all scalars of that type.
    indices: HashMap<Elem, usize>,
    /// Number of input arrays/tensors registered so far; also the id handed
    /// to the context for the next one.
    num_input: Id,
    /// Number of outputs registered so far; also the id handed to the context
    /// for the next one.
    num_output: Id,
}
20
21impl KernelBuilder {
22    /// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion.
23    pub fn scalar(&mut self, elem: Elem) -> ExpandElement {
24        let index = match self.indices.get_mut(&elem) {
25            Some(index) => match self.inputs.get_mut(*index).unwrap() {
26                InputInfo::Scalar { elem: _, size } => {
27                    *size += 1;
28                    *size as Id - 1
29                }
30                _ => panic!("Should be a scalar."),
31            },
32            None => {
33                self.indices.insert(elem, self.inputs.len());
34                self.inputs.push(InputInfo::Scalar { size: 1, elem });
35                0
36            }
37        };
38
39        self.context.scalar(index, elem)
40    }
41
42    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
43    pub fn output_tensor(&mut self, item: Item) -> ExpandElement {
44        self.outputs.push(OutputInfo::Array {
45            item,
46            has_extended_meta: true,
47        });
48        let variable = self.context.output(self.num_output, item);
49        self.num_output += 1;
50
51        variable
52    }
53
54    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
55    pub fn input_tensor(&mut self, item: Item) -> ExpandElement {
56        self.inputs.push(InputInfo::Array {
57            item,
58            visibility: Visibility::Read,
59            has_extended_meta: true,
60        });
61        let variable = self.context.input(self.num_input, item);
62        self.num_input += 1;
63        variable
64    }
65
66    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
67    pub fn output_array(&mut self, item: Item) -> ExpandElement {
68        self.outputs.push(OutputInfo::Array {
69            item,
70            has_extended_meta: false,
71        });
72        let variable = self.context.output(self.num_output, item);
73        self.num_output += 1;
74
75        variable
76    }
77
78    /// Register an output that uses the same resource as the input as the given position.
79    pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
80        let input = self
81            .inputs
82            .get_mut(position as usize)
83            .expect("Position valid");
84
85        if let InputInfo::Array {
86            visibility, item, ..
87        } = input
88        {
89            *visibility = Visibility::ReadWrite;
90            let variable = self.context.input(position, *item);
91            return variable;
92        }
93
94        panic!("No input found at position {position}");
95    }
96
97    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
98    pub fn input_array(&mut self, item: Item) -> ExpandElement {
99        self.inputs.push(InputInfo::Array {
100            item,
101            visibility: Visibility::Read,
102            has_extended_meta: false,
103        });
104        let variable = self.context.input(self.num_input, item);
105        self.num_input += 1;
106        variable
107    }
108
109    /// Build the [kernel definition](KernelDefinition).
110    pub fn build(self, settings: KernelSettings) -> KernelDefinition {
111        KernelIntegrator::new(KernelExpansion {
112            scope: self.context.into_scope(),
113            inputs: self.inputs,
114            outputs: self.outputs,
115            kernel_name: settings.kernel_name.clone(),
116        })
117        .integrate(settings)
118    }
119
120    pub fn new() -> Self {
121        Self {
122            context: CubeContext::root(),
123            inputs: Vec::new(),
124            outputs: Vec::new(),
125            indices: HashMap::new(),
126            num_input: 0,
127            num_output: 0,
128        }
129    }
130}
131
132impl Default for KernelBuilder {
133    fn default() -> Self {
134        Self::new()
135    }
136}