Skip to main content

cubecl_core/compute/
builder.rs

1use alloc::{rc::Rc, vec::Vec};
2use core::sync::atomic::{AtomicI8, Ordering};
3
4use crate::{
5    BufferInfo, KernelExpansion, KernelIntegrator, KernelSettings, ScalarInfo,
6    ir::{Id, Type},
7    prelude::KernelDefinition,
8};
9use alloc::collections::BTreeMap;
10use cubecl_ir::{
11    DeviceProperties, ExpandElement, Scope, StorageType, TargetProperties, Variable, VariableKind,
12};
13use cubecl_runtime::{
14    config::{GlobalConfig, compilation::CompilationLogLevel},
15    kernel::Visibility,
16};
17
18/// Prepare a kernel to create a [`KernelDefinition`].
19pub struct KernelBuilder {
20    /// Cube [scope](Scope).
21    pub scope: Scope,
22    buffers: Vec<BufferInfo>,
23    scalars: BTreeMap<StorageType, usize>,
24    tensor_maps: Vec<BufferInfo>,
25}
26
/// Cached "build kernels in debug mode" decision, shared by every `KernelBuilder::new` call.
///
/// `-1` means the decision has not been computed yet; after the first lookup it holds
/// `1` (enabled) or `0` (disabled).
static DEBUG: AtomicI8 = AtomicI8::new(-1_i8);
28
29impl KernelBuilder {
30    /// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion.
31    pub fn scalar(&mut self, storage: StorageType) -> ExpandElement {
32        let id = self.scalars.entry(storage).or_default();
33        let expand = self.scope.scalar(*id as Id, storage);
34        *id += 1;
35        expand
36    }
37
38    fn buffer_id(&self) -> Id {
39        self.buffers.len() as Id + self.tensor_maps.len() as Id
40    }
41
42    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
43    pub fn output_tensor(&mut self, item: Type) -> ExpandElement {
44        let id = self.buffer_id();
45        self.buffers.push(BufferInfo {
46            id,
47            item,
48            visibility: Visibility::ReadWrite,
49            has_extended_meta: true,
50        });
51        self.scope.output(id, item)
52    }
53
54    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
55    pub fn input_tensor_map(&mut self, item: Type) -> ExpandElement {
56        let id = self.buffer_id();
57        self.tensor_maps.push(BufferInfo {
58            id,
59            item,
60            visibility: Visibility::ReadWrite,
61            has_extended_meta: true,
62        });
63        ExpandElement::Plain(Variable::new(VariableKind::TensorMapInput(id), item))
64    }
65
66    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
67    pub fn output_tensor_map(&mut self, item: Type) -> ExpandElement {
68        let id = self.buffer_id();
69        self.tensor_maps.push(BufferInfo {
70            id,
71            item,
72            visibility: Visibility::Read,
73            has_extended_meta: true,
74        });
75        ExpandElement::Plain(Variable::new(VariableKind::TensorMapOutput(id), item))
76    }
77
78    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
79    pub fn input_tensor(&mut self, item: Type) -> ExpandElement {
80        let id = self.buffer_id();
81        self.buffers.push(BufferInfo {
82            id,
83            item,
84            visibility: Visibility::Read,
85            has_extended_meta: true,
86        });
87        self.scope.input(id, item)
88    }
89
90    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
91    pub fn output_array(&mut self, item: Type) -> ExpandElement {
92        let id = self.buffer_id();
93        self.buffers.push(BufferInfo {
94            id,
95            item,
96            visibility: Visibility::ReadWrite,
97            has_extended_meta: false,
98        });
99        self.scope.output(id, item)
100    }
101
102    /// Register an output that uses the same resource as the input as the given position.
103    pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
104        let input = self
105            .buffers
106            .get_mut(position as usize)
107            .expect("Position valid");
108
109        input.visibility = Visibility::ReadWrite;
110        self.scope.input(position, input.item)
111    }
112
113    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
114    pub fn input_array(&mut self, item: Type) -> ExpandElement {
115        let id = self.buffer_id();
116        self.buffers.push(BufferInfo {
117            id,
118            item,
119            visibility: Visibility::Read,
120            has_extended_meta: false,
121        });
122        self.scope.input(id, item)
123    }
124
125    pub fn runtime_properties(&mut self, properties: TargetProperties) {
126        self.scope.runtime_properties = Rc::new(properties);
127    }
128
129    pub fn device_properties(&mut self, properties: &DeviceProperties) {
130        self.scope.device_properties(properties);
131    }
132
133    /// Build the [kernel definition](KernelDefinition).
134    pub fn build(self, settings: KernelSettings) -> KernelDefinition {
135        let scalars = self
136            .scalars
137            .into_iter()
138            .map(|(ty, count)| ScalarInfo { ty, count })
139            .collect();
140        KernelIntegrator::new(KernelExpansion {
141            scope: self.scope,
142            buffers: self.buffers,
143            scalars,
144            tensor_maps: self.tensor_maps,
145        })
146        .integrate(settings)
147    }
148
149    pub fn new() -> Self {
150        let debug = DEBUG.load(Ordering::Relaxed);
151        let debug = if debug == -1 {
152            let val = match GlobalConfig::get().compilation.logger.level {
153                CompilationLogLevel::Full => 1,
154                _ => 0,
155            };
156
157            DEBUG.store(val, Ordering::Relaxed);
158            val == 1
159        } else {
160            debug == 1
161        };
162
163        Self {
164            scope: Scope::root(debug),
165            buffers: Default::default(),
166            scalars: Default::default(),
167            tensor_maps: Default::default(),
168        }
169    }
170}
171
172impl Default for KernelBuilder {
173    fn default() -> Self {
174        Self::new()
175    }
176}