cubecl_cpp/shared/
body.rs1use super::{Dialect, Instruction, Variable, barrier::BarrierOps, pipeline::PipelineOps};
2use std::fmt::Display;
3
/// The body of a kernel for dialect `D`, rendered to source text by its
/// [`Display`] implementation.
///
/// Rendering order: the dialect's bindings, shared memory declarations,
/// pipelines, barriers, constant arrays, local arrays, the dialect's WMMA
/// locals, and finally the instructions.
#[derive(Debug, Clone)]
pub struct Body<D: Dialect> {
    /// Instructions, written last, after every declaration.
    pub instructions: Vec<Instruction<D>>,
    /// Shared memory buffers. When rendering, byte offsets are assigned on a
    /// copy (largest alignment first), so the stored values are not updated.
    pub shared_memories: Vec<super::SharedMemory<D>>,
    /// Pipeline operations, each written on its own line via `Display`.
    pub pipelines: Vec<PipelineOps<D>>,
    /// Barrier operations, each written on its own line via `Display`.
    pub barriers: Vec<BarrierOps<D>>,
    /// Constant arrays, written as `const ... arrays_{index}[size] = {...};`.
    pub const_arrays: Vec<super::ConstArray<D>>,
    /// Local arrays, written as `... l_arr_{index}[size];` declarations.
    pub local_arrays: Vec<super::LocalArray<D>>,
}
14
15impl<D: Dialect> Display for Body<D> {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 D::compile_bindings_body(f, self)?;
18
19 let mut shared_memories = self.shared_memories.clone();
21 shared_memories.sort_by_key(|smem| smem.align.unwrap_or(smem.item.size() as u32));
22 shared_memories.reverse();
23
24 let mut shared_offset = 0u32;
25
26 for mut shared in shared_memories {
27 let align = shared.align.unwrap_or(shared.item.size() as u32);
28 let size_bytes = shared.size * shared.item.size() as u32;
29 shared.offset = shared_offset.next_multiple_of(align);
30 shared_offset = shared.offset + size_bytes;
31 D::compile_shared_memory_declaration(f, &shared)?;
32 }
33
34 for pipeline in self.pipelines.iter() {
35 writeln!(f, "{pipeline}")?;
36 }
37 for barrier in self.barriers.iter() {
38 writeln!(f, "{barrier}")?;
39 }
40
41 for const_array in self.const_arrays.iter() {
42 f.write_fmt(format_args!(
43 "const {} arrays_{}[{}] = {{",
44 const_array.item, const_array.index, const_array.size
45 ))?;
46 let elem = const_array.item.elem;
47 for value in const_array.values.iter().copied() {
48 let value = match value {
49 Variable::ConstantScalar(value, _) => Variable::ConstantScalar(value, elem),
50 _ => unreachable!("Value is always constant"),
51 };
52 f.write_fmt(format_args!("{value},"))?;
53 }
54 f.write_str("};\n")?;
55 }
56
57 for array in self.local_arrays.iter() {
59 write!(
60 f,
61 "{} l_arr_{}[{}];\n\n",
62 array.item, array.index, array.size
63 )?;
64 }
65
66 D::compile_wmma_local_variables(f)?;
67
68 for ops in self.instructions.iter() {
69 write!(f, "{ops}")?;
70 }
71
72 Ok(())
73 }
74}