cubecl_core/compute/
kernel.rs

1use std::{fmt::Display, marker::PhantomData};
2
3use crate::{Compiler, Kernel, KernelId, KernelOptions};
4use alloc::sync::Arc;
5use cubecl_common::{CubeDim, ExecutionMode};
6use cubecl_ir::{Elem, Id, Item, Scope};
7use serde::{Deserialize, Serialize};
8
9/// A kernel, compiled in the target language
10pub struct CompiledKernel<C: Compiler> {
11    /// The name of the kernel entrypoint.
12    /// For example
13    ///
14    /// ```text
15    /// #[cube(launch)]
16    /// fn gelu_array<F: Float, R: Runtime>() {}
17    /// ```
18    ///
19    /// would have the entrypoint name "gelu_array".
20    pub entrypoint_name: String,
21
22    /// A fully qualified debug name of the kernel.
23    ///
24    /// For example
25    ///
26    /// ```text
27    /// #[cube(launch)]
28    /// fn gelu_array<F: Float, R: Runtime>() {}
29    /// ```
30    ///
31    /// would have a debug name such as
32    ///
33    /// ```text
34    /// gelu::gelu_array::GeluArray<
35    ///    cubecl_core::frontend::element::float::F32,
36    ///    cubecl_cuda::runtime::CudaRuntime,
37    /// >
38    /// ```
39    pub debug_name: Option<&'static str>,
40
41    /// Source code of the kernel
42    pub source: String,
43    /// In-memory representation of the kernel
44    pub repr: Option<C::Representation>,
45    /// Size of a cube for the compiled kernel
46    pub cube_dim: CubeDim,
47    /// Extra debugging information about the compiled kernel.
48    pub debug_info: Option<DebugInformation>,
49}
50
51/// Extra debugging information about the compiled kernel.
52#[derive(new)]
53pub struct DebugInformation {
54    /// The language tag of the source..
55    pub lang_tag: &'static str,
56    /// The compilation id.
57    pub id: KernelId,
58}
59
60impl Display for KernelId {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        match &self.info {
63            Some(info) => f.write_str(
64                format_str(
65                    format!("{:?}", info).as_str(),
66                    &[('(', ')'), ('[', ']'), ('{', '}')],
67                    true,
68                )
69                .as_str(),
70            ),
71            None => f.write_str("No info"),
72        }
73    }
74}
75
76impl<C: Compiler> Display for CompiledKernel<C> {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.write_str("\n[START_KERNEL_COMPILATION]")?;
79
80        if let Some(name) = self.debug_name {
81            if name.len() <= 32 {
82                f.write_fmt(format_args!("\nname: {name}"))?;
83            } else {
84                let name = format_str(name, &[('<', '>')], false);
85                f.write_fmt(format_args!("\nname: {name}"))?;
86            }
87        }
88
89        f.write_fmt(format_args!(
90            "
91cube_dim: ({}, {}, {})",
92            self.cube_dim.x, self.cube_dim.y, self.cube_dim.z,
93        ))?;
94
95        if let Some(info) = &self.debug_info {
96            f.write_fmt(format_args!(
97                "\ninfo: {}",
98                format_str(
99                    format!("{:?}", info.id).as_str(),
100                    &[('(', ')'), ('[', ']'), ('{', '}')],
101                    true
102                )
103            ))?;
104        }
105
106        f.write_fmt(format_args!(
107            "
108source:
109```{}
110{}
111```
112[END_KERNEL_COMPILATION]
113",
114            self.debug_info
115                .as_ref()
116                .map(|info| info.lang_tag)
117                .unwrap_or(""),
118            self.source
119        ))
120    }
121}
122
/// Pretty-prints a compact, `Debug`-style string (e.g. a type name or a `{:?}` dump)
/// across several indented lines.
///
/// Rules:
/// - every space character of `input` is dropped;
/// - an opening marker is followed by a newline and one more indentation level
///   (4 spaces per level);
/// - a closing marker goes on its own line at the enclosing level, after a trailing
///   comma, while an empty pair such as `()` stays on one line;
/// - a comma nested inside at least one marker pair ends the current line;
/// - with `include_space`, a space separates text from a following opening marker
///   (`Foo {`) and a space is added after every `:`.
///
/// Unbalanced input is tolerated. A closing marker without a matching opener (such as
/// the `>` of `->` in a function-pointer type name) keeps the depth at zero instead of
/// underflowing it, which would panic.
fn format_str(input: &str, markers: &[(char, char)], include_space: bool) -> String {
    const INDENTATION: usize = 4;

    // Appends `width` spaces to `out`.
    fn push_indent(out: &mut String, width: usize) {
        for _ in 0..width {
            out.push(' ');
        }
    }

    let mut result = String::with_capacity(input.len() * 2);
    let mut depth: usize = 0;
    // Last significant character handled; `' '` at the start and right after a `": "`,
    // `','` right after a line-breaking separator comma.
    let mut prev = ' ';

    for c in input.chars().filter(|&c| c != ' ') {
        let mut found_marker = false;

        for &(start, end) in markers {
            if c == start {
                depth += 1;
                // Separate an opener from preceding text, but not from the indentation
                // that already begins a line (as in `((` or `, (`).
                let at_line_start =
                    result.is_empty() || result.ends_with(|ch: char| ch == ' ' || ch == '\n');
                if include_space && !at_line_start {
                    result.push(' ');
                }
                result.push(start);
                result.push('\n');
                push_indent(&mut result, INDENTATION * depth);
                found_marker = true;
            } else if c == end {
                // Saturate: an unmatched closer must not underflow `depth`.
                depth = depth.saturating_sub(1);
                if prev == start {
                    // Empty pair: remove the newline and indentation the opener pushed.
                    for _ in 0..(1 + INDENTATION * (depth + 1)) {
                        result.pop();
                    }
                } else {
                    if prev == ',' {
                        // The line was already ended by a separator comma (e.g. the
                        // 1-tuple `(a,)`), so only its indentation has to be redone.
                        let kept = result.trim_end_matches(' ').len();
                        result.truncate(kept);
                    } else {
                        // Drop the space emitted after a trailing `:`.
                        if prev == ' ' {
                            result.pop();
                        }
                        result.push_str(",\n");
                    }
                    push_indent(&mut result, INDENTATION * depth);
                }
                result.push(end);
                found_marker = true;
            }
        }

        if found_marker {
            prev = c;
            continue;
        }

        if c == ',' && depth > 0 {
            // Drop the space emitted after a trailing `:`.
            if prev == ' ' {
                result.pop();
            }
            result.push_str(",\n");
            push_indent(&mut result, INDENTATION * depth);
            prev = ',';
            continue;
        }

        if c == ':' && include_space {
            result.push_str(": ");
            prev = ' ';
        } else {
            result.push(c);
            prev = c;
        }
    }

    result
}
196
197#[derive(Debug, Clone)]
198#[allow(missing_docs)]
199pub struct KernelDefinition {
200    pub buffers: Vec<Binding>,
201    pub tensor_maps: Vec<Id>,
202    pub scalars: Vec<ScalarBinding>,
203    pub cube_dim: CubeDim,
204    pub body: Scope,
205    pub options: KernelOptions,
206}
207
208#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
209#[allow(missing_docs)]
210pub struct Binding {
211    pub id: Id,
212    pub location: Location,
213    pub visibility: Visibility,
214    pub item: Item,
215    pub size: Option<usize>,
216    pub has_extended_meta: bool,
217}
218
219#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
220#[allow(missing_docs)]
221pub struct ScalarBinding {
222    pub elem: Elem,
223    pub count: usize,
224}
225
226#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
227#[allow(missing_docs)]
228pub enum Location {
229    Storage,
230    Cube,
231}
232
233#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
234#[allow(missing_docs)]
235pub enum Visibility {
236    Read,
237    ReadWrite,
238}
239
240/// Kernel trait with the ComputeShader that will be compiled and cached based on the
241/// provided id.
242pub trait CubeTask<C: Compiler>: Send + Sync {
243    /// Identifier for the kernel, used for caching kernel compilation.
244    fn id(&self) -> KernelId;
245    /// Compile the kernel into source
246    fn compile(
247        &self,
248        compiler: &mut C,
249        compilation_options: &C::CompilationOptions,
250        mode: ExecutionMode,
251    ) -> CompiledKernel<C>;
252    fn name(&self) -> &'static str {
253        core::any::type_name::<Self>()
254    }
255}
256
257/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
258#[derive(new)]
259pub struct KernelTask<C: Compiler, K: Kernel> {
260    kernel_definition: K,
261    _compiler: PhantomData<C>,
262}
263
264impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
265    fn compile(
266        &self,
267        compiler: &mut C,
268        compilation_options: &C::CompilationOptions,
269        mode: ExecutionMode,
270    ) -> CompiledKernel<C> {
271        let gpu_ir = self.kernel_definition.define();
272        let entrypoint_name = gpu_ir.options.kernel_name.clone();
273        let cube_dim = gpu_ir.cube_dim;
274        let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode);
275
276        CompiledKernel {
277            entrypoint_name,
278            debug_name: Some(core::any::type_name::<K>()),
279            source: lower_level_ir.to_string(),
280            repr: Some(lower_level_ir),
281            cube_dim,
282            debug_info: None,
283        }
284    }
285
286    fn id(&self) -> KernelId {
287        self.kernel_definition.id().clone()
288    }
289
290    fn name(&self) -> &'static str {
291        core::any::type_name::<K>()
292    }
293}
294
295impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
296    fn compile(
297        &self,
298        compiler: &mut C,
299        compilation_options: &C::CompilationOptions,
300        mode: ExecutionMode,
301    ) -> CompiledKernel<C> {
302        self.as_ref().compile(compiler, compilation_options, mode)
303    }
304
305    fn id(&self) -> KernelId {
306        self.as_ref().id()
307    }
308    fn name(&self) -> &'static str {
309        self.as_ref().name()
310    }
311}
312
313impl<C: Compiler> CubeTask<C> for Box<dyn CubeTask<C>> {
314    fn compile(
315        &self,
316        compiler: &mut C,
317        compilation_options: &C::CompilationOptions,
318        mode: ExecutionMode,
319    ) -> CompiledKernel<C> {
320        self.as_ref().compile(compiler, compilation_options, mode)
321    }
322
323    fn id(&self) -> KernelId {
324        self.as_ref().id()
325    }
326
327    fn name(&self) -> &'static str {
328        self.as_ref().name()
329    }
330}