cubecl_core/codegen/integrator.rs

1use cubecl_common::CubeDim;
2use cubecl_ir::{Id, Scope, StorageType, Type};
3
4use crate::compute::{Binding, KernelDefinition, Location, ScalarBinding, Visibility};
5
6/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
7/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
8#[derive(Clone)]
9pub struct KernelIntegrator {
10    expansion: KernelExpansion,
11    buffer_bindings: Vec<Binding>,
12    scalar_bindings: Vec<ScalarBinding>,
13    tensor_maps: Vec<Binding>,
14}
15
16/// The information necessary to compile a [kernel definition](KernelDefinition).
17#[derive(Clone)]
18pub struct KernelExpansion {
19    pub buffers: Vec<BufferInfo>,
20    pub scalars: Vec<ScalarInfo>,
21    pub tensor_maps: Vec<BufferInfo>,
22    pub scope: Scope,
23}
24
25#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
26pub struct KernelSettings {
27    pub cube_dim: CubeDim,
28    pub options: KernelOptions,
29}
30
31#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
32pub struct KernelOptions {
33    pub kernel_name: String,
34    pub debug_symbols: bool,
35    pub cluster_dim: Option<CubeDim>,
36}
37
38impl KernelSettings {
39    /// Set cube dimension.
40    #[allow(dead_code)]
41    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
42        self.cube_dim = cube_dim;
43        self
44    }
45
46    /// Set kernel name.
47    #[allow(dead_code)]
48    pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
49        self.options.kernel_name = name.as_ref().to_string();
50        self
51    }
52
53    /// Activate debug symbols
54    pub fn debug_symbols(mut self) -> Self {
55        self.options.debug_symbols = true;
56        self
57    }
58
59    /// Set cluster dim
60    pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
61        self.options.cluster_dim = Some(cluster_dim);
62        self
63    }
64}
65
66/// Information related to a buffer binding.
67#[derive(Clone, Debug)]
68pub struct BufferInfo {
69    pub id: Id,
70    pub item: Type,
71    pub visibility: Visibility,
72    /// Whether this input has extended metadata (rank, shape, strides)
73    pub has_extended_meta: bool,
74}
75
76/// Information related to a scalar input.
77#[derive(Clone, Debug)]
78pub struct ScalarInfo {
79    pub ty: StorageType,
80    pub count: usize,
81}
82
83impl KernelIntegrator {
84    /// Starts a new compilation.
85    pub fn new(info: KernelExpansion) -> Self {
86        Self {
87            expansion: info,
88            buffer_bindings: Default::default(),
89            scalar_bindings: Default::default(),
90            tensor_maps: Default::default(),
91        }
92    }
93
94    /// Performs the compilation with the provided [settings](KernelSettings).
95    pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
96        self.register_buffers();
97        self.register_scalars();
98        self.register_tensor_maps();
99
100        self.scalar_bindings.sort_by_key(|binding| binding.ty);
101
102        KernelDefinition {
103            buffers: self.buffer_bindings,
104            tensor_maps: self.tensor_maps,
105            scalars: self.scalar_bindings,
106            cube_dim: settings.cube_dim,
107            body: self.expansion.scope,
108            options: settings.options,
109        }
110    }
111
112    fn register_buffers(&mut self) {
113        for buffer in self.expansion.buffers.drain(..) {
114            self.buffer_bindings.push(Binding {
115                id: buffer.id,
116                ty: buffer.item,
117                visibility: buffer.visibility,
118                location: Location::Storage,
119                has_extended_meta: buffer.has_extended_meta,
120                size: None,
121            });
122        }
123    }
124
125    fn register_scalars(&mut self) {
126        for scalar in self.expansion.scalars.drain(..) {
127            self.scalar_bindings.push(ScalarBinding {
128                ty: scalar.ty,
129                count: scalar.count,
130            });
131        }
132    }
133
134    fn register_tensor_maps(&mut self) {
135        for buffer in self.expansion.tensor_maps.drain(..) {
136            self.tensor_maps.push(Binding {
137                id: buffer.id,
138                ty: buffer.item,
139                visibility: buffer.visibility,
140                location: Location::Storage,
141                has_extended_meta: buffer.has_extended_meta,
142                size: None,
143            });
144        }
145    }
146}