1use alloc::{
2 boxed::Box,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{
7 fmt::Display,
8 marker::PhantomData,
9 sync::atomic::{AtomicI8, Ordering},
10};
11
12use cubecl_common::format::format_str;
13use cubecl_ir::{Id, Scope, StorageType, Type};
14use serde::{Deserialize, Serialize};
15
16use crate::{
17 compiler::{CompilationError, Compiler, CubeTask},
18 config::{GlobalConfig, compilation::CompilationLogLevel},
19 id::KernelId,
20 server::{CubeDim, ExecutionMode},
21};
22
/// Identifying metadata exposed by every kernel.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait KernelMetadata: Send + Sync + 'static {
    /// Name of the kernel.
    ///
    /// The default implementation returns the implementing type's name as
    /// reported by [`core::any::type_name`].
    fn name(&self) -> &'static str {
        core::any::type_name::<Self>()
    }

    /// Returns the [`KernelId`] of this kernel.
    // NOTE(review): presumably the key used to cache compiled kernels — confirm against callers.
    fn id(&self) -> KernelId;

    /// Returns the [`StorageType`] this kernel reports as its address type.
    fn address_type(&self) -> StorageType;
}
36
/// What a [`CubeKernel`] returns from `define` and what is then handed to the
/// compiler: the kernel's arguments, cube dimensions, body and options.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub struct KernelDefinition {
    // Buffer arguments.
    pub buffers: Vec<KernelArg>,
    // Tensor-map arguments (described with the same `KernelArg` type as buffers).
    pub tensor_maps: Vec<KernelArg>,
    // Scalar arguments.
    pub scalars: Vec<ScalarKernelArg>,
    // Cube dimensions; copied into `CompiledKernel::cube_dim` by `KernelTask::compile`.
    pub cube_dim: CubeDim,
    // Kernel body, expressed as a `cubecl_ir` scope.
    pub body: Scope,
    // Per-kernel options (name, debug symbols, cluster dimensions).
    pub options: KernelOptions,
}
47
/// Per-kernel options carried inside a [`KernelDefinition`].
#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct KernelOptions {
    /// Name of the kernel; `KernelTask::compile` copies it into the compiled
    /// kernel's `entrypoint_name`.
    pub kernel_name: String,
    /// Debug-symbols flag.
    // NOTE(review): presumably asks the compiler to emit debug symbols — confirm in the compilers.
    pub debug_symbols: bool,
    /// Optional cluster dimensions; `None` when no cluster dimension is set.
    pub cluster_dim: Option<CubeDim>,
}
58
/// Description of one kernel argument, as listed in the `buffers` and
/// `tensor_maps` of a [`KernelDefinition`].
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct KernelArg {
    /// Identifier of the argument.
    pub id: Id,
    /// Access the kernel has to the argument (see [`Visibility`]).
    pub visibility: Visibility,
    /// Type of the argument.
    pub ty: Type,
    /// Optional size of the argument.
    // NOTE(review): the unit (elements vs. bytes) is not visible here — confirm.
    pub size: Option<usize>,
    /// Whether the argument has extended metadata.
    // NOTE(review): the meaning of "extended" metadata is defined elsewhere — confirm.
    pub has_extended_meta: bool,
}
74
/// Description of the scalar arguments of a kernel for one storage type.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarKernelArg {
    // Storage type of the scalars.
    pub ty: StorageType,
    // NOTE(review): presumably the number of scalars of type `ty` — confirm.
    pub count: usize,
}
81
/// Access mode of a [`KernelArg`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Visibility {
    // Read access only.
    Read,
    // Read and write access.
    ReadWrite,
}
88
/// Result of compiling a kernel with compiler `C`.
///
/// Its `Display` implementation produces the compilation log entry, in a short
/// or a full form depending on the configured compilation log level.
pub struct CompiledKernel<C: Compiler> {
    /// Name of the kernel entry point; `KernelTask::compile` takes it from
    /// `KernelOptions::kernel_name`.
    pub entrypoint_name: String,

    /// Optional name shown in the compilation log. `KernelTask::compile` sets
    /// it to the type name of the kernel. Names longer than 32 bytes are
    /// shortened or reformatted when displayed.
    pub debug_name: Option<&'static str>,

    /// Source text of the compiled kernel.
    pub source: String,
    /// Compiler-specific representation of the kernel, when available.
    pub repr: Option<C::Representation>,
    /// Cube dimensions taken from the kernel definition.
    pub cube_dim: CubeDim,
    /// Optional extra information printed in the full log output.
    pub debug_info: Option<DebugInformation>,
}
130
/// Extra information attached to a [`CompiledKernel`] for the full log output.
#[derive(new)]
pub struct DebugInformation {
    /// Language tag written right after the opening code fence when the
    /// kernel source is printed.
    pub lang_tag: &'static str,
    /// Kernel identifier, printed as the `id:` entry of the full log output.
    pub id: KernelId,
}
139
/// A kernel that can describe itself as a [`KernelDefinition`].
pub trait CubeKernel: KernelMetadata {
    /// Builds the definition that is handed to the compiler.
    fn define(&self) -> KernelDefinition;
}
145
/// Pairs a [`CubeKernel`] with a compiler type `C` so that it can be compiled
/// as a [`CubeTask<C>`].
pub struct KernelTask<C: Compiler, K: CubeKernel> {
    // The wrapped kernel itself (a `CubeKernel`, not a `KernelDefinition`,
    // despite the field name).
    kernel_definition: K,
    // Ties the task to the compiler type without storing a compiler.
    _compiler: PhantomData<C>,
}
151
/// Owns a boxed, type-erased [`CubeTask`] for compiler `C`.
pub struct CubeTaskKernel<C: Compiler> {
    /// The wrapped task.
    pub task: Box<dyn CubeTask<C>>,
}
157
158impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
159 pub fn new(kernel_definition: K) -> Self {
161 Self {
162 kernel_definition,
163 _compiler: PhantomData,
164 }
165 }
166}
167
168impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
169 fn compile(
170 &self,
171 compiler: &mut C,
172 compilation_options: &C::CompilationOptions,
173 mode: ExecutionMode,
174 addr_type: StorageType,
175 ) -> Result<CompiledKernel<C>, CompilationError> {
176 let gpu_ir = self.kernel_definition.define();
177 let entrypoint_name = gpu_ir.options.kernel_name.clone();
178 let cube_dim = gpu_ir.cube_dim;
179 let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode, addr_type)?;
180
181 Ok(CompiledKernel {
182 entrypoint_name,
183 debug_name: Some(core::any::type_name::<K>()),
184 source: lower_level_ir.to_string(),
185 repr: Some(lower_level_ir),
186 cube_dim,
187 debug_info: None,
188 })
189 }
190}
191
192impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
193 fn id(&self) -> KernelId {
195 self.kernel_definition.id()
196 }
197
198 fn name(&self) -> &'static str {
200 self.kernel_definition.name()
201 }
202
203 fn address_type(&self) -> StorageType {
204 self.kernel_definition.address_type()
205 }
206}
207
208impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
209 fn id(&self) -> KernelId {
211 self.as_ref().id()
212 }
213
214 fn name(&self) -> &'static str {
216 self.as_ref().name()
217 }
218
219 fn address_type(&self) -> StorageType {
220 self.as_ref().address_type()
221 }
222}
223
224static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
225
226fn compilation_level() -> u8 {
227 let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
228 if compilation_level == -1 {
229 let val = match GlobalConfig::get().compilation.logger.level {
230 CompilationLogLevel::Full => 2,
231 CompilationLogLevel::Disabled => 0,
232 CompilationLogLevel::Basic => 1,
233 };
234
235 COMPILATION_LEVEL.store(val, Ordering::Relaxed);
236 val as u8
237 } else {
238 compilation_level as u8
239 }
240}
241
242impl<C: Compiler> Display for CompiledKernel<C> {
243 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
244 match compilation_level() {
245 2 => self.format_full(f),
246 _ => self.format_basic(f),
247 }
248 }
249}
250
251impl<C: Compiler> CompiledKernel<C> {
252 fn format_basic(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
253 f.write_str("[Compiling kernel]")?;
254 if let Some(name) = self.debug_name {
255 if name.len() <= 32 {
256 f.write_fmt(format_args!(" {name}"))?;
257 } else {
258 f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
259 }
260 }
261
262 Ok(())
263 }
264
265 fn format_full(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
266 f.write_str("[START_KERNEL_COMPILATION]")?;
267
268 if let Some(name) = self.debug_name {
269 if name.len() <= 32 {
270 f.write_fmt(format_args!("\nname: {name}"))?;
271 } else {
272 let name = format_str(name, &[('<', '>')], false);
273 f.write_fmt(format_args!("\nname: {name}"))?;
274 }
275 }
276
277 if let Some(info) = &self.debug_info {
278 f.write_fmt(format_args!("\nid: {:#?}", info.id))?;
279 }
280
281 f.write_fmt(format_args!(
282 "
283source:
284```{}
285{}
286```
287[END_KERNEL_COMPILATION]
288",
289 self.debug_info
290 .as_ref()
291 .map(|info| info.lang_tag)
292 .unwrap_or(""),
293 self.source
294 ))
295 }
296}