cubecl_cpp/shared/
kernel.rs

1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5    CubeDim,
6    compute::{Location, Visibility},
7    ir::Id,
8};
9
10use std::{collections::HashSet, fmt::Display};
11
12#[derive(Debug, PartialEq, Eq, Clone)]
13pub struct Binding<D: Dialect> {
14    pub id: Id,
15    pub item: Item<D>,
16    pub location: Location,
17    pub size: Option<usize>,
18    pub vis: Visibility,
19}
20
21#[derive(Debug, PartialEq, Eq, Clone)]
22pub struct SharedMemory<D: Dialect> {
23    pub index: Id,
24    pub item: Item<D>,
25    pub size: u32,
26    pub align: Option<u32>,
27}
28
29#[derive(Debug, PartialEq, Clone)]
30pub struct ConstArray<D: Dialect> {
31    pub index: Id,
32    pub item: Item<D>,
33    pub size: u32,
34    pub values: Vec<Variable<D>>,
35}
36
37#[derive(Debug, PartialEq, Eq, Clone)]
38pub struct LocalArray<D: Dialect> {
39    pub index: Id,
40    pub item: Item<D>,
41    pub size: u32,
42}
43
44impl<D: Dialect> LocalArray<D> {
45    pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
46        Self { index, item, size }
47    }
48}
49
50impl<D: Dialect> SharedMemory<D> {
51    pub fn new(index: Id, item: Item<D>, size: u32, align: Option<u32>) -> Self {
52        Self {
53            index,
54            item,
55            size,
56            align,
57        }
58    }
59}
60
61#[derive(Debug, Clone)]
62pub struct ComputeKernel<D: Dialect> {
63    pub tensor_maps: Vec<Id>,
64    pub buffers: Vec<Binding<D>>,
65    pub scalars: Vec<(Elem<D>, usize)>,
66    pub meta_static_len: usize,
67    pub body: Body<D>,
68    pub cube_dim: CubeDim,
69    pub cluster_dim: Option<CubeDim>,
70    pub extensions: Vec<D::Extension>,
71    pub flags: Flags,
72    pub items: HashSet<super::Item<D>>,
73    pub kernel_name: String,
74}
75
76impl<D: Dialect> ComputeKernel<D> {
77    pub fn shared_memory_size(&self) -> usize {
78        let mut current = 0usize;
79
80        for var in self.body.shared_memories.iter() {
81            let factor = var.item.vectorization;
82            let elem_size_bytes = var.item.elem().size();
83            current += (var.size as usize) * factor * elem_size_bytes;
84        }
85
86        current
87    }
88}
89
90impl<D: Dialect> Display for ComputeKernel<D> {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        let mut flags = self.flags.clone();
93        if !self.tensor_maps.is_empty() {
94            flags.inst_tma = true;
95        }
96
97        // Program Scope -----------------------------------------------------
98        D::compile_includes(f, &flags)?;
99        D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
100        D::compile_polyfills(f, &flags)?;
101        D::compile_extensions(f, &self.extensions)?;
102
103        // Kernel signature --------------------------------------------------
104        D::compile_kernel_signature(
105            f,
106            &self.kernel_name,
107            &self.tensor_maps,
108            &self.buffers,
109            &self.scalars,
110            &self.flags,
111        )?;
112
113        // Body --------------------------------------------------------------
114        f.write_str(" {\n")?;
115        compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
116        write!(f, "{}", self.body)?;
117        f.write_str("\n}")?;
118
119        Ok(())
120    }
121}
122
123pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
124    writeln!(f, "typedef unsigned int uint;")?;
125    writeln!(f, "typedef unsigned char uint8;")?;
126    writeln!(f, "typedef unsigned short uint16;")?;
127    writeln!(f, "typedef unsigned int uint32;")?;
128    writeln!(f, "typedef unsigned long long int uint64;")?;
129
130    writeln!(f, "typedef signed char int8;")?;
131    writeln!(f, "typedef signed short int16;")?;
132    writeln!(f, "typedef signed int int32;")?;
133    writeln!(f, "typedef signed long long int int64;")?;
134
135    Ok(())
136}
137
138pub fn type_vectorized_definitions<D: Dialect>(
139    f: &mut std::fmt::Formatter<'_>,
140    items: &HashSet<Item<D>>,
141) -> std::fmt::Result {
142    for item in items.iter() {
143        let elem = item.elem;
144        let size = item.vectorization;
145        let alignment = elem.size() * size;
146        if size > 1 {
147            write!(
148                f,
149                "
150struct __align__({alignment}) {item} {{"
151            )?;
152
153            for i in 0..size {
154                write!(
155                    f,
156                    "
157    {elem} i_{i};"
158                )?;
159            }
160
161            f.write_str("\n};")?;
162        }
163    }
164    Ok(())
165}
166
167pub fn type_scalar_definitions<D: Dialect>(
168    f: &mut std::fmt::Formatter<'_>,
169    scalars: &[(Elem<D>, usize)],
170) -> std::fmt::Result {
171    for (elem, count) in scalars.iter() {
172        writeln!(
173            f,
174            "
175struct scalars_{elem}_st {{
176{elem} x[{count}];
177}};"
178        )?;
179    }
180    Ok(())
181}
182
183pub fn type_info_definition<D: Dialect>(
184    f: &mut std::fmt::Formatter<'_>,
185    static_len: usize,
186) -> std::fmt::Result {
187    if static_len > 0 {
188        write!(
189            f,
190            "
191struct metadata_st {{
192uint x[{}];
193}};
194",
195            static_len
196        )?;
197    }
198    Ok(())
199}
200
201pub fn compile_bindings<D: Dialect>(
202    f: &mut core::fmt::Formatter<'_>,
203    tensor_maps: &[Id],
204    buffers: &[Binding<D>],
205    trailing_comma: bool,
206    flags: &Flags,
207) -> core::fmt::Result {
208    write!(f, "    ")?;
209
210    let mut args = Vec::new();
211
212    args.extend(
213        tensor_maps
214            .iter()
215            .map(|id| format!("const __grid_constant__ CUtensorMap tensor_map_{id}")),
216    );
217    args.extend(buffers.iter().map(|binding| match binding.vis {
218        Visibility::Read => {
219            format!("{} buffer_{}[]", binding.item, binding.id)
220            // TODO: It breaks slices, because we can't easily create pointer to __restrict__,
221            // we should have multiple pointer types to enable that optimization.
222            //
223            // write!(f, "const {}* __restrict__ input_{}", binding.item, index)?;
224        }
225        Visibility::ReadWrite => {
226            format!("{} buffer_{}[]", binding.item, binding.id)
227        }
228    }));
229    args.extend(
230        flags
231            .has_dynamic_meta
232            .then(|| format!("const uint32* __restrict__ {INFO_NAME}")),
233    );
234
235    write!(f, "{}", args.join(", "))?;
236    if trailing_comma {
237        f.write_str(", ")?;
238    }
239    Ok(())
240}
241
242pub fn compile_scalars_dynamic<D: Dialect>(
243    f: &mut std::fmt::Formatter<'_>,
244    scalars: &[(Elem<D>, usize)],
245) -> core::fmt::Result {
246    let scalar_inputs = scalars
247        .iter()
248        .map(|(elem, _)| format!("const {}* __restrict__ scalars_{}", elem, elem));
249    let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
250
251    write!(f, "{}", scalar_inputs.join(","))
252}
253
254pub fn compile_scalars_static<D: Dialect>(
255    f: &mut std::fmt::Formatter<'_>,
256    scalars: &[(Elem<D>, usize)],
257    flags: &Flags,
258) -> core::fmt::Result {
259    let mut scalar_inputs = Vec::new();
260
261    // Need to sort elements because of alignment when packing
262    // Metadata is align 4 so it needs to be spliced in the middle.
263    let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
264        for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
265            scalar_inputs.push(format!(
266                "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
267            ));
268        }
269    };
270
271    // Pack 64-bit aligned types first, since metadata is 32-bit aligned
272    scalars_of_size(&mut scalar_inputs, 8);
273
274    // Pack metadata
275    if flags.static_meta_length > 0 {
276        scalar_inputs.push(format!(
277            "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
278        ));
279    }
280
281    // Pack remaining scalars that are 4 bytes or below
282    for size in [4, 2, 1] {
283        scalars_of_size(&mut scalar_inputs, size);
284    }
285
286    write!(f, "{}", scalar_inputs.join(", "))
287}
288
289fn compile_cube_builtin_bindings_decl<D: Dialect>(
290    f: &mut core::fmt::Formatter<'_>,
291    settings: &Flags,
292) -> core::fmt::Result {
293    if settings.indexes.absolute_pos_tuple {
294        D::compile_absolute_pos_tuple_computation(f)?;
295    }
296
297    if settings.indexes.unit_pos {
298        D::compile_unit_pos_computation(f)?;
299    }
300
301    if settings.indexes.absolute_pos {
302        let variable = Variable::<D>::AbsolutePos;
303        let ty = variable.item();
304        let absolute_pos_x = Variable::<D>::AbsolutePosX;
305        let absolute_pos_y = Variable::<D>::AbsolutePosY;
306        let absolute_pos_z = Variable::<D>::AbsolutePosZ;
307        let cube_count_x = Variable::<D>::CubeCountX;
308        let cube_count_y = Variable::<D>::CubeCountY;
309        let cube_dim_x = Variable::<D>::CubeDimX;
310        let cube_dim_y = Variable::<D>::CubeDimY;
311        writeln!(
312            f,
313            "{ty} {variable} = ({absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y}) + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x}) + {absolute_pos_x};"
314        )?;
315    }
316
317    if settings.indexes.cube_dim {
318        let variable = Variable::<D>::CubeDim;
319        let ty = variable.item();
320        let cube_dim_x = Variable::<D>::CubeDimX;
321        let cube_dim_y = Variable::<D>::CubeDimY;
322        let cube_dim_z = Variable::<D>::CubeDimZ;
323        writeln!(
324            f,
325            "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
326        )?;
327    }
328
329    if settings.indexes.cube_count {
330        let variable = Variable::<D>::CubeCount;
331        let ty = variable.item();
332        let cube_count_x = Variable::<D>::CubeCountX;
333        let cube_count_y = Variable::<D>::CubeCountY;
334        let cube_count_z = Variable::<D>::CubeCountZ;
335        writeln!(
336            f,
337            "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
338        )?;
339    }
340
341    if settings.indexes.cube_pos {
342        let variable = Variable::<D>::CubePos;
343        let ty = variable.item();
344        let cube_pos_x = Variable::<D>::CubePosX;
345        let cube_pos_y = Variable::<D>::CubePosY;
346        let cube_pos_z = Variable::<D>::CubePosZ;
347        let cube_count_x = Variable::<D>::CubeCountX;
348        let cube_count_y = Variable::<D>::CubeCountY;
349        writeln!(
350            f,
351            "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
352        )?;
353    }
354
355    if settings.indexes.plane_dim_checked {
356        let plane_dim = Variable::<D>::PlaneDim;
357        let variable = Variable::<D>::PlaneDimChecked;
358        let ty = variable.item();
359        let cube_dim_x = Variable::<D>::CubeDimX;
360        let cube_dim_y = Variable::<D>::CubeDimY;
361        let cube_dim_z = Variable::<D>::CubeDimZ;
362        writeln!(
363            f,
364            "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
365        )?;
366    }
367
368    if settings.indexes.cluster_pos {
369        f.write_str(
370            "
371cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
372",
373        )?;
374    }
375
376    Ok(())
377}