cubecl_core/codegen/integrator.rs

1use cubecl_common::CubeDim;
2use cubecl_ir::{Elem, Id, Item, Scope};
3
4use crate::{
5    compute::{Binding, KernelDefinition, Location, ScalarBinding, Visibility},
6    prelude::FastMath,
7};
8
9/// The kernel integrator allows you to create a [kernel definition](KernelDefinition) based on
10/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings).
11#[derive(Clone)]
12pub struct KernelIntegrator {
13    expansion: KernelExpansion,
14    buffer_bindings: Vec<Binding>,
15    scalar_bindings: Vec<ScalarBinding>,
16    tensor_maps: Vec<Id>,
17}
18
19/// The information necessary to compile a [kernel definition](KernelDefinition).
20#[derive(Clone)]
21pub struct KernelExpansion {
22    pub buffers: Vec<BufferInfo>,
23    pub scalars: Vec<ScalarInfo>,
24    pub tensor_maps: Vec<Id>,
25    pub scope: Scope,
26}
27
28#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
29pub struct KernelSettings {
30    pub cube_dim: CubeDim,
31    pub options: KernelOptions,
32}
33
34#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
35pub struct KernelOptions {
36    pub kernel_name: String,
37    pub debug_symbols: bool,
38    pub fp_math_mode: FastMath,
39    pub cluster_dim: Option<CubeDim>,
40}
41
42impl KernelSettings {
43    /// Set cube dimension.
44    #[allow(dead_code)]
45    pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self {
46        self.cube_dim = cube_dim;
47        self
48    }
49
50    /// Set kernel name.
51    #[allow(dead_code)]
52    pub fn kernel_name<S: AsRef<str>>(mut self, name: S) -> Self {
53        self.options.kernel_name = name.as_ref().to_string();
54        self
55    }
56
57    /// Activate debug symbols
58    pub fn debug_symbols(mut self) -> Self {
59        self.options.debug_symbols = true;
60        self
61    }
62
63    /// Set FP math mode
64    pub fn fp_math_mode(mut self, mode: FastMath) -> Self {
65        self.options.fp_math_mode = mode;
66        self
67    }
68
69    /// Set cluster dim
70    pub fn cluster_dim(mut self, cluster_dim: CubeDim) -> Self {
71        self.options.cluster_dim = Some(cluster_dim);
72        self
73    }
74}
75
76/// Information related to a buffer binding.
77#[derive(Clone, Debug)]
78pub struct BufferInfo {
79    pub id: Id,
80    pub item: Item,
81    pub visibility: Visibility,
82    /// Whether this input has extended metadata (rank, shape, strides)
83    pub has_extended_meta: bool,
84}
85
86/// Information related to a scalar input.
87#[derive(Clone, Debug)]
88pub struct ScalarInfo {
89    pub elem: Elem,
90    pub count: usize,
91}
92
93impl KernelIntegrator {
94    /// Starts a new compilation.
95    pub fn new(info: KernelExpansion) -> Self {
96        Self {
97            expansion: info,
98            buffer_bindings: Default::default(),
99            scalar_bindings: Default::default(),
100            tensor_maps: Default::default(),
101        }
102    }
103
104    /// Performs the compilation with the provided [settings](KernelSettings).
105    pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
106        self.register_buffers();
107        self.register_scalars();
108        self.register_tensor_maps();
109
110        self.scalar_bindings.sort_by_key(|binding| binding.elem);
111
112        KernelDefinition {
113            buffers: self.buffer_bindings,
114            tensor_maps: self.tensor_maps,
115            scalars: self.scalar_bindings,
116            cube_dim: settings.cube_dim,
117            body: self.expansion.scope,
118            options: settings.options,
119        }
120    }
121
122    fn register_buffers(&mut self) {
123        for buffer in self.expansion.buffers.drain(..) {
124            self.buffer_bindings.push(Binding {
125                id: buffer.id,
126                item: buffer.item,
127                visibility: buffer.visibility,
128                location: Location::Storage,
129                has_extended_meta: buffer.has_extended_meta,
130                size: None,
131            });
132        }
133    }
134
135    fn register_scalars(&mut self) {
136        for scalar in self.expansion.scalars.drain(..) {
137            self.scalar_bindings.push(ScalarBinding {
138                elem: scalar.elem,
139                count: scalar.count,
140            });
141        }
142    }
143
144    fn register_tensor_maps(&mut self) {
145        for id in self.expansion.tensor_maps.drain(..) {
146            self.tensor_maps.push(id);
147        }
148    }
149}