cubecl_core/compute/builder.rs

1use std::sync::atomic::{AtomicI8, Ordering};
2
3use alloc::collections::BTreeMap;
4
5use cubecl_ir::{ExpandElement, Scope, Variable, VariableKind};
6use cubecl_runtime::config::{GlobalConfig, compilation::CompilationLogLevel};
7
8use crate::ir::{Elem, Id, Item};
9use crate::prelude::KernelDefinition;
10use crate::{BufferInfo, KernelSettings, ScalarInfo};
11use crate::{KernelExpansion, KernelIntegrator};
12
13use super::Visibility;
14
15/// Prepare a kernel to create a [kernel definition](crate::KernelDefinition).
16pub struct KernelBuilder {
17    /// Cube [scope](Scope).
18    pub scope: Scope,
19    buffers: Vec<BufferInfo>,
20    scalars: BTreeMap<Elem, usize>,
21    tensor_maps: Vec<Id>,
22}
23
/// Process-wide cache of the debug flag passed to [`Scope::root`] by [`KernelBuilder::new`].
///
/// Encoding: `-1` means the global config has not been consulted yet, `0` means disabled and
/// `1` means enabled.
static DEBUG: AtomicI8 = AtomicI8::new(-1);
25
26impl KernelBuilder {
27    /// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion.
28    pub fn scalar(&mut self, elem: Elem) -> ExpandElement {
29        let id = self.scalars.entry(elem).or_default();
30        let expand = self.scope.scalar(*id as Id, elem);
31        *id += 1;
32        expand
33    }
34
35    fn buffer_id(&self) -> Id {
36        self.buffers.len() as Id + self.tensor_maps.len() as Id
37    }
38
39    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
40    pub fn output_tensor(&mut self, item: Item) -> ExpandElement {
41        let id = self.buffer_id();
42        self.buffers.push(BufferInfo {
43            id,
44            item,
45            visibility: Visibility::ReadWrite,
46            has_extended_meta: true,
47        });
48        self.scope.output(id, item)
49    }
50
51    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
52    pub fn tensor_map(&mut self) -> ExpandElement {
53        let id = self.buffer_id();
54        self.tensor_maps.push(id);
55        ExpandElement::Plain(Variable::new(
56            VariableKind::TensorMap(id),
57            Item::new(Elem::Bool),
58        ))
59    }
60
61    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
62    pub fn input_tensor(&mut self, item: Item) -> ExpandElement {
63        let id = self.buffer_id();
64        self.buffers.push(BufferInfo {
65            id,
66            item,
67            visibility: Visibility::Read,
68            has_extended_meta: true,
69        });
70        self.scope.input(id, item)
71    }
72
73    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
74    pub fn output_array(&mut self, item: Item) -> ExpandElement {
75        let id = self.buffer_id();
76        self.buffers.push(BufferInfo {
77            id,
78            item,
79            visibility: Visibility::ReadWrite,
80            has_extended_meta: false,
81        });
82        self.scope.output(id, item)
83    }
84
85    /// Register an output that uses the same resource as the input as the given position.
86    pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
87        let input = self
88            .buffers
89            .get_mut(position as usize)
90            .expect("Position valid");
91
92        input.visibility = Visibility::ReadWrite;
93        self.scope.input(position, input.item)
94    }
95
96    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
97    pub fn input_array(&mut self, item: Item) -> ExpandElement {
98        let id = self.buffer_id();
99        self.buffers.push(BufferInfo {
100            id,
101            item,
102            visibility: Visibility::Read,
103            has_extended_meta: false,
104        });
105        self.scope.input(id, item)
106    }
107
108    /// Build the [kernel definition](KernelDefinition).
109    pub fn build(self, settings: KernelSettings) -> KernelDefinition {
110        let scalars = self
111            .scalars
112            .into_iter()
113            .map(|(elem, count)| ScalarInfo { elem, count })
114            .collect();
115        KernelIntegrator::new(KernelExpansion {
116            scope: self.scope,
117            buffers: self.buffers,
118            scalars,
119            tensor_maps: self.tensor_maps,
120        })
121        .integrate(settings)
122    }
123
124    pub fn new() -> Self {
125        let debug = DEBUG.load(Ordering::Relaxed);
126        let debug = if debug == -1 {
127            let val = match GlobalConfig::get().compilation.logger.level {
128                CompilationLogLevel::Full => 1,
129                _ => 0,
130            };
131
132            DEBUG.store(val, Ordering::Relaxed);
133            val == 1
134        } else {
135            debug == 1
136        };
137        Self {
138            scope: Scope::root(debug),
139            buffers: Default::default(),
140            scalars: Default::default(),
141            tensor_maps: Default::default(),
142        }
143    }
144}
145
146impl Default for KernelBuilder {
147    fn default() -> Self {
148        Self::new()
149    }
150}