cubecl_cpp/shared/
kernel.rs

1use super::{Body, Dialect, Item, Variable};
2use cubecl_core::{
3    ir::{CubeDim, Id, Visibility},
4    CompilerRepresentation,
5};
6use std::{collections::HashSet, fmt::Display};
7
8#[derive(Debug, PartialEq, Eq, Clone)]
9pub struct Binding<D: Dialect> {
10    pub item: Item<D>,
11    pub size: Option<usize>,
12    pub vis: Visibility,
13}
14
15#[derive(Debug, PartialEq, Eq, Clone)]
16pub struct SharedMemory<D: Dialect> {
17    pub index: Id,
18    pub item: Item<D>,
19    pub size: u32,
20}
21
22#[derive(Debug, PartialEq, Clone)]
23pub struct ConstArray<D: Dialect> {
24    pub index: Id,
25    pub item: Item<D>,
26    pub size: u32,
27    pub values: Vec<Variable<D>>,
28}
29
30#[derive(Debug, PartialEq, Eq, Clone)]
31pub struct LocalArray<D: Dialect> {
32    pub index: Id,
33    pub item: Item<D>,
34    pub size: u32,
35}
36
37impl<D: Dialect> LocalArray<D> {
38    pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
39        Self { index, item, size }
40    }
41}
42
43impl<D: Dialect> SharedMemory<D> {
44    pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
45        Self { index, item, size }
46    }
47}
48
49#[derive(Debug, Clone)]
50pub struct ComputeKernel<D: Dialect> {
51    pub inputs: Vec<Binding<D>>,
52    pub outputs: Vec<Binding<D>>,
53    pub named: Vec<(String, Binding<D>)>,
54    pub cube_dim: CubeDim,
55    pub body: Body<D>,
56    pub wmma_activated: bool,
57    pub bf16: bool,
58    pub f16: bool,
59    pub items: HashSet<super::Item<D>>,
60    pub kernel_name: String,
61}
62
63impl<D: Dialect> CompilerRepresentation for ComputeKernel<D> {
64    fn shared_memory_size(&self) -> usize {
65        let mut current = 0usize;
66
67        for var in self.body.shared_memories.iter() {
68            let factor = var.item.vectorization;
69            let elem_size_bytes = var.item.elem().size();
70            current += (var.size as usize) * factor * elem_size_bytes;
71        }
72
73        current
74    }
75}
76
77impl<D: Dialect> Display for ComputeKernel<D> {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        if self.bf16 {
80            D::include_bf16(f)?;
81        }
82
83        if self.f16 {
84            D::include_f16(f)?;
85        }
86
87        if self.wmma_activated {
88            D::wmma_includes(f)?;
89        }
90
91        f.write_str("typedef unsigned char uint8;\n")?;
92        f.write_str("typedef unsigned short uint16;\n")?;
93        f.write_str("typedef unsigned int uint;\n")?;
94        f.write_str("typedef unsigned long long int uint64;\n")?;
95        f.write_str("typedef long long int int64;\n")?;
96        D::deftypes(f)?;
97
98        for item in self.items.iter() {
99            let elem = item.elem;
100            let size = item.vectorization;
101            let alignment = elem.size() * size;
102            if size > 1 {
103                write!(
104                    f,
105                    "
106struct __align__({alignment}) {item} {{"
107                )?;
108
109                for i in 0..size {
110                    write!(
111                        f,
112                        "
113    {elem} i_{i};"
114                    )?;
115                }
116
117                f.write_str("\n};\n")?;
118            }
119        }
120
121        write!(
122            f,
123            "
124
125extern \"C\" __global__ void {}(
126",
127            self.kernel_name
128        )?;
129
130        let num_bindings = self.inputs.len() + self.outputs.len() + self.named.len();
131        let mut binding_index = 0;
132        for (index, binding) in self.inputs.iter().enumerate() {
133            binding_index += 1;
134            match binding.vis {
135                Visibility::Read => {
136                    write!(f, "{} input_{}[]", binding.item, index)?;
137                    // TODO: It breaks slices, because we can't easily create pointer to __restrict__,
138                    // we should have multiple pointer types to enable that optimization.
139                    //
140                    // write!(f, "const {}* __restrict__ input_{}", binding.item, index)?;
141                }
142                Visibility::ReadWrite => {
143                    write!(f, "{} input_{}[]", binding.item, index)?;
144                }
145            }
146            if binding_index < num_bindings {
147                f.write_str(",")?;
148            }
149        }
150        for (index, binding) in self.outputs.iter().enumerate() {
151            binding_index += 1;
152            write!(f, "{} output_{}[]", binding.item, index)?;
153            if binding_index < num_bindings {
154                f.write_str(",")?;
155            }
156        }
157        for (name, binding) in self.named.iter() {
158            binding_index += 1;
159
160            match binding.vis {
161                Visibility::Read => {
162                    write!(f, "{} {}[]", binding.item, name)?;
163                    // TODO: It breaks slices, because we can't easily create pointer to __restrict__,
164                    // we should have multiple pointer types to enable that optimization.
165                    //
166                    // write!(f, "const {}* __restrict__ {}", binding.item, name)?;
167                }
168                Visibility::ReadWrite => {
169                    write!(f, "{} {}[]", binding.item, name)?;
170                }
171            }
172
173            if binding_index < num_bindings {
174                f.write_str(",")?;
175            }
176        }
177
178        f.write_str("\n) {\n")?;
179
180        write!(f, "{}", self.body)?;
181        f.write_str("\n}")?;
182
183        Ok(())
184    }
185}