cubecl_cpp/metal/
attribute.rs

1use std::fmt::Display;
2
/// Attribute name attached to a buffer-like kernel argument in generated
/// shader source.
///
/// The [`Display`] impl renders only the bare attribute name; use
/// [`BufferAttribute::indexed_fmt`] to render the full, indexed attribute.
pub enum BufferAttribute {
    /// Rendered as `buffer`.
    Buffer,
    /// Rendered as `threadgroup`.
    ThreadGroup,
    /// No attribute; rendered as the empty string.
    None,
}

impl BufferAttribute {
    /// Writes the indexed attribute, with a leading space, into `f`:
    /// ` [[buffer(<index>)]]` or ` [[threadgroup(<index>)]]`.
    ///
    /// For [`BufferAttribute::None`] nothing is written. There is no attribute
    /// name to index, and emitting ` [[(<index>)]]` would produce an attribute
    /// with an empty name.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the underlying formatter.
    pub fn indexed_fmt(&self, index: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            // Stay consistent with `Display`, which renders `None` as nothing.
            Self::None => Ok(()),
            Self::Buffer | Self::ThreadGroup => write!(f, " [[{self}({index})]]"),
        }
    }
}

impl Display for BufferAttribute {
    /// Writes the bare attribute name (no brackets, no index).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Buffer => f.write_str("buffer"),
            Self::ThreadGroup => f.write_str("threadgroup"),
            Self::None => Ok(()),
        }
    }
}
24
/// Built-in attribute that can be attached to a kernel argument in generated
/// shader source.
///
/// Each variant other than [`BuiltInAttribute::None`] is rendered by the
/// [`Display`] impl as the snake_case form of its name wrapped in `[[` `]]`.
pub enum BuiltInAttribute {
    SIMDgroupIndexInThreadgroup,
    ThreadIndexInSIMDgroup,
    ThreadIndexInThreadgroup,
    ThreadPositionInGrid,
    ThreadPositionInThreadgroup,
    ThreadgroupPositionInGrid,
    ThreadgroupsPerGrid,
    ThreadsPerSIMDgroup,
    ThreadsPerThreadgroup,
    /// No attribute; rendered as the empty string.
    None,
}

impl Display for BuiltInAttribute {
    /// Writes the complete bracketed attribute, or nothing at all for
    /// [`BuiltInAttribute::None`].
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Pick the literal first, then emit it with a single write.
        let rendered = match self {
            Self::SIMDgroupIndexInThreadgroup => "[[simdgroup_index_in_threadgroup]]",
            Self::ThreadIndexInSIMDgroup => "[[thread_index_in_simdgroup]]",
            Self::ThreadIndexInThreadgroup => "[[thread_index_in_threadgroup]]",
            Self::ThreadPositionInGrid => "[[thread_position_in_grid]]",
            Self::ThreadPositionInThreadgroup => "[[thread_position_in_threadgroup]]",
            Self::ThreadgroupPositionInGrid => "[[threadgroup_position_in_grid]]",
            Self::ThreadgroupsPerGrid => "[[threadgroups_per_grid]]",
            Self::ThreadsPerSIMDgroup => "[[threads_per_simdgroup]]",
            Self::ThreadsPerThreadgroup => "[[threads_per_threadgroup]]",
            // Nothing to emit: leave the formatter untouched.
            Self::None => return Ok(()),
        };
        f.write_str(rendered)
    }
}