Skip to main content

cubecl_cpp/shared/
kernel.rs

1use super::{Body, Component, Dialect, Elem, Flags, INFO_NAME, Item, Variable};
2use cubecl_core::{CubeDim, ir::Id, prelude::Visibility};
3
4use std::{collections::HashSet, fmt::Display};
5
/// Describes one argument of a generated kernel (a tensor map or a buffer).
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct KernelArg<D: Dialect> {
    /// Identifier of the argument; `compile_bindings` uses it as the suffix of the
    /// generated parameter names (`tensor_map_{id}` / `buffer_{id}`).
    pub id: Id,
    /// Type the generated buffer pointer points to.
    pub item: Item<D>,
    /// Optional size of the argument.
    /// NOTE(review): not read by the code visible in this file; meaning and units unconfirmed.
    pub size: Option<usize>,
    /// Access visibility; `compile_bindings` emits read-only, non-atomic buffers as
    /// `const … __restrict__` pointers and everything else as plain pointers.
    pub vis: Visibility,
}
13
/// A shared-memory declaration: either an array of items or a single item.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum SharedMemory<D: Dialect> {
    /// An array of `length` elements of type `item`.
    Array {
        /// Identifier of the declaration.
        index: Id,
        /// Element type of the array.
        item: Item<D>,
        /// Number of elements; `size()` multiplies it by `item.size()`.
        length: usize,
        /// Requested alignment (presumably in bytes — TODO confirm).
        align: usize,
        /// Start position of the declaration; the constructors set it to 0 and it is
        /// assigned later.
        offset: usize,
    },
    /// A single value of type `item`.
    Value {
        /// Identifier of the declaration.
        index: Id,
        /// Type of the value; `size()` returns `item.size()`.
        item: Item<D>,
        /// Requested alignment (presumably in bytes — TODO confirm).
        align: usize,
        /// Start position of the declaration; the constructors set it to 0 and it is
        /// assigned later.
        offset: usize,
    },
}
30
31impl<D: Dialect> SharedMemory<D> {
32    pub fn size(&self) -> usize {
33        match self {
34            SharedMemory::Array { item, length, .. } => *length * item.size(),
35            SharedMemory::Value { item, .. } => item.size(),
36        }
37    }
38
39    pub fn align(&self) -> usize {
40        match self {
41            SharedMemory::Array { align, .. } => *align,
42            SharedMemory::Value { align, .. } => *align,
43        }
44    }
45
46    pub fn offset(&self) -> usize {
47        match self {
48            SharedMemory::Array { offset, .. } => *offset,
49            SharedMemory::Value { offset, .. } => *offset,
50        }
51    }
52}
53
/// A constant array declaration together with its values.
#[derive(Debug, PartialEq, Clone)]
pub struct ConstArray<D: Dialect> {
    /// Identifier of the array.
    pub index: Id,
    /// Element type of the array.
    pub item: Item<D>,
    /// Declared size of the array (presumably the element count — TODO confirm against
    /// `values.len()`).
    pub size: u32,
    /// The values stored in the array.
    pub values: Vec<Variable<D>>,
}
61
/// A local array declaration.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LocalArray<D: Dialect> {
    /// Identifier of the array.
    pub index: Id,
    /// Element type of the array.
    pub item: Item<D>,
    /// Size of the array (presumably the element count — TODO confirm).
    pub size: usize,
}
68
69impl<D: Dialect> LocalArray<D> {
70    pub fn new(index: Id, item: Item<D>, size: usize) -> Self {
71        Self { index, item, size }
72    }
73}
74
75impl<D: Dialect> SharedMemory<D> {
76    pub fn new_array(index: Id, item: Item<D>, size: usize, align: usize) -> Self {
77        Self::Array {
78            index,
79            item,
80            length: size,
81            align,
82            offset: 0, // initialized later
83        }
84    }
85
86    pub fn new_value(index: Id, item: Item<D>, align: usize) -> Self {
87        Self::Value {
88            index,
89            item,
90            align,
91            offset: 0, // initialized later
92        }
93    }
94}
95
/// Everything needed to render a kernel's source; the `Display` implementation writes the
/// program-scope declarations, the kernel signature and the body from these fields.
#[derive(Debug, Clone)]
pub struct ComputeKernel<D: Dialect> {
    /// Tensor-map arguments; when non-empty, rendering turns the `inst_tma` flag on.
    pub tensor_maps: Vec<KernelArg<D>>,
    /// Buffer arguments passed to the kernel signature.
    pub buffers: Vec<KernelArg<D>>,
    /// Scalar element types paired with a `usize` (presumably a count — TODO confirm);
    /// forwarded to the dialect's type definitions.
    pub scalars: Vec<(Elem<D>, usize)>,
    /// Info description forwarded to the dialect's type definitions.
    pub info: cubecl_core::Info,
    /// NOTE(review): not read by the code visible in this file.
    pub meta_static_len: usize,
    /// Kernel body, written between the braces that follow the signature.
    pub body: Body<D>,
    /// NOTE(review): not read by the code visible in this file.
    pub cube_dim: CubeDim,
    /// NOTE(review): not read by the code visible in this file.
    pub cluster_dim: Option<CubeDim>,
    /// Dialect extensions compiled at program scope.
    pub extensions: Vec<D::Extension>,
    /// Flags steering what gets emitted (includes, polyfills, builtin declarations, …).
    pub flags: Flags<D>,
    /// Item types forwarded to the dialect's type definitions.
    pub items: HashSet<super::Item<D>>,
    /// Name used in the generated kernel signature.
    pub kernel_name: String,
}
111
112impl<D: Dialect> ComputeKernel<D> {
113    pub fn shared_memory_size(&self) -> usize {
114        let smems = self.body.shared_memories.iter();
115        let ends = smems.map(|it| it.offset() + it.size());
116        ends.max().unwrap_or_default()
117    }
118}
119
120impl<D: Dialect> Display for ComputeKernel<D> {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        let mut flags = self.flags.clone();
123        if !self.tensor_maps.is_empty() {
124            flags.inst_tma = true;
125        }
126
127        // Program Scope -----------------------------------------------------
128        D::compile_includes(f, &flags)?;
129        D::compile_type_definitions(f, &self.items, &self.scalars, &self.info, &flags)?;
130        D::compile_polyfills(f, &flags)?;
131        D::compile_extensions(f, &self.extensions)?;
132
133        // Kernel signature --------------------------------------------------
134        D::compile_kernel_signature(
135            f,
136            &self.kernel_name,
137            &self.tensor_maps,
138            &self.buffers,
139            &self.flags,
140        )?;
141
142        // Body --------------------------------------------------------------
143        f.write_str(" {\n")?;
144        compile_cube_builtin_bindings_decl::<D>(f, &self.flags)?;
145        write!(f, "{}", self.body)?;
146        f.write_str("\n}")?;
147
148        Ok(())
149    }
150}
151
152pub fn type_definitions<D: Dialect>(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153    writeln!(f, "typedef unsigned int uint;")?;
154    writeln!(f, "typedef unsigned char uint8;")?;
155    writeln!(f, "typedef unsigned short uint16;")?;
156    writeln!(f, "typedef unsigned int uint32;")?;
157    writeln!(f, "typedef unsigned long long int uint64;")?;
158
159    writeln!(f, "typedef signed char int8;")?;
160    writeln!(f, "typedef signed short int16;")?;
161    writeln!(f, "typedef signed int int32;")?;
162    writeln!(f, "typedef signed long long int int64;")?;
163
164    Ok(())
165}
166
167pub fn type_vectorized_definitions<D: Dialect>(
168    f: &mut std::fmt::Formatter<'_>,
169    items: &HashSet<Item<D>>,
170) -> std::fmt::Result {
171    for item in items.iter() {
172        let elem = item.elem;
173        let size = item.vectorization;
174        let alignment = elem.size() * size;
175        if size > 1 {
176            write!(
177                f,
178                "
179struct __align__({alignment}) {item} {{"
180            )?;
181
182            for i in 0..size {
183                write!(
184                    f,
185                    "
186    {elem} i_{i};"
187                )?;
188            }
189
190            f.write_str("\n};")?;
191        }
192    }
193    Ok(())
194}
195
196pub fn type_info_definition_sized<D: Dialect>(
197    f: &mut std::fmt::Formatter<'_>,
198    info: &cubecl_core::Info,
199    scalars: &[(Elem<D>, usize)],
200    address_type: Item<D>,
201) -> std::fmt::Result {
202    let scalars = info
203        .scalars
204        .iter()
205        .zip(scalars)
206        .map(|(field, (ty, _))| format!("{ty} scalars_{ty}[{}];", field.padded_size()))
207        .collect::<Vec<_>>()
208        .join("\n");
209    let static_meta = info
210        .sized_meta
211        .as_ref()
212        .map(|field| format!("{address_type} static_meta[{}];", field.padded_size()))
213        .unwrap_or_default();
214    write!(
215        f,
216        "
217struct info_st {{
218    {scalars}{static_meta}
219}};
220"
221    )
222}
223
224pub fn compile_bindings<D: Dialect>(
225    f: &mut core::fmt::Formatter<'_>,
226    tensor_maps: &[KernelArg<D>],
227    buffers: &[KernelArg<D>],
228    trailing_comma: bool,
229) -> core::fmt::Result {
230    write!(f, "    ")?;
231
232    let mut args = Vec::new();
233
234    args.extend(tensor_maps.iter().map(|binding| {
235        format!(
236            "const __grid_constant__ CUtensorMap tensor_map_{}",
237            binding.id
238        )
239    }));
240    args.extend(
241        tensor_maps
242            .iter()
243            .chain(buffers.iter())
244            .map(|binding| match binding.vis {
245                Visibility::Read if !binding.item.is_atomic() => {
246                    format!("const {}* __restrict__ buffer_{}", binding.item, binding.id)
247                }
248                Visibility::Read => {
249                    format!("{}* buffer_{}", binding.item, binding.id)
250                }
251                Visibility::ReadWrite => {
252                    format!("{}* buffer_{}", binding.item, binding.id)
253                }
254            }),
255    );
256
257    write!(f, "{}", args.join(", "))?;
258    if trailing_comma {
259        f.write_str(", ")?;
260    }
261    Ok(())
262}
263
264pub fn compile_info_dynamic<D: Dialect>(
265    f: &mut std::fmt::Formatter<'_>,
266    flags: &Flags<D>,
267) -> core::fmt::Result {
268    if flags.has_info {
269        write!(f, "const info_st* __restrict__ {INFO_NAME}_ptr")
270    } else {
271        Ok(())
272    }
273}
274
275pub fn compile_info_static<D: Dialect>(
276    f: &mut std::fmt::Formatter<'_>,
277    flags: &Flags<D>,
278) -> core::fmt::Result {
279    let mut inputs = Vec::new();
280
281    if flags.has_dynamic_meta {
282        inputs.push(format!(
283            "const {}* __restrict__ dynamic_meta",
284            flags.address_type
285        ))
286    }
287
288    if flags.has_info {
289        inputs.push(format!("const __grid_constant__ info_st {INFO_NAME}"));
290    }
291
292    write!(f, "{}", inputs.join(", "))
293}
294
295fn compile_cube_builtin_bindings_decl<D: Dialect>(
296    f: &mut core::fmt::Formatter<'_>,
297    settings: &Flags<D>,
298) -> core::fmt::Result {
299    if settings.indexes.absolute_pos_tuple {
300        D::compile_absolute_pos_tuple_computation(f)?;
301    }
302
303    if settings.indexes.unit_pos {
304        D::compile_unit_pos_computation(f)?;
305    }
306
307    if settings.indexes.absolute_pos {
308        let variable = Variable::<D>::AbsolutePos(settings.address_type.elem);
309        let ty = variable.item();
310        let absolute_pos_x = Variable::<D>::AbsolutePosX.fmt_cast_to(ty);
311        let absolute_pos_y = Variable::<D>::AbsolutePosY.fmt_cast_to(ty);
312        let absolute_pos_z = Variable::<D>::AbsolutePosZ.fmt_cast_to(ty);
313        let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
314        let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
315        let cube_dim_x = Variable::<D>::CubeDimX.fmt_cast_to(ty);
316        let cube_dim_y = Variable::<D>::CubeDimY.fmt_cast_to(ty);
317        writeln!(
318            f,
319            "{ty} {variable} = (
320                {absolute_pos_z} * {cube_count_x} * {cube_dim_x} * {cube_count_y} * {cube_dim_y})
321                + ({absolute_pos_y} * {cube_count_x} * {cube_dim_x})
322                + {absolute_pos_x};"
323        )?;
324    }
325
326    if settings.indexes.cube_dim {
327        let variable = Variable::<D>::CubeDim;
328        let ty = variable.item();
329        let cube_dim_x = Variable::<D>::CubeDimX;
330        let cube_dim_y = Variable::<D>::CubeDimY;
331        let cube_dim_z = Variable::<D>::CubeDimZ;
332        writeln!(
333            f,
334            "{ty} {variable} = {cube_dim_x} * {cube_dim_y} * {cube_dim_z};"
335        )?;
336    }
337
338    if settings.indexes.cube_count {
339        let variable = Variable::<D>::CubeCount(settings.address_type.elem);
340        let ty = variable.item();
341        let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
342        let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
343        let cube_count_z = Variable::<D>::CubeCountZ.fmt_cast_to(ty);
344        writeln!(
345            f,
346            "{ty} {variable} = {cube_count_x} * {cube_count_y} * {cube_count_z};"
347        )?;
348    }
349
350    if settings.indexes.cube_pos {
351        let variable = Variable::<D>::CubePos(settings.address_type.elem);
352        let ty = variable.item();
353        let cube_pos_x = Variable::<D>::CubePosX.fmt_cast_to(ty);
354        let cube_pos_y = Variable::<D>::CubePosY.fmt_cast_to(ty);
355        let cube_pos_z = Variable::<D>::CubePosZ.fmt_cast_to(ty);
356        let cube_count_x = Variable::<D>::CubeCountX.fmt_cast_to(ty);
357        let cube_count_y = Variable::<D>::CubeCountY.fmt_cast_to(ty);
358        writeln!(
359            f,
360            "{ty} {variable} = ({cube_pos_z} * {cube_count_y} * {cube_count_x}) + ({cube_pos_y} * {cube_count_x}) + {cube_pos_x};"
361        )?;
362    }
363
364    if settings.indexes.plane_dim_checked {
365        let plane_dim = Variable::<D>::PlaneDim;
366        let variable = Variable::<D>::PlaneDimChecked;
367        let ty = variable.item();
368        let cube_dim_x = Variable::<D>::CubeDimX;
369        let cube_dim_y = Variable::<D>::CubeDimY;
370        let cube_dim_z = Variable::<D>::CubeDimZ;
371        writeln!(
372            f,
373            "{ty} {variable} = min({plane_dim}, {cube_dim_x} * {cube_dim_y} * {cube_dim_z});"
374        )?;
375    }
376
377    if settings.indexes.cluster_pos {
378        f.write_str(
379            "
380cooperative_groups::cluster_group cluster = cooperative_groups::this_cluster();
381",
382        )?;
383    }
384
385    Ok(())
386}