use alloc::{
    boxed::Box,
    format,
    string::{String, ToString},
    vec::Vec,
};
use core::{
    fmt::Display,
    marker::PhantomData,
    sync::atomic::{AtomicI8, Ordering},
};

use cubecl_common::{CubeDim, ExecutionMode, format::format_str};
use cubecl_ir::{Id, Scope, StorageType, Type};
use serde::{Deserialize, Serialize};

use crate::{
    compiler::{CompilationError, Compiler, CubeTask},
    config::{GlobalConfig, compilation::CompilationLogLevel},
    id::KernelId,
};
22
23pub trait KernelMetadata: Send + Sync + 'static {
25 fn name(&self) -> &'static str {
27 core::any::type_name::<Self>()
28 }
29
30 fn id(&self) -> KernelId;
32}
33
34#[derive(Debug, Clone)]
35#[allow(missing_docs)]
36pub struct KernelDefinition {
37 pub buffers: Vec<Binding>,
38 pub tensor_maps: Vec<Binding>,
39 pub scalars: Vec<ScalarBinding>,
40 pub cube_dim: CubeDim,
41 pub body: Scope,
42 pub options: KernelOptions,
43}
44
45#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
46pub struct KernelOptions {
48 pub kernel_name: String,
50 pub debug_symbols: bool,
52 pub cluster_dim: Option<CubeDim>,
54}
55
56#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
57#[allow(missing_docs)]
58pub struct Binding {
59 pub id: Id,
60 pub location: Location,
61 pub visibility: Visibility,
62 pub ty: Type,
63 pub size: Option<usize>,
64 pub has_extended_meta: bool,
65}
66
67#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
68#[allow(missing_docs)]
69pub struct ScalarBinding {
70 pub ty: StorageType,
71 pub count: usize,
72}
73
74#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
75#[allow(missing_docs)]
76pub enum Location {
77 Storage,
78 Cube,
79}
80
81#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
82#[allow(missing_docs)]
83pub enum Visibility {
84 Read,
85 ReadWrite,
86}
87
88pub struct CompiledKernel<C: Compiler> {
90 pub entrypoint_name: String,
100
101 pub debug_name: Option<&'static str>,
119
120 pub source: String,
122 pub repr: Option<C::Representation>,
124 pub cube_dim: CubeDim,
126 pub debug_info: Option<DebugInformation>,
128}
129
130#[derive(new)]
132pub struct DebugInformation {
133 pub lang_tag: &'static str,
135 pub id: KernelId,
137}
138
139pub trait CubeKernel: KernelMetadata {
141 fn define(&self) -> KernelDefinition;
143}
144
145pub struct KernelTask<C: Compiler, K: CubeKernel> {
147 kernel_definition: K,
148 _compiler: PhantomData<C>,
149}
150
151pub struct CubeTaskKernel<C: Compiler> {
153 pub task: Box<dyn CubeTask<C>>,
155}
156
157impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
158 pub fn new(kernel_definition: K) -> Self {
160 Self {
161 kernel_definition,
162 _compiler: PhantomData,
163 }
164 }
165}
166
167impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
168 fn compile(
169 &self,
170 compiler: &mut C,
171 compilation_options: &C::CompilationOptions,
172 mode: ExecutionMode,
173 ) -> Result<CompiledKernel<C>, CompilationError> {
174 let gpu_ir = self.kernel_definition.define();
175 let entrypoint_name = gpu_ir.options.kernel_name.clone();
176 let cube_dim = gpu_ir.cube_dim;
177 let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode)?;
178
179 Ok(CompiledKernel {
180 entrypoint_name,
181 debug_name: Some(core::any::type_name::<K>()),
182 source: lower_level_ir.to_string(),
183 repr: Some(lower_level_ir),
184 cube_dim,
185 debug_info: None,
186 })
187 }
188}
189
190impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
191 fn id(&self) -> KernelId {
193 self.kernel_definition.id()
194 }
195
196 fn name(&self) -> &'static str {
198 self.kernel_definition.name()
199 }
200}
201
202impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
203 fn id(&self) -> KernelId {
205 self.as_ref().id()
206 }
207
208 fn name(&self) -> &'static str {
210 self.as_ref().name()
211 }
212}
213
214static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
215
216fn compilation_level() -> u8 {
217 let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
218 if compilation_level == -1 {
219 let val = match GlobalConfig::get().compilation.logger.level {
220 CompilationLogLevel::Full => 2,
221 CompilationLogLevel::Disabled => 0,
222 CompilationLogLevel::Basic => 1,
223 };
224
225 COMPILATION_LEVEL.store(val, Ordering::Relaxed);
226 val as u8
227 } else {
228 compilation_level as u8
229 }
230}
231
232impl<C: Compiler> Display for CompiledKernel<C> {
233 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
234 match compilation_level() {
235 2 => self.format_full(f),
236 _ => self.format_basic(f),
237 }
238 }
239}
240
241impl<C: Compiler> CompiledKernel<C> {
242 fn format_basic(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
243 f.write_str("[Compiling kernel]")?;
244 if let Some(name) = self.debug_name {
245 if name.len() <= 32 {
246 f.write_fmt(format_args!(" {name}"))?;
247 } else {
248 f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
249 }
250 }
251
252 Ok(())
253 }
254
255 fn format_full(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
256 f.write_str("[START_KERNEL_COMPILATION]")?;
257
258 if let Some(name) = self.debug_name {
259 if name.len() <= 32 {
260 f.write_fmt(format_args!("\nname: {name}"))?;
261 } else {
262 let name = format_str(name, &[('<', '>')], false);
263 f.write_fmt(format_args!("\nname: {name}"))?;
264 }
265 }
266
267 f.write_fmt(format_args!(
268 "
269cube_dim: ({}, {}, {})",
270 self.cube_dim.x, self.cube_dim.y, self.cube_dim.z,
271 ))?;
272
273 if let Some(info) = &self.debug_info {
274 f.write_fmt(format_args!(
275 "\ninfo: {}",
276 format_str(
277 format!("{:?}", info.id).as_str(),
278 &[('(', ')'), ('[', ']'), ('{', '}')],
279 true
280 )
281 ))?;
282 }
283
284 f.write_fmt(format_args!(
285 "
286source:
287```{}
288{}
289```
290[END_KERNEL_COMPILATION]
291",
292 self.debug_info
293 .as_ref()
294 .map(|info| info.lang_tag)
295 .unwrap_or(""),
296 self.source
297 ))
298 }
299}