// cubecl_cpp/shared/body.rs

1use super::{Dialect, Instruction, Variable, barrier::BarrierOps, pipeline::PipelineOps};
2use std::fmt::Display;
3
4/// A body is composed of a list of [instructions](Instruction).
5#[derive(Debug, Clone)]
6pub struct Body<D: Dialect> {
7    pub instructions: Vec<Instruction<D>>,
8    pub shared_memories: Vec<super::SharedMemory<D>>,
9    pub pipelines: Vec<PipelineOps<D>>,
10    pub barriers: Vec<BarrierOps<D>>,
11    pub const_arrays: Vec<super::ConstArray<D>>,
12    pub local_arrays: Vec<super::LocalArray<D>>,
13}
14
15impl<D: Dialect> Display for Body<D> {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        for shared in self.shared_memories.iter() {
18            let item = &shared.item;
19            let index = &shared.index;
20            let size = &shared.size;
21            D::compile_shared_memory_qualifier(f, shared)?;
22            writeln!(f, " {item} shared_memory_{index}[{size}];",)?;
23        }
24
25        for pipeline in self.pipelines.iter() {
26            writeln!(f, "{pipeline}")?;
27        }
28        for barrier in self.barriers.iter() {
29            writeln!(f, "{barrier}")?;
30        }
31
32        for const_array in self.const_arrays.iter() {
33            f.write_fmt(format_args!(
34                "const {} arrays_{}[{}] = {{",
35                const_array.item, const_array.index, const_array.size
36            ))?;
37            let elem = const_array.item.elem;
38            for value in const_array.values.iter().copied() {
39                let value = match value {
40                    Variable::ConstantScalar(value, _) => Variable::ConstantScalar(value, elem),
41                    _ => unreachable!("Value is always constant"),
42                };
43                f.write_fmt(format_args!("{value},"))?;
44            }
45            f.write_str("};\n")?;
46        }
47
48        // Local arrays
49        for array in self.local_arrays.iter() {
50            write!(
51                f,
52                "{} l_arr_{}[{}];\n\n",
53                array.item, array.index, array.size
54            )?;
55        }
56
57        D::compile_local_variables(f)?;
58
59        for ops in self.instructions.iter() {
60            write!(f, "{ops}")?;
61        }
62
63        Ok(())
64    }
65}