cubecl_core/compute/
kernel.rs

1use std::{
2    fmt::Display,
3    marker::PhantomData,
4    sync::atomic::{AtomicI8, Ordering},
5};
6
7use crate::{Compiler, KernelOptions};
8use cubecl_common::{CubeDim, ExecutionMode};
9use cubecl_ir::{Elem, Id, Item, Scope};
10use cubecl_runtime::{
11    config::{GlobalConfig, compilation::CompilationLogLevel},
12    id::{KernelId, format_str},
13    kernel::KernelMetadata,
14};
15use serde::{Deserialize, Serialize};
16
17/// A kernel, compiled in the target language
18pub struct CompiledKernel<C: Compiler> {
19    /// The name of the kernel entrypoint.
20    /// For example
21    ///
22    /// ```text
23    /// #[cube(launch)]
24    /// fn gelu_array<F: Float, R: Runtime>() {}
25    /// ```
26    ///
27    /// would have the entrypoint name "gelu_array".
28    pub entrypoint_name: String,
29
30    /// A fully qualified debug name of the kernel.
31    ///
32    /// For example
33    ///
34    /// ```text
35    /// #[cube(launch)]
36    /// fn gelu_array<F: Float, R: Runtime>() {}
37    /// ```
38    ///
39    /// would have a debug name such as
40    ///
41    /// ```text
42    /// gelu::gelu_array::GeluArray<
43    ///    cubecl_core::frontend::element::float::F32,
44    ///    cubecl_cuda::runtime::CudaRuntime,
45    /// >
46    /// ```
47    pub debug_name: Option<&'static str>,
48
49    /// Source code of the kernel
50    pub source: String,
51    /// In-memory representation of the kernel
52    pub repr: Option<C::Representation>,
53    /// Size of a cube for the compiled kernel
54    pub cube_dim: CubeDim,
55    /// Extra debugging information about the compiled kernel.
56    pub debug_info: Option<DebugInformation>,
57}
58
59/// Extra debugging information about the compiled kernel.
60#[derive(new)]
61pub struct DebugInformation {
62    /// The language tag of the source..
63    pub lang_tag: &'static str,
64    /// The compilation id.
65    pub id: KernelId,
66}
67
68static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
69
70fn compilation_level() -> u8 {
71    let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
72    if compilation_level == -1 {
73        let val = match GlobalConfig::get().compilation.logger.level {
74            CompilationLogLevel::Full => 2,
75            CompilationLogLevel::Disabled => 0,
76            CompilationLogLevel::Basic => 1,
77        };
78
79        COMPILATION_LEVEL.store(val, Ordering::Relaxed);
80        val as u8
81    } else {
82        compilation_level as u8
83    }
84}
85
86impl<C: Compiler> Display for CompiledKernel<C> {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        match compilation_level() {
89            2 => self.format_full(f),
90            _ => self.format_basic(f),
91        }
92    }
93}
94
95impl<C: Compiler> CompiledKernel<C> {
96    fn format_basic(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        f.write_str("[Compiling kernel]")?;
98        if let Some(name) = self.debug_name {
99            if name.len() <= 32 {
100                f.write_fmt(format_args!(" {name}"))?;
101            } else {
102                f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
103            }
104        }
105
106        Ok(())
107    }
108
109    fn format_full(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        f.write_str("[START_KERNEL_COMPILATION]")?;
111
112        if let Some(name) = self.debug_name {
113            if name.len() <= 32 {
114                f.write_fmt(format_args!("\nname: {name}"))?;
115            } else {
116                let name = format_str(name, &[('<', '>')], false);
117                f.write_fmt(format_args!("\nname: {name}"))?;
118            }
119        }
120
121        f.write_fmt(format_args!(
122            "
123cube_dim: ({}, {}, {})",
124            self.cube_dim.x, self.cube_dim.y, self.cube_dim.z,
125        ))?;
126
127        if let Some(info) = &self.debug_info {
128            f.write_fmt(format_args!(
129                "\ninfo: {}",
130                format_str(
131                    format!("{:?}", info.id).as_str(),
132                    &[('(', ')'), ('[', ']'), ('{', '}')],
133                    true
134                )
135            ))?;
136        }
137
138        f.write_fmt(format_args!(
139            "
140source:
141```{}
142{}
143```
144[END_KERNEL_COMPILATION]
145",
146            self.debug_info
147                .as_ref()
148                .map(|info| info.lang_tag)
149                .unwrap_or(""),
150            self.source
151        ))
152    }
153}
154
155#[derive(Debug, Clone)]
156#[allow(missing_docs)]
157pub struct KernelDefinition {
158    pub buffers: Vec<Binding>,
159    pub tensor_maps: Vec<Id>,
160    pub scalars: Vec<ScalarBinding>,
161    pub cube_dim: CubeDim,
162    pub body: Scope,
163    pub options: KernelOptions,
164}
165
166#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
167#[allow(missing_docs)]
168pub struct Binding {
169    pub id: Id,
170    pub location: Location,
171    pub visibility: Visibility,
172    pub item: Item,
173    pub size: Option<usize>,
174    pub has_extended_meta: bool,
175}
176
177#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
178#[allow(missing_docs)]
179pub struct ScalarBinding {
180    pub elem: Elem,
181    pub count: usize,
182}
183
184#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
185#[allow(missing_docs)]
186pub enum Location {
187    Storage,
188    Cube,
189}
190
191#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
192#[allow(missing_docs)]
193pub enum Visibility {
194    Read,
195    ReadWrite,
196}
197
198pub trait CubeKernel: KernelMetadata {
199    fn define(&self) -> KernelDefinition;
200}
201
202/// Kernel trait with the ComputeShader that will be compiled and cached based on the
203/// provided id.
204pub trait CubeTask<C: Compiler>: KernelMetadata + Send + Sync {
205    fn compile(
206        &self,
207        compiler: &mut C,
208        compilation_options: &C::CompilationOptions,
209        mode: ExecutionMode,
210    ) -> CompiledKernel<C>;
211}
212
213/// Wraps a [kernel](Kernel) to allow it be compiled.
214pub struct KernelTask<C: Compiler, K: CubeKernel> {
215    kernel_definition: K,
216    _compiler: PhantomData<C>,
217}
218
219pub struct CubeTaskKernel<C: Compiler> {
220    pub task: Box<dyn CubeTask<C>>,
221}
222
223impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
224    pub fn new(kernel_definition: K) -> Self {
225        Self {
226            kernel_definition,
227            _compiler: PhantomData,
228        }
229    }
230}
231
232impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
233    fn compile(
234        &self,
235        compiler: &mut C,
236        compilation_options: &C::CompilationOptions,
237        mode: ExecutionMode,
238    ) -> CompiledKernel<C> {
239        let gpu_ir = self.kernel_definition.define();
240        let entrypoint_name = gpu_ir.options.kernel_name.clone();
241        let cube_dim = gpu_ir.cube_dim;
242        let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode);
243
244        CompiledKernel {
245            entrypoint_name,
246            debug_name: Some(core::any::type_name::<K>()),
247            source: lower_level_ir.to_string(),
248            repr: Some(lower_level_ir),
249            cube_dim,
250            debug_info: None,
251        }
252    }
253}
254
255impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
256    // Forward ID to underlying kernel definition.
257    fn id(&self) -> KernelId {
258        self.kernel_definition.id()
259    }
260
261    // Forward name to underlying kernel definition.
262    fn name(&self) -> &'static str {
263        self.kernel_definition.name()
264    }
265}
266
267impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
268    // Deref and use existing ID.
269    fn id(&self) -> KernelId {
270        self.as_ref().id()
271    }
272
273    // Deref and use existing name.
274    fn name(&self) -> &'static str {
275        self.as_ref().name()
276    }
277}