1use alloc::{
2 boxed::Box,
3 format,
4 string::{String, ToString},
5 vec::Vec,
6};
7use core::{
8 fmt::Display,
9 marker::PhantomData,
10 sync::atomic::{AtomicI8, Ordering},
11};
12
13use cubecl_common::format::format_str;
14use cubecl_ir::{Id, Scope, StorageType, Type};
15use serde::{Deserialize, Serialize};
16
17use crate::{
18 compiler::{CompilationError, Compiler, CubeTask},
19 config::{GlobalConfig, compilation::CompilationLogLevel},
20 id::KernelId,
21 server::{CubeDim, ExecutionMode},
22};
23
/// Metadata exposed by every kernel: a human-readable name and a [`KernelId`].
///
/// Implementors must be `Send + Sync + 'static`.
pub trait KernelMetadata: Send + Sync + 'static {
    /// Name of the kernel.
    ///
    /// Defaults to the Rust type name of the implementor, as reported by
    /// [`core::any::type_name`].
    fn name(&self) -> &'static str {
        core::any::type_name::<Self>()
    }

    /// Identifier of the kernel.
    // NOTE(review): presumably used as the key when caching compiled kernels — confirm
    // against the callers of `id()`.
    fn id(&self) -> KernelId;
}
34
/// Description of a kernel as produced by [`CubeKernel::define`] and consumed by
/// [`Compiler::compile`]: its bindings, cube dimensions, body and options.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub struct KernelDefinition {
    // Buffer bindings of the kernel.
    pub buffers: Vec<Binding>,
    // Tensor-map bindings of the kernel.
    pub tensor_maps: Vec<Binding>,
    // Scalar bindings, one entry per storage type with its element count.
    pub scalars: Vec<ScalarBinding>,
    // Copied into `CompiledKernel::cube_dim` by `KernelTask::compile`.
    pub cube_dim: CubeDim,
    // The kernel body as an IR scope.
    pub body: Scope,
    // Extra options; `options.kernel_name` becomes the compiled entrypoint name.
    pub options: KernelOptions,
}
45
/// Options attached to a [`KernelDefinition`].
#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct KernelOptions {
    /// Name of the kernel; `KernelTask::compile` uses it as the entrypoint name of the
    /// compiled kernel.
    pub kernel_name: String,
    /// Whether debug symbols are requested.
    // NOTE(review): not read in this file; presumably consumed by the compiler backends.
    pub debug_symbols: bool,
    /// Optional cluster dimensions; `None` when no cluster dimension is set.
    // NOTE(review): exact semantics are backend-defined — confirm.
    pub cluster_dim: Option<CubeDim>,
}
56
/// A single binding of a kernel, used for both the buffer and the tensor-map lists of
/// a [`KernelDefinition`].
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Binding {
    // Identifier of the binding.
    pub id: Id,
    // Where the bound data lives (see `Location`).
    pub location: Location,
    // Whether the kernel may only read or also write the binding (see `Visibility`).
    pub visibility: Visibility,
    // IR type of the bound data.
    pub ty: Type,
    // Optional size; `None` when no size is recorded.
    // NOTE(review): unit (elements vs. bytes) is not visible here — confirm.
    pub size: Option<usize>,
    // NOTE(review): presumably marks bindings that carry extended metadata — confirm
    // with the code that builds bindings.
    pub has_extended_meta: bool,
}
67
/// A group of scalar inputs sharing one storage type.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarBinding {
    // Storage type of the scalars in this group.
    pub ty: StorageType,
    // Number of scalars of that type.
    pub count: usize,
}
74
/// Location of a [`Binding`].
// NOTE(review): the precise memory spaces behind each variant are backend-defined and
// not visible in this file — confirm before relying on them.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Location {
    Storage,
    Cube,
}
81
/// Access mode of a [`Binding`]: read-only or read-write.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Visibility {
    Read,
    ReadWrite,
}
88
/// A kernel after compilation by a [`Compiler`].
///
/// Its [`Display`] implementation produces the compilation log entry for the kernel.
pub struct CompiledKernel<C: Compiler> {
    /// Name of the kernel entrypoint; `KernelTask::compile` fills it from
    /// `KernelOptions::kernel_name`.
    pub entrypoint_name: String,

    /// Optional name shown in compilation logs.
    ///
    /// `KernelTask::compile` sets it to the Rust type name of the kernel. When printed,
    /// names longer than 32 bytes are shortened (basic log level) or reformatted
    /// (full log level).
    pub debug_name: Option<&'static str>,

    /// Source code of the kernel; printed inside a code fence at full log level.
    pub source: String,
    /// Compiler-specific representation of the kernel, when available.
    pub repr: Option<C::Representation>,
    /// Cube dimensions of the kernel, printed as `(x, y, z)` at full log level.
    pub cube_dim: CubeDim,
    /// Optional extra information included in the full log output.
    pub debug_info: Option<DebugInformation>,
}
130
/// Extra information attached to a [`CompiledKernel`] and used only when the kernel is
/// printed at full log level.
#[derive(new)]
pub struct DebugInformation {
    /// Language tag placed right after the opening code fence of the printed source.
    pub lang_tag: &'static str,
    /// Kernel id, printed through its `Debug` representation under `info:`.
    pub id: KernelId,
}
139
/// A kernel that can describe itself as a [`KernelDefinition`].
///
/// [`KernelTask`] calls [`define`](CubeKernel::define) on every compilation and hands
/// the result to the compiler.
pub trait CubeKernel: KernelMetadata {
    /// Builds the definition of this kernel.
    fn define(&self) -> KernelDefinition;
}
145
/// Adapter that turns a [`CubeKernel`] into a [`CubeTask`] for the compiler `C`.
pub struct KernelTask<C: Compiler, K: CubeKernel> {
    // The wrapped kernel; despite the field name this is the kernel itself, whose
    // `define()` yields the actual `KernelDefinition`.
    kernel_definition: K,
    // Ties the task to a compiler type without storing a compiler value.
    _compiler: PhantomData<C>,
}
151
/// Wrapper around a boxed, type-erased [`CubeTask`] for the compiler `C`.
pub struct CubeTaskKernel<C: Compiler> {
    /// The boxed task.
    pub task: Box<dyn CubeTask<C>>,
}
157
158impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
159 pub fn new(kernel_definition: K) -> Self {
161 Self {
162 kernel_definition,
163 _compiler: PhantomData,
164 }
165 }
166}
167
168impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
169 fn compile(
170 &self,
171 compiler: &mut C,
172 compilation_options: &C::CompilationOptions,
173 mode: ExecutionMode,
174 ) -> Result<CompiledKernel<C>, CompilationError> {
175 let gpu_ir = self.kernel_definition.define();
176 let entrypoint_name = gpu_ir.options.kernel_name.clone();
177 let cube_dim = gpu_ir.cube_dim;
178 let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode)?;
179
180 Ok(CompiledKernel {
181 entrypoint_name,
182 debug_name: Some(core::any::type_name::<K>()),
183 source: lower_level_ir.to_string(),
184 repr: Some(lower_level_ir),
185 cube_dim,
186 debug_info: None,
187 })
188 }
189}
190
191impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
192 fn id(&self) -> KernelId {
194 self.kernel_definition.id()
195 }
196
197 fn name(&self) -> &'static str {
199 self.kernel_definition.name()
200 }
201}
202
203impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
204 fn id(&self) -> KernelId {
206 self.as_ref().id()
207 }
208
209 fn name(&self) -> &'static str {
211 self.as_ref().name()
212 }
213}
214
215static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
216
217fn compilation_level() -> u8 {
218 let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
219 if compilation_level == -1 {
220 let val = match GlobalConfig::get().compilation.logger.level {
221 CompilationLogLevel::Full => 2,
222 CompilationLogLevel::Disabled => 0,
223 CompilationLogLevel::Basic => 1,
224 };
225
226 COMPILATION_LEVEL.store(val, Ordering::Relaxed);
227 val as u8
228 } else {
229 compilation_level as u8
230 }
231}
232
233impl<C: Compiler> Display for CompiledKernel<C> {
234 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
235 match compilation_level() {
236 2 => self.format_full(f),
237 _ => self.format_basic(f),
238 }
239 }
240}
241
242impl<C: Compiler> CompiledKernel<C> {
243 fn format_basic(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
244 f.write_str("[Compiling kernel]")?;
245 if let Some(name) = self.debug_name {
246 if name.len() <= 32 {
247 f.write_fmt(format_args!(" {name}"))?;
248 } else {
249 f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
250 }
251 }
252
253 Ok(())
254 }
255
256 fn format_full(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
257 f.write_str("[START_KERNEL_COMPILATION]")?;
258
259 if let Some(name) = self.debug_name {
260 if name.len() <= 32 {
261 f.write_fmt(format_args!("\nname: {name}"))?;
262 } else {
263 let name = format_str(name, &[('<', '>')], false);
264 f.write_fmt(format_args!("\nname: {name}"))?;
265 }
266 }
267
268 f.write_fmt(format_args!(
269 "
270cube_dim: ({}, {}, {})",
271 self.cube_dim.x, self.cube_dim.y, self.cube_dim.z,
272 ))?;
273
274 if let Some(info) = &self.debug_info {
275 f.write_fmt(format_args!(
276 "\ninfo: {}",
277 format_str(
278 format!("{:?}", info.id).as_str(),
279 &[('(', ')'), ('[', ']'), ('{', '}')],
280 true
281 )
282 ))?;
283 }
284
285 f.write_fmt(format_args!(
286 "
287source:
288```{}
289{}
290```
291[END_KERNEL_COMPILATION]
292",
293 self.debug_info
294 .as_ref()
295 .map(|info| info.lang_tag)
296 .unwrap_or(""),
297 self.source
298 ))
299 }
300}