cubecl_cpp/shared/
kernel.rs

1use crate::shared::STATIC_INFO_NAME;
2
3use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
4use cubecl_core::{
5    CubeDim,
6    ir::Id,
7    prelude::{Location, Visibility},
8};
9
10use std::{collections::HashSet, fmt::Display};
11
/// A buffer-like kernel argument (a plain buffer or a tensor map) as rendered in the kernel
/// signature by `compile_bindings`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding<D: Dialect> {
    /// Identifier used to name the generated parameter (`buffer_{id}` / `tensor_map_{id}`).
    pub id: Id,
    /// Pointee type of the generated pointer parameter.
    pub item: Item<D>,
    pub location: Location,
    /// Optional static length; `None` presumably means the length is only known at runtime
    /// (TODO confirm against the producer of bindings).
    pub size: Option<usize>,
    /// Access mode: `Read` bindings become `const T* __restrict__`, `ReadWrite` become `T*`.
    pub vis: Visibility,
}
20
21#[derive(Debug, PartialEq, Eq, Clone)]
22pub enum SharedMemory<D: Dialect> {
23    Array {
24        index: Id,
25        item: Item<D>,
26        length: usize,
27        align: usize,
28        offset: usize,
29    },
30    Value {
31        index: Id,
32        item: Item<D>,
33        align: usize,
34        offset: usize,
35    },
36}
37
38impl<D: Dialect> SharedMemory<D> {
39    pub fn size(&self) -> usize {
40        match self {
41            SharedMemory::Array { item, length, .. } => *length * item.size(),
42            SharedMemory::Value { item, .. } => item.size(),
43        }
44    }
45
46    pub fn align(&self) -> usize {
47        match self {
48            SharedMemory::Array { align, .. } => *align,
49            SharedMemory::Value { align, .. } => *align,
50        }
51    }
52
53    pub fn offset(&self) -> usize {
54        match self {
55            SharedMemory::Array { offset, .. } => *offset,
56            SharedMemory::Value { offset, .. } => *offset,
57        }
58    }
59}
60
/// A constant array declaration: `size` elements of `item` initialised from `values`.
///
/// Only `PartialEq` is derived (no `Eq`), presumably because `Variable` is not `Eq`.
#[derive(Debug, PartialEq, Clone)]
pub struct ConstArray<D: Dialect> {
    pub index: Id,
    pub item: Item<D>,
    /// Declared element count; presumably equals `values.len()` (not enforced here).
    pub size: u32,
    pub values: Vec<Variable<D>>,
}
68
69#[derive(Debug, PartialEq, Eq, Clone)]
70pub struct LocalArray<D: Dialect> {
71    pub index: Id,
72    pub item: Item<D>,
73    pub size: usize,
74}
75
76impl<D: Dialect> LocalArray<D> {
77    pub fn new(index: Id, item: Item<D>, size: usize) -> Self {
78        Self { index, item, size }
79    }
80}
81
82impl<D: Dialect> SharedMemory<D> {
83    pub fn new_array(index: Id, item: Item<D>, size: usize, align: usize) -> Self {
84        Self::Array {
85            index,
86            item,
87            length: size,
88            align,
89            offset: 0, // initialized later
90        }
91    }
92
93    pub fn new_value(index: Id, item: Item<D>, align: usize) -> Self {
94        Self::Value {
95            index,
96            item,
97            align,
98            offset: 0, // initialized later
99        }
100    }
101}
102
103#[derive(Debug, Clone)]
104pub struct ComputeKernel<D: Dialect> {
105    pub tensor_maps: Vec<Binding<D>>,
106    pub buffers: Vec<Binding<D>>,
107    pub scalars: Vec<(Elem<D>, usize)>,
108    pub meta_static_len: usize,
109    pub body: Body<D>,
110    pub cube_dim: CubeDim,
111    pub cluster_dim: Option<CubeDim>,
112    pub extensions: Vec<D::Extension>,
113    pub flags: Flags<D>,
114    pub items: HashSet<super::Item<D>>,
115    pub kernel_name: String,
116}
117
118impl<D: Dialect> ComputeKernel<D> {
119    pub fn shared_memory_size(&self) -> usize {
120        let smems = self.body.shared_memories.iter();
121        let ends = smems.map(|it| it.offset() + it.size());
122        ends.max().unwrap_or_default()
123    }
124}
125
126impl<D: Dialect> Display for ComputeKernel<D> {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        let mut flags = self.flags.clone();
129        if !self.tensor_maps.is_empty() {
130            flags.inst_tma = true;
131        }
132
133        // Program Scope -----------------------------------------------------
134        D::compile_includes(f, &flags)?;
135        D::compile_type_definitions(f, &self.items, &self.scalars, &flags)?;
136        D::compile_polyfills(f, &flags)?;
137        D::compile_extensions(f, &self.extensions)?;
138
139        // Kernel signature --------------------------------------------------
140        D::compile_kernel_signature(
141            f,
142            &self.kernel_name,
143            &self.tensor_maps,
144            &self.buffers,
145            &self.scalars,
146            &self.flags,
147        )?;
148
149        // Body --------------------------------------------------------------
150        f.write_str(" {\n")?;
151        compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
152        write!(f, "{}", self.body)?;
153        f.write_str("\n}")?;
154
155        Ok(())
156    }
157}
158
159pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160    writeln!(f, "typedef unsigned int uint;")?;
161    writeln!(f, "typedef unsigned char uint8;")?;
162    writeln!(f, "typedef unsigned short uint16;")?;
163    writeln!(f, "typedef unsigned int uint32;")?;
164    writeln!(f, "typedef unsigned long long int uint64;")?;
165
166    writeln!(f, "typedef signed char int8;")?;
167    writeln!(f, "typedef signed short int16;")?;
168    writeln!(f, "typedef signed int int32;")?;
169    writeln!(f, "typedef signed long long int int64;")?;
170
171    Ok(())
172}
173
174pub fn type_vectorized_definitions<D: Dialect>(
175    f: &mut std::fmt::Formatter<'_>,
176    items: &HashSet<Item<D>>,
177) -> std::fmt::Result {
178    for item in items.iter() {
179        let elem = item.elem;
180        let size = item.vectorization;
181        let alignment = elem.size() * size;
182        if size > 1 {
183            write!(
184                f,
185                "
186struct __align__({alignment}) {item} {{"
187            )?;
188
189            for i in 0..size {
190                write!(
191                    f,
192                    "
193    {elem} i_{i};"
194                )?;
195            }
196
197            f.write_str("\n};")?;
198        }
199    }
200    Ok(())
201}
202
203pub fn type_scalar_definitions<D: Dialect>(
204    f: &mut std::fmt::Formatter<'_>,
205    scalars: &[(Elem<D>, usize)],
206) -> std::fmt::Result {
207    for (elem, count) in scalars.iter() {
208        writeln!(
209            f,
210            "
211struct scalars_{elem}_st {{
212{elem} x[{count}];
213}};"
214        )?;
215    }
216    Ok(())
217}
218
219pub fn type_info_definition<D: Dialect>(
220    f: &mut std::fmt::Formatter<'_>,
221    static_len: usize,
222    address_type: Item<D>,
223) -> std::fmt::Result {
224    if static_len > 0 {
225        write!(
226            f,
227            "
228struct metadata_st {{
229    {address_type} x[{static_len}];
230}};
231"
232        )?;
233    }
234    Ok(())
235}
236
237pub fn compile_bindings<D: Dialect>(
238    f: &mut core::fmt::Formatter<'_>,
239    tensor_maps: &[Binding<D>],
240    buffers: &[Binding<D>],
241    trailing_comma: bool,
242    flags: &Flags<D>,
243) -> core::fmt::Result {
244    write!(f, "    ")?;
245
246    let mut args = Vec::new();
247
248    args.extend(tensor_maps.iter().map(|binding| {
249        format!(
250            "const __grid_constant__ CUtensorMap tensor_map_{}",
251            binding.id
252        )
253    }));
254    args.extend(
255        tensor_maps
256            .iter()
257            .chain(buffers.iter())
258            .map(|binding| match binding.vis {
259                Visibility::Read => {
260                    format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
261                }
262                Visibility::ReadWrite => {
263                    format!("{}* buffer_{}", binding.item, binding.id)
264                }
265            }),
266    );
267    args.extend(
268        flags
269            .has_dynamic_meta
270            .then(|| format!("const {}* __restrict__ {INFO_NAME}", flags.address_type)),
271    );
272
273    write!(f, "{}", args.join(", "))?;
274    if trailing_comma {
275        f.write_str(", ")?;
276    }
277    Ok(())
278}
279
280pub fn compile_scalars_dynamic<D: Dialect>(
281    f: &mut std::fmt::Formatter<'_>,
282    scalars: &[(Elem<D>, usize)],
283) -> core::fmt::Result {
284    let scalar_inputs = scalars
285        .iter()
286        .map(|(elem, _)| format!("const {elem}* __restrict__ scalars_{elem}"));
287    let scalar_inputs = scalar_inputs.collect::<Vec<String>>();
288
289    write!(f, "{}", scalar_inputs.join(","))
290}
291
292pub fn compile_scalars_static<D: Dialect>(
293    f: &mut std::fmt::Formatter<'_>,
294    scalars: &[(Elem<D>, usize)],
295    flags: &Flags<D>,
296) -> core::fmt::Result {
297    let mut scalar_inputs = Vec::new();
298
299    // Need to sort elements because of alignment when packing
300    // Metadata is align 4 so it needs to be spliced in the middle.
301    let scalars_of_size = |scalar_inputs: &mut Vec<String>, size: usize| {
302        for (elem, _) in scalars.iter().filter(|it| it.0.size() == size) {
303            scalar_inputs.push(format!(
304                "const __grid_constant__ scalars_{elem}_st scalars_{elem}"
305            ));
306        }
307    };
308
309    // Pack 64-bit aligned types first, since metadata is 32-bit aligned
310    scalars_of_size(&mut scalar_inputs, 8);
311
312    // Pack metadata
313    if flags.static_meta_length > 0 {
314        scalar_inputs.push(format!(
315            "const __grid_constant__ metadata_st {STATIC_INFO_NAME}"
316        ));
317    }
318
319    // Pack remaining scalars that are 4 bytes or below
320    for size in [4, 2, 1] {
321        scalars_of_size(&mut scalar_inputs, size);
322    }
323
324    write!(f, "{}", scalar_inputs.join(", "))
325}
326
/// Emits declarations for the builtin index variables requested in `settings.indexes`.
///
/// Each flag is handled independently, in a fixed order. Derived indexes are built from the
/// per-axis builtins. Where `settings.address_type.elem` determines the result type, each
/// per-axis operand is first cast to that type (via `fmt_cast_to`) before being combined.
fn compile_cube_builtin_bindings_decl<D: Dialect>(
    f: &mut core::fmt::Formatter<'_>,
    settings: &Flags<D>,
) -> core::fmt::Result {
    // The tuple form of the absolute position is produced by the dialect.
    if settings.indexes.absolute_pos_tuple {
        D::compile_absolute_pos_tuple_computation(f)?;
    }

    // The unit position computation is also delegated to the dialect.
    if settings.indexes.unit_pos {
        D::compile_unit_pos_computation(f)?;
    }

    // Linear absolute position, in the address type:
    //   z * (count_x * dim_x * count_y * dim_y) + y * (count_x * dim_x) + x
    if settings.indexes.absolute_pos {
        let variable = Variable::<D>::AbsolutePos(settings.address_type.elem);
        let ty = variable.item();
        let absolute_pos_x = Variable::<D>::AbsolutePosX.fmt_cast_to(ty);
        let absolute_pos_y = Variable::<D>::AbsolutePosY.fmt_cast_to(ty);
        let absolute_pos_z = Variable::<D>::AbsolutePosZ.fmt_cast_to(ty);
        let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
        let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
        let cube_dim_x = Variable::<D>::CubeDimX.fmt_cast_to(ty);
        let cube_dim_y = Variable::<D>::CubeDimY.fmt_cast_to(ty);
        writeln!(
            f,
            "{ty} {variable} = (
                {absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y})
                + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x})
                + {absolute_pos_x};"
        )?;
    }

    // Number of units per cube: the product of the three cube dimensions (no cast applied).
    if settings.indexes.cube_dim {
        let variable = Variable::<D>::CubeDim;
        let ty = variable.item();
        let cube_dim_x = Variable::<D>::CubeDimX;
        let cube_dim_y = Variable::<D>::CubeDimY;
        let cube_dim_z = Variable::<D>::CubeDimZ;
        writeln!(
            f,
            "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
        )?;
    }

    // Total number of cubes, in the address type.
    if settings.indexes.cube_count {
        let variable = Variable::<D>::CubeCount(settings.address_type.elem);
        let ty = variable.item();
        let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
        let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
        let cube_count_z = Variable::<D>::CubeCountZ.fmt_cast_to(ty);
        writeln!(
            f,
            "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
        )?;
    }

    // Linear cube index, in the address type: z * count_y * count_x + y * count_x + x.
    if settings.indexes.cube_pos {
        let variable = Variable::<D>::CubePos(settings.address_type.elem);
        let ty = variable.item();
        let cube_pos_x = Variable::<D>::CubePosX.fmt_cast_to(ty);
        let cube_pos_y = Variable::<D>::CubePosY.fmt_cast_to(ty);
        let cube_pos_z = Variable::<D>::CubePosZ.fmt_cast_to(ty);
        let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
        let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
        writeln!(
            f,
            "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
        )?;
    }

    // Plane dimension clamped to the number of units in the cube, so that cubes with fewer
    // units than a plane report their actual unit count.
    if settings.indexes.plane_dim_checked {
        let plane_dim = Variable::<D>::PlaneDim;
        let variable = Variable::<D>::PlaneDimChecked;
        let ty = variable.item();
        let cube_dim_x = Variable::<D>::CubeDimX;
        let cube_dim_y = Variable::<D>::CubeDimY;
        let cube_dim_z = Variable::<D>::CubeDimZ;
        writeln!(
            f,
            "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
        )?;
    }

    // Declares a `cluster` handle to the current cluster via cooperative groups.
    if settings.indexes.cluster_pos {
        f.write_str(
            "
cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
",
        )?;
    }

    Ok(())
}