cubecl_cpp/shared/
kernel.rs

1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5    CubeDim,
6    ir::Id,
7    prelude::{Location, Visibility},
8};
9
10use std::{collections::HashSet, fmt::Display};
11
12#[derive(Debug, PartialEq, Eq, Clone)]
13pub struct Binding<D: Dialect> {
14    pub id: Id,
15    pub item: Item<D>,
16    pub location: Location,
17    pub size: Option<usize>,
18    pub vis: Visibility,
19}
20
21#[derive(Debug, PartialEq, Eq, Clone)]
22pub enum SharedMemory<D: Dialect> {
23    Array {
24        index: Id,
25        item: Item<D>,
26        length: u32,
27        align: u32,
28        offset: u32,
29    },
30    Value {
31        index: Id,
32        item: Item<D>,
33        align: u32,
34        offset: u32,
35    },
36}
37
38impl<D: Dialect> SharedMemory<D> {
39    pub fn size(&self) -> u32 {
40        match self {
41            SharedMemory::Array { item, length, .. } => *length * item.size() as u32,
42            SharedMemory::Value { item, .. } => item.size() as u32,
43        }
44    }
45
46    pub fn align(&self) -> u32 {
47        match self {
48            SharedMemory::Array { align, .. } => *align,
49            SharedMemory::Value { align, .. } => *align,
50        }
51    }
52
53    pub fn offset(&self) -> u32 {
54        match self {
55            SharedMemory::Array { offset, .. } => *offset,
56            SharedMemory::Value { offset, .. } => *offset,
57        }
58    }
59}
60
61#[derive(Debug, PartialEq, Clone)]
62pub struct ConstArray<D: Dialect> {
63    pub index: Id,
64    pub item: Item<D>,
65    pub size: u32,
66    pub values: Vec<Variable<D>>,
67}
68
69#[derive(Debug, PartialEq, Eq, Clone)]
70pub struct LocalArray<D: Dialect> {
71    pub index: Id,
72    pub item: Item<D>,
73    pub size: u32,
74}
75
76impl<D: Dialect> LocalArray<D> {
77    pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
78        Self { index, item, size }
79    }
80}
81
82impl<D: Dialect> SharedMemory<D> {
83    pub fn new_array(index: Id, item: Item<D>, size: u32, align: u32) -> Self {
84        Self::Array {
85            index,
86            item,
87            length: size,
88            align,
89            offset: 0, // initialized later
90        }
91    }
92
93    pub fn new_value(index: Id, item: Item<D>, align: u32) -> Self {
94        Self::Value {
95            index,
96            item,
97            align,
98            offset: 0, // initialized later
99        }
100    }
101}
102
103#[derive(Debug, Clone)]
104pub struct ComputeKernel<D: Dialect> {
105    pub tensor_maps: Vec<Binding<D>>,
106    pub buffers: Vec<Binding<D>>,
107    pub scalars: Vec<(Elem<D>, usize)>,
108    pub meta_static_len: usize,
109    pub body: Body<D>,
110    pub cube_dim: CubeDim,
111    pub cluster_dim: Option<CubeDim>,
112    pub extensions: Vec<D::Extension>,
113    pub flags: Flags,
114    pub items: HashSet<super::Item<D>>,
115    pub kernel_name: String,
116}
117
118impl<D: Dialect> ComputeKernel<D> {
119    pub fn shared_memory_size(&self) -> usize {
120        let smems = self.body.shared_memories.iter();
121        let ends = smems.map(|it| it.offset() + it.size());
122        ends.max().unwrap_or_default() as usize
123    }
124}
125
126impl<D: Dialect> Display for ComputeKernel<D> {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        let mut flags = self.flags.clone();
129        if !self.tensor_maps.is_empty() {
130            flags.inst_tma = true;
131        }
132
133        // Program Scope -----------------------------------------------------
134        D::compile_includes(f, &flags)?;
135        D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
136        D::compile_polyfills(f, &flags)?;
137        D::compile_extensions(f, &self.extensions)?;
138
139        // Kernel signature --------------------------------------------------
140        D::compile_kernel_signature(
141            f,
142            &self.kernel_name,
143            &self.tensor_maps,
144            &self.buffers,
145            &self.scalars,
146            &self.flags,
147        )?;
148
149        // Body --------------------------------------------------------------
150        f.write_str(" {\n")?;
151        compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
152        write!(f, "{}", self.body)?;
153        f.write_str("\n}")?;
154
155        Ok(())
156    }
157}
158
159pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160    writeln!(f, "typedef unsigned int uint;")?;
161    writeln!(f, "typedef unsigned char uint8;")?;
162    writeln!(f, "typedef unsigned short uint16;")?;
163    writeln!(f, "typedef unsigned int uint32;")?;
164    writeln!(f, "typedef unsigned long long int uint64;")?;
165
166    writeln!(f, "typedef signed char int8;")?;
167    writeln!(f, "typedef signed short int16;")?;
168    writeln!(f, "typedef signed int int32;")?;
169    writeln!(f, "typedef signed long long int int64;")?;
170
171    Ok(())
172}
173
174pub fn type_vectorized_definitions<D: Dialect>(
175    f: &mut std::fmt::Formatter<'_>,
176    items: &HashSet<Item<D>>,
177) -> std::fmt::Result {
178    for item in items.iter() {
179        let elem = item.elem;
180        let size = item.vectorization;
181        let alignment = elem.size() * size;
182        if size > 1 {
183            write!(
184                f,
185                "
186struct __align__({alignment}) {item} {{"
187            )?;
188
189            for i in 0..size {
190                write!(
191                    f,
192                    "
193    {elem} i_{i};"
194                )?;
195            }
196
197            f.write_str("\n};")?;
198        }
199    }
200    Ok(())
201}
202
203pub fn type_scalar_definitions<D: Dialect>(
204    f: &mut std::fmt::Formatter<'_>,
205    scalars: &[(Elem<D>, usize)],
206) -> std::fmt::Result {
207    for (elem, count) in scalars.iter() {
208        writeln!(
209            f,
210            "
211struct scalars_{elem}_st {{
212{elem} x[{count}];
213}};"
214        )?;
215    }
216    Ok(())
217}
218
219pub fn type_info_definition<D: Dialect>(
220    f: &mut std::fmt::Formatter<'_>,
221    static_len: usize,
222) -> std::fmt::Result {
223    if static_len > 0 {
224        write!(
225            f,
226            "
227struct metadata_st {{
228uint x[{static_len}];
229}};
230"
231        )?;
232    }
233    Ok(())
234}
235
236pub fn compile_bindings<D: Dialect>(
237    f: &mut core::fmt::Formatter<'_>,
238    tensor_maps: &[Binding<D>],
239    buffers: &[Binding<D>],
240    trailing_comma: bool,
241    flags: &Flags,
242) -> core::fmt::Result {
243    write!(f, "    ")?;
244
245    let mut args = Vec::new();
246
247    args.extend(tensor_maps.iter().map(|binding| {
248        format!(
249            "const __grid_constant__ CUtensorMap tensor_map_{}",
250            binding.id
251        )
252    }));
253    args.extend(
254        tensor_maps
255            .iter()
256            .chain(buffers.iter())
257            .map(|binding| match binding.vis {
258                Visibility::Read => {
259                    format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
260                }
261                Visibility::ReadWrite => {
262                    format!("{}* buffer_{}", binding.item, binding.id)
263                }
264            }),
265    );
266    args.extend(
267        flags
268            .has_dynamic_meta
269            .then(|| format!("const uint32* __restrict__ {INFO_NAME}")),
270    );
271
272    write!(f, "{}", args.join(", "))?;
273    if trailing_comma {
274        f.write_str(", ")?;
275    }
276    Ok(())
277}
278
279pub fn compile_scalars_dynamic<D: Dialect>(
280    f: &mut std::fmt::Formatter<'_>,
281    scalars: &[(Elem<D>, usize)],
282) -> core::fmt::Result {
283    let scalar_inputs = scalars
284        .iter()
285        .map(|(elem, _)| format!("const {elem}* __restrict__ scalars_{elem}"));
286    let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
287
288    write!(f, "{}", scalar_inputs.join(","))
289}
290
291pub fn compile_scalars_static<D: Dialect>(
292    f: &mut std::fmt::Formatter<'_>,
293    scalars: &[(Elem<D>, usize)],
294    flags: &Flags,
295) -> core::fmt::Result {
296    let mut scalar_inputs = Vec::new();
297
298    // Need to sort elements because of alignment when packing
299    // Metadata is align 4 so it needs to be spliced in the middle.
300    let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
301        for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
302            scalar_inputs.push(format!(
303                "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
304            ));
305        }
306    };
307
308    // Pack 64-bit aligned types first, since metadata is 32-bit aligned
309    scalars_of_size(&mut scalar_inputs, 8);
310
311    // Pack metadata
312    if flags.static_meta_length > 0 {
313        scalar_inputs.push(format!(
314            "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
315        ));
316    }
317
318    // Pack remaining scalars that are 4 bytes or below
319    for size in [4, 2, 1] {
320        scalars_of_size(&mut scalar_inputs, size);
321    }
322
323    write!(f, "{}", scalar_inputs.join(", "))
324}
325
326fn compile_cube_builtin_bindings_decl<D: Dialect>(
327    f: &mut core::fmt::Formatter<'_>,
328    settings: &Flags,
329) -> core::fmt::Result {
330    if settings.indexes.absolute_pos_tuple {
331        D::compile_absolute_pos_tuple_computation(f)?;
332    }
333
334    if settings.indexes.unit_pos {
335        D::compile_unit_pos_computation(f)?;
336    }
337
338    if settings.indexes.absolute_pos {
339        let variable = Variable::<D>::AbsolutePos;
340        let ty = variable.item();
341        let absolute_pos_x = Variable::<D>::AbsolutePosX;
342        let absolute_pos_y = Variable::<D>::AbsolutePosY;
343        let absolute_pos_z = Variable::<D>::AbsolutePosZ;
344        let cube_count_x = Variable::<D>::CubeCountX;
345        let cube_count_y = Variable::<D>::CubeCountY;
346        let cube_dim_x = Variable::<D>::CubeDimX;
347        let cube_dim_y = Variable::<D>::CubeDimY;
348        writeln!(
349            f,
350            "{ty} {variable} = ({absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y}) + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x}) + {absolute_pos_x};"
351        )?;
352    }
353
354    if settings.indexes.cube_dim {
355        let variable = Variable::<D>::CubeDim;
356        let ty = variable.item();
357        let cube_dim_x = Variable::<D>::CubeDimX;
358        let cube_dim_y = Variable::<D>::CubeDimY;
359        let cube_dim_z = Variable::<D>::CubeDimZ;
360        writeln!(
361            f,
362            "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
363        )?;
364    }
365
366    if settings.indexes.cube_count {
367        let variable = Variable::<D>::CubeCount;
368        let ty = variable.item();
369        let cube_count_x = Variable::<D>::CubeCountX;
370        let cube_count_y = Variable::<D>::CubeCountY;
371        let cube_count_z = Variable::<D>::CubeCountZ;
372        writeln!(
373            f,
374            "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
375        )?;
376    }
377
378    if settings.indexes.cube_pos {
379        let variable = Variable::<D>::CubePos;
380        let ty = variable.item();
381        let cube_pos_x = Variable::<D>::CubePosX;
382        let cube_pos_y = Variable::<D>::CubePosY;
383        let cube_pos_z = Variable::<D>::CubePosZ;
384        let cube_count_x = Variable::<D>::CubeCountX;
385        let cube_count_y = Variable::<D>::CubeCountY;
386        writeln!(
387            f,
388            "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
389        )?;
390    }
391
392    if settings.indexes.plane_dim_checked {
393        let plane_dim = Variable::<D>::PlaneDim;
394        let variable = Variable::<D>::PlaneDimChecked;
395        let ty = variable.item();
396        let cube_dim_x = Variable::<D>::CubeDimX;
397        let cube_dim_y = Variable::<D>::CubeDimY;
398        let cube_dim_z = Variable::<D>::CubeDimZ;
399        writeln!(
400            f,
401            "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
402        )?;
403    }
404
405    if settings.indexes.cluster_pos {
406        f.write_str(
407            "
408cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
409",
410        )?;
411    }
412
413    Ok(())
414}