cubecl_cpp/shared/
body.rs

1use super::{Dialect, Instruction, Variable, barrier::BarrierOps, pipeline::PipelineOps};
2use std::fmt::Display;
3
4/// A body is composed of a list of [instructions](Instruction).
5#[derive(Debug, Clone)]
6pub struct Body<D: Dialect> {
7    pub instructions: Vec<Instruction<D>>,
8    pub shared_memories: Vec<super::SharedMemory<D>>,
9    pub pipelines: Vec<PipelineOps<D>>,
10    pub barriers: Vec<BarrierOps<D>>,
11    pub const_arrays: Vec<super::ConstArray<D>>,
12    pub local_arrays: Vec<super::LocalArray<D>>,
13}
14
15impl<D: Dialect> Display for Body<D> {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        D::compile_bindings_body(f, self)?;
18
19        // Put highest alignment at the front to reduce padding
20        let mut shared_memories = self.shared_memories.clone();
21        shared_memories.sort_by_key(|smem| smem.align.unwrap_or(smem.item.size() as u32));
22        shared_memories.reverse();
23
24        let mut shared_offset = 0u32;
25
26        for mut shared in shared_memories {
27            let align = shared.align.unwrap_or(shared.item.size() as u32);
28            let size_bytes = shared.size * shared.item.size() as u32;
29            shared.offset = shared_offset.next_multiple_of(align);
30            shared_offset = shared.offset + size_bytes;
31            D::compile_shared_memory_declaration(f, &shared)?;
32        }
33
34        for pipeline in self.pipelines.iter() {
35            writeln!(f, "{pipeline}")?;
36        }
37        for barrier in self.barriers.iter() {
38            writeln!(f, "{barrier}")?;
39        }
40
41        for const_array in self.const_arrays.iter() {
42            f.write_fmt(format_args!(
43                "const {} arrays_{}[{}] = {{",
44                const_array.item, const_array.index, const_array.size
45            ))?;
46            let elem = const_array.item.elem;
47            for value in const_array.values.iter().copied() {
48                let value = match value {
49                    Variable::ConstantScalar(value, _) => Variable::ConstantScalar(value, elem),
50                    _ => unreachable!("Value is always constant"),
51                };
52                f.write_fmt(format_args!("{value},"))?;
53            }
54            f.write_str("};\n")?;
55        }
56
57        // Local arrays
58        for array in self.local_arrays.iter() {
59            write!(
60                f,
61                "{} l_arr_{}[{}];\n\n",
62                array.item, array.index, array.size
63            )?;
64        }
65
66        D::compile_wmma_local_variables(f)?;
67
68        for ops in self.instructions.iter() {
69            write!(f, "{ops}")?;
70        }
71
72        Ok(())
73    }
74}