cubecl_core/compute/kernel.rs

1use std::{fmt::Display, marker::PhantomData};
2
3use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel, KernelId};
4use alloc::sync::Arc;
5use cubecl_runtime::ExecutionMode;
6
7/// A kernel, compiled in the target language
8pub struct CompiledKernel<C: Compiler> {
9    /// The name of the kernel entrypoint.
10    /// For example
11    ///
12    /// ```text
13    /// #[cube(launch)]
14    /// fn gelu_array<F: Float, R: Runtime>() {}
15    /// ```
16    ///
17    /// would have the entrypoint name "gelu_array".
18    pub entrypoint_name: String,
19
20    /// A fully qualified debug name of the kernel.
21    ///
22    /// For example
23    ///
24    /// ```text
25    /// #[cube(launch)]
26    /// fn gelu_array<F: Float, R: Runtime>() {}
27    /// ```
28    ///
29    /// would have a debug name such as
30    ///
31    /// ```text
32    /// gelu::gelu_array::GeluArray<
33    ///    cubecl_core::frontend::element::float::F32,
34    ///    cubecl_cuda::runtime::CudaRuntime,
35    /// >
36    /// ```
37    pub debug_name: Option<&'static str>,
38
39    /// Source code of the kernel
40    pub source: String,
41    /// In-memory representation of the kernel
42    pub repr: Option<C::Representation>,
43    /// Size of a cube for the compiled kernel
44    pub cube_dim: CubeDim,
45    /// The number of bytes used by the share memory
46    pub shared_mem_bytes: usize,
47    /// Extra debugging information about the compiled kernel.
48    pub debug_info: Option<DebugInformation>,
49}
50
51/// Extra debugging information about the compiled kernel.
52#[derive(new)]
53pub struct DebugInformation {
54    /// The language tag of the source..
55    pub lang_tag: &'static str,
56    /// The compilation id.
57    pub id: KernelId,
58}
59
60impl Display for KernelId {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        match &self.info {
63            Some(info) => f.write_str(
64                format_str(
65                    format!("{:?}", info).as_str(),
66                    &[('(', ')'), ('[', ']'), ('{', '}')],
67                    true,
68                )
69                .as_str(),
70            ),
71            None => f.write_str("No info"),
72        }
73    }
74}
75
76impl<C: Compiler> Display for CompiledKernel<C> {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.write_str("\n[START_KERNEL_COMPILATION]")?;
79
80        if let Some(name) = self.debug_name {
81            if name.len() <= 32 {
82                f.write_fmt(format_args!("\nname: {name}"))?;
83            } else {
84                let name = format_str(name, &[('<', '>')], false);
85                f.write_fmt(format_args!("\nname: {name}"))?;
86            }
87        }
88
89        f.write_fmt(format_args!(
90            "
91cube_dim: ({}, {}, {})
92shared_memory: {} bytes",
93            self.cube_dim.x, self.cube_dim.y, self.cube_dim.z, self.shared_mem_bytes,
94        ))?;
95
96        if let Some(info) = &self.debug_info {
97            f.write_fmt(format_args!(
98                "\ninfo: {}",
99                format_str(
100                    format!("{:?}", info.id).as_str(),
101                    &[('(', ')'), ('[', ']'), ('{', '}')],
102                    true
103                )
104            ))?;
105        }
106
107        f.write_fmt(format_args!(
108            "
109source:
110```{}
111{}
112```
113[END_KERNEL_COMPILATION]
114",
115            self.debug_info
116                .as_ref()
117                .map(|info| info.lang_tag)
118                .unwrap_or(""),
119            self.source
120        ))
121    }
122}
123
/// Pretty-prints a single-line `Debug`-style string by breaking it into
/// indented lines at the given bracket `markers` and at commas.
///
/// Formatting rules:
///
/// * Every `' '` in `kernel_id` is dropped; all output spacing is generated.
/// * An opening marker increases the nesting depth by one. It is followed by a
///   newline and 4 spaces of indentation per level.
/// * A comma inside a group (`depth > 0`) ends the line. Top-level commas are
///   copied as plain characters.
/// * A closing marker adds a trailing comma after the last item and goes on its
///   own line at the outer indentation. A group closed right after it was
///   opened (e.g. `()`) stays on one line.
/// * A closing marker with no open group is copied as a plain character. This
///   avoids underflowing the depth, e.g. for `->` when `>` is a marker.
/// * With `include_space`, a space is inserted between a preceding token and an
///   opening marker (`Foo(` -> `Foo (`), and after every `:`. No space is added
///   when the output already ends in generated whitespace (indentation or `": "`).
fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> String {
    const INDENTATION: usize = 4;

    let mut result = String::with_capacity(kernel_id.len());
    let mut depth: usize = 0;
    // Last significant input character handled; set to `' '` right after a
    // generated `": "`. It is not updated by in-group commas.
    let mut prev = ' ';

    for c in kernel_id.chars().filter(|&c| c != ' ') {
        let mut found_marker = false;

        for &(start, end) in markers {
            if c == start {
                depth += 1;
                // Check the real output (not `prev`) so that a marker following
                // indentation is not shifted one column to the right.
                if include_space && !result.is_empty() && !result.ends_with(' ') {
                    result.push(' ');
                }
                result.push(start);
                result.push('\n');
                result.push_str(&" ".repeat(INDENTATION * depth));
                found_marker = true;
            } else if c == end && depth > 0 {
                depth -= 1;
                if prev == start {
                    // Empty group: remove the newline and indentation pushed by
                    // the opening marker so it renders as e.g. `()`.
                    for _ in 0..INDENTATION * (depth + 1) + 1 {
                        result.pop();
                    }
                } else {
                    if prev == ' ' {
                        // Drop the space generated after a trailing `:`.
                        result.pop();
                    }
                    result.push_str(",\n");
                    result.push_str(&" ".repeat(INDENTATION * depth));
                }
                result.push(end);
                found_marker = true;
            }
        }

        if found_marker {
            prev = c;
            continue;
        }

        if c == ',' && depth > 0 {
            if prev == ' ' {
                result.pop();
            }
            result.push_str(",\n");
            result.push_str(&" ".repeat(INDENTATION * depth));
            continue;
        }

        result.push(c);
        if c == ':' && include_space {
            result.push(' ');
            prev = ' ';
        } else {
            prev = c;
        }
    }

    result
}
197
198/// Kernel trait with the ComputeShader that will be compiled and cached based on the
199/// provided id.
200pub trait CubeTask<C: Compiler>: Send + Sync {
201    /// Identifier for the kernel, used for caching kernel compilation.
202    fn id(&self) -> KernelId;
203    /// Compile the kernel into source
204    fn compile(
205        &self,
206        compilation_options: &C::CompilationOptions,
207        mode: ExecutionMode,
208    ) -> CompiledKernel<C>;
209    fn name(&self) -> &'static str {
210        core::any::type_name::<Self>()
211    }
212}
213
214/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask).
215#[derive(new)]
216pub struct KernelTask<C: Compiler, K: Kernel> {
217    kernel_definition: K,
218    _compiler: PhantomData<C>,
219}
220
221impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
222    fn compile(
223        &self,
224        compilation_options: &C::CompilationOptions,
225        mode: ExecutionMode,
226    ) -> CompiledKernel<C> {
227        let gpu_ir = self.kernel_definition.define();
228        let entrypoint_name = gpu_ir.kernel_name.clone();
229        let cube_dim = gpu_ir.cube_dim;
230        let lower_level_ir = C::compile(gpu_ir, compilation_options, mode);
231        let shared_mem_bytes = lower_level_ir.shared_memory_size();
232
233        CompiledKernel {
234            entrypoint_name,
235            debug_name: Some(core::any::type_name::<K>()),
236            source: lower_level_ir.to_string(),
237            repr: Some(lower_level_ir),
238            cube_dim,
239            shared_mem_bytes,
240            debug_info: None,
241        }
242    }
243
244    fn id(&self) -> KernelId {
245        self.kernel_definition.id().clone()
246    }
247
248    fn name(&self) -> &'static str {
249        core::any::type_name::<K>()
250    }
251}
252
253impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
254    fn compile(
255        &self,
256        compilation_options: &C::CompilationOptions,
257        mode: ExecutionMode,
258    ) -> CompiledKernel<C> {
259        self.as_ref().compile(compilation_options, mode)
260    }
261
262    fn id(&self) -> KernelId {
263        self.as_ref().id()
264    }
265    fn name(&self) -> &'static str {
266        self.as_ref().name()
267    }
268}
269
270impl<C: Compiler> CubeTask<C> for Box<dyn CubeTask<C>> {
271    fn compile(
272        &self,
273        compilation_options: &C::CompilationOptions,
274        mode: ExecutionMode,
275    ) -> CompiledKernel<C> {
276        self.as_ref().compile(compilation_options, mode)
277    }
278
279    fn id(&self) -> KernelId {
280        self.as_ref().id()
281    }
282
283    fn name(&self) -> &'static str {
284        self.as_ref().name()
285    }
286}