cubecl_cpp/shared/
body.rs

1use super::{Dialect, Instruction, Variable};
2use std::fmt::Display;
3
4/// A body is composed of a list of [instructions](Instruction).
5#[derive(Debug, Clone)]
6pub struct Body<D: Dialect> {
7    pub instructions: Vec<Instruction<D>>,
8    pub shared_memories: Vec<super::SharedMemory<D>>,
9    pub const_arrays: Vec<super::ConstArray<D>>,
10    pub local_arrays: Vec<super::LocalArray<D>>,
11    pub warp_size_checked: bool,
12    pub settings: VariableSettings,
13}
14
/// Selects which built-in helper variables a kernel body declares before its instructions.
///
/// Every flag defaults to `false`, meaning the corresponding variable is not emitted.
#[derive(Debug, Clone, Default)]
pub struct VariableSettings {
    /// Declare `uint idxGlobal`, the row-major linearized global thread index.
    ///
    /// Because `idxGlobal` is computed from `absoluteIdx`, setting this also forces
    /// `absoluteIdx` to be declared.
    pub idx_global: bool,
    /// Declare `int threadIdxGlobal`, the linearized thread index within its block.
    pub thread_idx_global: bool,
    /// Declare `int3 absoluteIdx` (global thread coordinates) if any of the three flags is set.
    ///
    /// NOTE(review): the flags are presumably the x, y and z components; all three
    /// components are computed no matter which flag requested the declaration.
    pub absolute_idx: (bool, bool, bool),
    /// Declare `int blockIdxGlobal`, the linearized block index within the grid.
    pub block_idx_global: bool,
    /// Declare `int blockDimGlobal`, the number of threads per block.
    pub block_dim_global: bool,
    /// Declare `int gridDimGlobal`, the number of blocks in the grid.
    pub grid_dim_global: bool,
}
25
26impl<D: Dialect> Display for Body<D> {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        if self.settings.idx_global
29            || self.settings.absolute_idx.0
30            || self.settings.absolute_idx.1
31            || self.settings.absolute_idx.2
32        {
33            f.write_str(
34                "
35    int3 absoluteIdx = make_int3(
36        blockIdx.x * blockDim.x + threadIdx.x,
37        blockIdx.y * blockDim.y + threadIdx.y,
38        blockIdx.z * blockDim.z + threadIdx.z
39    );
40",
41            )?;
42        }
43
44        if self.settings.idx_global {
45            f.write_str(
46                "
47    uint idxGlobal = (absoluteIdx.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + (absoluteIdx.y * gridDim.x * blockDim.x) + absoluteIdx.x;
48",
49            )?;
50        }
51
52        if self.settings.thread_idx_global {
53            f.write_str(
54                "
55    int threadIdxGlobal = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.x * blockDim.y);
56            ",
57            )?;
58        }
59        if self.settings.block_idx_global {
60            f.write_str(
61                "
62    int blockIdxGlobal = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * (gridDim.x * gridDim.y);
63            ",
64            )?;
65        }
66
67        if self.settings.block_dim_global {
68            f.write_str(
69                "
70    int blockDimGlobal = blockDim.x * blockDim.y * blockDim.z;
71            ",
72            )?;
73        }
74
75        if self.settings.grid_dim_global {
76            f.write_str(
77                "
78    int gridDimGlobal = gridDim.x * gridDim.y * gridDim.z;
79            ",
80            )?;
81        }
82
83        if self.warp_size_checked {
84            f.write_str(
85                "
86 int warpSizeChecked = min(warpSize, blockDim.x * blockDim.y * blockDim.z);
87",
88            )?;
89        }
90
91        for shared in self.shared_memories.iter() {
92            writeln!(
93                f,
94                "__shared__ {} shared_memory_{}[{}];",
95                shared.item, shared.index, shared.size
96            )?;
97        }
98
99        for const_array in self.const_arrays.iter() {
100            f.write_fmt(format_args!(
101                "const {} arrays_{}[{}] = {{",
102                const_array.item, const_array.index, const_array.size
103            ))?;
104            let elem = const_array.item.elem;
105            for value in const_array.values.iter().copied() {
106                let value = match value {
107                    Variable::ConstantScalar(value, _) => Variable::ConstantScalar(value, elem),
108                    _ => unreachable!("Value is always constant"),
109                };
110                f.write_fmt(format_args!("{value},"))?;
111            }
112            f.write_str("};\n")?;
113        }
114
115        // Local arrays
116        for array in self.local_arrays.iter() {
117            write!(
118                f,
119                "{} l_arr_{}[{}];\n\n",
120                array.item, array.index, array.size
121            )?;
122        }
123
124        D::local_variables(f)?;
125
126        for ops in self.instructions.iter() {
127            write!(f, "{ops}")?;
128        }
129
130        Ok(())
131    }
132}