1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5 CubeDim,
6 compute::{Location, Visibility},
7 ir::Id,
8};
9
10use std::{collections::HashSet, fmt::Display};
11
12#[derive(Debug, PartialEq, Eq, Clone)]
13pub struct Binding<D: Dialect> {
14 pub id: Id,
15 pub item: Item<D>,
16 pub location: Location,
17 pub size: Option<usize>,
18 pub vis: Visibility,
19}
20
21#[derive(Debug, PartialEq, Eq, Clone)]
22pub struct SharedMemory<D: Dialect> {
23 pub index: Id,
24 pub item: Item<D>,
25 pub size: u32,
26 pub align: Option<u32>,
27}
28
29#[derive(Debug, PartialEq, Clone)]
30pub struct ConstArray<D: Dialect> {
31 pub index: Id,
32 pub item: Item<D>,
33 pub size: u32,
34 pub values: Vec<Variable<D>>,
35}
36
37#[derive(Debug, PartialEq, Eq, Clone)]
38pub struct LocalArray<D: Dialect> {
39 pub index: Id,
40 pub item: Item<D>,
41 pub size: u32,
42}
43
44impl<D: Dialect> LocalArray<D> {
45 pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
46 Self { index, item, size }
47 }
48}
49
50impl<D: Dialect> SharedMemory<D> {
51 pub fn new(index: Id, item: Item<D>, size: u32, align: Option<u32>) -> Self {
52 Self {
53 index,
54 item,
55 size,
56 align,
57 }
58 }
59}
60
61#[derive(Debug, Clone)]
62pub struct ComputeKernel<D: Dialect> {
63 pub tensor_maps: Vec<Id>,
64 pub buffers: Vec<Binding<D>>,
65 pub scalars: Vec<(Elem<D>, usize)>,
66 pub meta_static_len: usize,
67 pub body: Body<D>,
68 pub cube_dim: CubeDim,
69 pub cluster_dim: Option<CubeDim>,
70 pub extensions: Vec<D::Extension>,
71 pub flags: Flags,
72 pub items: HashSet<super::Item<D>>,
73 pub kernel_name: String,
74}
75
76impl<D: Dialect> ComputeKernel<D> {
77 pub fn shared_memory_size(&self) -> usize {
78 let mut current = 0usize;
79
80 for var in self.body.shared_memories.iter() {
81 let factor = var.item.vectorization;
82 let elem_size_bytes = var.item.elem().size();
83 current += (var.size as usize) * factor * elem_size_bytes;
84 }
85
86 current
87 }
88}
89
90impl<D: Dialect> Display for ComputeKernel<D> {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 let mut flags = self.flags.clone();
93 if !self.tensor_maps.is_empty() {
94 flags.inst_tma = true;
95 }
96
97 D::compile_includes(f, &flags)?;
99 D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
100 D::compile_polyfills(f, &flags)?;
101 D::compile_extensions(f, &self.extensions)?;
102
103 D::compile_kernel_signature(
105 f,
106 &self.kernel_name,
107 &self.tensor_maps,
108 &self.buffers,
109 &self.scalars,
110 &self.flags,
111 )?;
112
113 f.write_str(" {\n")?;
115 compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
116 write!(f, "{}", self.body)?;
117 f.write_str("\n}")?;
118
119 Ok(())
120 }
121}
122
123pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124 writeln!(f, "typedef unsigned int uint;")?;
125 writeln!(f, "typedef unsigned char uint8;")?;
126 writeln!(f, "typedef unsigned short uint16;")?;
127 writeln!(f, "typedef unsigned int uint32;")?;
128 writeln!(f, "typedef unsigned long long int uint64;")?;
129
130 writeln!(f, "typedef signed char int8;")?;
131 writeln!(f, "typedef signed short int16;")?;
132 writeln!(f, "typedef signed int int32;")?;
133 writeln!(f, "typedef signed long long int int64;")?;
134
135 Ok(())
136}
137
138pub fn type_vectorized_definitions<D: Dialect>(
139 f: &mut std::fmt::Formatter<'_>,
140 items: &HashSet<Item<D>>,
141) -> std::fmt::Result {
142 for item in items.iter() {
143 let elem = item.elem;
144 let size = item.vectorization;
145 let alignment = elem.size() * size;
146 if size > 1 {
147 write!(
148 f,
149 "
150struct __align__({alignment}) {item} {{"
151 )?;
152
153 for i in 0..size {
154 write!(
155 f,
156 "
157 {elem} i_{i};"
158 )?;
159 }
160
161 f.write_str("\n};")?;
162 }
163 }
164 Ok(())
165}
166
167pub fn type_scalar_definitions<D: Dialect>(
168 f: &mut std::fmt::Formatter<'_>,
169 scalars: &[(Elem<D>, usize)],
170) -> std::fmt::Result {
171 for (elem, count) in scalars.iter() {
172 writeln!(
173 f,
174 "
175struct scalars_{elem}_st {{
176{elem} x[{count}];
177}};"
178 )?;
179 }
180 Ok(())
181}
182
183pub fn type_info_definition<D: Dialect>(
184 f: &mut std::fmt::Formatter<'_>,
185 static_len: usize,
186) -> std::fmt::Result {
187 if static_len > 0 {
188 write!(
189 f,
190 "
191struct metadata_st {{
192uint x[{}];
193}};
194",
195 static_len
196 )?;
197 }
198 Ok(())
199}
200
201pub fn compile_bindings<D: Dialect>(
202 f: &mut core::fmt::Formatter<'_>,
203 tensor_maps: &[Id],
204 buffers: &[Binding<D>],
205 trailing_comma: bool,
206 flags: &Flags,
207) -> core::fmt::Result {
208 write!(f, " ")?;
209
210 let mut args = Vec::new();
211
212 args.extend(
213 tensor_maps
214 .iter()
215 .map(|id| format!("const __grid_constant__ CUtensorMap tensor_map_{id}")),
216 );
217 args.extend(buffers.iter().map(|binding| match binding.vis {
218 Visibility::Read => {
219 format!("{} buffer_{}[]", binding.item, binding.id)
220 }
225 Visibility::ReadWrite => {
226 format!("{} buffer_{}[]", binding.item, binding.id)
227 }
228 }));
229 args.extend(
230 flags
231 .has_dynamic_meta
232 .then(|| format!("const uint32* __restrict__ {INFO_NAME}")),
233 );
234
235 write!(f, "{}", args.join(", "))?;
236 if trailing_comma {
237 f.write_str(", ")?;
238 }
239 Ok(())
240}
241
242pub fn compile_scalars_dynamic<D: Dialect>(
243 f: &mut std::fmt::Formatter<'_>,
244 scalars: &[(Elem<D>, usize)],
245) -> core::fmt::Result {
246 let scalar_inputs = scalars
247 .iter()
248 .map(|(elem, _)| format!("const {}* __restrict__ scalars_{}", elem, elem));
249 let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
250
251 write!(f, "{}", scalar_inputs.join(","))
252}
253
254pub fn compile_scalars_static<D: Dialect>(
255 f: &mut std::fmt::Formatter<'_>,
256 scalars: &[(Elem<D>, usize)],
257 flags: &Flags,
258) -> core::fmt::Result {
259 let mut scalar_inputs = Vec::new();
260
261 let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
264 for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
265 scalar_inputs.push(format!(
266 "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
267 ));
268 }
269 };
270
271 scalars_of_size(&mut scalar_inputs, 8);
273
274 if flags.static_meta_length > 0 {
276 scalar_inputs.push(format!(
277 "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
278 ));
279 }
280
281 for size in [4, 2, 1] {
283 scalars_of_size(&mut scalar_inputs, size);
284 }
285
286 write!(f, "{}", scalar_inputs.join(", "))
287}
288
289fn compile_cube_builtin_bindings_decl<D: Dialect>(
290 f: &mut core::fmt::Formatter<'_>,
291 settings: &Flags,
292) -> core::fmt::Result {
293 if settings.indexes.absolute_pos_tuple {
294 D::compile_absolute_pos_tuple_computation(f)?;
295 }
296
297 if settings.indexes.unit_pos {
298 D::compile_unit_pos_computation(f)?;
299 }
300
301 if settings.indexes.absolute_pos {
302 let variable = Variable::<D>::AbsolutePos;
303 let ty = variable.item();
304 let absolute_pos_x = Variable::<D>::AbsolutePosX;
305 let absolute_pos_y = Variable::<D>::AbsolutePosY;
306 let absolute_pos_z = Variable::<D>::AbsolutePosZ;
307 let cube_count_x = Variable::<D>::CubeCountX;
308 let cube_count_y = Variable::<D>::CubeCountY;
309 let cube_dim_x = Variable::<D>::CubeDimX;
310 let cube_dim_y = Variable::<D>::CubeDimY;
311 writeln!(
312 f,
313 "{ty} {variable} = ({absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y}) + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x}) + {absolute_pos_x};"
314 )?;
315 }
316
317 if settings.indexes.cube_dim {
318 let variable = Variable::<D>::CubeDim;
319 let ty = variable.item();
320 let cube_dim_x = Variable::<D>::CubeDimX;
321 let cube_dim_y = Variable::<D>::CubeDimY;
322 let cube_dim_z = Variable::<D>::CubeDimZ;
323 writeln!(
324 f,
325 "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
326 )?;
327 }
328
329 if settings.indexes.cube_count {
330 let variable = Variable::<D>::CubeCount;
331 let ty = variable.item();
332 let cube_count_x = Variable::<D>::CubeCountX;
333 let cube_count_y = Variable::<D>::CubeCountY;
334 let cube_count_z = Variable::<D>::CubeCountZ;
335 writeln!(
336 f,
337 "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
338 )?;
339 }
340
341 if settings.indexes.cube_pos {
342 let variable = Variable::<D>::CubePos;
343 let ty = variable.item();
344 let cube_pos_x = Variable::<D>::CubePosX;
345 let cube_pos_y = Variable::<D>::CubePosY;
346 let cube_pos_z = Variable::<D>::CubePosZ;
347 let cube_count_x = Variable::<D>::CubeCountX;
348 let cube_count_y = Variable::<D>::CubeCountY;
349 writeln!(
350 f,
351 "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
352 )?;
353 }
354
355 if settings.indexes.plane_dim_checked {
356 let plane_dim = Variable::<D>::PlaneDim;
357 let variable = Variable::<D>::PlaneDimChecked;
358 let ty = variable.item();
359 let cube_dim_x = Variable::<D>::CubeDimX;
360 let cube_dim_y = Variable::<D>::CubeDimY;
361 let cube_dim_z = Variable::<D>::CubeDimZ;
362 writeln!(
363 f,
364 "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
365 )?;
366 }
367
368 if settings.indexes.cluster_pos {
369 f.write_str(
370 "
371cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
372",
373 )?;
374 }
375
376 Ok(())
377}