cubecl_core/compute/
builder.rs

1use alloc::collections::BTreeMap;
2
3use cubecl_ir::{ExpandElement, Scope, Variable, VariableKind};
4use cubecl_runtime::debug::DebugLogger;
5
6use crate::ir::{Elem, Id, Item};
7use crate::prelude::KernelDefinition;
8use crate::{BufferInfo, KernelSettings, ScalarInfo};
9use crate::{KernelExpansion, KernelIntegrator};
10
11use super::Visibility;
12
13/// Prepare a kernel to create a [kernel definition](crate::KernelDefinition).
14pub struct KernelBuilder {
15    /// Cube [scope](Scope).
16    pub context: Scope,
17    buffers: Vec<BufferInfo>,
18    scalars: BTreeMap<Elem, usize>,
19    tensor_maps: Vec<Id>,
20}
21
22impl KernelBuilder {
23    /// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion.
24    pub fn scalar(&mut self, elem: Elem) -> ExpandElement {
25        let id = self.scalars.entry(elem).or_default();
26        let expand = self.context.scalar(*id as Id, elem);
27        *id += 1;
28        expand
29    }
30
31    fn buffer_id(&self) -> Id {
32        self.buffers.len() as Id + self.tensor_maps.len() as Id
33    }
34
35    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
36    pub fn output_tensor(&mut self, item: Item) -> ExpandElement {
37        let id = self.buffer_id();
38        self.buffers.push(BufferInfo {
39            id,
40            item,
41            visibility: Visibility::ReadWrite,
42            has_extended_meta: true,
43        });
44        self.context.output(id, item)
45    }
46
47    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
48    pub fn tensor_map(&mut self) -> ExpandElement {
49        let id = self.buffer_id();
50        self.tensor_maps.push(id);
51        ExpandElement::Plain(Variable::new(
52            VariableKind::TensorMap(id),
53            Item::new(Elem::Bool),
54        ))
55    }
56
57    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
58    pub fn input_tensor(&mut self, item: Item) -> ExpandElement {
59        let id = self.buffer_id();
60        self.buffers.push(BufferInfo {
61            id,
62            item,
63            visibility: Visibility::Read,
64            has_extended_meta: true,
65        });
66        self.context.input(id, item)
67    }
68
69    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
70    pub fn output_array(&mut self, item: Item) -> ExpandElement {
71        let id = self.buffer_id();
72        self.buffers.push(BufferInfo {
73            id,
74            item,
75            visibility: Visibility::ReadWrite,
76            has_extended_meta: false,
77        });
78        self.context.output(id, item)
79    }
80
81    /// Register an output that uses the same resource as the input as the given position.
82    pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
83        let input = self
84            .buffers
85            .get_mut(position as usize)
86            .expect("Position valid");
87
88        input.visibility = Visibility::ReadWrite;
89        self.context.input(position, input.item)
90    }
91
92    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
93    pub fn input_array(&mut self, item: Item) -> ExpandElement {
94        let id = self.buffer_id();
95        self.buffers.push(BufferInfo {
96            id,
97            item,
98            visibility: Visibility::Read,
99            has_extended_meta: false,
100        });
101        self.context.input(id, item)
102    }
103
104    /// Build the [kernel definition](KernelDefinition).
105    pub fn build(self, settings: KernelSettings) -> KernelDefinition {
106        let scalars = self
107            .scalars
108            .into_iter()
109            .map(|(elem, count)| ScalarInfo { elem, count })
110            .collect();
111        KernelIntegrator::new(KernelExpansion {
112            scope: self.context,
113            buffers: self.buffers,
114            scalars,
115            tensor_maps: self.tensor_maps,
116        })
117        .integrate(settings)
118    }
119
120    pub fn new() -> Self {
121        Self {
122            context: Scope::root(DebugLogger::default().is_activated()),
123            buffers: Default::default(),
124            scalars: Default::default(),
125            tensor_maps: Default::default(),
126        }
127    }
128}
129
130impl Default for KernelBuilder {
131    fn default() -> Self {
132        Self::new()
133    }
134}