1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5 CubeDim,
6 ir::Id,
7 prelude::{Location, Visibility},
8};
9
10use std::{collections::HashSet, fmt::Display};
11
/// A kernel parameter bound to device memory (a buffer or a tensor map).
///
/// Rendered into the kernel signature by `compile_bindings`, which uses `id` to
/// name the parameter, `item` as the pointee type, and `vis` to choose between a
/// read-only (`const ... __restrict__`) and a mutable pointer.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding<D: Dialect> {
    /// Identifier used in generated parameter names (`buffer_{id}`, `tensor_map_{id}`).
    pub id: Id,
    /// Type of the bound data, emitted as the pointer's pointee type.
    pub item: Item<D>,
    /// NOTE(review): not read in this file; presumably the binding's memory location kind — confirm.
    pub location: Location,
    /// NOTE(review): not read in this file; presumably a statically known length when available — confirm.
    pub size: Option<usize>,
    /// Access mode; `Visibility::Read` bindings are emitted as `const` + `__restrict__`.
    pub vis: Visibility,
}
20
/// A shared-memory allocation declared by a kernel.
///
/// `size()` reports `length * item.size()` for arrays and `item.size()` for single
/// values; `ComputeKernel::shared_memory_size` takes the maximum `offset + size()`
/// over all allocations of the kernel body.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum SharedMemory<D: Dialect> {
    /// A fixed-length array of `item`s.
    Array {
        index: Id,
        item: Item<D>,
        /// Number of `item`s in the array.
        length: usize,
        /// Requested alignment of the allocation.
        align: usize,
        /// Start of this allocation within the shared region. The constructors set it
        /// to 0 — NOTE(review): presumably assigned later by a layout pass; confirm.
        offset: usize,
    },
    /// A single `item`.
    Value {
        index: Id,
        item: Item<D>,
        /// Requested alignment of the allocation.
        align: usize,
        /// Start of this allocation within the shared region (0 from the constructor).
        offset: usize,
    },
}
37
38impl<D: Dialect> SharedMemory<D> {
39 pub fn size(&self) -> usize {
40 match self {
41 SharedMemory::Array { item, length, .. } => *length * item.size(),
42 SharedMemory::Value { item, .. } => item.size(),
43 }
44 }
45
46 pub fn align(&self) -> usize {
47 match self {
48 SharedMemory::Array { align, .. } => *align,
49 SharedMemory::Value { align, .. } => *align,
50 }
51 }
52
53 pub fn offset(&self) -> usize {
54 match self {
55 SharedMemory::Array { offset, .. } => *offset,
56 SharedMemory::Value { offset, .. } => *offset,
57 }
58 }
59}
60
/// A constant array declared by a kernel.
///
/// Derives `PartialEq` but not `Eq`, unlike the neighbouring types.
/// NOTE(review): no field is read in this file; presumably the array is emitted with
/// `values` as its initializer and `size` equal to `values.len()` — confirm.
#[derive(Debug, PartialEq, Clone)]
pub struct ConstArray<D: Dialect> {
    pub index: Id,
    pub item: Item<D>,
    pub size: u32,
    pub values: Vec<Variable<D>>,
}
68
/// A local array declared by a kernel, identified by `index`, holding `size`
/// values of type `item`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LocalArray<D: Dialect> {
    pub index: Id,
    pub item: Item<D>,
    pub size: usize,
}
75
76impl<D: Dialect> LocalArray<D> {
77 pub fn new(index: Id, item: Item<D>, size: usize) -> Self {
78 Self { index, item, size }
79 }
80}
81
82impl<D: Dialect> SharedMemory<D> {
83 pub fn new_array(index: Id, item: Item<D>, size: usize, align: usize) -> Self {
84 Self::Array {
85 index,
86 item,
87 length: size,
88 align,
89 offset: 0, }
91 }
92
93 pub fn new_value(index: Id, item: Item<D>, align: usize) -> Self {
94 Self::Value {
95 index,
96 item,
97 align,
98 offset: 0, }
100 }
101}
102
/// Everything needed to render one compute kernel as source text (see its
/// `Display` implementation).
#[derive(Debug, Clone)]
pub struct ComputeKernel<D: Dialect> {
    /// Tensor-map bindings; rendered as `CUtensorMap` parameters by `compile_bindings`.
    pub tensor_maps: Vec<Binding<D>>,
    /// Buffer bindings; rendered as pointer parameters by `compile_bindings`.
    pub buffers: Vec<Binding<D>>,
    /// Scalar arguments grouped by element type, with the number of scalars of that type.
    pub scalars: Vec<(Elem<D>, usize)>,
    /// NOTE(review): not read in this file; presumably the static metadata length — confirm.
    pub meta_static_len: usize,
    /// Kernel body, written after the builtin index declarations.
    pub body: Body<D>,
    /// NOTE(review): not read in this file; presumably the launch cube dimensions.
    pub cube_dim: CubeDim,
    /// NOTE(review): not read in this file; presumably the cluster dimensions, when used.
    pub cluster_dim: Option<CubeDim>,
    /// Dialect extensions passed to `D::compile_extensions`.
    pub extensions: Vec<D::Extension>,
    /// Compilation flags controlling what the preamble and body declare.
    pub flags: Flags<D>,
    /// Item types passed to `D::compile_type_definitions`.
    pub items: HashSet<super::Item<D>>,
    /// Name used in the kernel signature.
    pub kernel_name: String,
}
117
118impl<D: Dialect> ComputeKernel<D> {
119 pub fn shared_memory_size(&self) -> usize {
120 let smems = self.body.shared_memories.iter();
121 let ends = smems.map(|it| it.offset() + it.size());
122 ends.max().unwrap_or_default()
123 }
124}
125
126impl<D: Dialect> Display for ComputeKernel<D> {
127 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128 let mut flags = self.flags.clone();
129 if !self.tensor_maps.is_empty() {
130 flags.inst_tma = true;
131 }
132
133 D::compile_includes(f, &flags)?;
135 D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
136 D::compile_polyfills(f, &flags)?;
137 D::compile_extensions(f, &self.extensions)?;
138
139 D::compile_kernel_signature(
141 f,
142 &self.kernel_name,
143 &self.tensor_maps,
144 &self.buffers,
145 &self.scalars,
146 &self.flags,
147 )?;
148
149 f.write_str(" {\n")?;
151 compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
152 write!(f, "{}", self.body)?;
153 f.write_str("\n}")?;
154
155 Ok(())
156 }
157}
158
159pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 writeln!(f, "typedef unsigned int uint;")?;
161 writeln!(f, "typedef unsigned char uint8;")?;
162 writeln!(f, "typedef unsigned short uint16;")?;
163 writeln!(f, "typedef unsigned int uint32;")?;
164 writeln!(f, "typedef unsigned long long int uint64;")?;
165
166 writeln!(f, "typedef signed char int8;")?;
167 writeln!(f, "typedef signed short int16;")?;
168 writeln!(f, "typedef signed int int32;")?;
169 writeln!(f, "typedef signed long long int int64;")?;
170
171 Ok(())
172}
173
174pub fn type_vectorized_definitions<D: Dialect>(
175 f: &mut std::fmt::Formatter<'_>,
176 items: &HashSet<Item<D>>,
177) -> std::fmt::Result {
178 for item in items.iter() {
179 let elem = item.elem;
180 let size = item.vectorization;
181 let alignment = elem.size() * size;
182 if size > 1 {
183 write!(
184 f,
185 "
186struct __align__({alignment}) {item} {{"
187 )?;
188
189 for i in 0..size {
190 write!(
191 f,
192 "
193 {elem} i_{i};"
194 )?;
195 }
196
197 f.write_str("\n};")?;
198 }
199 }
200 Ok(())
201}
202
203pub fn type_scalar_definitions<D: Dialect>(
204 f: &mut std::fmt::Formatter<'_>,
205 scalars: &[(Elem<D>, usize)],
206) -> std::fmt::Result {
207 for (elem, count) in scalars.iter() {
208 writeln!(
209 f,
210 "
211struct scalars_{elem}_st {{
212{elem} x[{count}];
213}};"
214 )?;
215 }
216 Ok(())
217}
218
219pub fn type_info_definition<D: Dialect>(
220 f: &mut std::fmt::Formatter<'_>,
221 static_len: usize,
222 address_type: Item<D>,
223) -> std::fmt::Result {
224 if static_len > 0 {
225 write!(
226 f,
227 "
228struct metadata_st {{
229 {address_type} x[{static_len}];
230}};
231"
232 )?;
233 }
234 Ok(())
235}
236
237pub fn compile_bindings<D: Dialect>(
238 f: &mut core::fmt::Formatter<'_>,
239 tensor_maps: &[Binding<D>],
240 buffers: &[Binding<D>],
241 trailing_comma: bool,
242 flags: &Flags<D>,
243) -> core::fmt::Result {
244 write!(f, " ")?;
245
246 let mut args = Vec::new();
247
248 args.extend(tensor_maps.iter().map(|binding| {
249 format!(
250 "const __grid_constant__ CUtensorMap tensor_map_{}",
251 binding.id
252 )
253 }));
254 args.extend(
255 tensor_maps
256 .iter()
257 .chain(buffers.iter())
258 .map(|binding| match binding.vis {
259 Visibility::Read => {
260 format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
261 }
262 Visibility::ReadWrite => {
263 format!("{}* buffer_{}", binding.item, binding.id)
264 }
265 }),
266 );
267 args.extend(
268 flags
269 .has_dynamic_meta
270 .then(|| format!("const {}* __restrict__ {INFO_NAME}", flags.address_type)),
271 );
272
273 write!(f, "{}", args.join(", "))?;
274 if trailing_comma {
275 f.write_str(", ")?;
276 }
277 Ok(())
278}
279
280pub fn compile_scalars_dynamic<D: Dialect>(
281 f: &mut std::fmt::Formatter<'_>,
282 scalars: &[(Elem<D>, usize)],
283) -> core::fmt::Result {
284 let scalar_inputs = scalars
285 .iter()
286 .map(|(elem, _)| format!("const {elem}* __restrict__ scalars_{elem}"));
287 let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
288
289 write!(f, "{}", scalar_inputs.join(","))
290}
291
292pub fn compile_scalars_static<D: Dialect>(
293 f: &mut std::fmt::Formatter<'_>,
294 scalars: &[(Elem<D>, usize)],
295 flags: &Flags<D>,
296) -> core::fmt::Result {
297 let mut scalar_inputs = Vec::new();
298
299 let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
302 for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
303 scalar_inputs.push(format!(
304 "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
305 ));
306 }
307 };
308
309 scalars_of_size(&mut scalar_inputs, 8);
311
312 if flags.static_meta_length > 0 {
314 scalar_inputs.push(format!(
315 "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
316 ));
317 }
318
319 for size in [4, 2, 1] {
321 scalars_of_size(&mut scalar_inputs, size);
322 }
323
324 write!(f, "{}", scalar_inputs.join(", "))
325}
326
327fn compile_cube_builtin_bindings_decl<D: Dialect>(
328 f: &mut core::fmt::Formatter<'_>,
329 settings: &Flags<D>,
330) -> core::fmt::Result {
331 if settings.indexes.absolute_pos_tuple {
332 D::compile_absolute_pos_tuple_computation(f)?;
333 }
334
335 if settings.indexes.unit_pos {
336 D::compile_unit_pos_computation(f)?;
337 }
338
339 if settings.indexes.absolute_pos {
340 let variable = Variable::<D>::AbsolutePos(settings.address_type.elem);
341 let ty = variable.item();
342 let absolute_pos_x = Variable::<D>::AbsolutePosX.fmt_cast_to(ty);
343 let absolute_pos_y = Variable::<D>::AbsolutePosY.fmt_cast_to(ty);
344 let absolute_pos_z = Variable::<D>::AbsolutePosZ.fmt_cast_to(ty);
345 let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
346 let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
347 let cube_dim_x = Variable::<D>::CubeDimX.fmt_cast_to(ty);
348 let cube_dim_y = Variable::<D>::CubeDimY.fmt_cast_to(ty);
349 writeln!(
350 f,
351 "{ty} {variable} = (
352 {absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y})
353 + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x})
354 + {absolute_pos_x};"
355 )?;
356 }
357
358 if settings.indexes.cube_dim {
359 let variable = Variable::<D>::CubeDim;
360 let ty = variable.item();
361 let cube_dim_x = Variable::<D>::CubeDimX;
362 let cube_dim_y = Variable::<D>::CubeDimY;
363 let cube_dim_z = Variable::<D>::CubeDimZ;
364 writeln!(
365 f,
366 "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
367 )?;
368 }
369
370 if settings.indexes.cube_count {
371 let variable = Variable::<D>::CubeCount(settings.address_type.elem);
372 let ty = variable.item();
373 let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
374 let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
375 let cube_count_z = Variable::<D>::CubeCountZ.fmt_cast_to(ty);
376 writeln!(
377 f,
378 "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
379 )?;
380 }
381
382 if settings.indexes.cube_pos {
383 let variable = Variable::<D>::CubePos(settings.address_type.elem);
384 let ty = variable.item();
385 let cube_pos_x = Variable::<D>::CubePosX.fmt_cast_to(ty);
386 let cube_pos_y = Variable::<D>::CubePosY.fmt_cast_to(ty);
387 let cube_pos_z = Variable::<D>::CubePosZ.fmt_cast_to(ty);
388 let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
389 let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
390 writeln!(
391 f,
392 "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
393 )?;
394 }
395
396 if settings.indexes.plane_dim_checked {
397 let plane_dim = Variable::<D>::PlaneDim;
398 let variable = Variable::<D>::PlaneDimChecked;
399 let ty = variable.item();
400 let cube_dim_x = Variable::<D>::CubeDimX;
401 let cube_dim_y = Variable::<D>::CubeDimY;
402 let cube_dim_z = Variable::<D>::CubeDimZ;
403 writeln!(
404 f,
405 "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
406 )?;
407 }
408
409 if settings.indexes.cluster_pos {
410 f.write_str(
411 "
412cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
413",
414 )?;
415 }
416
417 Ok(())
418}