cubecl_cpp/shared/
kernel.rs1use super::{Body, Dialect, Item, Variable};
2use cubecl_core::{
3 ir::{CubeDim, Id, Visibility},
4 CompilerRepresentation,
5};
6use std::{collections::HashSet, fmt::Display};
7
/// A single buffer argument of a [`ComputeKernel`].
///
/// Bindings are stored in the kernel's `inputs`, `outputs` and `named` lists and
/// are rendered by the kernel's `Display` impl as `<item> <name>[]` parameters
/// of the generated kernel function.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Binding<D: Dialect> {
    /// Item type of the buffer; printed as the parameter's type.
    pub item: Item<D>,
    /// Optional size of the binding. Not read anywhere in this file —
    /// presumably an element count known at compile time; TODO confirm against
    /// the code that constructs bindings.
    pub size: Option<usize>,
    /// Visibility of the buffer (`Read` or `ReadWrite`). The `Display` impl in
    /// this file matches on it but emits the same declaration for both.
    pub vis: Visibility,
}
14
/// Declaration of a shared-memory array used by a kernel body.
///
/// `ComputeKernel::shared_memory_size` sums these declarations to obtain the
/// kernel's total shared-memory footprint.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SharedMemory<D: Dialect> {
    /// Identifier of this array.
    pub index: Id,
    /// Item type of each entry in the array.
    pub item: Item<D>,
    /// Number of `item` entries; the byte size is computed as
    /// `size * item.vectorization * item.elem().size()`.
    pub size: u32,
}
21
/// Declaration of an array together with the values it holds.
///
/// Unlike the sibling array types this struct derives `PartialEq` but not `Eq`.
/// It is not used elsewhere in this file.
#[derive(Debug, PartialEq, Clone)]
pub struct ConstArray<D: Dialect> {
    /// Identifier of this array.
    pub index: Id,
    /// Item type of each entry in the array.
    pub item: Item<D>,
    /// Declared size of the array — presumably equal to `values.len()`;
    /// TODO confirm against the code that constructs it.
    pub size: u32,
    /// The values stored in the array.
    pub values: Vec<Variable<D>>,
}
29
/// Declaration of a local array: an identifier, an item type and a size.
///
/// Has the same shape as [`SharedMemory`]; it is not used elsewhere in this file.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LocalArray<D: Dialect> {
    /// Identifier of this array.
    pub index: Id,
    /// Item type of each entry in the array.
    pub item: Item<D>,
    /// Declared size of the array — presumably a count of `item` entries, as
    /// for `SharedMemory`; TODO confirm.
    pub size: u32,
}
36
37impl<D: Dialect> LocalArray<D> {
38 pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
39 Self { index, item, size }
40 }
41}
42
43impl<D: Dialect> SharedMemory<D> {
44 pub fn new(index: Id, item: Item<D>, size: u32) -> Self {
45 Self { index, item, size }
46 }
47}
48
/// A complete kernel ready to be rendered as source text.
///
/// The `Display` impl turns this value into the kernel's source code: optional
/// dialect includes, type definitions, and an `extern "C" __global__` function
/// whose parameters are the bindings listed here and whose body is `body`.
#[derive(Debug, Clone)]
pub struct ComputeKernel<D: Dialect> {
    /// Input buffers; rendered as parameters named `input_<position>`.
    pub inputs: Vec<Binding<D>>,
    /// Output buffers; rendered as parameters named `output_<position>`.
    pub outputs: Vec<Binding<D>>,
    /// Buffers rendered under an explicit parameter name.
    pub named: Vec<(String, Binding<D>)>,
    /// Cube dimensions of the kernel. Not read in this file — presumably
    /// consumed by the code that launches the kernel; TODO confirm.
    pub cube_dim: CubeDim,
    /// Body of the kernel function; also holds the shared-memory declarations
    /// summed by `shared_memory_size`.
    pub body: Body<D>,
    /// When `true`, `D::wmma_includes` is emitted at the top of the source.
    pub wmma_activated: bool,
    /// When `true`, `D::include_bf16` is emitted at the top of the source.
    pub bf16: bool,
    /// When `true`, `D::include_f16` is emitted at the top of the source.
    pub f16: bool,
    /// Item types used by the kernel; every item with a vectorization greater
    /// than one gets a struct definition in the rendered source.
    pub items: HashSet<super::Item<D>>,
    /// Name given to the generated kernel function.
    pub kernel_name: String,
}
62
63impl<D: Dialect> CompilerRepresentation for ComputeKernel<D> {
64 fn shared_memory_size(&self) -> usize {
65 let mut current = 0usize;
66
67 for var in self.body.shared_memories.iter() {
68 let factor = var.item.vectorization;
69 let elem_size_bytes = var.item.elem().size();
70 current += (var.size as usize) * factor * elem_size_bytes;
71 }
72
73 current
74 }
75}
76
77impl<D: Dialect> Display for ComputeKernel<D> {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 if self.bf16 {
80 D::include_bf16(f)?;
81 }
82
83 if self.f16 {
84 D::include_f16(f)?;
85 }
86
87 if self.wmma_activated {
88 D::wmma_includes(f)?;
89 }
90
91 f.write_str("typedef unsigned char uint8;\n")?;
92 f.write_str("typedef unsigned short uint16;\n")?;
93 f.write_str("typedef unsigned int uint;\n")?;
94 f.write_str("typedef unsigned long long int uint64;\n")?;
95 f.write_str("typedef long long int int64;\n")?;
96 D::deftypes(f)?;
97
98 for item in self.items.iter() {
99 let elem = item.elem;
100 let size = item.vectorization;
101 let alignment = elem.size() * size;
102 if size > 1 {
103 write!(
104 f,
105 "
106struct __align__({alignment}) {item} {{"
107 )?;
108
109 for i in 0..size {
110 write!(
111 f,
112 "
113 {elem} i_{i};"
114 )?;
115 }
116
117 f.write_str("\n};\n")?;
118 }
119 }
120
121 write!(
122 f,
123 "
124
125extern \"C\" __global__ void {}(
126",
127 self.kernel_name
128 )?;
129
130 let num_bindings = self.inputs.len() + self.outputs.len() + self.named.len();
131 let mut binding_index = 0;
132 for (index, binding) in self.inputs.iter().enumerate() {
133 binding_index += 1;
134 match binding.vis {
135 Visibility::Read => {
136 write!(f, "{} input_{}[]", binding.item, index)?;
137 }
142 Visibility::ReadWrite => {
143 write!(f, "{} input_{}[]", binding.item, index)?;
144 }
145 }
146 if binding_index < num_bindings {
147 f.write_str(",")?;
148 }
149 }
150 for (index, binding) in self.outputs.iter().enumerate() {
151 binding_index += 1;
152 write!(f, "{} output_{}[]", binding.item, index)?;
153 if binding_index < num_bindings {
154 f.write_str(",")?;
155 }
156 }
157 for (name, binding) in self.named.iter() {
158 binding_index += 1;
159
160 match binding.vis {
161 Visibility::Read => {
162 write!(f, "{} {}[]", binding.item, name)?;
163 }
168 Visibility::ReadWrite => {
169 write!(f, "{} {}[]", binding.item, name)?;
170 }
171 }
172
173 if binding_index < num_bindings {
174 f.write_str(",")?;
175 }
176 }
177
178 f.write_str("\n) {\n")?;
179
180 write!(f, "{}", self.body)?;
181 f.write_str("\n}")?;
182
183 Ok(())
184 }
185}