1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5 CubeDim,
6 compute::{Location, Visibility},
7 ir::Id,
8};
9
10use std::{collections::HashSet, fmt::Display};
11
/// Description of one kernel argument binding.
///
/// In this file bindings are consumed by [`compile_bindings`], which emits a
/// `buffer_{id}` pointer parameter typed by `item` for every binding, plus a
/// `tensor_map_{id}` parameter for bindings passed in the tensor-map list.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding<D: Dialect> {
    /// Identifier used as the suffix of the generated parameter name(s).
    pub id: Id,
    /// Item type rendered as the pointee type of the generated pointer.
    pub item: Item<D>,
    /// Not read in this file.
    /// NOTE(review): presumably where the bound resource lives — confirm
    /// against `cubecl_core::compute::Location`.
    pub location: Location,
    /// Not read in this file.
    /// NOTE(review): presumably an optional fixed length — confirm with callers.
    pub size: Option<usize>,
    /// `Read` bindings are emitted as `const ... __restrict__` pointers,
    /// `ReadWrite` bindings as plain mutable pointers.
    pub vis: Visibility,
}
20
/// Description of one shared memory region used by a kernel body.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SharedMemory<D: Dialect> {
    /// Identifier of this region (how it is rendered is not visible in this file).
    pub index: Id,
    /// Item type stored in the region; `size()` multiplies `length` by `item.size()`.
    pub item: Item<D>,
    /// Number of items in the region.
    pub length: u32,
    /// Requested alignment. Not read in this file.
    /// NOTE(review): presumably in bytes — confirm with the allocator code.
    pub align: u32,
    /// Start offset of the region. `new` initializes it to 0 and
    /// `ComputeKernel::shared_memory_size` adds it to `size()`.
    /// NOTE(review): presumably assigned by a later layout pass, in the same
    /// unit as `size()` — confirm.
    pub offset: u32,
}
29
30impl<D: Dialect> SharedMemory<D> {
31 pub fn size(&self) -> u32 {
32 self.length * self.item.size() as u32
33 }
34}
35
/// Description of a constant array together with its values.
///
/// Not used by any function in this file; field meanings below are taken from
/// the names and types only.
#[derive(Debug, PartialEq, Clone)]
pub struct ConstArray<D: Dialect> {
    /// Identifier of the array.
    pub index: Id,
    /// Item type of the array.
    pub item: Item<D>,
    /// NOTE(review): presumably the number of items — confirm it always
    /// matches `values.len()`.
    pub size: u32,
    /// Values held by the array.
    pub values: Vec<Variable<D>>,
}
43
/// Description of a local array.
///
/// Only constructed (via `new`) in this file; how it is rendered is not
/// visible here.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LocalArray<D: Dialect> {
    /// Identifier of the array.
    pub index: Id,
    /// Item type of the array.
    pub item: Item<D>,
    /// NOTE(review): presumably the number of items — confirm with callers.
    pub size: u32,
}
50
51impl<D: Dialect> LocalArray<D> {
52 pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
53 Self { index, item, size }
54 }
55}
56
57impl<D: Dialect> SharedMemory<D> {
58 pub fn new(index: Id, item: Item<D>, size: u32, align: u32) -> Self {
59 Self {
60 index,
61 item,
62 length: size,
63 align,
64 offset: 0, }
66 }
67}
68
/// Everything needed to render one compute kernel as source text.
///
/// The `Display` implementation in this file turns a value of this type into
/// the final kernel source for dialect `D`.
#[derive(Debug, Clone)]
pub struct ComputeKernel<D: Dialect> {
    /// Bindings passed as tensor maps. When this list is non-empty, rendering
    /// forces the `inst_tma` flag on.
    pub tensor_maps: Vec<Binding<D>>,
    /// Regular buffer bindings forwarded to the kernel signature.
    pub buffers: Vec<Binding<D>>,
    /// Scalar argument groups as `(element type, count)` pairs.
    pub scalars: Vec<(Elem<D>, usize)>,
    /// Not read in this file.
    /// NOTE(review): presumably mirrors `Flags::static_meta_length` — confirm.
    pub meta_static_len: usize,
    /// Kernel body, rendered between the braces that follow the signature.
    pub body: Body<D>,
    /// Not read in this file.
    pub cube_dim: CubeDim,
    /// Not read in this file.
    pub cluster_dim: Option<CubeDim>,
    /// Extensions handed to `D::compile_extensions`.
    pub extensions: Vec<D::Extension>,
    /// Feature flags steering includes, type definitions, polyfills, the
    /// signature and the builtin declarations.
    pub flags: Flags,
    /// Item types handed to `D::compile_type_definitions`.
    pub items: HashSet<super::Item<D>>,
    /// Name handed to `D::compile_kernel_signature`.
    pub kernel_name: String,
}
83
84impl<D: Dialect> ComputeKernel<D> {
85 pub fn shared_memory_size(&self) -> usize {
86 let smems = self.body.shared_memories.iter();
87 let ends = smems.map(|it| it.offset + it.size());
88 ends.max().unwrap_or_default() as usize
89 }
90}
91
92impl<D: Dialect> Display for ComputeKernel<D> {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 let mut flags = self.flags.clone();
95 if !self.tensor_maps.is_empty() {
96 flags.inst_tma = true;
97 }
98
99 D::compile_includes(f, &flags)?;
101 D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
102 D::compile_polyfills(f, &flags)?;
103 D::compile_extensions(f, &self.extensions)?;
104
105 D::compile_kernel_signature(
107 f,
108 &self.kernel_name,
109 &self.tensor_maps,
110 &self.buffers,
111 &self.scalars,
112 &self.flags,
113 )?;
114
115 f.write_str(" {\n")?;
117 compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
118 write!(f, "{}", self.body)?;
119 f.write_str("\n}")?;
120
121 Ok(())
122 }
123}
124
125pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 writeln!(f, "typedef unsigned int uint;")?;
127 writeln!(f, "typedef unsigned char uint8;")?;
128 writeln!(f, "typedef unsigned short uint16;")?;
129 writeln!(f, "typedef unsigned int uint32;")?;
130 writeln!(f, "typedef unsigned long long int uint64;")?;
131
132 writeln!(f, "typedef signed char int8;")?;
133 writeln!(f, "typedef signed short int16;")?;
134 writeln!(f, "typedef signed int int32;")?;
135 writeln!(f, "typedef signed long long int int64;")?;
136
137 Ok(())
138}
139
140pub fn type_vectorized_definitions<D: Dialect>(
141 f: &mut std::fmt::Formatter<'_>,
142 items: &HashSet<Item<D>>,
143) -> std::fmt::Result {
144 for item in items.iter() {
145 let elem = item.elem;
146 let size = item.vectorization;
147 let alignment = elem.size() * size;
148 if size > 1 {
149 write!(
150 f,
151 "
152struct __align__({alignment}) {item} {{"
153 )?;
154
155 for i in 0..size {
156 write!(
157 f,
158 "
159 {elem} i_{i};"
160 )?;
161 }
162
163 f.write_str("\n};")?;
164 }
165 }
166 Ok(())
167}
168
169pub fn type_scalar_definitions<D: Dialect>(
170 f: &mut std::fmt::Formatter<'_>,
171 scalars: &[(Elem<D>, usize)],
172) -> std::fmt::Result {
173 for (elem, count) in scalars.iter() {
174 writeln!(
175 f,
176 "
177struct scalars_{elem}_st {{
178{elem} x[{count}];
179}};"
180 )?;
181 }
182 Ok(())
183}
184
185pub fn type_info_definition<D: Dialect>(
186 f: &mut std::fmt::Formatter<'_>,
187 static_len: usize,
188) -> std::fmt::Result {
189 if static_len > 0 {
190 write!(
191 f,
192 "
193struct metadata_st {{
194uint x[{static_len}];
195}};
196"
197 )?;
198 }
199 Ok(())
200}
201
202pub fn compile_bindings<D: Dialect>(
203 f: &mut core::fmt::Formatter<'_>,
204 tensor_maps: &[Binding<D>],
205 buffers: &[Binding<D>],
206 trailing_comma: bool,
207 flags: &Flags,
208) -> core::fmt::Result {
209 write!(f, " ")?;
210
211 let mut args = Vec::new();
212
213 args.extend(tensor_maps.iter().map(|binding| {
214 format!(
215 "const __grid_constant__ CUtensorMap tensor_map_{}",
216 binding.id
217 )
218 }));
219 args.extend(
220 tensor_maps
221 .iter()
222 .chain(buffers.iter())
223 .map(|binding| match binding.vis {
224 Visibility::Read => {
225 format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
226 }
227 Visibility::ReadWrite => {
228 format!("{}* buffer_{}", binding.item, binding.id)
229 }
230 }),
231 );
232 args.extend(
233 flags
234 .has_dynamic_meta
235 .then(|| format!("const uint32* __restrict__ {INFO_NAME}")),
236 );
237
238 write!(f, "{}", args.join(", "))?;
239 if trailing_comma {
240 f.write_str(", ")?;
241 }
242 Ok(())
243}
244
245pub fn compile_scalars_dynamic<D: Dialect>(
246 f: &mut std::fmt::Formatter<'_>,
247 scalars: &[(Elem<D>, usize)],
248) -> core::fmt::Result {
249 let scalar_inputs = scalars
250 .iter()
251 .map(|(elem, _)| format!("const {elem}* __restrict__ scalars_{elem}"));
252 let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
253
254 write!(f, "{}", scalar_inputs.join(","))
255}
256
257pub fn compile_scalars_static<D: Dialect>(
258 f: &mut std::fmt::Formatter<'_>,
259 scalars: &[(Elem<D>, usize)],
260 flags: &Flags,
261) -> core::fmt::Result {
262 let mut scalar_inputs = Vec::new();
263
264 let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
267 for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
268 scalar_inputs.push(format!(
269 "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
270 ));
271 }
272 };
273
274 scalars_of_size(&mut scalar_inputs, 8);
276
277 if flags.static_meta_length > 0 {
279 scalar_inputs.push(format!(
280 "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
281 ));
282 }
283
284 for size in [4, 2, 1] {
286 scalars_of_size(&mut scalar_inputs, size);
287 }
288
289 write!(f, "{}", scalar_inputs.join(", "))
290}
291
292fn compile_cube_builtin_bindings_decl<D: Dialect>(
293 f: &mut core::fmt::Formatter<'_>,
294 settings: &Flags,
295) -> core::fmt::Result {
296 if settings.indexes.absolute_pos_tuple {
297 D::compile_absolute_pos_tuple_computation(f)?;
298 }
299
300 if settings.indexes.unit_pos {
301 D::compile_unit_pos_computation(f)?;
302 }
303
304 if settings.indexes.absolute_pos {
305 let variable = Variable::<D>::AbsolutePos;
306 let ty = variable.item();
307 let absolute_pos_x = Variable::<D>::AbsolutePosX;
308 let absolute_pos_y = Variable::<D>::AbsolutePosY;
309 let absolute_pos_z = Variable::<D>::AbsolutePosZ;
310 let cube_count_x = Variable::<D>::CubeCountX;
311 let cube_count_y = Variable::<D>::CubeCountY;
312 let cube_dim_x = Variable::<D>::CubeDimX;
313 let cube_dim_y = Variable::<D>::CubeDimY;
314 writeln!(
315 f,
316 "{ty} {variable} = ({absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y}) + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x}) + {absolute_pos_x};"
317 )?;
318 }
319
320 if settings.indexes.cube_dim {
321 let variable = Variable::<D>::CubeDim;
322 let ty = variable.item();
323 let cube_dim_x = Variable::<D>::CubeDimX;
324 let cube_dim_y = Variable::<D>::CubeDimY;
325 let cube_dim_z = Variable::<D>::CubeDimZ;
326 writeln!(
327 f,
328 "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
329 )?;
330 }
331
332 if settings.indexes.cube_count {
333 let variable = Variable::<D>::CubeCount;
334 let ty = variable.item();
335 let cube_count_x = Variable::<D>::CubeCountX;
336 let cube_count_y = Variable::<D>::CubeCountY;
337 let cube_count_z = Variable::<D>::CubeCountZ;
338 writeln!(
339 f,
340 "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
341 )?;
342 }
343
344 if settings.indexes.cube_pos {
345 let variable = Variable::<D>::CubePos;
346 let ty = variable.item();
347 let cube_pos_x = Variable::<D>::CubePosX;
348 let cube_pos_y = Variable::<D>::CubePosY;
349 let cube_pos_z = Variable::<D>::CubePosZ;
350 let cube_count_x = Variable::<D>::CubeCountX;
351 let cube_count_y = Variable::<D>::CubeCountY;
352 writeln!(
353 f,
354 "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
355 )?;
356 }
357
358 if settings.indexes.plane_dim_checked {
359 let plane_dim = Variable::<D>::PlaneDim;
360 let variable = Variable::<D>::PlaneDimChecked;
361 let ty = variable.item();
362 let cube_dim_x = Variable::<D>::CubeDimX;
363 let cube_dim_y = Variable::<D>::CubeDimY;
364 let cube_dim_z = Variable::<D>::CubeDimZ;
365 writeln!(
366 f,
367 "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
368 )?;
369 }
370
371 if settings.indexes.cluster_pos {
372 f.write_str(
373 "
374cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
375",
376 )?;
377 }
378
379 Ok(())
380}