1use std::{
2 fmt::Display,
3 marker::PhantomData,
4 sync::atomic::{AtomicI8, Ordering},
5};
6
7use crate::{Compiler, KernelOptions};
8use cubecl_common::{CubeDim, ExecutionMode};
9use cubecl_ir::{Elem, Id, Item, Scope};
10use cubecl_runtime::{
11 config::{GlobalConfig, compilation::CompilationLogLevel},
12 id::{KernelId, format_str},
13 kernel::KernelMetadata,
14};
15use serde::{Deserialize, Serialize};
16
/// Result of compiling a kernel with the compiler `C`.
///
/// Its [`Display`] implementation produces the compilation log entry; the
/// verbosity of that entry depends on the configured compilation log level.
pub struct CompiledKernel<C: Compiler> {
    /// Name of the kernel entry point. `KernelTask::compile` fills it from the
    /// `kernel_name` of the kernel options.
    pub entrypoint_name: String,

    /// Optional human-readable name used only when formatting the kernel for
    /// logs. `KernelTask::compile` sets it to the Rust type name of the kernel.
    pub debug_name: Option<&'static str>,

    /// Textual source of the compiled kernel, printed by the full log format.
    pub source: String,
    /// Compiler-specific representation of the kernel, when available.
    pub repr: Option<C::Representation>,
    /// Cube dimensions (`x`, `y`, `z`) the kernel was defined with.
    pub cube_dim: CubeDim,
    /// Extra information printed by the full log format, when present.
    pub debug_info: Option<DebugInformation>,
}
58
/// Additional information attached to a [`CompiledKernel`] for logging purposes.
#[derive(new)]
pub struct DebugInformation {
    /// Tag written right after the opening code fence when the kernel source
    /// is printed by the full log format.
    pub lang_tag: &'static str,
    /// Identifier of the kernel; its `Debug` output is printed on the `info:`
    /// line of the full log format.
    pub id: KernelId,
}
67
68static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);
69
70fn compilation_level() -> u8 {
71 let compilation_level = COMPILATION_LEVEL.load(Ordering::Relaxed);
72 if compilation_level == -1 {
73 let val = match GlobalConfig::get().compilation.logger.level {
74 CompilationLogLevel::Full => 2,
75 CompilationLogLevel::Disabled => 0,
76 CompilationLogLevel::Basic => 1,
77 };
78
79 COMPILATION_LEVEL.store(val, Ordering::Relaxed);
80 val as u8
81 } else {
82 compilation_level as u8
83 }
84}
85
86impl<C: Compiler> Display for CompiledKernel<C> {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 match compilation_level() {
89 2 => self.format_full(f),
90 _ => self.format_basic(f),
91 }
92 }
93}
94
95impl<C: Compiler> CompiledKernel<C> {
96 fn format_basic(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.write_str("[Compiling kernel]")?;
98 if let Some(name) = self.debug_name {
99 if name.len() <= 32 {
100 f.write_fmt(format_args!(" {name}"))?;
101 } else {
102 f.write_fmt(format_args!(" {}", name.split('<').next().unwrap_or("")))?;
103 }
104 }
105
106 Ok(())
107 }
108
109 fn format_full(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 f.write_str("[START_KERNEL_COMPILATION]")?;
111
112 if let Some(name) = self.debug_name {
113 if name.len() <= 32 {
114 f.write_fmt(format_args!("\nname: {name}"))?;
115 } else {
116 let name = format_str(name, &[('<', '>')], false);
117 f.write_fmt(format_args!("\nname: {name}"))?;
118 }
119 }
120
121 f.write_fmt(format_args!(
122 "
123cube_dim: ({}, {}, {})",
124 self.cube_dim.x, self.cube_dim.y, self.cube_dim.z,
125 ))?;
126
127 if let Some(info) = &self.debug_info {
128 f.write_fmt(format_args!(
129 "\ninfo: {}",
130 format_str(
131 format!("{:?}", info.id).as_str(),
132 &[('(', ')'), ('[', ']'), ('{', '}')],
133 true
134 )
135 ))?;
136 }
137
138 f.write_fmt(format_args!(
139 "
140source:
141```{}
142{}
143```
144[END_KERNEL_COMPILATION]
145",
146 self.debug_info
147 .as_ref()
148 .map(|info| info.lang_tag)
149 .unwrap_or(""),
150 self.source
151 ))
152 }
153}
154
/// Description of a kernel as produced by [`CubeKernel::define`] and consumed
/// by a [`Compiler`].
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub struct KernelDefinition {
    /// Buffer bindings of the kernel.
    pub buffers: Vec<Binding>,
    /// Ids of the tensor maps of the kernel.
    pub tensor_maps: Vec<Id>,
    /// Scalar bindings of the kernel.
    pub scalars: Vec<ScalarBinding>,
    /// Cube dimensions; copied into the resulting [`CompiledKernel`].
    pub cube_dim: CubeDim,
    /// Scope holding the kernel body.
    pub body: Scope,
    /// Kernel options; `kernel_name` is used as the entry point name.
    pub options: KernelOptions,
}
165
/// A single buffer binding of a [`KernelDefinition`].
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Binding {
    /// Identifier of the binding.
    pub id: Id,
    /// Where the bound data lives (see [`Location`]).
    pub location: Location,
    /// Whether the binding is read-only or read-write (see [`Visibility`]).
    pub visibility: Visibility,
    /// Item type of the bound data.
    pub item: Item,
    /// Optional size of the binding.
    // NOTE(review): presumably a number of elements known at compile time — confirm.
    pub size: Option<usize>,
    /// Flag about extended metadata for this binding.
    // NOTE(review): exact meaning is not visible in this file — confirm.
    pub has_extended_meta: bool,
}
176
/// A group of scalar inputs of a [`KernelDefinition`] sharing one element type.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarBinding {
    /// Element type of the scalars.
    pub elem: Elem,
    /// Count associated with `elem`.
    // NOTE(review): presumably the number of scalars of this type — confirm.
    pub count: usize,
}
183
/// Location of the data referenced by a [`Binding`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Location {
    // NOTE(review): presumably global/storage memory — confirm.
    Storage,
    // NOTE(review): presumably memory local to a cube — confirm.
    Cube,
}
190
/// Access mode of a [`Binding`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Visibility {
    // NOTE(review): presumably the kernel only reads from the binding — confirm.
    Read,
    // NOTE(review): presumably the kernel may read and write the binding — confirm.
    ReadWrite,
}
197
/// A kernel that can describe itself as a [`KernelDefinition`].
///
/// Implementors also provide [`KernelMetadata`] (id and name).
pub trait CubeKernel: KernelMetadata {
    /// Builds the definition of this kernel.
    fn define(&self) -> KernelDefinition;
}
201
/// A kernel that can be compiled with the compiler `C`.
///
/// Implementors must be `Send + Sync` and provide [`KernelMetadata`].
pub trait CubeTask<C: Compiler>: KernelMetadata + Send + Sync {
    /// Compiles the kernel with the given compiler, compilation options and
    /// execution mode, returning the resulting [`CompiledKernel`].
    fn compile(
        &self,
        compiler: &mut C,
        compilation_options: &C::CompilationOptions,
        mode: ExecutionMode,
    ) -> CompiledKernel<C>;
}
212
/// Adapter that turns a [`CubeKernel`] into a [`CubeTask`] for the compiler `C`.
pub struct KernelTask<C: Compiler, K: CubeKernel> {
    /// The wrapped kernel; its `define()` output is what gets compiled.
    kernel_definition: K,
    /// Ties the task to the compiler type without storing a compiler.
    _compiler: PhantomData<C>,
}
218
/// Owns a type-erased [`CubeTask`] for the compiler `C`.
pub struct CubeTaskKernel<C: Compiler> {
    /// The boxed task.
    pub task: Box<dyn CubeTask<C>>,
}
222
223impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
224 pub fn new(kernel_definition: K) -> Self {
225 Self {
226 kernel_definition,
227 _compiler: PhantomData,
228 }
229 }
230}
231
232impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
233 fn compile(
234 &self,
235 compiler: &mut C,
236 compilation_options: &C::CompilationOptions,
237 mode: ExecutionMode,
238 ) -> CompiledKernel<C> {
239 let gpu_ir = self.kernel_definition.define();
240 let entrypoint_name = gpu_ir.options.kernel_name.clone();
241 let cube_dim = gpu_ir.cube_dim;
242 let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode);
243
244 CompiledKernel {
245 entrypoint_name,
246 debug_name: Some(core::any::type_name::<K>()),
247 source: lower_level_ir.to_string(),
248 repr: Some(lower_level_ir),
249 cube_dim,
250 debug_info: None,
251 }
252 }
253}
254
255impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
256 fn id(&self) -> KernelId {
258 self.kernel_definition.id()
259 }
260
261 fn name(&self) -> &'static str {
263 self.kernel_definition.name()
264 }
265}
266
267impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
268 fn id(&self) -> KernelId {
270 self.as_ref().id()
271 }
272
273 fn name(&self) -> &'static str {
275 self.as_ref().name()
276 }
277}