1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5 CubeDim,
6 compute::{Location, Visibility},
7 ir::Id,
8};
9
10use std::{collections::HashSet, fmt::Display};
11
12#[derive(Debug, PartialEq, Eq, Clone)]
13pub struct Binding<D: Dialect> {
14 pub id: Id,
15 pub item: Item<D>,
16 pub location: Location,
17 pub size: Option<usize>,
18 pub vis: Visibility,
19}
20
21#[derive(Debug, PartialEq, Eq, Clone)]
22pub struct SharedMemory<D: Dialect> {
23 pub index: Id,
24 pub item: Item<D>,
25 pub size: u32,
26 pub align: Option<u32>,
27 pub offset: u32,
28}
29
30#[derive(Debug, PartialEq, Clone)]
31pub struct ConstArray<D: Dialect> {
32 pub index: Id,
33 pub item: Item<D>,
34 pub size: u32,
35 pub values: Vec<Variable<D>>,
36}
37
38#[derive(Debug, PartialEq, Eq, Clone)]
39pub struct LocalArray<D: Dialect> {
40 pub index: Id,
41 pub item: Item<D>,
42 pub size: u32,
43}
44
45impl<D: Dialect> LocalArray<D> {
46 pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
47 Self { index, item, size }
48 }
49}
50
51impl<D: Dialect> SharedMemory<D> {
52 pub fn new(index: Id, item: Item<D>, size: u32, align: Option<u32>) -> Self {
53 Self {
54 index,
55 item,
56 size,
57 align,
58 offset: 0, }
60 }
61}
62
63#[derive(Debug, Clone)]
64pub struct ComputeKernel<D: Dialect> {
65 pub tensor_maps: Vec<Id>,
66 pub buffers: Vec<Binding<D>>,
67 pub scalars: Vec<(Elem<D>, usize)>,
68 pub meta_static_len: usize,
69 pub body: Body<D>,
70 pub cube_dim: CubeDim,
71 pub cluster_dim: Option<CubeDim>,
72 pub extensions: Vec<D::Extension>,
73 pub flags: Flags,
74 pub items: HashSet<super::Item<D>>,
75 pub kernel_name: String,
76}
77
78impl<D: Dialect> ComputeKernel<D> {
79 pub fn shared_memory_size(&self) -> usize {
80 let mut shared_memories = self.body.shared_memories.clone();
83 shared_memories.sort_by_key(|smem| smem.align.unwrap_or(smem.item.size() as u32));
84 shared_memories.reverse();
85
86 let mut current = 0usize;
87
88 for var in self.body.shared_memories.iter() {
89 let align = var.align.unwrap_or(var.item.size() as u32);
90 let size_bytes = var.size as usize * var.item.size();
91 let offset = current.next_multiple_of(align as usize);
92 current = offset + size_bytes;
93 }
94
95 current
96 }
97}
98
99impl<D: Dialect> Display for ComputeKernel<D> {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 let mut flags = self.flags.clone();
102 if !self.tensor_maps.is_empty() {
103 flags.inst_tma = true;
104 }
105
106 D::compile_includes(f, &flags)?;
108 D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
109 D::compile_polyfills(f, &flags)?;
110 D::compile_extensions(f, &self.extensions)?;
111
112 D::compile_kernel_signature(
114 f,
115 &self.kernel_name,
116 &self.tensor_maps,
117 &self.buffers,
118 &self.scalars,
119 &self.flags,
120 )?;
121
122 f.write_str(" {\n")?;
124 compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
125 write!(f, "{}", self.body)?;
126 f.write_str("\n}")?;
127
128 Ok(())
129 }
130}
131
132pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 writeln!(f, "typedef unsigned int uint;")?;
134 writeln!(f, "typedef unsigned char uint8;")?;
135 writeln!(f, "typedef unsigned short uint16;")?;
136 writeln!(f, "typedef unsigned int uint32;")?;
137 writeln!(f, "typedef unsigned long long int uint64;")?;
138
139 writeln!(f, "typedef signed char int8;")?;
140 writeln!(f, "typedef signed short int16;")?;
141 writeln!(f, "typedef signed int int32;")?;
142 writeln!(f, "typedef signed long long int int64;")?;
143
144 Ok(())
145}
146
147pub fn type_vectorized_definitions<D: Dialect>(
148 f: &mut std::fmt::Formatter<'_>,
149 items: &HashSet<Item<D>>,
150) -> std::fmt::Result {
151 for item in items.iter() {
152 let elem = item.elem;
153 let size = item.vectorization;
154 let alignment = elem.size() * size;
155 if size > 1 {
156 write!(
157 f,
158 "
159struct __align__({alignment}) {item} {{"
160 )?;
161
162 for i in 0..size {
163 write!(
164 f,
165 "
166 {elem} i_{i};"
167 )?;
168 }
169
170 f.write_str("\n};")?;
171 }
172 }
173 Ok(())
174}
175
176pub fn type_scalar_definitions<D: Dialect>(
177 f: &mut std::fmt::Formatter<'_>,
178 scalars: &[(Elem<D>, usize)],
179) -> std::fmt::Result {
180 for (elem, count) in scalars.iter() {
181 writeln!(
182 f,
183 "
184struct scalars_{elem}_st {{
185{elem} x[{count}];
186}};"
187 )?;
188 }
189 Ok(())
190}
191
192pub fn type_info_definition<D: Dialect>(
193 f: &mut std::fmt::Formatter<'_>,
194 static_len: usize,
195) -> std::fmt::Result {
196 if static_len > 0 {
197 write!(
198 f,
199 "
200struct metadata_st {{
201uint x[{static_len}];
202}};
203"
204 )?;
205 }
206 Ok(())
207}
208
209pub fn compile_bindings<D: Dialect>(
210 f: &mut core::fmt::Formatter<'_>,
211 tensor_maps: &[Id],
212 buffers: &[Binding<D>],
213 trailing_comma: bool,
214 flags: &Flags,
215) -> core::fmt::Result {
216 write!(f, " ")?;
217
218 let mut args = Vec::new();
219
220 args.extend(
221 tensor_maps
222 .iter()
223 .map(|id| format!("const __grid_constant__ CUtensorMap tensor_map_{id}")),
224 );
225 args.extend(buffers.iter().map(|binding| match binding.vis {
226 Visibility::Read => {
227 format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
228 }
229 Visibility::ReadWrite => {
230 format!("{}* buffer_{}", binding.item, binding.id)
231 }
232 }));
233 args.extend(
234 flags
235 .has_dynamic_meta
236 .then(|| format!("const uint32* __restrict__ {INFO_NAME}")),
237 );
238
239 write!(f, "{}", args.join(", "))?;
240 if trailing_comma {
241 f.write_str(", ")?;
242 }
243 Ok(())
244}
245
246pub fn compile_scalars_dynamic<D: Dialect>(
247 f: &mut std::fmt::Formatter<'_>,
248 scalars: &[(Elem<D>, usize)],
249) -> core::fmt::Result {
250 let scalar_inputs = scalars
251 .iter()
252 .map(|(elem, _)| format!("const {elem}* __restrict__ scalars_{elem}"));
253 let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
254
255 write!(f, "{}", scalar_inputs.join(","))
256}
257
258pub fn compile_scalars_static<D: Dialect>(
259 f: &mut std::fmt::Formatter<'_>,
260 scalars: &[(Elem<D>, usize)],
261 flags: &Flags,
262) -> core::fmt::Result {
263 let mut scalar_inputs = Vec::new();
264
265 let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
268 for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
269 scalar_inputs.push(format!(
270 "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
271 ));
272 }
273 };
274
275 scalars_of_size(&mut scalar_inputs, 8);
277
278 if flags.static_meta_length > 0 {
280 scalar_inputs.push(format!(
281 "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
282 ));
283 }
284
285 for size in [4, 2, 1] {
287 scalars_of_size(&mut scalar_inputs, size);
288 }
289
290 write!(f, "{}", scalar_inputs.join(", "))
291}
292
293fn compile_cube_builtin_bindings_decl<D: Dialect>(
294 f: &mut core::fmt::Formatter<'_>,
295 settings: &Flags,
296) -> core::fmt::Result {
297 if settings.indexes.absolute_pos_tuple {
298 D::compile_absolute_pos_tuple_computation(f)?;
299 }
300
301 if settings.indexes.unit_pos {
302 D::compile_unit_pos_computation(f)?;
303 }
304
305 if settings.indexes.absolute_pos {
306 let variable = Variable::<D>::AbsolutePos;
307 let ty = variable.item();
308 let absolute_pos_x = Variable::<D>::AbsolutePosX;
309 let absolute_pos_y = Variable::<D>::AbsolutePosY;
310 let absolute_pos_z = Variable::<D>::AbsolutePosZ;
311 let cube_count_x = Variable::<D>::CubeCountX;
312 let cube_count_y = Variable::<D>::CubeCountY;
313 let cube_dim_x = Variable::<D>::CubeDimX;
314 let cube_dim_y = Variable::<D>::CubeDimY;
315 writeln!(
316 f,
317 "{ty} {variable} = ({absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y}) + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x}) + {absolute_pos_x};"
318 )?;
319 }
320
321 if settings.indexes.cube_dim {
322 let variable = Variable::<D>::CubeDim;
323 let ty = variable.item();
324 let cube_dim_x = Variable::<D>::CubeDimX;
325 let cube_dim_y = Variable::<D>::CubeDimY;
326 let cube_dim_z = Variable::<D>::CubeDimZ;
327 writeln!(
328 f,
329 "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
330 )?;
331 }
332
333 if settings.indexes.cube_count {
334 let variable = Variable::<D>::CubeCount;
335 let ty = variable.item();
336 let cube_count_x = Variable::<D>::CubeCountX;
337 let cube_count_y = Variable::<D>::CubeCountY;
338 let cube_count_z = Variable::<D>::CubeCountZ;
339 writeln!(
340 f,
341 "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
342 )?;
343 }
344
345 if settings.indexes.cube_pos {
346 let variable = Variable::<D>::CubePos;
347 let ty = variable.item();
348 let cube_pos_x = Variable::<D>::CubePosX;
349 let cube_pos_y = Variable::<D>::CubePosY;
350 let cube_pos_z = Variable::<D>::CubePosZ;
351 let cube_count_x = Variable::<D>::CubeCountX;
352 let cube_count_y = Variable::<D>::CubeCountY;
353 writeln!(
354 f,
355 "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
356 )?;
357 }
358
359 if settings.indexes.plane_dim_checked {
360 let plane_dim = Variable::<D>::PlaneDim;
361 let variable = Variable::<D>::PlaneDimChecked;
362 let ty = variable.item();
363 let cube_dim_x = Variable::<D>::CubeDimX;
364 let cube_dim_y = Variable::<D>::CubeDimY;
365 let cube_dim_z = Variable::<D>::CubeDimZ;
366 writeln!(
367 f,
368 "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
369 )?;
370 }
371
372 if settings.indexes.cluster_pos {
373 f.write_str(
374 "
375cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
376",
377 )?;
378 }
379
380 Ok(())
381}