cubecl_cpp/shared/
kernel.rs

1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5    CubeDim,
6    compute::{Location, Visibility},
7    ir::Id,
8};
9
10use std::{collections::HashSet, fmt::Display};
11
12#[derive(Debug, PartialEq, Eq, Clone)]
13pub struct Binding<D: Dialect> {
14    pub id: Id,
15    pub item: Item<D>,
16    pub location: Location,
17    pub size: Option<usize>,
18    pub vis: Visibility,
19}
20
21#[derive(Debug, PartialEq, Eq, Clone)]
22pub struct SharedMemory<D: Dialect> {
23    pub index: Id,
24    pub item: Item<D>,
25    pub size: u32,
26    pub align: Option<u32>,
27    pub offset: u32,
28}
29
30#[derive(Debug, PartialEq, Clone)]
31pub struct ConstArray<D: Dialect> {
32    pub index: Id,
33    pub item: Item<D>,
34    pub size: u32,
35    pub values: Vec<Variable<D>>,
36}
37
38#[derive(Debug, PartialEq, Eq, Clone)]
39pub struct LocalArray<D: Dialect> {
40    pub index: Id,
41    pub item: Item<D>,
42    pub size: u32,
43}
44
45impl<D: Dialect> LocalArray<D> {
46    pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
47        Self { index, item, size }
48    }
49}
50
51impl<D: Dialect> SharedMemory<D> {
52    pub fn new(index: Id, item: Item<D>, size: u32, align: Option<u32>) -> Self {
53        Self {
54            index,
55            item,
56            size,
57            align,
58            offset: 0, // initialized later
59        }
60    }
61}
62
63#[derive(Debug, Clone)]
64pub struct ComputeKernel<D: Dialect> {
65    pub tensor_maps: Vec<Id>,
66    pub buffers: Vec<Binding<D>>,
67    pub scalars: Vec<(Elem<D>, usize)>,
68    pub meta_static_len: usize,
69    pub body: Body<D>,
70    pub cube_dim: CubeDim,
71    pub cluster_dim: Option<CubeDim>,
72    pub extensions: Vec<D::Extension>,
73    pub flags: Flags,
74    pub items: HashSet<super::Item<D>>,
75    pub kernel_name: String,
76}
77
78impl<D: Dialect> ComputeKernel<D> {
79    pub fn shared_memory_size(&self) -> usize {
80        // Account for alignment padding between shared memory buffers
81        // Sorted to minimize that padding
82        let mut shared_memories = self.body.shared_memories.clone();
83        shared_memories.sort_by_key(|smem| smem.align.unwrap_or(smem.item.size() as u32));
84        shared_memories.reverse();
85
86        let mut current = 0usize;
87
88        for var in self.body.shared_memories.iter() {
89            let align = var.align.unwrap_or(var.item.size() as u32);
90            let size_bytes = var.size as usize * var.item.size();
91            let offset = current.next_multiple_of(align as usize);
92            current = offset + size_bytes;
93        }
94
95        current
96    }
97}
98
99impl<D: Dialect> Display for ComputeKernel<D> {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        let mut flags = self.flags.clone();
102        if !self.tensor_maps.is_empty() {
103            flags.inst_tma = true;
104        }
105
106        // Program Scope -----------------------------------------------------
107        D::compile_includes(f, &flags)?;
108        D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
109        D::compile_polyfills(f, &flags)?;
110        D::compile_extensions(f, &self.extensions)?;
111
112        // Kernel signature --------------------------------------------------
113        D::compile_kernel_signature(
114            f,
115            &self.kernel_name,
116            &self.tensor_maps,
117            &self.buffers,
118            &self.scalars,
119            &self.flags,
120        )?;
121
122        // Body --------------------------------------------------------------
123        f.write_str(" {\n")?;
124        compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
125        write!(f, "{}", self.body)?;
126        f.write_str("\n}")?;
127
128        Ok(())
129    }
130}
131
132pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133    writeln!(f, "typedef unsigned int uint;")?;
134    writeln!(f, "typedef unsigned char uint8;")?;
135    writeln!(f, "typedef unsigned short uint16;")?;
136    writeln!(f, "typedef unsigned int uint32;")?;
137    writeln!(f, "typedef unsigned long long int uint64;")?;
138
139    writeln!(f, "typedef signed char int8;")?;
140    writeln!(f, "typedef signed short int16;")?;
141    writeln!(f, "typedef signed int int32;")?;
142    writeln!(f, "typedef signed long long int int64;")?;
143
144    Ok(())
145}
146
147pub fn type_vectorized_definitions<D: Dialect>(
148    f: &mut std::fmt::Formatter<'_>,
149    items: &HashSet<Item<D>>,
150) -> std::fmt::Result {
151    for item in items.iter() {
152        let elem = item.elem;
153        let size = item.vectorization;
154        let alignment = elem.size() * size;
155        if size > 1 {
156            write!(
157                f,
158                "
159struct __align__({alignment}) {item} {{"
160            )?;
161
162            for i in 0..size {
163                write!(
164                    f,
165                    "
166    {elem} i_{i};"
167                )?;
168            }
169
170            f.write_str("\n};")?;
171        }
172    }
173    Ok(())
174}
175
176pub fn type_scalar_definitions<D: Dialect>(
177    f: &mut std::fmt::Formatter<'_>,
178    scalars: &[(Elem<D>, usize)],
179) -> std::fmt::Result {
180    for (elem, count) in scalars.iter() {
181        writeln!(
182            f,
183            "
184struct scalars_{elem}_st {{
185{elem} x[{count}];
186}};"
187        )?;
188    }
189    Ok(())
190}
191
192pub fn type_info_definition<D: Dialect>(
193    f: &mut std::fmt::Formatter<'_>,
194    static_len: usize,
195) -> std::fmt::Result {
196    if static_len > 0 {
197        write!(
198            f,
199            "
200struct metadata_st {{
201uint x[{static_len}];
202}};
203"
204        )?;
205    }
206    Ok(())
207}
208
209pub fn compile_bindings<D: Dialect>(
210    f: &mut core::fmt::Formatter<'_>,
211    tensor_maps: &[Id],
212    buffers: &[Binding<D>],
213    trailing_comma: bool,
214    flags: &Flags,
215) -> core::fmt::Result {
216    write!(f, "    ")?;
217
218    let mut args = Vec::new();
219
220    args.extend(
221        tensor_maps
222            .iter()
223            .map(|id| format!("const __grid_constant__ CUtensorMap tensor_map_{id}")),
224    );
225    args.extend(buffers.iter().map(|binding| match binding.vis {
226        Visibility::Read => {
227            format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
228        }
229        Visibility::ReadWrite => {
230            format!("{}* buffer_{}", binding.item, binding.id)
231        }
232    }));
233    args.extend(
234        flags
235            .has_dynamic_meta
236            .then(|| format!("const uint32* __restrict__ {INFO_NAME}")),
237    );
238
239    write!(f, "{}", args.join(", "))?;
240    if trailing_comma {
241        f.write_str(", ")?;
242    }
243    Ok(())
244}
245
246pub fn compile_scalars_dynamic<D: Dialect>(
247    f: &mut std::fmt::Formatter<'_>,
248    scalars: &[(Elem<D>, usize)],
249) -> core::fmt::Result {
250    let scalar_inputs = scalars
251        .iter()
252        .map(|(elem, _)| format!("const {elem}* __restrict__ scalars_{elem}"));
253    let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
254
255    write!(f, "{}", scalar_inputs.join(","))
256}
257
258pub fn compile_scalars_static<D: Dialect>(
259    f: &mut std::fmt::Formatter<'_>,
260    scalars: &[(Elem<D>, usize)],
261    flags: &Flags,
262) -> core::fmt::Result {
263    let mut scalar_inputs = Vec::new();
264
265    // Need to sort elements because of alignment when packing
266    // Metadata is align 4 so it needs to be spliced in the middle.
267    let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
268        for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
269            scalar_inputs.push(format!(
270                "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
271            ));
272        }
273    };
274
275    // Pack 64-bit aligned types first, since metadata is 32-bit aligned
276    scalars_of_size(&mut scalar_inputs, 8);
277
278    // Pack metadata
279    if flags.static_meta_length > 0 {
280        scalar_inputs.push(format!(
281            "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
282        ));
283    }
284
285    // Pack remaining scalars that are 4 bytes or below
286    for size in [4, 2, 1] {
287        scalars_of_size(&mut scalar_inputs, size);
288    }
289
290    write!(f, "{}", scalar_inputs.join(", "))
291}
292
293fn compile_cube_builtin_bindings_decl<D: Dialect>(
294    f: &mut core::fmt::Formatter<'_>,
295    settings: &Flags,
296) -> core::fmt::Result {
297    if settings.indexes.absolute_pos_tuple {
298        D::compile_absolute_pos_tuple_computation(f)?;
299    }
300
301    if settings.indexes.unit_pos {
302        D::compile_unit_pos_computation(f)?;
303    }
304
305    if settings.indexes.absolute_pos {
306        let variable = Variable::<D>::AbsolutePos;
307        let ty = variable.item();
308        let absolute_pos_x = Variable::<D>::AbsolutePosX;
309        let absolute_pos_y = Variable::<D>::AbsolutePosY;
310        let absolute_pos_z = Variable::<D>::AbsolutePosZ;
311        let cube_count_x = Variable::<D>::CubeCountX;
312        let cube_count_y = Variable::<D>::CubeCountY;
313        let cube_dim_x = Variable::<D>::CubeDimX;
314        let cube_dim_y = Variable::<D>::CubeDimY;
315        writeln!(
316            f,
317            "{ty} {variable} = ({absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y}) + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x}) + {absolute_pos_x};"
318        )?;
319    }
320
321    if settings.indexes.cube_dim {
322        let variable = Variable::<D>::CubeDim;
323        let ty = variable.item();
324        let cube_dim_x = Variable::<D>::CubeDimX;
325        let cube_dim_y = Variable::<D>::CubeDimY;
326        let cube_dim_z = Variable::<D>::CubeDimZ;
327        writeln!(
328            f,
329            "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
330        )?;
331    }
332
333    if settings.indexes.cube_count {
334        let variable = Variable::<D>::CubeCount;
335        let ty = variable.item();
336        let cube_count_x = Variable::<D>::CubeCountX;
337        let cube_count_y = Variable::<D>::CubeCountY;
338        let cube_count_z = Variable::<D>::CubeCountZ;
339        writeln!(
340            f,
341            "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
342        )?;
343    }
344
345    if settings.indexes.cube_pos {
346        let variable = Variable::<D>::CubePos;
347        let ty = variable.item();
348        let cube_pos_x = Variable::<D>::CubePosX;
349        let cube_pos_y = Variable::<D>::CubePosY;
350        let cube_pos_z = Variable::<D>::CubePosZ;
351        let cube_count_x = Variable::<D>::CubeCountX;
352        let cube_count_y = Variable::<D>::CubeCountY;
353        writeln!(
354            f,
355            "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
356        )?;
357    }
358
359    if settings.indexes.plane_dim_checked {
360        let plane_dim = Variable::<D>::PlaneDim;
361        let variable = Variable::<D>::PlaneDimChecked;
362        let ty = variable.item();
363        let cube_dim_x = Variable::<D>::CubeDimX;
364        let cube_dim_y = Variable::<D>::CubeDimY;
365        let cube_dim_z = Variable::<D>::CubeDimZ;
366        writeln!(
367            f,
368            "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
369        )?;
370    }
371
372    if settings.indexes.cluster_pos {
373        f.write_str(
374            "
375cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
376",
377        )?;
378    }
379
380    Ok(())
381}