cubecl_runtime/
kernel.rs

1use alloc::{
2    boxed::Box,
3    format,
4    string::{String, ToString},
5    vec::Vec,
6};
7use core::{
8    fmt::Display,
9    marker::PhantomData,
10    sync::atomic::{AtomicI8, Ordering},
11};
12
13use cubecl_common::{CubeDim, ExecutionMode, format::format_str};
14use cubecl_ir::{Id, Scope, StorageType, Type};
15use serde::{Deserialize, Serialize};
16
17use crate::{
18    compiler::{CompilationError, Compiler, CubeTask},
19    config::{GlobalConfig, compilation::CompilationLogLevel},
20    id::KernelId,
21};
22
23/// Implement this trait to create a [kernel definition](KernelDefinition).
24pub trait KernelMetadata: Send + Sync + 'static {
25    /// Name of the kernel for debugging.
26    fn name(&self) -> &'static str {
27        core::any::type_name::<Self>()
28    }
29
30    /// Identifier for the kernel, used for caching kernel compilation.
31    fn id(&self) -> KernelId;
32}
33
34#[derive(Debug, Clone)]
35#[allow(missing_docs)]
36pub struct KernelDefinition {
37    pub buffers: Vec<Binding>,
38    pub tensor_maps: Vec<Binding>,
39    pub scalars: Vec<ScalarBinding>,
40    pub cube_dim: CubeDim,
41    pub body: Scope,
42    pub options: KernelOptions,
43}
44
45#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
46/// Options for a specific kernel compilation
47pub struct KernelOptions {
48    /// The name of the kernel
49    pub kernel_name: String,
50    /// Whether to include debug symbols
51    pub debug_symbols: bool,
52    /// CUDA Cluster dim, if any
53    pub cluster_dim: Option<CubeDim>,
54}
55
56#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
57#[allow(missing_docs)]
58pub struct Binding {
59    pub id: Id,
60    pub location: Location,
61    pub visibility: Visibility,
62    pub ty: Type,
63    pub size: Option<usize>,
64    pub has_extended_meta: bool,
65}
66
67#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
68#[allow(missing_docs)]
69pub struct ScalarBinding {
70    pub ty: StorageType,
71    pub count: usize,
72}
73
74#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
75#[allow(missing_docs)]
76pub enum Location {
77    Storage,
78    Cube,
79}
80
81#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
82#[allow(missing_docs)]
83pub enum Visibility {
84    Read,
85    ReadWrite,
86}
87
88/// A kernel, compiled in the target language
89pub struct CompiledKernel<C: Compiler> {
90    /// The name of the kernel entrypoint.
91    /// For example
92    ///
93    /// ```text
94    /// #[cube(launch)]
95    /// fn gelu_array<F: Float, R: Runtime>() {}
96    /// ```
97    ///
98    /// would have the entrypoint name "gelu_array".
99    pub entrypoint_name: String,
100
101    /// A fully qualified debug name of the kernel.
102    ///
103    /// For example
104    ///
105    /// ```text
106    /// #[cube(launch)]
107    /// fn gelu_array<F: Float, R: Runtime>() {}
108    /// ```
109    ///
110    /// would have a debug name such as
111    ///
112    /// ```text
113    /// gelu::gelu_array::GeluArray<
114    ///    cubecl_core::frontend::element::float::F32,
115    ///    cubecl_cuda::runtime::CudaRuntime,
116    /// >
117    /// ```
118    pub debug_name: Option<&'static str>,
119
120    /// Source code of the kernel
121    pub source: String,
122    /// In-memory representation of the kernel
123    pub repr: Option<C::Representation>,
124    /// Size of a cube for the compiled kernel
125    pub cube_dim: CubeDim,
126    /// Extra debugging information about the compiled kernel.
127    pub debug_info: Option<DebugInformation>,
128}
129
130/// Extra debugging information about the compiled kernel.
131#[derive(new)]
132pub struct DebugInformation {
133    /// The language tag of the source..
134    pub lang_tag: &'static str,
135    /// The compilation id.
136    pub id: KernelId,
137}
138
139/// Kernel that can be defined
140pub trait CubeKernel: KernelMetadata {
141    /// Define the kernel for compilation
142    fn define(&self) -> KernelDefinition;
143}
144
145/// Wraps a [kernel](Kernel) to allow it be compiled.
146pub struct KernelTask<C: Compiler, K: CubeKernel> {
147    kernel_definition: K,
148    _compiler: PhantomData<C>,
149}
150
151/// Generic [CubeTask] for compiling kernels
152pub struct CubeTaskKernel<C: Compiler> {
153    /// The inner compilation task being wrapped
154    pub task: Box<dyn CubeTask<C>>,
155}
156
157impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
158    /// Create a new kernel task
159    pub fn new(kernel_definition: K) -> Self {
160        Self {
161            kernel_definition,
162            _compiler: PhantomData,
163        }
164    }
165}
166
167impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
168    fn compile(
169        &self,
170        compiler: &mut C,
171        compilation_options: &C::CompilationOptions,
172        mode: ExecutionMode,
173    ) -> Result<CompiledKernel<C>, CompilationError> {
174        let gpu_ir = self.kernel_definition.define();
175        let entrypoint_name = gpu_ir.options.kernel_name.clone();
176        let cube_dim = gpu_ir.cube_dim;
177        let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode)?;
178
179        Ok(CompiledKernel {
180            entrypoint_name,
181            debug_name: Some(core::any::type_name::<K>()),
182            source: lower_level_ir.to_string(),
183            repr: Some(lower_level_ir),
184            cube_dim,
185            debug_info: None,
186        })
187    }
188}
189
190impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
191    // Forward ID to underlying kernel definition.
192    fn id(&self) -> KernelId {
193        self.kernel_definition.id()
194    }
195
196    // Forward name to underlying kernel definition.
197    fn name(&self) -> &'static str {
198        self.kernel_definition.name()
199    }
200}
201
202impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
203    // Deref and use existing ID.
204    fn id(&self) -> KernelId {
205        self.as_ref().id()
206    }
207
208    // Deref and use existing name.
209    fn name(&self) -> &'static str {
210        self.as_ref().name()
211    }
212}
213
214static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
215
216fn compilation_level() -> u8 {
217    let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
218    if compilation_level == -1 {
219        let val = match GlobalConfig::get().compilation.logger.level {
220            CompilationLogLevel::Full => 2,
221            CompilationLogLevel::Disabled => 0,
222            CompilationLogLevel::Basic => 1,
223        };
224
225        COMPILATION_LEVEL.store(val, Ordering::Relaxed);
226        val as u8
227    } else {
228        compilation_level as u8
229    }
230}
231
232impl<C: Compiler> Display for CompiledKernel<C> {
233    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
234        match compilation_level() {
235            2 => self.format_full(f),
236            _ => self.format_basic(f),
237        }
238    }
239}
240
241impl<C: Compiler> CompiledKernel<C> {
242    fn format_basic(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
243        f.write_str("[Compiling kernel]")?;
244        if let Some(name) = self.debug_name {
245            if name.len() <= 32 {
246                f.write_fmt(format_args!(" {name}"))?;
247            } else {
248                f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
249            }
250        }
251
252        Ok(())
253    }
254
255    fn format_full(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
256        f.write_str("[START_KERNEL_COMPILATION]")?;
257
258        if let Some(name) = self.debug_name {
259            if name.len() <= 32 {
260                f.write_fmt(format_args!("\nname: {name}"))?;
261            } else {
262                let name = format_str(name, &[('<', '>')], false);
263                f.write_fmt(format_args!("\nname: {name}"))?;
264            }
265        }
266
267        f.write_fmt(format_args!(
268            "
269cube_dim: ({}, {}, {})",
270            self.cube_dim.x, self.cube_dim.y, self.cube_dim.z,
271        ))?;
272
273        if let Some(info) = &self.debug_info {
274            f.write_fmt(format_args!(
275                "\ninfo: {}",
276                format_str(
277                    format!("{:?}", info.id).as_str(),
278                    &[('(', ')'), ('[', ']'), ('{', '}')],
279                    true
280                )
281            ))?;
282        }
283
284        f.write_fmt(format_args!(
285            "
286source:
287```{}
288{}
289```
290[END_KERNEL_COMPILATION]
291",
292            self.debug_info
293                .as_ref()
294                .map(|info| info.lang_tag)
295                .unwrap_or(""),
296            self.source
297        ))
298    }
299}