cubecl_core/compute/builder.rs

1use std::{
2    rc::Rc,
3    sync::atomic::{AtomicI8, Ordering},
4};
5
6use crate::{
7    BufferInfo, KernelExpansion, KernelIntegrator, KernelSettings, ScalarInfo,
8    ir::{Id, Type},
9    prelude::KernelDefinition,
10};
11use alloc::collections::BTreeMap;
12use cubecl_ir::{
13    DeviceProperties, ExpandElement, Scope, StorageType, TargetProperties, Variable, VariableKind,
14};
15use cubecl_runtime::{
16    config::{GlobalConfig, compilation::CompilationLogLevel},
17    kernel::Visibility,
18};
19
/// Prepare a kernel to create a [kernel definition](crate::KernelDefinition).
///
/// Buffers, scalars and tensor maps are registered one at a time through the `input_*`,
/// `output_*` and `scalar` methods; `build` then hands everything collected here, together
/// with the scope, to the kernel integrator.
pub struct KernelBuilder {
    /// Cube [scope](Scope).
    pub scope: Scope,
    /// Registered tensors and arrays, in registration order.
    buffers: Vec<BufferInfo>,
    /// Number of scalars registered so far for each storage type.
    scalars: BTreeMap<StorageType, usize>,
    /// Registered tensor maps, in registration order. Their ids are drawn from the same
    /// counter as `buffers` (see `buffer_id`).
    tensor_maps: Vec<BufferInfo>,
}
28
/// Cached "debug scope" flag shared by every `KernelBuilder`.
///
/// `-1` means not resolved yet, `0` means off, `1` means on. `KernelBuilder::new` resolves it
/// from the global compilation log level on first use and stores the result here, so the
/// configuration is only consulted while the value is still `-1`.
static DEBUG: AtomicI8 = AtomicI8::new(-1);
30
31impl KernelBuilder {
32    /// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion.
33    pub fn scalar(&mut self, storage: StorageType) -> ExpandElement {
34        let id = self.scalars.entry(storage).or_default();
35        let expand = self.scope.scalar(*id as Id, storage);
36        *id += 1;
37        expand
38    }
39
40    fn buffer_id(&self) -> Id {
41        self.buffers.len() as Id + self.tensor_maps.len() as Id
42    }
43
44    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
45    pub fn output_tensor(&mut self, item: Type) -> ExpandElement {
46        let id = self.buffer_id();
47        self.buffers.push(BufferInfo {
48            id,
49            item,
50            visibility: Visibility::ReadWrite,
51            has_extended_meta: true,
52        });
53        self.scope.output(id, item)
54    }
55
56    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
57    pub fn input_tensor_map(&mut self, item: Type) -> ExpandElement {
58        let id = self.buffer_id();
59        self.tensor_maps.push(BufferInfo {
60            id,
61            item,
62            visibility: Visibility::ReadWrite,
63            has_extended_meta: true,
64        });
65        ExpandElement::Plain(Variable::new(VariableKind::TensorMapInput(id), item))
66    }
67
68    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
69    pub fn output_tensor_map(&mut self, item: Type) -> ExpandElement {
70        let id = self.buffer_id();
71        self.tensor_maps.push(BufferInfo {
72            id,
73            item,
74            visibility: Visibility::Read,
75            has_extended_meta: true,
76        });
77        ExpandElement::Plain(Variable::new(VariableKind::TensorMapOutput(id), item))
78    }
79
80    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
81    pub fn input_tensor(&mut self, item: Type) -> ExpandElement {
82        let id = self.buffer_id();
83        self.buffers.push(BufferInfo {
84            id,
85            item,
86            visibility: Visibility::Read,
87            has_extended_meta: true,
88        });
89        self.scope.input(id, item)
90    }
91
92    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
93    pub fn output_array(&mut self, item: Type) -> ExpandElement {
94        let id = self.buffer_id();
95        self.buffers.push(BufferInfo {
96            id,
97            item,
98            visibility: Visibility::ReadWrite,
99            has_extended_meta: false,
100        });
101        self.scope.output(id, item)
102    }
103
104    /// Register an output that uses the same resource as the input as the given position.
105    pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
106        let input = self
107            .buffers
108            .get_mut(position as usize)
109            .expect("Position valid");
110
111        input.visibility = Visibility::ReadWrite;
112        self.scope.input(position, input.item)
113    }
114
115    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
116    pub fn input_array(&mut self, item: Type) -> ExpandElement {
117        let id = self.buffer_id();
118        self.buffers.push(BufferInfo {
119            id,
120            item,
121            visibility: Visibility::Read,
122            has_extended_meta: false,
123        });
124        self.scope.input(id, item)
125    }
126
127    pub fn runtime_properties(&mut self, properties: TargetProperties) {
128        self.scope.runtime_properties = Rc::new(properties);
129    }
130
131    pub fn device_properties(&mut self, properties: &DeviceProperties) {
132        self.scope.device_properties(properties);
133    }
134
135    /// Build the [kernel definition](KernelDefinition).
136    pub fn build(self, settings: KernelSettings) -> KernelDefinition {
137        let scalars = self
138            .scalars
139            .into_iter()
140            .map(|(ty, count)| ScalarInfo { ty, count })
141            .collect();
142        KernelIntegrator::new(KernelExpansion {
143            scope: self.scope,
144            buffers: self.buffers,
145            scalars,
146            tensor_maps: self.tensor_maps,
147        })
148        .integrate(settings)
149    }
150
151    pub fn new() -> Self {
152        let debug = DEBUG.load(Ordering::Relaxed);
153        let debug = if debug == -1 {
154            let val = match GlobalConfig::get().compilation.logger.level {
155                CompilationLogLevel::Full => 1,
156                _ => 0,
157            };
158
159            DEBUG.store(val, Ordering::Relaxed);
160            val == 1
161        } else {
162            debug == 1
163        };
164
165        Self {
166            scope: Scope::root(debug),
167            buffers: Default::default(),
168            scalars: Default::default(),
169            tensor_maps: Default::default(),
170        }
171    }
172}
173
174impl Default for KernelBuilder {
175    fn default() -> Self {
176        Self::new()
177    }
178}