Skip to main content

cubecl_runtime/
kernel.rs

1use alloc::{
2    boxed::Box,
3    string::{String, ToString},
4    vec::Vec,
5};
6use core::{
7    fmt::Display,
8    marker::PhantomData,
9    sync::atomic::{AtomicI8, Ordering},
10};
11
12use cubecl_common::format::format_str;
13use cubecl_ir::{Id, Scope, StorageType, Type};
14use serde::{Deserialize, Serialize};
15
16use crate::{
17    compiler::{CompilationError, Compiler, CubeTask},
18    config::{GlobalConfig, compilation::CompilationLogLevel},
19    id::KernelId,
20    server::{CubeDim, ExecutionMode},
21};
22
23/// Implement this trait to create a [kernel definition](KernelDefinition).
24pub trait KernelMetadata: Send + Sync + 'static {
25    /// Name of the kernel for debugging.
26    fn name(&self) -> &'static str {
27        core::any::type_name::<Self>()
28    }
29
30    /// Identifier for the kernel, used for caching kernel compilation.
31    fn id(&self) -> KernelId;
32
33    /// Type of addresses in this kernel
34    fn address_type(&self) -> StorageType;
35}
36
37#[derive(Debug, Clone)]
38#[allow(missing_docs)]
39pub struct KernelDefinition {
40    pub buffers: Vec<KernelArg>,
41    pub tensor_maps: Vec<KernelArg>,
42    pub scalars: Vec<ScalarKernelArg>,
43    pub cube_dim: CubeDim,
44    pub body: Scope,
45    pub options: KernelOptions,
46}
47
48#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
49/// Options for a specific kernel compilation
50pub struct KernelOptions {
51    /// The name of the kernel
52    pub kernel_name: String,
53    /// Whether to include debug symbols
54    pub debug_symbols: bool,
55    /// CUDA Cluster dim, if any
56    pub cluster_dim: Option<CubeDim>,
57}
58
59#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
60/// Global argument of a kernel.
61pub struct KernelArg {
62    /// The kernel id.
63    pub id: Id,
64    /// Whether the global argument can only be accessed for reading, or if it can also be accessed
65    /// for write.
66    pub visibility: Visibility,
67    /// The type of the argument.
68    pub ty: Type,
69    /// The size of the argument.
70    pub size: Option<usize>,
71    /// Whether the argument has metadata.
72    pub has_extended_meta: bool,
73}
74
75#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
76#[allow(missing_docs)]
77pub struct ScalarKernelArg {
78    pub ty: StorageType,
79    pub count: usize,
80}
81
82#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
83#[allow(missing_docs)]
84pub enum Visibility {
85    Read,
86    ReadWrite,
87}
88
89/// A kernel, compiled in the target language
90pub struct CompiledKernel<C: Compiler> {
91    /// The name of the kernel entrypoint.
92    /// For example
93    ///
94    /// ```text
95    /// #[cube(launch)]
96    /// fn gelu_array<F: Float, R: Runtime>() {}
97    /// ```
98    ///
99    /// would have the entrypoint name "`gelu_array`".
100    pub entrypoint_name: String,
101
102    /// A fully qualified debug name of the kernel.
103    ///
104    /// For example
105    ///
106    /// ```text
107    /// #[cube(launch)]
108    /// fn gelu_array<F: Float, R: Runtime>() {}
109    /// ```
110    ///
111    /// would have a debug name such as
112    ///
113    /// ```text
114    /// gelu::gelu_array::GeluArray<
115    ///    cubecl_core::frontend::element::float::F32,
116    ///    cubecl_cuda::runtime::CudaRuntime,
117    /// >
118    /// ```
119    pub debug_name: Option<&'static str>,
120
121    /// Source code of the kernel
122    pub source: String,
123    /// In-memory representation of the kernel
124    pub repr: Option<C::Representation>,
125    /// Size of a cube for the compiled kernel
126    pub cube_dim: CubeDim,
127    /// Extra debugging information about the compiled kernel.
128    pub debug_info: Option<DebugInformation>,
129}
130
131/// Extra debugging information about the compiled kernel.
132#[derive(new)]
133pub struct DebugInformation {
134    /// The language tag of the source..
135    pub lang_tag: &'static str,
136    /// The compilation id.
137    pub id: KernelId,
138}
139
140/// Kernel that can be defined
141pub trait CubeKernel: KernelMetadata {
142    /// Define the kernel for compilation
143    fn define(&self) -> KernelDefinition;
144}
145
146/// Wraps a [`CubeKernel`] to allow it be compiled.
147pub struct KernelTask<C: Compiler, K: CubeKernel> {
148    kernel_definition: K,
149    _compiler: PhantomData<C>,
150}
151
152/// Generic [`CubeTask`] for compiling kernels
153pub struct CubeTaskKernel<C: Compiler> {
154    /// The inner compilation task being wrapped
155    pub task: Box<dyn CubeTask<C>>,
156}
157
158impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
159    /// Create a new kernel task
160    pub fn new(kernel_definition: K) -> Self {
161        Self {
162            kernel_definition,
163            _compiler: PhantomData,
164        }
165    }
166}
167
168impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
169    fn compile(
170        &self,
171        compiler: &mut C,
172        compilation_options: &C::CompilationOptions,
173        mode: ExecutionMode,
174        addr_type: StorageType,
175    ) -> Result<CompiledKernel<C>, CompilationError> {
176        let gpu_ir = self.kernel_definition.define();
177        let entrypoint_name = gpu_ir.options.kernel_name.clone();
178        let cube_dim = gpu_ir.cube_dim;
179        let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode, addr_type)?;
180
181        Ok(CompiledKernel {
182            entrypoint_name,
183            debug_name: Some(core::any::type_name::<K>()),
184            source: lower_level_ir.to_string(),
185            repr: Some(lower_level_ir),
186            cube_dim,
187            debug_info: None,
188        })
189    }
190}
191
192impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
193    // Forward ID to underlying kernel definition.
194    fn id(&self) -> KernelId {
195        self.kernel_definition.id()
196    }
197
198    // Forward name to underlying kernel definition.
199    fn name(&self) -> &'static str {
200        self.kernel_definition.name()
201    }
202
203    fn address_type(&self) -> StorageType {
204        self.kernel_definition.address_type()
205    }
206}
207
208impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
209    // Deref and use existing ID.
210    fn id(&self) -> KernelId {
211        self.as_ref().id()
212    }
213
214    // Deref and use existing name.
215    fn name(&self) -> &'static str {
216        self.as_ref().name()
217    }
218
219    fn address_type(&self) -> StorageType {
220        self.as_ref().address_type()
221    }
222}
223
224static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
225
226fn compilation_level() -> u8 {
227    let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
228    if compilation_level == -1 {
229        let val = match GlobalConfig::get().compilation.logger.level {
230            CompilationLogLevel::Full => 2,
231            CompilationLogLevel::Disabled => 0,
232            CompilationLogLevel::Basic => 1,
233        };
234
235        COMPILATION_LEVEL.store(val, Ordering::Relaxed);
236        val as u8
237    } else {
238        compilation_level as u8
239    }
240}
241
242impl<C: Compiler> Display for CompiledKernel<C> {
243    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
244        match compilation_level() {
245            2 => self.format_full(f),
246            _ => self.format_basic(f),
247        }
248    }
249}
250
251impl<C: Compiler> CompiledKernel<C> {
252    fn format_basic(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
253        f.write_str("[Compiling kernel]")?;
254        if let Some(name) = self.debug_name {
255            if name.len() <= 32 {
256                f.write_fmt(format_args!(" {name}"))?;
257            } else {
258                f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
259            }
260        }
261
262        Ok(())
263    }
264
265    fn format_full(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
266        f.write_str("[START_KERNEL_COMPILATION]")?;
267
268        if let Some(name) = self.debug_name {
269            if name.len() <= 32 {
270                f.write_fmt(format_args!("\nname: {name}"))?;
271            } else {
272                let name = format_str(name, &[('<', '>')], false);
273                f.write_fmt(format_args!("\nname: {name}"))?;
274            }
275        }
276
277        if let Some(info) = &self.debug_info {
278            f.write_fmt(format_args!("\nid: {:#?}", info.id))?;
279        }
280
281        f.write_fmt(format_args!(
282            "
283source:
284```{}
285{}
286```
287[END_KERNEL_COMPILATION]
288",
289            self.debug_info
290                .as_ref()
291                .map(|info| info.lang_tag)
292                .unwrap_or(""),
293            self.source
294        ))
295    }
296}