1use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
2use cubecl_core::{CubeDim, ir::Id, prelude::Visibility};
3
4use std::{collections::HashSet, fmt::Display};
5
/// One kernel argument (a tensor map or a buffer) as seen by the code generator.
///
/// `compile_bindings` turns each of these into a parameter of the generated
/// kernel signature.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct KernelArg<D: Dialect> {
    /// Identifier embedded in the generated parameter name
    /// (`buffer_{id}` and, for tensor maps, `tensor_map_{id}`).
    pub id: Id,
    /// Pointee type printed for the generated `buffer_{id}` pointer parameter.
    pub item: Item<D>,
    /// Optional size of the argument. Not read by the code in this part of the
    /// file. NOTE(review): presumably an element count — confirm against the
    /// code that builds these values.
    pub size: Option<usize>,
    /// Access mode. `compile_bindings` emits a `const … __restrict__` pointer
    /// for `Visibility::Read` non-atomic items and a plain pointer otherwise.
    pub vis: Visibility,
}
13
/// Description of one shared-memory allocation used by a kernel: either an
/// array of `length` items or a single item.
///
/// Both variants carry an alignment and an offset; the constructors
/// (`new_array` / `new_value`) create them with `offset` set to `0`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum SharedMemory<D: Dialect> {
    /// An array of `length` elements of type `item`.
    Array {
        /// Identifier of this allocation.
        index: Id,
        /// Element type.
        item: Item<D>,
        /// Number of elements; `size()` returns `length * item.size()`.
        length: usize,
        /// Required alignment, returned by `align()`.
        align: usize,
        /// Start position of the allocation, returned by `offset()`.
        /// NOTE(review): presumably relative to the start of the kernel's
        /// shared memory region — confirm.
        offset: usize,
    },
    /// A single value of type `item`.
    Value {
        /// Identifier of this allocation.
        index: Id,
        /// Value type; `size()` returns `item.size()`.
        item: Item<D>,
        /// Required alignment, returned by `align()`.
        align: usize,
        /// Start position of the allocation, returned by `offset()`.
        offset: usize,
    },
}
30
31impl<D: Dialect> SharedMemory<D> {
32 pub fn size(&self) -> usize {
33 match self {
34 SharedMemory::Array { item, length, .. } => *length * item.size(),
35 SharedMemory::Value { item, .. } => item.size(),
36 }
37 }
38
39 pub fn align(&self) -> usize {
40 match self {
41 SharedMemory::Array { align, .. } => *align,
42 SharedMemory::Value { align, .. } => *align,
43 }
44 }
45
46 pub fn offset(&self) -> usize {
47 match self {
48 SharedMemory::Array { offset, .. } => *offset,
49 SharedMemory::Value { offset, .. } => *offset,
50 }
51 }
52}
53
/// A constant array known to the code generator, together with the variables
/// that make up its contents.
///
/// Only `PartialEq` is derived (not `Eq`), unlike the other array descriptors
/// in this file.
#[derive(Debug, PartialEq, Clone)]
pub struct ConstArray<D: Dialect> {
    /// Identifier of the array.
    pub index: Id,
    /// Element type.
    pub item: Item<D>,
    /// Declared size of the array. NOTE(review): presumably the element count,
    /// expected to match `values.len()` — confirm.
    pub size: u32,
    /// Values stored in the array.
    pub values: Vec<Variable<D>>,
}
61
/// Descriptor of a local array: an identifier, an element type and a size.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LocalArray<D: Dialect> {
    /// Identifier of the array.
    pub index: Id,
    /// Element type.
    pub item: Item<D>,
    /// Size of the array. NOTE(review): presumably a number of elements —
    /// confirm against the code that declares the array.
    pub size: usize,
}
68
69impl<D: Dialect> LocalArray<D> {
70 pub fn new(index: Id, item: Item<D>, size: usize) -> Self {
71 Self { index, item, size }
72 }
73}
74
75impl<D: Dialect> SharedMemory<D> {
76 pub fn new_array(index: Id, item: Item<D>, size: usize, align: usize) -> Self {
77 Self::Array {
78 index,
79 item,
80 length: size,
81 align,
82 offset: 0, }
84 }
85
86 pub fn new_value(index: Id, item: Item<D>, align: usize) -> Self {
87 Self::Value {
88 index,
89 item,
90 align,
91 offset: 0, }
93 }
94}
95
/// Everything needed to render one compute kernel as source text for dialect
/// `D`. The `Display` implementation performs the rendering.
#[derive(Debug, Clone)]
pub struct ComputeKernel<D: Dialect> {
    /// Tensor-map arguments. When this is non-empty, rendering turns on the
    /// `inst_tma` flag.
    pub tensor_maps: Vec<KernelArg<D>>,
    /// Buffer arguments of the kernel.
    pub buffers: Vec<KernelArg<D>>,
    /// Scalar element types paired with a `usize`, passed to the dialect's type
    /// definitions. NOTE(review): the `usize` is presumably a count per
    /// element type — confirm.
    pub scalars: Vec<(Elem<D>, usize)>,
    /// Info layout description, passed to the dialect's type definitions.
    pub info: cubecl_core::Info,
    /// Not read in this part of the file. NOTE(review): presumably the length
    /// of the static metadata — confirm.
    pub meta_static_len: usize,
    /// Kernel body; written between the braces of the generated kernel, and the
    /// source of the shared memories summed up by `shared_memory_size`.
    pub body: Body<D>,
    /// Cube dimensions. Not read by the rendering code in this part of the file.
    pub cube_dim: CubeDim,
    /// Optional cluster dimensions. Not read by the rendering code in this
    /// part of the file.
    pub cluster_dim: Option<CubeDim>,
    /// Dialect extensions emitted through `D::compile_extensions`.
    pub extensions: Vec<D::Extension>,
    /// Feature flags that drive includes, type definitions, polyfills, the
    /// signature and the builtin declarations.
    pub flags: Flags<D>,
    /// Set of item types handed to `D::compile_type_definitions`.
    pub items: HashSet<super::Item<D>>,
    /// Name of the generated kernel function.
    pub kernel_name: String,
}
111
112impl<D: Dialect> ComputeKernel<D> {
113 pub fn shared_memory_size(&self) -> usize {
114 let smems = self.body.shared_memories.iter();
115 let ends = smems.map(|it| it.offset() + it.size());
116 ends.max().unwrap_or_default()
117 }
118}
119
120impl<D: Dialect> Display for ComputeKernel<D> {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 let mut flags = self.flags.clone();
123 if !self.tensor_maps.is_empty() {
124 flags.inst_tma = true;
125 }
126
127 D::compile_includes(f, &flags)?;
129 D::compile_type_definitions(f, &self.items, &self.scalars, &self.info, &flags)?;
130 D::compile_polyfills(f, &flags)?;
131 D::compile_extensions(f, &self.extensions)?;
132
133 D::compile_kernel_signature(
135 f,
136 &self.kernel_name,
137 &self.tensor_maps,
138 &self.buffers,
139 &self.flags,
140 )?;
141
142 f.write_str(" {\n")?;
144 compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
145 write!(f, "{}", self.body)?;
146 f.write_str("\n}")?;
147
148 Ok(())
149 }
150}
151
152pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 writeln!(f, "typedef unsigned int uint;")?;
154 writeln!(f, "typedef unsigned char uint8;")?;
155 writeln!(f, "typedef unsigned short uint16;")?;
156 writeln!(f, "typedef unsigned int uint32;")?;
157 writeln!(f, "typedef unsigned long long int uint64;")?;
158
159 writeln!(f, "typedef signed char int8;")?;
160 writeln!(f, "typedef signed short int16;")?;
161 writeln!(f, "typedef signed int int32;")?;
162 writeln!(f, "typedef signed long long int int64;")?;
163
164 Ok(())
165}
166
167pub fn type_vectorized_definitions<D: Dialect>(
168 f: &mut std::fmt::Formatter<'_>,
169 items: &HashSet<Item<D>>,
170) -> std::fmt::Result {
171 for item in items.iter() {
172 let elem = item.elem;
173 let size = item.vectorization;
174 let alignment = elem.size() * size;
175 if size > 1 {
176 write!(
177 f,
178 "
179struct __align__({alignment}) {item} {{"
180 )?;
181
182 for i in 0..size {
183 write!(
184 f,
185 "
186 {elem} i_{i};"
187 )?;
188 }
189
190 f.write_str("\n};")?;
191 }
192 }
193 Ok(())
194}
195
196pub fn type_info_definition_sized<D: Dialect>(
197 f: &mut std::fmt::Formatter<'_>,
198 info: &cubecl_core::Info,
199 scalars: &[(Elem<D>, usize)],
200 address_type: Item<D>,
201) -> std::fmt::Result {
202 let scalars = info
203 .scalars
204 .iter()
205 .zip(scalars)
206 .map(|(field, (ty, _))| format!("{ty} scalars_{ty}[{}];", field.padded_size()))
207 .collect::<Vec<_>>()
208 .join("\n");
209 let static_meta = info
210 .sized_meta
211 .as_ref()
212 .map(|field| format!("{address_type} static_meta[{}];", field.padded_size()))
213 .unwrap_or_default();
214 write!(
215 f,
216 "
217struct info_st {{
218 {scalars}{static_meta}
219}};
220"
221 )
222}
223
224pub fn compile_bindings<D: Dialect>(
225 f: &mut core::fmt::Formatter<'_>,
226 tensor_maps: &[KernelArg<D>],
227 buffers: &[KernelArg<D>],
228 trailing_comma: bool,
229) -> core::fmt::Result {
230 write!(f, " ")?;
231
232 let mut args = Vec::new();
233
234 args.extend(tensor_maps.iter().map(|binding| {
235 format!(
236 "const __grid_constant__ CUtensorMap tensor_map_{}",
237 binding.id
238 )
239 }));
240 args.extend(
241 tensor_maps
242 .iter()
243 .chain(buffers.iter())
244 .map(|binding| match binding.vis {
245 Visibility::Read if !binding.item.is_atomic() => {
246 format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
247 }
248 Visibility::Read => {
249 format!("{}* buffer_{}", binding.item, binding.id)
250 }
251 Visibility::ReadWrite => {
252 format!("{}* buffer_{}", binding.item, binding.id)
253 }
254 }),
255 );
256
257 write!(f, "{}", args.join(", "))?;
258 if trailing_comma {
259 f.write_str(", ")?;
260 }
261 Ok(())
262}
263
264pub fn compile_info_dynamic<D: Dialect>(
265 f: &mut std::fmt::Formatter<'_>,
266 flags: &Flags<D>,
267) -> core::fmt::Result {
268 if flags.has_info {
269 write!(f, "const info_st* __restrict__ {INFO_NAME}_ptr")
270 } else {
271 Ok(())
272 }
273}
274
275pub fn compile_info_static<D: Dialect>(
276 f: &mut std::fmt::Formatter<'_>,
277 flags: &Flags<D>,
278) -> core::fmt::Result {
279 let mut inputs = Vec::new();
280
281 if flags.has_dynamic_meta {
282 inputs.push(format!(
283 "const {}* __restrict__ dynamic_meta",
284 flags.address_type
285 ))
286 }
287
288 if flags.has_info {
289 inputs.push(format!("const __grid_constant__ info_st {INFO_NAME}"));
290 }
291
292 write!(f, "{}", inputs.join(", "))
293}
294
295fn compile_cube_builtin_bindings_decl<D: Dialect>(
296 f: &mut core::fmt::Formatter<'_>,
297 settings: &Flags<D>,
298) -> core::fmt::Result {
299 if settings.indexes.absolute_pos_tuple {
300 D::compile_absolute_pos_tuple_computation(f)?;
301 }
302
303 if settings.indexes.unit_pos {
304 D::compile_unit_pos_computation(f)?;
305 }
306
307 if settings.indexes.absolute_pos {
308 let variable = Variable::<D>::AbsolutePos(settings.address_type.elem);
309 let ty = variable.item();
310 let absolute_pos_x = Variable::<D>::AbsolutePosX.fmt_cast_to(ty);
311 let absolute_pos_y = Variable::<D>::AbsolutePosY.fmt_cast_to(ty);
312 let absolute_pos_z = Variable::<D>::AbsolutePosZ.fmt_cast_to(ty);
313 let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
314 let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
315 let cube_dim_x = Variable::<D>::CubeDimX.fmt_cast_to(ty);
316 let cube_dim_y = Variable::<D>::CubeDimY.fmt_cast_to(ty);
317 writeln!(
318 f,
319 "{ty} {variable} = (
320 {absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y})
321 + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x})
322 + {absolute_pos_x};"
323 )?;
324 }
325
326 if settings.indexes.cube_dim {
327 let variable = Variable::<D>::CubeDim;
328 let ty = variable.item();
329 let cube_dim_x = Variable::<D>::CubeDimX;
330 let cube_dim_y = Variable::<D>::CubeDimY;
331 let cube_dim_z = Variable::<D>::CubeDimZ;
332 writeln!(
333 f,
334 "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
335 )?;
336 }
337
338 if settings.indexes.cube_count {
339 let variable = Variable::<D>::CubeCount(settings.address_type.elem);
340 let ty = variable.item();
341 let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
342 let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
343 let cube_count_z = Variable::<D>::CubeCountZ.fmt_cast_to(ty);
344 writeln!(
345 f,
346 "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
347 )?;
348 }
349
350 if settings.indexes.cube_pos {
351 let variable = Variable::<D>::CubePos(settings.address_type.elem);
352 let ty = variable.item();
353 let cube_pos_x = Variable::<D>::CubePosX.fmt_cast_to(ty);
354 let cube_pos_y = Variable::<D>::CubePosY.fmt_cast_to(ty);
355 let cube_pos_z = Variable::<D>::CubePosZ.fmt_cast_to(ty);
356 let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
357 let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
358 writeln!(
359 f,
360 "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
361 )?;
362 }
363
364 if settings.indexes.plane_dim_checked {
365 let plane_dim = Variable::<D>::PlaneDim;
366 let variable = Variable::<D>::PlaneDimChecked;
367 let ty = variable.item();
368 let cube_dim_x = Variable::<D>::CubeDimX;
369 let cube_dim_y = Variable::<D>::CubeDimY;
370 let cube_dim_z = Variable::<D>::CubeDimZ;
371 writeln!(
372 f,
373 "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
374 )?;
375 }
376
377 if settings.indexes.cluster_pos {
378 f.write_str(
379 "
380cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
381",
382 )?;
383 }
384
385 Ok(())
386}