cubecl_cpp/shared/
body.rs1use super::{Dialect, Instruction, Variable};
2use std::fmt::Display;
3
4#[derive(Debug, Clone)]
6pub struct Body<D: Dialect> {
7 pub instructions: Vec<Instruction<D>>,
8 pub shared_memories: Vec<super::SharedMemory<D>>,
9 pub const_arrays: Vec<super::ConstArray<D>>,
10 pub local_arrays: Vec<super::LocalArray<D>>,
11 pub warp_size_checked: bool,
12 pub settings: VariableSettings,
13}
14
/// Selects which derived index variables are declared at the start of a [`Body`].
///
/// Each flag corresponds to one C++ local computed from the `threadIdx`,
/// `blockIdx`, `blockDim` and `gridDim` builtins. All flags default to `false`,
/// i.e. nothing extra is declared.
#[derive(Debug, Clone, Default)]
pub struct VariableSettings {
    /// Declare `idxGlobal`, the thread's linear index over the whole grid.
    /// It is computed from `absoluteIdx`, so it also forces that declaration.
    pub idx_global: bool,
    /// Declare `threadIdxGlobal`, the thread's linear index within its block.
    pub thread_idx_global: bool,
    /// Per-axis `(x, y, z)` requests for `absoluteIdx`, the thread's coordinates
    /// across the grid. Any `true` component declares the whole `int3`.
    pub absolute_idx: (bool, bool, bool),
    /// Declare `blockIdxGlobal`, the block's linear index within the grid.
    pub block_idx_global: bool,
    /// Declare `blockDimGlobal`, the total number of threads per block.
    pub block_dim_global: bool,
    /// Declare `gridDimGlobal`, the total number of blocks in the grid.
    pub grid_dim_global: bool,
}
25
26impl<D: Dialect> Display for Body<D> {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 if self.settings.idx_global
29 || self.settings.absolute_idx.0
30 || self.settings.absolute_idx.1
31 || self.settings.absolute_idx.2
32 {
33 f.write_str(
34 "
35 int3 absoluteIdx = make_int3(
36 blockIdx.x * blockDim.x + threadIdx.x,
37 blockIdx.y * blockDim.y + threadIdx.y,
38 blockIdx.z * blockDim.z + threadIdx.z
39 );
40",
41 )?;
42 }
43
44 if self.settings.idx_global {
45 f.write_str(
46 "
47 uint idxGlobal = (absoluteIdx.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + (absoluteIdx.y * gridDim.x * blockDim.x) + absoluteIdx.x;
48",
49 )?;
50 }
51
52 if self.settings.thread_idx_global {
53 f.write_str(
54 "
55 int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.x * blockDim.y);
56 ",
57 )?;
58 }
59 if self.settings.block_idx_global {
60 f.write_str(
61 "
62 int blockIdxGlobal = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * (gridDim.x * gridDim.y);
63 ",
64 )?;
65 }
66
67 if self.settings.block_dim_global {
68 f.write_str(
69 "
70 int blockDimGlobal = blockDim.x * blockDim.y * blockDim.z;
71 ",
72 )?;
73 }
74
75 if self.settings.grid_dim_global {
76 f.write_str(
77 "
78 int gridDimGlobal = gridDim.x * gridDim.y * gridDim.z;
79 ",
80 )?;
81 }
82
83 if self.warp_size_checked {
84 f.write_str(
85 "
86 int warpSizeChecked = min(warpSize, blockDim.x * blockDim.y * blockDim.z);
87",
88 )?;
89 }
90
91 for shared in self.shared_memories.iter() {
92 writeln!(
93 f,
94 "__shared__ {} shared_memory_{}[{}];",
95 shared.item, shared.index, shared.size
96 )?;
97 }
98
99 for const_array in self.const_arrays.iter() {
100 f.write_fmt(format_args!(
101 "const {} arrays_{}[{}] = {{",
102 const_array.item, const_array.index, const_array.size
103 ))?;
104 let elem = const_array.item.elem;
105 for value in const_array.values.iter().copied() {
106 let value = match value {
107 Variable::ConstantScalar(value, _) => Variable::ConstantScalar(value, elem),
108 _ => unreachable!("Value is always constant"),
109 };
110 f.write_fmt(format_args!("{value},"))?;
111 }
112 f.write_str("};\n")?;
113 }
114
115 for array in self.local_arrays.iter() {
117 write!(
118 f,
119 "{} l_arr_{}[{}];\n\n",
120 array.item, array.index, array.size
121 )?;
122 }
123
124 D::local_variables(f)?;
125
126 for ops in self.instructions.iter() {
127 write!(f, "{ops}")?;
128 }
129
130 Ok(())
131 }
132}