cubecl_runtime/
kernel.rs

1use alloc::{
2    boxed::Box,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{
7    fmt::Display,
8    marker::PhantomData,
9    sync::atomic::{AtomicI8, Ordering},
10};
11
12use cubecl_common::format::format_str;
13use cubecl_ir::{Id, Scope, StorageType, Type};
14use serde::{Deserialize, Serialize};
15
16use crate::{
17    compiler::{CompilationError, Compiler, CubeTask},
18    config::{GlobalConfig, compilation::CompilationLogLevel},
19    id::KernelId,
20    server::{CubeDim, ExecutionMode},
21};
22
/// Implement this trait to create a [kernel definition](KernelDefinition).
///
/// It exposes the debug name, the compilation-cache identifier and the address type
/// of a kernel. Implementors must be `Send + Sync + 'static`.
pub trait KernelMetadata: Send + Sync + 'static {
    /// Name of the kernel for debugging.
    ///
    /// Defaults to the type name of the implementor, as reported by
    /// [`core::any::type_name`].
    fn name(&self) -> &'static str {
        core::any::type_name::<Self>()
    }

    /// Identifier for the kernel, used for caching kernel compilation.
    fn id(&self) -> KernelId;

    /// Type of addresses in this kernel.
    fn address_type(&self) -> StorageType;
}
36
/// Description of a kernel, as produced by [`CubeKernel::define`] and handed to the
/// compiler by [`KernelTask`].
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub struct KernelDefinition {
    /// Bindings registered as buffers.
    pub buffers: Vec<Binding>,
    /// Bindings registered as tensor maps.
    pub tensor_maps: Vec<Binding>,
    /// Scalar bindings; each entry pairs a storage type with a count.
    pub scalars: Vec<ScalarBinding>,
    /// Cube dimensions; copied into [`CompiledKernel::cube_dim`] when the kernel is compiled.
    pub cube_dim: CubeDim,
    /// Scope holding the body of the kernel.
    pub body: Scope,
    /// Compilation options; `options.kernel_name` becomes the entrypoint name of the
    /// compiled kernel.
    pub options: KernelOptions,
}
47
#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
/// Options for a specific kernel compilation.
///
/// The derived `Default` yields an empty name, no debug symbols and no cluster dimension.
pub struct KernelOptions {
    /// The name of the kernel.
    ///
    /// [`KernelTask`] uses it as the entrypoint name of the compiled kernel.
    pub kernel_name: String,
    /// Whether to include debug symbols.
    pub debug_symbols: bool,
    /// CUDA Cluster dim, if any.
    pub cluster_dim: Option<CubeDim>,
}
58
/// A single binding entry of a [`KernelDefinition`], used by both its `buffers` and
/// `tensor_maps` lists.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Binding {
    /// Identifier of this binding.
    pub id: Id,
    /// Location of the binding; see [`Location`].
    pub location: Location,
    /// Access mode of the binding; see [`Visibility`].
    pub visibility: Visibility,
    /// Type associated with the binding.
    pub ty: Type,
    /// Optional size.
    // NOTE(review): the unit (elements vs. bytes) is not visible from this file — confirm.
    pub size: Option<usize>,
    // NOTE(review): presumably signals that extended metadata accompanies this binding —
    // confirm against the compilers that read this flag.
    pub has_extended_meta: bool,
}
69
/// Scalar entry of [`KernelDefinition::scalars`]: a storage type together with a count.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarBinding {
    /// Storage type of the scalar(s).
    pub ty: StorageType,
    // NOTE(review): presumably the number of scalars of type `ty` — confirm against callers.
    pub count: usize,
}
76
/// Location recorded on a [`Binding`].
// NOTE(review): presumably `Storage` designates device storage buffers and `Cube` memory
// local to a cube — the distinction is not visible in this file; confirm before relying on it.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Location {
    Storage,
    Cube,
}
83
/// Access mode recorded on a [`Binding`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Visibility {
    /// Read access.
    Read,
    /// Read and write access.
    ReadWrite,
}
90
/// A kernel, compiled in the target language.
///
/// Its `Display` implementation renders a compilation log entry whose verbosity depends
/// on the configured compilation log level.
pub struct CompiledKernel<C: Compiler> {
    /// The name of the kernel entrypoint.
    /// For example
    ///
    /// ```text
    /// #[cube(launch)]
    /// fn gelu_array<F: Float, R: Runtime>() {}
    /// ```
    ///
    /// would have the entrypoint name "gelu_array".
    pub entrypoint_name: String,

    /// A fully qualified debug name of the kernel.
    ///
    /// For example
    ///
    /// ```text
    /// #[cube(launch)]
    /// fn gelu_array<F: Float, R: Runtime>() {}
    /// ```
    ///
    /// would have a debug name such as
    ///
    /// ```text
    /// gelu::gelu_array::GeluArray<
    ///    cubecl_core::frontend::element::float::F32,
    ///    cubecl_cuda::runtime::CudaRuntime,
    /// >
    /// ```
    pub debug_name: Option<&'static str>,

    /// Source code of the kernel.
    ///
    /// When built by [`KernelTask`], this is the string form of the compiler's representation.
    pub source: String,
    /// In-memory representation of the kernel, if available.
    pub repr: Option<C::Representation>,
    /// Size of a cube for the compiled kernel.
    pub cube_dim: CubeDim,
    /// Extra debugging information about the compiled kernel.
    ///
    /// When present, its language tag and id are included in the full log output.
    pub debug_info: Option<DebugInformation>,
}
132
/// Extra debugging information about the compiled kernel.
#[derive(new)]
pub struct DebugInformation {
    /// The language tag of the source.
    ///
    /// It is written right after the opening code fence of the full compilation log.
    pub lang_tag: &'static str,
    /// The compilation id.
    pub id: KernelId,
}
141
/// Kernel that can be defined, i.e. that can produce the [`KernelDefinition`] handed to a
/// compiler.
pub trait CubeKernel: KernelMetadata {
    /// Define the kernel for compilation.
    ///
    /// Returns the [`KernelDefinition`] describing this kernel.
    fn define(&self) -> KernelDefinition;
}
147
/// Wraps a [kernel](CubeKernel) so that it can be compiled by a compiler of type `C`.
pub struct KernelTask<C: Compiler, K: CubeKernel> {
    // The wrapped kernel; despite the field name it holds the kernel itself, whose
    // `define` method produces the actual definition.
    kernel_definition: K,
    // Ties the task to the compiler type without storing a compiler.
    _compiler: PhantomData<C>,
}
153
/// Generic [CubeTask] for compiling kernels.
///
/// Holds the wrapped task as a boxed trait object.
pub struct CubeTaskKernel<C: Compiler> {
    /// The inner compilation task being wrapped.
    pub task: Box<dyn CubeTask<C>>,
}
159
160impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
161    /// Create a new kernel task
162    pub fn new(kernel_definition: K) -> Self {
163        Self {
164            kernel_definition,
165            _compiler: PhantomData,
166        }
167    }
168}
169
170impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
171    fn compile(
172        &self,
173        compiler: &mut C,
174        compilation_options: &C::CompilationOptions,
175        mode: ExecutionMode,
176        addr_type: StorageType,
177    ) -> Result<CompiledKernel<C>, CompilationError> {
178        let gpu_ir = self.kernel_definition.define();
179        let entrypoint_name = gpu_ir.options.kernel_name.clone();
180        let cube_dim = gpu_ir.cube_dim;
181        let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode, addr_type)?;
182
183        Ok(CompiledKernel {
184            entrypoint_name,
185            debug_name: Some(core::any::type_name::<K>()),
186            source: lower_level_ir.to_string(),
187            repr: Some(lower_level_ir),
188            cube_dim,
189            debug_info: None,
190        })
191    }
192}
193
194impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
195    // Forward ID to underlying kernel definition.
196    fn id(&self) -> KernelId {
197        self.kernel_definition.id()
198    }
199
200    // Forward name to underlying kernel definition.
201    fn name(&self) -> &'static str {
202        self.kernel_definition.name()
203    }
204
205    fn address_type(&self) -> StorageType {
206        self.kernel_definition.address_type()
207    }
208}
209
210impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
211    // Deref and use existing ID.
212    fn id(&self) -> KernelId {
213        self.as_ref().id()
214    }
215
216    // Deref and use existing name.
217    fn name(&self) -> &'static str {
218        self.as_ref().name()
219    }
220
221    fn address_type(&self) -> StorageType {
222        self.as_ref().address_type()
223    }
224}
225
226static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
227
228fn compilation_level() -> u8 {
229    let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
230    if compilation_level == -1 {
231        let val = match GlobalConfig::get().compilation.logger.level {
232            CompilationLogLevel::Full => 2,
233            CompilationLogLevel::Disabled => 0,
234            CompilationLogLevel::Basic => 1,
235        };
236
237        COMPILATION_LEVEL.store(val, Ordering::Relaxed);
238        val as u8
239    } else {
240        compilation_level as u8
241    }
242}
243
244impl<C: Compiler> Display for CompiledKernel<C> {
245    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
246        match compilation_level() {
247            2 => self.format_full(f),
248            _ => self.format_basic(f),
249        }
250    }
251}
252
253impl<C: Compiler> CompiledKernel<C> {
254    fn format_basic(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
255        f.write_str("[Compiling kernel]")?;
256        if let Some(name) = self.debug_name {
257            if name.len() <= 32 {
258                f.write_fmt(format_args!(" {name}"))?;
259            } else {
260                f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
261            }
262        }
263
264        Ok(())
265    }
266
267    fn format_full(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
268        f.write_str("[START_KERNEL_COMPILATION]")?;
269
270        if let Some(name) = self.debug_name {
271            if name.len() <= 32 {
272                f.write_fmt(format_args!("\nname: {name}"))?;
273            } else {
274                let name = format_str(name, &[('<', '>')], false);
275                f.write_fmt(format_args!("\nname: {name}"))?;
276            }
277        }
278
279        if let Some(info) = &self.debug_info {
280            f.write_fmt(format_args!("\nid: {:#?}", info.id))?;
281        }
282
283        f.write_fmt(format_args!(
284            "
285source:
286```{}
287{}
288```
289[END_KERNEL_COMPILATION]
290",
291            self.debug_info
292                .as_ref()
293                .map(|info| info.lang_tag)
294                .unwrap_or(""),
295            self.source
296        ))
297    }
298}