cubecl_core/compute/
builder.rs

1use std::{
2    rc::Rc,
3    sync::atomic::{AtomicI8, Ordering},
4};
5
6use alloc::collections::BTreeMap;
7
8use cubecl_ir::{ExpandElement, Scope, StorageType, TargetProperties, Variable, VariableKind};
9use cubecl_runtime::config::{GlobalConfig, compilation::CompilationLogLevel};
10
11use crate::ir::{Id, Type};
12use crate::prelude::KernelDefinition;
13use crate::{BufferInfo, KernelSettings, ScalarInfo};
14use crate::{KernelExpansion, KernelIntegrator};
15
16use super::Visibility;
17
18/// Prepare a kernel to create a [kernel definition](crate::KernelDefinition).
19pub struct KernelBuilder {
20    /// Cube [scope](Scope).
21    pub scope: Scope,
22    buffers: Vec<BufferInfo>,
23    scalars: BTreeMap<StorageType, usize>,
24    tensor_maps: Vec<BufferInfo>,
25}
26
27static DEBUG: AtomicI8 = AtomicI8::new(-1);
28
29impl KernelBuilder {
30    /// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion.
31    pub fn scalar(&mut self, storage: StorageType) -> ExpandElement {
32        let id = self.scalars.entry(storage).or_default();
33        let expand = self.scope.scalar(*id as Id, storage);
34        *id += 1;
35        expand
36    }
37
38    fn buffer_id(&self) -> Id {
39        self.buffers.len() as Id + self.tensor_maps.len() as Id
40    }
41
42    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
43    pub fn output_tensor(&mut self, item: Type) -> ExpandElement {
44        let id = self.buffer_id();
45        self.buffers.push(BufferInfo {
46            id,
47            item,
48            visibility: Visibility::ReadWrite,
49            has_extended_meta: true,
50        });
51        self.scope.output(id, item)
52    }
53
54    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
55    pub fn input_tensor_map(&mut self, item: Type) -> ExpandElement {
56        let id = self.buffer_id();
57        self.tensor_maps.push(BufferInfo {
58            id,
59            item,
60            visibility: Visibility::ReadWrite,
61            has_extended_meta: true,
62        });
63        ExpandElement::Plain(Variable::new(VariableKind::TensorMapInput(id), item))
64    }
65
66    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
67    pub fn output_tensor_map(&mut self, item: Type) -> ExpandElement {
68        let id = self.buffer_id();
69        self.tensor_maps.push(BufferInfo {
70            id,
71            item,
72            visibility: Visibility::Read,
73            has_extended_meta: true,
74        });
75        ExpandElement::Plain(Variable::new(VariableKind::TensorMapOutput(id), item))
76    }
77
78    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
79    pub fn input_tensor(&mut self, item: Type) -> ExpandElement {
80        let id = self.buffer_id();
81        self.buffers.push(BufferInfo {
82            id,
83            item,
84            visibility: Visibility::Read,
85            has_extended_meta: true,
86        });
87        self.scope.input(id, item)
88    }
89
90    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
91    pub fn output_array(&mut self, item: Type) -> ExpandElement {
92        let id = self.buffer_id();
93        self.buffers.push(BufferInfo {
94            id,
95            item,
96            visibility: Visibility::ReadWrite,
97            has_extended_meta: false,
98        });
99        self.scope.output(id, item)
100    }
101
102    /// Register an output that uses the same resource as the input as the given position.
103    pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
104        let input = self
105            .buffers
106            .get_mut(position as usize)
107            .expect("Position valid");
108
109        input.visibility = Visibility::ReadWrite;
110        self.scope.input(position, input.item)
111    }
112
113    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
114    pub fn input_array(&mut self, item: Type) -> ExpandElement {
115        let id = self.buffer_id();
116        self.buffers.push(BufferInfo {
117            id,
118            item,
119            visibility: Visibility::Read,
120            has_extended_meta: false,
121        });
122        self.scope.input(id, item)
123    }
124
125    pub fn runtime_properties(&mut self, properties: TargetProperties) {
126        self.scope.runtime_properties = Rc::new(properties);
127    }
128
129    /// Build the [kernel definition](KernelDefinition).
130    pub fn build(self, settings: KernelSettings) -> KernelDefinition {
131        let scalars = self
132            .scalars
133            .into_iter()
134            .map(|(ty, count)| ScalarInfo { ty, count })
135            .collect();
136        KernelIntegrator::new(KernelExpansion {
137            scope: self.scope,
138            buffers: self.buffers,
139            scalars,
140            tensor_maps: self.tensor_maps,
141        })
142        .integrate(settings)
143    }
144
145    pub fn new() -> Self {
146        let debug = DEBUG.load(Ordering::Relaxed);
147        let debug = if debug == -1 {
148            let val = match GlobalConfig::get().compilation.logger.level {
149                CompilationLogLevel::Full => 1,
150                _ => 0,
151            };
152
153            DEBUG.store(val, Ordering::Relaxed);
154            val == 1
155        } else {
156            debug == 1
157        };
158        Self {
159            scope: Scope::root(debug),
160            buffers: Default::default(),
161            scalars: Default::default(),
162            tensor_maps: Default::default(),
163        }
164    }
165}
166
167impl Default for KernelBuilder {
168    fn default() -> Self {
169        Self::new()
170    }
171}