Skip to main content

cubecl_cpp/shared/
body.rs

1use crate::shared::Item;
2
3use super::{Dialect, Instruction, Variable, barrier::BarrierOps, pipeline::PipelineOps};
4use std::fmt::Display;
5
6/// A body is composed of a list of [instructions](Instruction).
7#[derive(Debug, Clone)]
8pub struct Body<D: Dialect> {
9    pub instructions: Vec<Instruction<D>>,
10    pub shared_memories: Vec<super::SharedMemory<D>>,
11    pub pipelines: Vec<PipelineOps<D>>,
12    pub barriers: Vec<BarrierOps<D>>,
13    pub const_arrays: Vec<super::ConstArray<D>>,
14    pub local_arrays: Vec<super::LocalArray<D>>,
15    pub info_by_ptr: bool,
16    pub has_dynamic_meta: bool,
17    pub address_type: Item<D>,
18}
19
20impl<D: Dialect> Display for Body<D> {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        D::compile_bindings_body(f, self)?;
23
24        for shared in &self.shared_memories {
25            D::compile_shared_memory_declaration(f, shared)?;
26        }
27
28        for pipeline in self.pipelines.iter() {
29            writeln!(f, "{pipeline}")?;
30        }
31        for barrier in self.barriers.iter() {
32            writeln!(f, "{barrier}")?;
33        }
34
35        for const_array in self.const_arrays.iter() {
36            f.write_fmt(format_args!(
37                "const {} arrays_{}[{}] = {{",
38                const_array.item, const_array.index, const_array.size
39            ))?;
40            let item = const_array.item;
41            for value in const_array.values.iter().copied() {
42                let value = match value {
43                    Variable::Constant(value, _) => Variable::Constant(value, item),
44                    _ => unreachable!("Value is always constant"),
45                };
46                f.write_fmt(format_args!("{value},"))?;
47            }
48            f.write_str("};\n")?;
49        }
50
51        // Local arrays
52        for array in self.local_arrays.iter() {
53            write!(
54                f,
55                "{} l_arr_{}[{}];\n\n",
56                array.item, array.index, array.size
57            )?;
58        }
59
60        D::compile_wmma_local_variables(f)?;
61
62        for ops in self.instructions.iter() {
63            write!(f, "{ops}")?;
64        }
65
66        Ok(())
67    }
68}