cubecl_core/compute/
builder.rs

1use std::{
2    rc::Rc,
3    sync::atomic::{AtomicI8, Ordering},
4};
5
6use alloc::collections::BTreeMap;
7
8use cubecl_ir::{ExpandElement, Scope, StorageType, TargetProperties, Variable, VariableKind};
9use cubecl_runtime::{
10    config::{GlobalConfig, compilation::CompilationLogLevel},
11    kernel::Visibility,
12};
13
14use crate::ir::{Id, Type};
15use crate::prelude::KernelDefinition;
16use crate::{BufferInfo, KernelSettings, ScalarInfo};
17use crate::{KernelExpansion, KernelIntegrator};
18
/// Prepare a kernel to create a [kernel definition](crate::KernelDefinition).
///
/// Collects the bindings (buffers, scalars and tensor maps) declared while a kernel is being
/// expanded, then turns them into a [`KernelDefinition`] via [`KernelBuilder::build`].
pub struct KernelBuilder {
    /// Cube [scope](Scope).
    pub scope: Scope,
    /// Array/tensor bindings in registration order. Their ids share a single id space with
    /// `tensor_maps`.
    buffers: Vec<BufferInfo>,
    /// Number of scalars registered so far, per storage type.
    scalars: BTreeMap<StorageType, usize>,
    /// Tensor map bindings in registration order. Their ids share a single id space with
    /// `buffers`.
    tensor_maps: Vec<BufferInfo>,
}
27
/// Cached "build with debug info" flag read by `KernelBuilder::new`.
///
/// `-1` means the global compilation config has not been consulted yet. After the first check,
/// it holds `1` when the compilation logger level is `Full` and `0` otherwise.
static DEBUG: AtomicI8 = AtomicI8::new(-1);
29
30impl KernelBuilder {
31    /// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion.
32    pub fn scalar(&mut self, storage: StorageType) -> ExpandElement {
33        let id = self.scalars.entry(storage).or_default();
34        let expand = self.scope.scalar(*id as Id, storage);
35        *id += 1;
36        expand
37    }
38
39    fn buffer_id(&self) -> Id {
40        self.buffers.len() as Id + self.tensor_maps.len() as Id
41    }
42
43    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
44    pub fn output_tensor(&mut self, item: Type) -> ExpandElement {
45        let id = self.buffer_id();
46        self.buffers.push(BufferInfo {
47            id,
48            item,
49            visibility: Visibility::ReadWrite,
50            has_extended_meta: true,
51        });
52        self.scope.output(id, item)
53    }
54
55    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
56    pub fn input_tensor_map(&mut self, item: Type) -> ExpandElement {
57        let id = self.buffer_id();
58        self.tensor_maps.push(BufferInfo {
59            id,
60            item,
61            visibility: Visibility::ReadWrite,
62            has_extended_meta: true,
63        });
64        ExpandElement::Plain(Variable::new(VariableKind::TensorMapInput(id), item))
65    }
66
67    /// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
68    pub fn output_tensor_map(&mut self, item: Type) -> ExpandElement {
69        let id = self.buffer_id();
70        self.tensor_maps.push(BufferInfo {
71            id,
72            item,
73            visibility: Visibility::Read,
74            has_extended_meta: true,
75        });
76        ExpandElement::Plain(Variable::new(VariableKind::TensorMapOutput(id), item))
77    }
78
79    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
80    pub fn input_tensor(&mut self, item: Type) -> ExpandElement {
81        let id = self.buffer_id();
82        self.buffers.push(BufferInfo {
83            id,
84            item,
85            visibility: Visibility::Read,
86            has_extended_meta: true,
87        });
88        self.scope.input(id, item)
89    }
90
91    /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion.
92    pub fn output_array(&mut self, item: Type) -> ExpandElement {
93        let id = self.buffer_id();
94        self.buffers.push(BufferInfo {
95            id,
96            item,
97            visibility: Visibility::ReadWrite,
98            has_extended_meta: false,
99        });
100        self.scope.output(id, item)
101    }
102
103    /// Register an output that uses the same resource as the input as the given position.
104    pub fn inplace_output(&mut self, position: Id) -> ExpandElement {
105        let input = self
106            .buffers
107            .get_mut(position as usize)
108            .expect("Position valid");
109
110        input.visibility = Visibility::ReadWrite;
111        self.scope.input(position, input.item)
112    }
113
114    /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.
115    pub fn input_array(&mut self, item: Type) -> ExpandElement {
116        let id = self.buffer_id();
117        self.buffers.push(BufferInfo {
118            id,
119            item,
120            visibility: Visibility::Read,
121            has_extended_meta: false,
122        });
123        self.scope.input(id, item)
124    }
125
126    pub fn runtime_properties(&mut self, properties: TargetProperties) {
127        self.scope.runtime_properties = Rc::new(properties);
128    }
129
130    /// Build the [kernel definition](KernelDefinition).
131    pub fn build(self, settings: KernelSettings) -> KernelDefinition {
132        let scalars = self
133            .scalars
134            .into_iter()
135            .map(|(ty, count)| ScalarInfo { ty, count })
136            .collect();
137        KernelIntegrator::new(KernelExpansion {
138            scope: self.scope,
139            buffers: self.buffers,
140            scalars,
141            tensor_maps: self.tensor_maps,
142        })
143        .integrate(settings)
144    }
145
146    pub fn new() -> Self {
147        let debug = DEBUG.load(Ordering::Relaxed);
148        let debug = if debug == -1 {
149            let val = match GlobalConfig::get().compilation.logger.level {
150                CompilationLogLevel::Full => 1,
151                _ => 0,
152            };
153
154            DEBUG.store(val, Ordering::Relaxed);
155            val == 1
156        } else {
157            debug == 1
158        };
159        Self {
160            scope: Scope::root(debug),
161            buffers: Default::default(),
162            scalars: Default::default(),
163            tensor_maps: Default::default(),
164        }
165    }
166}
167
168impl Default for KernelBuilder {
169    fn default() -> Self {
170        Self::new()
171    }
172}