1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5 CubeDim,
6 ir::Id,
7 prelude::{Location, Visibility},
8};
9
10use std::{collections::HashSet, fmt::Display};
11
12#[derive(Debug, PartialEq, Eq, Clone)]
13pub struct Binding<D: Dialect> {
14 pub id: Id,
15 pub item: Item<D>,
16 pub location: Location,
17 pub size: Option<usize>,
18 pub vis: Visibility,
19}
20
21#[derive(Debug, PartialEq, Eq, Clone)]
22pub enum SharedMemory<D: Dialect> {
23 Array {
24 index: Id,
25 item: Item<D>,
26 length: u32,
27 align: u32,
28 offset: u32,
29 },
30 Value {
31 index: Id,
32 item: Item<D>,
33 align: u32,
34 offset: u32,
35 },
36}
37
38impl<D: Dialect> SharedMemory<D> {
39 pub fn size(&self) -> u32 {
40 match self {
41 SharedMemory::Array { item, length, .. } => *length * item.size() as u32,
42 SharedMemory::Value { item, .. } => item.size() as u32,
43 }
44 }
45
46 pub fn align(&self) -> u32 {
47 match self {
48 SharedMemory::Array { align, .. } => *align,
49 SharedMemory::Value { align, .. } => *align,
50 }
51 }
52
53 pub fn offset(&self) -> u32 {
54 match self {
55 SharedMemory::Array { offset, .. } => *offset,
56 SharedMemory::Value { offset, .. } => *offset,
57 }
58 }
59}
60
61#[derive(Debug, PartialEq, Clone)]
62pub struct ConstArray<D: Dialect> {
63 pub index: Id,
64 pub item: Item<D>,
65 pub size: u32,
66 pub values: Vec<Variable<D>>,
67}
68
69#[derive(Debug, PartialEq, Eq, Clone)]
70pub struct LocalArray<D: Dialect> {
71 pub index: Id,
72 pub item: Item<D>,
73 pub size: u32,
74}
75
76impl<D: Dialect> LocalArray<D> {
77 pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
78 Self { index, item, size }
79 }
80}
81
82impl<D: Dialect> SharedMemory<D> {
83 pub fn new_array(index: Id, item: Item<D>, size: u32, align: u32) -> Self {
84 Self::Array {
85 index,
86 item,
87 length: size,
88 align,
89 offset: 0, }
91 }
92
93 pub fn new_value(index: Id, item: Item<D>, align: u32) -> Self {
94 Self::Value {
95 index,
96 item,
97 align,
98 offset: 0, }
100 }
101}
102
103#[derive(Debug, Clone)]
104pub struct ComputeKernel<D: Dialect> {
105 pub tensor_maps: Vec<Binding<D>>,
106 pub buffers: Vec<Binding<D>>,
107 pub scalars: Vec<(Elem<D>, usize)>,
108 pub meta_static_len: usize,
109 pub body: Body<D>,
110 pub cube_dim: CubeDim,
111 pub cluster_dim: Option<CubeDim>,
112 pub extensions: Vec<D::Extension>,
113 pub flags: Flags,
114 pub items: HashSet<super::Item<D>>,
115 pub kernel_name: String,
116}
117
118impl<D: Dialect> ComputeKernel<D> {
119 pub fn shared_memory_size(&self) -> usize {
120 let smems = self.body.shared_memories.iter();
121 let ends = smems.map(|it| it.offset() + it.size());
122 ends.max().unwrap_or_default() as usize
123 }
124}
125
126impl<D: Dialect> Display for ComputeKernel<D> {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 let mut flags = self.flags.clone();
129 if !self.tensor_maps.is_empty() {
130 flags.inst_tma = true;
131 }
132
133 D::compile_includes(f, &flags)?;
135 D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
136 D::compile_polyfills(f, &flags)?;
137 D::compile_extensions(f, &self.extensions)?;
138
139 D::compile_kernel_signature(
141 f,
142 &self.kernel_name,
143 &self.tensor_maps,
144 &self.buffers,
145 &self.scalars,
146 &self.flags,
147 )?;
148
149 f.write_str(" {\n")?;
151 compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
152 write!(f, "{}", self.body)?;
153 f.write_str("\n}")?;
154
155 Ok(())
156 }
157}
158
159pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 writeln!(f, "typedef unsigned int uint;")?;
161 writeln!(f, "typedef unsigned char uint8;")?;
162 writeln!(f, "typedef unsigned short uint16;")?;
163 writeln!(f, "typedef unsigned int uint32;")?;
164 writeln!(f, "typedef unsigned long long int uint64;")?;
165
166 writeln!(f, "typedef signed char int8;")?;
167 writeln!(f, "typedef signed short int16;")?;
168 writeln!(f, "typedef signed int int32;")?;
169 writeln!(f, "typedef signed long long int int64;")?;
170
171 Ok(())
172}
173
174pub fn type_vectorized_definitions<D: Dialect>(
175 f: &mut std::fmt::Formatter<'_>,
176 items: &HashSet<Item<D>>,
177) -> std::fmt::Result {
178 for item in items.iter() {
179 let elem = item.elem;
180 let size = item.vectorization;
181 let alignment = elem.size() * size;
182 if size > 1 {
183 write!(
184 f,
185 "
186struct __align__({alignment}) {item} {{"
187 )?;
188
189 for i in 0..size {
190 write!(
191 f,
192 "
193 {elem} i_{i};"
194 )?;
195 }
196
197 f.write_str("\n};")?;
198 }
199 }
200 Ok(())
201}
202
203pub fn type_scalar_definitions<D: Dialect>(
204 f: &mut std::fmt::Formatter<'_>,
205 scalars: &[(Elem<D>, usize)],
206) -> std::fmt::Result {
207 for (elem, count) in scalars.iter() {
208 writeln!(
209 f,
210 "
211struct scalars_{elem}_st {{
212{elem} x[{count}];
213}};"
214 )?;
215 }
216 Ok(())
217}
218
219pub fn type_info_definition<D: Dialect>(
220 f: &mut std::fmt::Formatter<'_>,
221 static_len: usize,
222) -> std::fmt::Result {
223 if static_len > 0 {
224 write!(
225 f,
226 "
227struct metadata_st {{
228uint x[{static_len}];
229}};
230"
231 )?;
232 }
233 Ok(())
234}
235
236pub fn compile_bindings<D: Dialect>(
237 f: &mut core::fmt::Formatter<'_>,
238 tensor_maps: &[Binding<D>],
239 buffers: &[Binding<D>],
240 trailing_comma: bool,
241 flags: &Flags,
242) -> core::fmt::Result {
243 write!(f, " ")?;
244
245 let mut args = Vec::new();
246
247 args.extend(tensor_maps.iter().map(|binding| {
248 format!(
249 "const __grid_constant__ CUtensorMap tensor_map_{}",
250 binding.id
251 )
252 }));
253 args.extend(
254 tensor_maps
255 .iter()
256 .chain(buffers.iter())
257 .map(|binding| match binding.vis {
258 Visibility::Read => {
259 format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
260 }
261 Visibility::ReadWrite => {
262 format!("{}* buffer_{}", binding.item, binding.id)
263 }
264 }),
265 );
266 args.extend(
267 flags
268 .has_dynamic_meta
269 .then(|| format!("const uint32* __restrict__ {INFO_NAME}")),
270 );
271
272 write!(f, "{}", args.join(", "))?;
273 if trailing_comma {
274 f.write_str(", ")?;
275 }
276 Ok(())
277}
278
279pub fn compile_scalars_dynamic<D: Dialect>(
280 f: &mut std::fmt::Formatter<'_>,
281 scalars: &[(Elem<D>, usize)],
282) -> core::fmt::Result {
283 let scalar_inputs = scalars
284 .iter()
285 .map(|(elem, _)| format!("const {elem}* __restrict__ scalars_{elem}"));
286 let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
287
288 write!(f, "{}", scalar_inputs.join(","))
289}
290
291pub fn compile_scalars_static<D: Dialect>(
292 f: &mut std::fmt::Formatter<'_>,
293 scalars: &[(Elem<D>, usize)],
294 flags: &Flags,
295) -> core::fmt::Result {
296 let mut scalar_inputs = Vec::new();
297
298 let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
301 for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
302 scalar_inputs.push(format!(
303 "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
304 ));
305 }
306 };
307
308 scalars_of_size(&mut scalar_inputs, 8);
310
311 if flags.static_meta_length > 0 {
313 scalar_inputs.push(format!(
314 "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
315 ));
316 }
317
318 for size in [4, 2, 1] {
320 scalars_of_size(&mut scalar_inputs, size);
321 }
322
323 write!(f, "{}", scalar_inputs.join(", "))
324}
325
326fn compile_cube_builtin_bindings_decl<D: Dialect>(
327 f: &mut core::fmt::Formatter<'_>,
328 settings: &Flags,
329) -> core::fmt::Result {
330 if settings.indexes.absolute_pos_tuple {
331 D::compile_absolute_pos_tuple_computation(f)?;
332 }
333
334 if settings.indexes.unit_pos {
335 D::compile_unit_pos_computation(f)?;
336 }
337
338 if settings.indexes.absolute_pos {
339 let variable = Variable::<D>::AbsolutePos;
340 let ty = variable.item();
341 let absolute_pos_x = Variable::<D>::AbsolutePosX;
342 let absolute_pos_y = Variable::<D>::AbsolutePosY;
343 let absolute_pos_z = Variable::<D>::AbsolutePosZ;
344 let cube_count_x = Variable::<D>::CubeCountX;
345 let cube_count_y = Variable::<D>::CubeCountY;
346 let cube_dim_x = Variable::<D>::CubeDimX;
347 let cube_dim_y = Variable::<D>::CubeDimY;
348 writeln!(
349 f,
350 "{ty} {variable} = ({absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y}) + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x}) + {absolute_pos_x};"
351 )?;
352 }
353
354 if settings.indexes.cube_dim {
355 let variable = Variable::<D>::CubeDim;
356 let ty = variable.item();
357 let cube_dim_x = Variable::<D>::CubeDimX;
358 let cube_dim_y = Variable::<D>::CubeDimY;
359 let cube_dim_z = Variable::<D>::CubeDimZ;
360 writeln!(
361 f,
362 "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
363 )?;
364 }
365
366 if settings.indexes.cube_count {
367 let variable = Variable::<D>::CubeCount;
368 let ty = variable.item();
369 let cube_count_x = Variable::<D>::CubeCountX;
370 let cube_count_y = Variable::<D>::CubeCountY;
371 let cube_count_z = Variable::<D>::CubeCountZ;
372 writeln!(
373 f,
374 "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
375 )?;
376 }
377
378 if settings.indexes.cube_pos {
379 let variable = Variable::<D>::CubePos;
380 let ty = variable.item();
381 let cube_pos_x = Variable::<D>::CubePosX;
382 let cube_pos_y = Variable::<D>::CubePosY;
383 let cube_pos_z = Variable::<D>::CubePosZ;
384 let cube_count_x = Variable::<D>::CubeCountX;
385 let cube_count_y = Variable::<D>::CubeCountY;
386 writeln!(
387 f,
388 "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
389 )?;
390 }
391
392 if settings.indexes.plane_dim_checked {
393 let plane_dim = Variable::<D>::PlaneDim;
394 let variable = Variable::<D>::PlaneDimChecked;
395 let ty = variable.item();
396 let cube_dim_x = Variable::<D>::CubeDimX;
397 let cube_dim_y = Variable::<D>::CubeDimY;
398 let cube_dim_z = Variable::<D>::CubeDimZ;
399 writeln!(
400 f,
401 "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
402 )?;
403 }
404
405 if settings.indexes.cluster_pos {
406 f.write_str(
407 "
408cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
409",
410 )?;
411 }
412
413 Ok(())
414}