1use std::{fmt::Display, marker::PhantomData};
2
3use crate::{Compiler, Kernel, KernelId, KernelOptions};
4use alloc::sync::Arc;
5use cubecl_common::{CubeDim, ExecutionMode};
6use cubecl_ir::{Elem, Id, Item, Scope};
7use serde::{Deserialize, Serialize};
8
/// Result of compiling a kernel with a given [`Compiler`]: the generated source text together
/// with the metadata kept alongside it.
pub struct CompiledKernel<C: Compiler> {
    /// Name of the kernel entry point.
    ///
    /// `KernelTask::compile` fills this from the kernel definition's `options.kernel_name`.
    pub entrypoint_name: String,

    /// Optional human-readable name, printed on the `name:` line by the `Display` impl.
    ///
    /// `KernelTask::compile` sets it to the Rust type name of the kernel.
    pub debug_name: Option<&'static str>,

    /// Source text of the kernel; printed inside the fenced `source:` section by `Display`.
    pub source: String,
    /// Optional compiler representation of the kernel.
    ///
    /// `KernelTask::compile` stores the value whose `to_string()` produced [`Self::source`].
    pub repr: Option<C::Representation>,
    /// Cube dimensions of the kernel; `Display` prints its `x`, `y` and `z` components.
    pub cube_dim: CubeDim,
    /// Optional debugging details; `KernelTask::compile` leaves this as `None`.
    pub debug_info: Option<DebugInformation>,
}
50
/// Extra details attached to a [`CompiledKernel`] and used when it is displayed.
#[derive(new)]
pub struct DebugInformation {
    /// Language tag written right after the opening code fence of the `source:` section.
    pub lang_tag: &'static str,
    /// Identifier of the kernel; its `Debug` form is printed on the `info:` line.
    pub id: KernelId,
}
59
60impl Display for KernelId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match &self.info {
63 Some(info) => f.write_str(
64 format_str(
65 format!("{:?}", info).as_str(),
66 &[('(', ')'), ('[', ']'), ('{', '}')],
67 true,
68 )
69 .as_str(),
70 ),
71 None => f.write_str("No info"),
72 }
73 }
74}
75
76impl<C: Compiler> Display for CompiledKernel<C> {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.write_str("\n[START_KERNEL_COMPILATION]")?;
79
80 if let Some(name) = self.debug_name {
81 if name.len() <= 32 {
82 f.write_fmt(format_args!("\nname: {name}"))?;
83 } else {
84 let name = format_str(name, &[('<', '>')], false);
85 f.write_fmt(format_args!("\nname: {name}"))?;
86 }
87 }
88
89 f.write_fmt(format_args!(
90 "
91cube_dim: ({}, {}, {})",
92 self.cube_dim.x, self.cube_dim.y, self.cube_dim.z,
93 ))?;
94
95 if let Some(info) = &self.debug_info {
96 f.write_fmt(format_args!(
97 "\ninfo: {}",
98 format_str(
99 format!("{:?}", info.id).as_str(),
100 &[('(', ')'), ('[', ']'), ('{', '}')],
101 true
102 )
103 ))?;
104 }
105
106 f.write_fmt(format_args!(
107 "
108source:
109```{}
110{}
111```
112[END_KERNEL_COMPILATION]
113",
114 self.debug_info
115 .as_ref()
116 .map(|info| info.lang_tag)
117 .unwrap_or(""),
118 self.source
119 ))
120 }
121}
122
/// Re-flows a compact single-line string (e.g. `Debug` output or a type name) into an
/// indented, multi-line layout.
///
/// * `kernel_id` - the text to re-flow. Every ASCII space in it is dropped.
/// * `markers` - `(open, close)` character pairs that delimit nesting levels. After an opening
///   marker the text continues on a new line indented by four spaces per level; a closing
///   marker is placed on its own line, preceded by a trailing `,`. An empty pair such as `()`
///   is kept on a single line.
/// * `include_space` - when `true`, a space is inserted before an opening marker (unless the
///   previous character already produced one) and after every `:`.
///
/// Inside at least one nesting level, each `,` ends the current line.
///
/// A closing marker that has no matching opening marker (for example the `>` of `->` in a
/// type name) is copied verbatim instead of closing a level, so unbalanced input can never
/// drive the nesting depth below zero.
///
/// Known quirk, kept as is: with `include_space`, an opening marker that directly follows
/// another opening marker or a `,` still receives the extra leading space.
fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> String {
    const INDENTATION: usize = 4;

    let mut result = String::with_capacity(kernel_id.len());
    let mut depth: usize = 0;

    // Last significant character handled. `' '` doubles as "nothing written yet" and as
    // "a `: ` was just written", i.e. the output currently ends with a removable space.
    let mut prev = ' ';

    for c in kernel_id.chars() {
        if c == ' ' {
            continue;
        }

        let mut found_marker = false;

        for &(start, end) in markers {
            if c == start {
                depth += 1;
                if prev != ' ' && include_space {
                    result.push(' ');
                }
                result.push(start);
                result.push('\n');
                result.push_str(&" ".repeat(INDENTATION * depth));
                found_marker = true;
            } else if c == end && depth > 0 {
                // `depth > 0` guards the decrement: an unmatched closer falls through and is
                // emitted like any other character further down.
                depth -= 1;
                if prev != start {
                    if prev == ' ' {
                        // Drop the space written after a `:`.
                        result.pop();
                    }
                    result.push_str(",\n");
                    result.push_str(&" ".repeat(INDENTATION * depth));
                    result.push(end);
                } else {
                    // Empty group: undo the newline and the indentation written after the
                    // opening marker so that the pair stays on one line.
                    for _ in 0..INDENTATION * depth + 1 + INDENTATION {
                        result.pop();
                    }
                    result.push(end);
                }
                found_marker = true;
            }
        }

        if found_marker {
            prev = c;
            continue;
        }

        if c == ',' && depth > 0 {
            if prev == ' ' {
                result.pop();
            }

            result.push_str(",\n");
            result.push_str(&" ".repeat(INDENTATION * depth));
            continue;
        }

        result.push(c);
        if c == ':' && include_space {
            result.push(' ');
            prev = ' ';
        } else {
            prev = c;
        }
    }

    result
}
196
/// Description of a kernel: its bindings, cube dimensions, body and options.
///
/// NOTE(review): presumably the value returned by `Kernel::define` and handed to
/// `Compiler::compile` in `KernelTask::compile` — confirm against those traits.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub struct KernelDefinition {
    /// Buffer bindings of the kernel.
    pub buffers: Vec<Binding>,
    /// Ids of the tensor maps (NOTE(review): exact meaning not visible here — confirm).
    pub tensor_maps: Vec<Id>,
    /// Scalar bindings of the kernel.
    pub scalars: Vec<ScalarBinding>,
    /// Cube dimensions of the kernel.
    pub cube_dim: CubeDim,
    /// Scope holding the kernel body.
    pub body: Scope,
    /// Kernel options; `KernelTask::compile` reads `kernel_name` from such options.
    pub options: KernelOptions,
}
207
/// A single buffer binding of a [`KernelDefinition`].
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Binding {
    /// Identifier of the binding.
    pub id: Id,
    /// Where the bound data is located; see [`Location`].
    pub location: Location,
    /// Access mode of the binding; see [`Visibility`].
    pub visibility: Visibility,
    /// Item type of the bound data.
    pub item: Item,
    /// Optional size (NOTE(review): unit and meaning of `None` not visible here — confirm).
    pub size: Option<usize>,
    /// NOTE(review): presumably marks bindings that carry extended metadata — confirm.
    pub has_extended_meta: bool,
}
218
/// A group of scalar values of one element type bound to a kernel.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarBinding {
    /// Element type of the scalars.
    pub elem: Elem,
    /// NOTE(review): presumably the number of scalars of this element type — confirm.
    pub count: usize,
}
225
/// Location of the data behind a [`Binding`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Location {
    /// NOTE(review): presumably memory in regular buffer storage — confirm.
    Storage,
    /// NOTE(review): presumably memory local to a cube — confirm.
    Cube,
}
232
/// Access mode of a [`Binding`].
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Visibility {
    /// NOTE(review): presumably the kernel only reads the binding — confirm.
    Read,
    /// NOTE(review): presumably the kernel may read and write the binding — confirm.
    ReadWrite,
}
239
/// A unit of work that can be compiled into a [`CompiledKernel`] by the compiler `C`.
///
/// Implementors must be `Send + Sync`.
pub trait CubeTask<C: Compiler>: Send + Sync {
    /// Returns the identifier of the kernel.
    fn id(&self) -> KernelId;
    /// Compiles the task with `compiler`, using the given compilation options and execution
    /// mode, and returns the resulting [`CompiledKernel`].
    fn compile(
        &self,
        compiler: &mut C,
        compilation_options: &C::CompilationOptions,
        mode: ExecutionMode,
    ) -> CompiledKernel<C>;
    /// Returns a name for the task; by default the Rust type name of the implementor.
    fn name(&self) -> &'static str {
        core::any::type_name::<Self>()
    }
}
256
/// Pairs a [`Kernel`] with the compiler type `C` so that it can be used as a [`CubeTask`].
#[derive(new)]
pub struct KernelTask<C: Compiler, K: Kernel> {
    /// The wrapped kernel; its `define()` and `id()` are used by the `CubeTask` impl.
    kernel_definition: K,
    /// Zero-sized marker that ties the task to the compiler type `C`.
    _compiler: PhantomData<C>,
}
263
264impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
265 fn compile(
266 &self,
267 compiler: &mut C,
268 compilation_options: &C::CompilationOptions,
269 mode: ExecutionMode,
270 ) -> CompiledKernel<C> {
271 let gpu_ir = self.kernel_definition.define();
272 let entrypoint_name = gpu_ir.options.kernel_name.clone();
273 let cube_dim = gpu_ir.cube_dim;
274 let lower_level_ir = compiler.compile(gpu_ir, compilation_options, mode);
275
276 CompiledKernel {
277 entrypoint_name,
278 debug_name: Some(core::any::type_name::<K>()),
279 source: lower_level_ir.to_string(),
280 repr: Some(lower_level_ir),
281 cube_dim,
282 debug_info: None,
283 }
284 }
285
286 fn id(&self) -> KernelId {
287 self.kernel_definition.id().clone()
288 }
289
290 fn name(&self) -> &'static str {
291 core::any::type_name::<K>()
292 }
293}
294
295impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
296 fn compile(
297 &self,
298 compiler: &mut C,
299 compilation_options: &C::CompilationOptions,
300 mode: ExecutionMode,
301 ) -> CompiledKernel<C> {
302 self.as_ref().compile(compiler, compilation_options, mode)
303 }
304
305 fn id(&self) -> KernelId {
306 self.as_ref().id()
307 }
308 fn name(&self) -> &'static str {
309 self.as_ref().name()
310 }
311}
312
313impl<C: Compiler> CubeTask<C> for Box<dyn CubeTask<C>> {
314 fn compile(
315 &self,
316 compiler: &mut C,
317 compilation_options: &C::CompilationOptions,
318 mode: ExecutionMode,
319 ) -> CompiledKernel<C> {
320 self.as_ref().compile(compiler, compilation_options, mode)
321 }
322
323 fn id(&self) -> KernelId {
324 self.as_ref().id()
325 }
326
327 fn name(&self) -> &'static str {
328 self.as_ref().name()
329 }
330}