1use std::{fmt::Display, marker::PhantomData};
2
3use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel, KernelId};
4use alloc::sync::Arc;
5use cubecl_runtime::ExecutionMode;
6
7pub struct CompiledKernel<C: Compiler> {
9 pub entrypoint_name: String,
19
20 pub debug_name: Option<&'static str>,
38
39 pub source: String,
41 pub repr: Option<C::Representation>,
43 pub cube_dim: CubeDim,
45 pub shared_mem_bytes: usize,
47 pub debug_info: Option<DebugInformation>,
49}
50
51#[derive(new)]
53pub struct DebugInformation {
54 pub lang_tag: &'static str,
56 pub id: KernelId,
58}
59
60impl Display for KernelId {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match &self.info {
63 Some(info) => f.write_str(
64 format_str(
65 format!("{:?}", info).as_str(),
66 &[('(', ')'), ('[', ']'), ('{', '}')],
67 true,
68 )
69 .as_str(),
70 ),
71 None => f.write_str("No info"),
72 }
73 }
74}
75
76impl<C: Compiler> Display for CompiledKernel<C> {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.write_str("\n[START_KERNEL_COMPILATION]")?;
79
80 if let Some(name) = self.debug_name {
81 if name.len() <= 32 {
82 f.write_fmt(format_args!("\nname: {name}"))?;
83 } else {
84 let name = format_str(name, &[('<', '>')], false);
85 f.write_fmt(format_args!("\nname: {name}"))?;
86 }
87 }
88
89 f.write_fmt(format_args!(
90 "
91cube_dim: ({}, {}, {})
92shared_memory: {} bytes",
93 self.cube_dim.x, self.cube_dim.y, self.cube_dim.z, self.shared_mem_bytes,
94 ))?;
95
96 if let Some(info) = &self.debug_info {
97 f.write_fmt(format_args!(
98 "\ninfo: {}",
99 format_str(
100 format!("{:?}", info.id).as_str(),
101 &[('(', ')'), ('[', ']'), ('{', '}')],
102 true
103 )
104 ))?;
105 }
106
107 f.write_fmt(format_args!(
108 "
109source:
110```{}
111{}
112```
113[END_KERNEL_COMPILATION]
114",
115 self.debug_info
116 .as_ref()
117 .map(|info| info.lang_tag)
118 .unwrap_or(""),
119 self.source
120 ))
121 }
122}
123
/// Re-indents a compact, bracketed string so that every nesting level sits on
/// its own indented lines.
///
/// # Parameters
/// - `kernel_id`: the text to format. All space characters in it are dropped
///   before formatting.
/// - `markers`: `(open, close)` character pairs that increase / decrease the
///   nesting depth.
/// - `include_space`: when `true`, a space is inserted before an opening
///   marker that directly follows other text, and after every `:`.
///
/// # Behavior
/// - After an opening marker, and after every `,` found inside a marker pair,
///   a new line indented by four spaces per nesting level is started.
/// - Before a closing marker a trailing `,` and a new line are emitted, unless
///   the pair is empty, in which case it is kept on one line (`()`).
/// - A closing marker that has no matching opening marker is copied verbatim
///   instead of underflowing the depth counter (which used to panic).
/// - The `>` of a `->` arrow is never treated as a closing marker.
///
/// The function is purely textual: it does not understand string literals or
/// `::` paths that may appear in the input.
fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> String {
    /// Number of spaces added per nesting level.
    const INDENTATION: usize = 4;

    /// Starts a new line in `out`, indented by `width` spaces.
    fn push_new_line(out: &mut String, width: usize) {
        out.push('\n');
        out.extend(core::iter::repeat(' ').take(width));
    }

    let mut result = String::with_capacity(kernel_id.len());
    let mut depth: usize = 0;

    // Last significant character seen; `' '` means "nothing yet" or "a space
    // was just emitted after a colon".
    let mut prev = ' ';

    for c in kernel_id.chars() {
        if c == ' ' {
            continue;
        }

        // The head of a `->` arrow (e.g. in `fn(u32) -> u32`) is plain text.
        let is_arrow_head = c == '>' && prev == '-';
        let mut found_marker = false;

        if !is_arrow_head {
            for &(start, end) in markers {
                if c == start {
                    depth += 1;
                    if prev != ' ' && include_space {
                        result.push(' ');
                    }
                    result.push(start);
                    push_new_line(&mut result, INDENTATION * depth);
                    found_marker = true;
                } else if c == end && depth > 0 {
                    // `depth > 0` guards against unbalanced closing markers.
                    depth -= 1;
                    if prev != start {
                        if prev == ' ' {
                            result.pop();
                        }
                        result.push(',');
                        push_new_line(&mut result, INDENTATION * depth);
                        result.push(end);
                    } else {
                        // Empty pair: undo the line break and indentation that
                        // were written right after the opening marker.
                        for _ in 0..INDENTATION * depth + 1 + INDENTATION {
                            result.pop();
                        }
                        result.push(end);
                    }
                    found_marker = true;
                }
            }
        }

        if found_marker {
            prev = c;
            continue;
        }

        if c == ',' && depth > 0 {
            if prev == ' ' {
                result.pop();
            }

            result.push(',');
            push_new_line(&mut result, INDENTATION * depth);
            continue;
        }

        if c == ':' && include_space {
            result.push(c);
            result.push(' ');
            prev = ' ';
        } else {
            result.push(c);
            prev = c;
        }
    }

    result
}
197
198pub trait CubeTask<C: Compiler>: Send + Sync {
201 fn id(&self) -> KernelId;
203 fn compile(
205 &self,
206 compilation_options: &C::CompilationOptions,
207 mode: ExecutionMode,
208 ) -> CompiledKernel<C>;
209 fn name(&self) -> &'static str {
210 core::any::type_name::<Self>()
211 }
212}
213
214#[derive(new)]
216pub struct KernelTask<C: Compiler, K: Kernel> {
217 kernel_definition: K,
218 _compiler: PhantomData<C>,
219}
220
221impl<C: Compiler, K: Kernel> CubeTask<C> for KernelTask<C, K> {
222 fn compile(
223 &self,
224 compilation_options: &C::CompilationOptions,
225 mode: ExecutionMode,
226 ) -> CompiledKernel<C> {
227 let gpu_ir = self.kernel_definition.define();
228 let entrypoint_name = gpu_ir.kernel_name.clone();
229 let cube_dim = gpu_ir.cube_dim;
230 let lower_level_ir = C::compile(gpu_ir, compilation_options, mode);
231 let shared_mem_bytes = lower_level_ir.shared_memory_size();
232
233 CompiledKernel {
234 entrypoint_name,
235 debug_name: Some(core::any::type_name::<K>()),
236 source: lower_level_ir.to_string(),
237 repr: Some(lower_level_ir),
238 cube_dim,
239 shared_mem_bytes,
240 debug_info: None,
241 }
242 }
243
244 fn id(&self) -> KernelId {
245 self.kernel_definition.id().clone()
246 }
247
248 fn name(&self) -> &'static str {
249 core::any::type_name::<K>()
250 }
251}
252
253impl<C: Compiler> CubeTask<C> for Arc<dyn CubeTask<C>> {
254 fn compile(
255 &self,
256 compilation_options: &C::CompilationOptions,
257 mode: ExecutionMode,
258 ) -> CompiledKernel<C> {
259 self.as_ref().compile(compilation_options, mode)
260 }
261
262 fn id(&self) -> KernelId {
263 self.as_ref().id()
264 }
265 fn name(&self) -> &'static str {
266 self.as_ref().name()
267 }
268}
269
270impl<C: Compiler> CubeTask<C> for Box<dyn CubeTask<C>> {
271 fn compile(
272 &self,
273 compilation_options: &C::CompilationOptions,
274 mode: ExecutionMode,
275 ) -> CompiledKernel<C> {
276 self.as_ref().compile(compilation_options, mode)
277 }
278
279 fn id(&self) -> KernelId {
280 self.as_ref().id()
281 }
282
283 fn name(&self) -> &'static str {
284 self.as_ref().name()
285 }
286}