cubecl_cpp/shared/
kernel.rs

1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5    CubeDim,
6    compute::{Location, Visibility},
7    ir::Id,
8};
9
10use std::{collections::HashSet, fmt::Display};
11
12#[derive(Debug, PartialEq, Eq, Clone)]
13pub struct Binding<D: Dialect> {
14    pub id: Id,
15    pub item: Item<D>,
16    pub location: Location,
17    pub size: Option<usize>,
18    pub vis: Visibility,
19}
20
21#[derive(Debug, PartialEq, Eq, Clone)]
22pub struct SharedMemory<D: Dialect> {
23    pub index: Id,
24    pub item: Item<D>,
25    pub length: u32,
26    pub align: u32,
27    pub offset: u32,
28}
29
30impl<D: Dialect> SharedMemory<D> {
31    pub fn size(&self) -> u32 {
32        self.length * self.item.size() as u32
33    }
34}
35
36#[derive(Debug, PartialEq, Clone)]
37pub struct ConstArray<D: Dialect> {
38    pub index: Id,
39    pub item: Item<D>,
40    pub size: u32,
41    pub values: Vec<Variable<D>>,
42}
43
44#[derive(Debug, PartialEq, Eq, Clone)]
45pub struct LocalArray<D: Dialect> {
46    pub index: Id,
47    pub item: Item<D>,
48    pub size: u32,
49}
50
51impl<D: Dialect> LocalArray<D> {
52    pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
53        Self { index, item, size }
54    }
55}
56
57impl<D: Dialect> SharedMemory<D> {
58    pub fn new(index: Id, item: Item<D>, size: u32, align: u32) -> Self {
59        Self {
60            index,
61            item,
62            length: size,
63            align,
64            offset: 0, // initialized later
65        }
66    }
67}
68
69#[derive(Debug, Clone)]
70pub struct ComputeKernel<D: Dialect> {
71    pub tensor_maps: Vec<Binding<D>>,
72    pub buffers: Vec<Binding<D>>,
73    pub scalars: Vec<(Elem<D>, usize)>,
74    pub meta_static_len: usize,
75    pub body: Body<D>,
76    pub cube_dim: CubeDim,
77    pub cluster_dim: Option<CubeDim>,
78    pub extensions: Vec<D::Extension>,
79    pub flags: Flags,
80    pub items: HashSet<super::Item<D>>,
81    pub kernel_name: String,
82}
83
84impl<D: Dialect> ComputeKernel<D> {
85    pub fn shared_memory_size(&self) -> usize {
86        let smems = self.body.shared_memories.iter();
87        let ends = smems.map(|it| it.offset + it.size());
88        ends.max().unwrap_or_default() as usize
89    }
90}
91
92impl<D: Dialect> Display for ComputeKernel<D> {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        let mut flags = self.flags.clone();
95        if !self.tensor_maps.is_empty() {
96            flags.inst_tma = true;
97        }
98
99        // Program Scope -----------------------------------------------------
100        D::compile_includes(f, &flags)?;
101        D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
102        D::compile_polyfills(f, &flags)?;
103        D::compile_extensions(f, &self.extensions)?;
104
105        // Kernel signature --------------------------------------------------
106        D::compile_kernel_signature(
107            f,
108            &self.kernel_name,
109            &self.tensor_maps,
110            &self.buffers,
111            &self.scalars,
112            &self.flags,
113        )?;
114
115        // Body --------------------------------------------------------------
116        f.write_str(" {\n")?;
117        compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
118        write!(f, "{}", self.body)?;
119        f.write_str("\n}")?;
120
121        Ok(())
122    }
123}
124
125pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126    writeln!(f, "typedef unsigned int uint;")?;
127    writeln!(f, "typedef unsigned char uint8;")?;
128    writeln!(f, "typedef unsigned short uint16;")?;
129    writeln!(f, "typedef unsigned int uint32;")?;
130    writeln!(f, "typedef unsigned long long int uint64;")?;
131
132    writeln!(f, "typedef signed char int8;")?;
133    writeln!(f, "typedef signed short int16;")?;
134    writeln!(f, "typedef signed int int32;")?;
135    writeln!(f, "typedef signed long long int int64;")?;
136
137    Ok(())
138}
139
140pub fn type_vectorized_definitions<D: Dialect>(
141    f: &mut std::fmt::Formatter<'_>,
142    items: &HashSet<Item<D>>,
143) -> std::fmt::Result {
144    for item in items.iter() {
145        let elem = item.elem;
146        let size = item.vectorization;
147        let alignment = elem.size() * size;
148        if size > 1 {
149            write!(
150                f,
151                "
152struct __align__({alignment}) {item} {{"
153            )?;
154
155            for i in 0..size {
156                write!(
157                    f,
158                    "
159    {elem} i_{i};"
160                )?;
161            }
162
163            f.write_str("\n};")?;
164        }
165    }
166    Ok(())
167}
168
169pub fn type_scalar_definitions<D: Dialect>(
170    f: &mut std::fmt::Formatter<'_>,
171    scalars: &[(Elem<D>, usize)],
172) -> std::fmt::Result {
173    for (elem, count) in scalars.iter() {
174        writeln!(
175            f,
176            "
177struct scalars_{elem}_st {{
178{elem} x[{count}];
179}};"
180        )?;
181    }
182    Ok(())
183}
184
185pub fn type_info_definition<D: Dialect>(
186    f: &mut std::fmt::Formatter<'_>,
187    static_len: usize,
188) -> std::fmt::Result {
189    if static_len > 0 {
190        write!(
191            f,
192            "
193struct metadata_st {{
194uint x[{static_len}];
195}};
196"
197        )?;
198    }
199    Ok(())
200}
201
202pub fn compile_bindings<D: Dialect>(
203    f: &mut core::fmt::Formatter<'_>,
204    tensor_maps: &[Binding<D>],
205    buffers: &[Binding<D>],
206    trailing_comma: bool,
207    flags: &Flags,
208) -> core::fmt::Result {
209    write!(f, "    ")?;
210
211    let mut args = Vec::new();
212
213    args.extend(tensor_maps.iter().map(|binding| {
214        format!(
215            "const __grid_constant__ CUtensorMap tensor_map_{}",
216            binding.id
217        )
218    }));
219    args.extend(
220        tensor_maps
221            .iter()
222            .chain(buffers.iter())
223            .map(|binding| match binding.vis {
224                Visibility::Read => {
225                    format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
226                }
227                Visibility::ReadWrite => {
228                    format!("{}* buffer_{}", binding.item, binding.id)
229                }
230            }),
231    );
232    args.extend(
233        flags
234            .has_dynamic_meta
235            .then(|| format!("const uint32* __restrict__ {INFO_NAME}")),
236    );
237
238    write!(f, "{}", args.join(", "))?;
239    if trailing_comma {
240        f.write_str(", ")?;
241    }
242    Ok(())
243}
244
245pub fn compile_scalars_dynamic<D: Dialect>(
246    f: &mut std::fmt::Formatter<'_>,
247    scalars: &[(Elem<D>, usize)],
248) -> core::fmt::Result {
249    let scalar_inputs = scalars
250        .iter()
251        .map(|(elem, _)| format!("const {elem}* __restrict__ scalars_{elem}"));
252    let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
253
254    write!(f, "{}", scalar_inputs.join(","))
255}
256
257pub fn compile_scalars_static<D: Dialect>(
258    f: &mut std::fmt::Formatter<'_>,
259    scalars: &[(Elem<D>, usize)],
260    flags: &Flags,
261) -> core::fmt::Result {
262    let mut scalar_inputs = Vec::new();
263
264    // Need to sort elements because of alignment when packing
265    // Metadata is align 4 so it needs to be spliced in the middle.
266    let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
267        for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
268            scalar_inputs.push(format!(
269                "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
270            ));
271        }
272    };
273
274    // Pack 64-bit aligned types first, since metadata is 32-bit aligned
275    scalars_of_size(&mut scalar_inputs, 8);
276
277    // Pack metadata
278    if flags.static_meta_length > 0 {
279        scalar_inputs.push(format!(
280            "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
281        ));
282    }
283
284    // Pack remaining scalars that are 4 bytes or below
285    for size in [4, 2, 1] {
286        scalars_of_size(&mut scalar_inputs, size);
287    }
288
289    write!(f, "{}", scalar_inputs.join(", "))
290}
291
292fn compile_cube_builtin_bindings_decl<D: Dialect>(
293    f: &mut core::fmt::Formatter<'_>,
294    settings: &Flags,
295) -> core::fmt::Result {
296    if settings.indexes.absolute_pos_tuple {
297        D::compile_absolute_pos_tuple_computation(f)?;
298    }
299
300    if settings.indexes.unit_pos {
301        D::compile_unit_pos_computation(f)?;
302    }
303
304    if settings.indexes.absolute_pos {
305        let variable = Variable::<D>::AbsolutePos;
306        let ty = variable.item();
307        let absolute_pos_x = Variable::<D>::AbsolutePosX;
308        let absolute_pos_y = Variable::<D>::AbsolutePosY;
309        let absolute_pos_z = Variable::<D>::AbsolutePosZ;
310        let cube_count_x = Variable::<D>::CubeCountX;
311        let cube_count_y = Variable::<D>::CubeCountY;
312        let cube_dim_x = Variable::<D>::CubeDimX;
313        let cube_dim_y = Variable::<D>::CubeDimY;
314        writeln!(
315            f,
316            "{ty} {variable} = ({absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y}) + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x}) + {absolute_pos_x};"
317        )?;
318    }
319
320    if settings.indexes.cube_dim {
321        let variable = Variable::<D>::CubeDim;
322        let ty = variable.item();
323        let cube_dim_x = Variable::<D>::CubeDimX;
324        let cube_dim_y = Variable::<D>::CubeDimY;
325        let cube_dim_z = Variable::<D>::CubeDimZ;
326        writeln!(
327            f,
328            "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
329        )?;
330    }
331
332    if settings.indexes.cube_count {
333        let variable = Variable::<D>::CubeCount;
334        let ty = variable.item();
335        let cube_count_x = Variable::<D>::CubeCountX;
336        let cube_count_y = Variable::<D>::CubeCountY;
337        let cube_count_z = Variable::<D>::CubeCountZ;
338        writeln!(
339            f,
340            "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
341        )?;
342    }
343
344    if settings.indexes.cube_pos {
345        let variable = Variable::<D>::CubePos;
346        let ty = variable.item();
347        let cube_pos_x = Variable::<D>::CubePosX;
348        let cube_pos_y = Variable::<D>::CubePosY;
349        let cube_pos_z = Variable::<D>::CubePosZ;
350        let cube_count_x = Variable::<D>::CubeCountX;
351        let cube_count_y = Variable::<D>::CubeCountY;
352        writeln!(
353            f,
354            "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
355        )?;
356    }
357
358    if settings.indexes.plane_dim_checked {
359        let plane_dim = Variable::<D>::PlaneDim;
360        let variable = Variable::<D>::PlaneDimChecked;
361        let ty = variable.item();
362        let cube_dim_x = Variable::<D>::CubeDimX;
363        let cube_dim_y = Variable::<D>::CubeDimY;
364        let cube_dim_z = Variable::<D>::CubeDimZ;
365        writeln!(
366            f,
367            "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
368        )?;
369    }
370
371    if settings.indexes.cluster_pos {
372        f.write_str(
373            "
374cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
375",
376        )?;
377    }
378
379    Ok(())
380}