// cubecl_runtime/kernel.rs

1use alloc::{
2    boxed::Box,
3    format,
4    string::{String, ToString},
5    vec::Vec,
6};
7use core::{
8    fmt::Display,
9    marker::PhantomData,
10    sync::atomic::{AtomicI8, Ordering},
11};
12
13use cubecl_common::format::format_str;
14use cubecl_ir::{Id, Scope, StorageType, Type};
15use serde::{Deserialize, Serialize};
16
17use crate::{
18    compiler::{CompilationError, Compiler, CubeTask},
19    config::{GlobalConfig, compilation::CompilationLogLevel},
20    id::KernelId,
21    server::{CubeDim, ExecutionMode},
22};
23
24/// Implement this trait to create a [kernel definition](KernelDefinition).
25pub trait KernelMetadata: Send + Sync + 'static {
26    /// Name of the kernel for debugging.
27    fn name(&self) -> &'static str {
28        core::any::type_name::<Self>()
29    }
30
31    /// Identifier for the kernel, used for caching kernel compilation.
32    fn id(&self) -> KernelId;
33}
34
35#[derive(Debug, Clone)]
36#[allow(missing_docs)]
37pub struct KernelDefinition {
38    pub buffers: Vec<Binding>,
39    pub tensor_maps: Vec<Binding>,
40    pub scalars: Vec<ScalarBinding>,
41    pub cube_dim: CubeDim,
42    pub body: Scope,
43    pub options: KernelOptions,
44}
45
46#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
47/// Options for a specific kernel compilation
48pub struct KernelOptions {
49    /// The name of the kernel
50    pub kernel_name: String,
51    /// Whether to include debug symbols
52    pub debug_symbols: bool,
53    /// CUDA Cluster dim, if any
54    pub cluster_dim: Option<CubeDim>,
55}
56
57#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
58#[allow(missing_docs)]
59pub struct Binding {
60    pub id: Id,
61    pub location: Location,
62    pub visibility: Visibility,
63    pub ty: Type,
64    pub size: Option<usize>,
65    pub has_extended_meta: bool,
66}
67
68#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
69#[allow(missing_docs)]
70pub struct ScalarBinding {
71    pub ty: StorageType,
72    pub count: usize,
73}
74
75#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
76#[allow(missing_docs)]
77pub enum Location {
78    Storage,
79    Cube,
80}
81
82#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
83#[allow(missing_docs)]
84pub enum Visibility {
85    Read,
86    ReadWrite,
87}
88
89/// A kernel, compiled in the target language
90pub struct CompiledKernel<C: Compiler> {
91    /// The name of the kernel entrypoint.
92    /// For example
93    ///
94    /// ```text
95    /// #[cube(launch)]
96    /// fn gelu_array<F: Float, R: Runtime>() {}
97    /// ```
98    ///
99    /// would have the entrypoint name "gelu_array".
100    pub entrypoint_name: String,
101
102    /// A fully qualified debug name of the kernel.
103    ///
104    /// For example
105    ///
106    /// ```text
107    /// #[cube(launch)]
108    /// fn gelu_array<F: Float, R: Runtime>() {}
109    /// ```
110    ///
111    /// would have a debug name such as
112    ///
113    /// ```text
114    /// gelu::gelu_array::GeluArray<
115    ///    cubecl_core::frontend::element::float::F32,
116    ///    cubecl_cuda::runtime::CudaRuntime,
117    /// >
118    /// ```
119    pub debug_name: Option<&'static str>,
120
121    /// Source code of the kernel
122    pub source: String,
123    /// In-memory representation of the kernel
124    pub repr: Option<C::Representation>,
125    /// Size of a cube for the compiled kernel
126    pub cube_dim: CubeDim,
127    /// Extra debugging information about the compiled kernel.
128    pub debug_info: Option<DebugInformation>,
129}
130
131/// Extra debugging information about the compiled kernel.
132#[derive(new)]
133pub struct DebugInformation {
134    /// The language tag of the source..
135    pub lang_tag: &'static str,
136    /// The compilation id.
137    pub id: KernelId,
138}
139
140/// Kernel that can be defined
141pub trait CubeKernel: KernelMetadata {
142    /// Define the kernel for compilation
143    fn define(&self) -> KernelDefinition;
144}
145
146/// Wraps a [kernel](Kernel) to allow it be compiled.
147pub struct KernelTask<C: Compiler, K: CubeKernel> {
148    kernel_definition: K,
149    _compiler: PhantomData<C>,
150}
151
152/// Generic [CubeTask] for compiling kernels
153pub struct CubeTaskKernel<C: Compiler> {
154    /// The inner compilation task being wrapped
155    pub task: Box<dyn CubeTask<C>>,
156}
157
158impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
159    /// Create a new kernel task
160    pub fn new(kernel_definition: K) -> Self {
161        Self {
162            kernel_definition,
163            _compiler: PhantomData,
164        }
165    }
166}
167
168impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
169    fn compile(
170        &self,
171        compiler: &mut C,
172        compilation_options: &C::CompilationOptions,
173        mode: ExecutionMode,
174    ) -> Result<CompiledKernel<C>, CompilationError> {
175        let gpu_ir = self.kernel_definition.define();
176        let entrypoint_name = gpu_ir.options.kernel_name.clone();
177        let cube_dim = gpu_ir.cube_dim;
178        let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode)?;
179
180        Ok(CompiledKernel {
181            entrypoint_name,
182            debug_name: Some(core::any::type_name::<K>()),
183            source: lower_level_ir.to_string(),
184            repr: Some(lower_level_ir),
185            cube_dim,
186            debug_info: None,
187        })
188    }
189}
190
191impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
192    // Forward ID to underlying kernel definition.
193    fn id(&self) -> KernelId {
194        self.kernel_definition.id()
195    }
196
197    // Forward name to underlying kernel definition.
198    fn name(&self) -> &'static str {
199        self.kernel_definition.name()
200    }
201}
202
203impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
204    // Deref and use existing ID.
205    fn id(&self) -> KernelId {
206        self.as_ref().id()
207    }
208
209    // Deref and use existing name.
210    fn name(&self) -> &'static str {
211        self.as_ref().name()
212    }
213}
214
215static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
216
217fn compilation_level() -> u8 {
218    let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
219    if compilation_level == -1 {
220        let val = match GlobalConfig::get().compilation.logger.level {
221            CompilationLogLevel::Full => 2,
222            CompilationLogLevel::Disabled => 0,
223            CompilationLogLevel::Basic => 1,
224        };
225
226        COMPILATION_LEVEL.store(val, Ordering::Relaxed);
227        val as u8
228    } else {
229        compilation_level as u8
230    }
231}
232
233impl<C: Compiler> Display for CompiledKernel<C> {
234    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
235        match compilation_level() {
236            2 => self.format_full(f),
237            _ => self.format_basic(f),
238        }
239    }
240}
241
242impl<C: Compiler> CompiledKernel<C> {
243    fn format_basic(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
244        f.write_str("[Compiling kernel]")?;
245        if let Some(name) = self.debug_name {
246            if name.len() <= 32 {
247                f.write_fmt(format_args!(" {name}"))?;
248            } else {
249                f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
250            }
251        }
252
253        Ok(())
254    }
255
256    fn format_full(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
257        f.write_str("[START_KERNEL_COMPILATION]")?;
258
259        if let Some(name) = self.debug_name {
260            if name.len() <= 32 {
261                f.write_fmt(format_args!("\nname: {name}"))?;
262            } else {
263                let name = format_str(name, &[('<', '>')], false);
264                f.write_fmt(format_args!("\nname: {name}"))?;
265            }
266        }
267
268        f.write_fmt(format_args!(
269            "
270cube_dim: ({}, {}, {})",
271            self.cube_dim.x, self.cube_dim.y, self.cube_dim.z,
272        ))?;
273
274        if let Some(info) = &self.debug_info {
275            f.write_fmt(format_args!(
276                "\ninfo: {}",
277                format_str(
278                    format!("{:?}", info.id).as_str(),
279                    &[('(', ')'), ('[', ']'), ('{', '}')],
280                    true
281                )
282            ))?;
283        }
284
285        f.write_fmt(format_args!(
286            "
287source:
288```{}
289{}
290```
291[END_KERNEL_COMPILATION]
292",
293            self.debug_info
294                .as_ref()
295                .map(|info| info.lang_tag)
296                .unwrap_or(""),
297            self.source
298        ))
299    }
300}