1use alloc::{
2 boxed::Box,
3 string::{String, ToString},
4 vec::Vec,
5};
6use core::{
7 fmt::Display,
8 marker::PhantomData,
9 sync::atomic::{AtomicI8, Ordering},
10};
11
12use cubecl_common::format::format_str;
13use cubecl_ir::{Id, Scope, StorageType, Type};
14use serde::{Deserialize, Serialize};
15
16use crate::{
17 compiler::{CompilationError, Compiler, CubeTask},
18 config::{GlobalConfig, compilation::CompilationLogLevel},
19 id::KernelId,
20 server::{CubeDim, ExecutionMode},
21};
22
/// Metadata exposed by every kernel: a name, an identifier and an address type.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait KernelMetadata: Send + Sync + 'static {
    /// Name of the kernel.
    ///
    /// Defaults to the Rust type name of the implementor, as reported by
    /// [`core::any::type_name`].
    fn name(&self) -> &'static str {
        core::any::type_name::<Self>()
    }

    /// Returns the [`KernelId`] of this kernel.
    // NOTE(review): presumably used as a compilation cache key — confirm against callers.
    fn id(&self) -> KernelId;

    /// Returns the [`StorageType`] associated with addresses in this kernel.
    // NOTE(review): exact meaning inferred from the method name — confirm.
    fn address_type(&self) -> StorageType;
}
36
/// Full description of a kernel as produced by [`CubeKernel::define`] and handed to
/// [`Compiler::compile`]: its bindings, cube dimensions, body and options.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub struct KernelDefinition {
    // Buffer bindings of the kernel.
    pub buffers: Vec<Binding>,
    // Tensor-map bindings of the kernel (same `Binding` shape as `buffers`).
    pub tensor_maps: Vec<Binding>,
    // Scalar bindings, one entry per storage type together with a count.
    pub scalars: Vec<ScalarBinding>,
    // Cube dimensions; `KernelTask::compile` copies this into the `CompiledKernel`.
    pub cube_dim: CubeDim,
    // Body of the kernel, expressed as an IR `Scope`.
    pub body: Scope,
    // Per-kernel options (name, debug symbols, cluster dimensions).
    pub options: KernelOptions,
}
47
/// Options attached to a single [`KernelDefinition`].
#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct KernelOptions {
    /// Name of the kernel; `KernelTask::compile` uses it as the entrypoint name of the
    /// resulting [`CompiledKernel`].
    pub kernel_name: String,
    /// Whether debug symbols are requested.
    // NOTE(review): how each compiler honours this flag is not visible here — confirm.
    pub debug_symbols: bool,
    /// Optional cluster dimensions.
    // NOTE(review): presumably only meaningful on backends with cluster support — confirm.
    pub cluster_dim: Option<CubeDim>,
}
58
/// Description of a single kernel binding, used for both the buffers and the tensor maps
/// of a [`KernelDefinition`].
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Binding {
    // Identifier of the binding.
    pub id: Id,
    // Where the binding lives (see `Location`).
    pub location: Location,
    // Access mode of the binding (see `Visibility`).
    pub visibility: Visibility,
    // IR type of the bound data.
    pub ty: Type,
    // Optional size.
    // NOTE(review): unit (elements vs. bytes) and the meaning of `None` are not visible
    // here — confirm.
    pub size: Option<usize>,
    // NOTE(review): presumably flags that the binding needs extended metadata — confirm.
    pub has_extended_meta: bool,
}
69
/// Description of the scalar bindings of one storage type in a [`KernelDefinition`].
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarBinding {
    // Storage type of the scalars.
    pub ty: StorageType,
    // Number of scalars of that type.
    pub count: usize,
}
76
/// Location of a [`Binding`].
// NOTE(review): variant semantics below are inferred from their names — confirm.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Location {
    // Presumably storage memory visible to the whole launch.
    Storage,
    // Presumably memory local to a cube.
    Cube,
}
83
/// Access mode of a [`Binding`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Visibility {
    // The binding is only read.
    Read,
    // The binding is both read and written.
    ReadWrite,
}
90
/// Result of compiling a kernel with compiler `C`.
///
/// Its [`Display`] implementation produces a compilation log entry whose verbosity
/// depends on the configured compilation log level.
pub struct CompiledKernel<C: Compiler> {
    /// Name of the kernel entrypoint.
    pub entrypoint_name: String,

    /// Optional name shown when the kernel is displayed.
    ///
    /// `KernelTask::compile` sets it to the Rust type name of the kernel.
    pub debug_name: Option<&'static str>,

    /// Source text of the compiled kernel.
    pub source: String,
    /// Compiler-specific representation of the kernel, when available.
    pub repr: Option<C::Representation>,
    /// Cube dimensions of the compiled kernel.
    pub cube_dim: CubeDim,
    /// Optional extra information (language tag and kernel id) used by the full log
    /// format.
    pub debug_info: Option<DebugInformation>,
}
132
/// Extra information about a [`CompiledKernel`], printed by its full log format.
#[derive(new)]
pub struct DebugInformation {
    /// Language tag placed after the opening code fence of the printed source.
    pub lang_tag: &'static str,
    /// Identifier of the kernel, printed with its `Debug` representation.
    pub id: KernelId,
}
141
/// A kernel that can describe itself as a [`KernelDefinition`], in addition to exposing
/// the usual [`KernelMetadata`].
pub trait CubeKernel: KernelMetadata {
    /// Builds the [`KernelDefinition`] that will be handed to the compiler.
    fn define(&self) -> KernelDefinition;
}
147
/// Adapter that turns a [`CubeKernel`] into a [`CubeTask`] for the compiler `C`.
pub struct KernelTask<C: Compiler, K: CubeKernel> {
    // The wrapped kernel; despite the field name this is the kernel itself, its
    // `KernelDefinition` is only produced on demand through `define()`.
    kernel_definition: K,
    // Ties the task to a compiler type without storing a compiler.
    _compiler: PhantomData<C>,
}
153
/// Wrapper around a boxed, type-erased [`CubeTask`] for the compiler `C`.
pub struct CubeTaskKernel<C: Compiler> {
    /// The wrapped task.
    pub task: Box<dyn CubeTask<C>>,
}
159
160impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
161 pub fn new(kernel_definition: K) -> Self {
163 Self {
164 kernel_definition,
165 _compiler: PhantomData,
166 }
167 }
168}
169
170impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
171 fn compile(
172 &self,
173 compiler: &mut C,
174 compilation_options: &C::CompilationOptions,
175 mode: ExecutionMode,
176 addr_type: StorageType,
177 ) -> Result<CompiledKernel<C>, CompilationError> {
178 let gpu_ir = self.kernel_definition.define();
179 let entrypoint_name = gpu_ir.options.kernel_name.clone();
180 let cube_dim = gpu_ir.cube_dim;
181 let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode, addr_type)?;
182
183 Ok(CompiledKernel {
184 entrypoint_name,
185 debug_name: Some(core::any::type_name::<K>()),
186 source: lower_level_ir.to_string(),
187 repr: Some(lower_level_ir),
188 cube_dim,
189 debug_info: None,
190 })
191 }
192}
193
194impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
195 fn id(&self) -> KernelId {
197 self.kernel_definition.id()
198 }
199
200 fn name(&self) -> &'static str {
202 self.kernel_definition.name()
203 }
204
205 fn address_type(&self) -> StorageType {
206 self.kernel_definition.address_type()
207 }
208}
209
210impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
211 fn id(&self) -> KernelId {
213 self.as_ref().id()
214 }
215
216 fn name(&self) -> &'static str {
218 self.as_ref().name()
219 }
220
221 fn address_type(&self) -> StorageType {
222 self.as_ref().address_type()
223 }
224}
225
/// Cached compilation log level, encoded as a small integer.
///
/// `-1` means "not resolved yet"; `compilation_level` replaces it with `0` (disabled),
/// `1` (basic) or `2` (full) the first time it runs.
static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
227
228fn compilation_level() -> u8 {
229 let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
230 if compilation_level == -1 {
231 let val = match GlobalConfig::get().compilation.logger.level {
232 CompilationLogLevel::Full => 2,
233 CompilationLogLevel::Disabled => 0,
234 CompilationLogLevel::Basic => 1,
235 };
236
237 COMPILATION_LEVEL.store(val, Ordering::Relaxed);
238 val as u8
239 } else {
240 compilation_level as u8
241 }
242}
243
244impl<C: Compiler> Display for CompiledKernel<C> {
245 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
246 match compilation_level() {
247 2 => self.format_full(f),
248 _ => self.format_basic(f),
249 }
250 }
251}
252
253impl<C: Compiler> CompiledKernel<C> {
254 fn format_basic(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
255 f.write_str("[Compiling kernel]")?;
256 if let Some(name) = self.debug_name {
257 if name.len() <= 32 {
258 f.write_fmt(format_args!(" {name}"))?;
259 } else {
260 f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
261 }
262 }
263
264 Ok(())
265 }
266
267 fn format_full(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
268 f.write_str("[START_KERNEL_COMPILATION]")?;
269
270 if let Some(name) = self.debug_name {
271 if name.len() <= 32 {
272 f.write_fmt(format_args!("\nname: {name}"))?;
273 } else {
274 let name = format_str(name, &[('<', '>')], false);
275 f.write_fmt(format_args!("\nname: {name}"))?;
276 }
277 }
278
279 if let Some(info) = &self.debug_info {
280 f.write_fmt(format_args!("\nid: {:#?}", info.id))?;
281 }
282
283 f.write_fmt(format_args!(
284 "
285source:
286```{}
287{}
288```
289[END_KERNEL_COMPILATION]
290",
291 self.debug_info
292 .as_ref()
293 .map(|info| info.lang_tag)
294 .unwrap_or(""),
295 self.source
296 ))
297 }
298}