cubecl_cpp/metal/
attribute.rs1use std::fmt::Display;
2
/// Kind of `[[...]]` attribute attached to a buffer argument in generated Metal source.
///
/// The [`Display`] impl yields only the bare keyword (`buffer`, `threadgroup`, or
/// nothing for [`BufferAttribute::None`]). Use [`BufferAttribute::indexed_fmt`] to
/// emit the complete bracketed, indexed form such as ` [[buffer(0)]]`.
pub enum BufferAttribute {
    /// Formats as `buffer`.
    Buffer,
    /// Formats as `threadgroup`.
    ThreadGroup,
    /// No attribute: formats to the empty string.
    None,
}

impl BufferAttribute {
    /// Writes ` [[<keyword>(<index>)]]`, including the leading space, to `f`.
    ///
    /// For [`BufferAttribute::None`] nothing is written. Its keyword is empty, so
    /// emitting the bracketed form would produce the malformed ` [[(<index>)]]`.
    ///
    /// # Errors
    /// Propagates any error returned by the underlying formatter.
    pub fn indexed_fmt(&self, index: usize, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::None => Ok(()),
            Self::Buffer | Self::ThreadGroup => write!(f, " [[{self}({index})]]"),
        }
    }
}

impl Display for BufferAttribute {
    /// Writes the bare attribute keyword, without brackets or index.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Buffer => f.write_str("buffer"),
            Self::ThreadGroup => f.write_str("threadgroup"),
            Self::None => Ok(()),
        }
    }
}
24
/// Built-in attributes that can decorate parameters in generated Metal source.
///
/// Formatting with [`Display`] yields the complete bracketed attribute, for example
/// `[[thread_position_in_grid]]`. [`BuiltInAttribute::None`] yields an empty string.
pub enum BuiltInAttribute {
    /// `[[simdgroup_index_in_threadgroup]]`
    SIMDgroupIndexInThreadgroup,
    /// `[[thread_index_in_simdgroup]]`
    ThreadIndexInSIMDgroup,
    /// `[[thread_index_in_threadgroup]]`
    ThreadIndexInThreadgroup,
    /// `[[thread_position_in_grid]]`
    ThreadPositionInGrid,
    /// `[[thread_position_in_threadgroup]]`
    ThreadPositionInThreadgroup,
    /// `[[threadgroup_position_in_grid]]`
    ThreadgroupPositionInGrid,
    /// `[[threadgroups_per_grid]]`
    ThreadgroupsPerGrid,
    /// `[[threads_per_simdgroup]]`
    ThreadsPerSIMDgroup,
    /// `[[threads_per_threadgroup]]`
    ThreadsPerThreadgroup,
    /// No built-in attribute: formats to nothing.
    None,
}

impl Display for BuiltInAttribute {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Look up the attribute name first; `None` has nothing to print, so bail out early.
        let attribute_name = match self {
            Self::SIMDgroupIndexInThreadgroup => "simdgroup_index_in_threadgroup",
            Self::ThreadIndexInSIMDgroup => "thread_index_in_simdgroup",
            Self::ThreadIndexInThreadgroup => "thread_index_in_threadgroup",
            Self::ThreadPositionInGrid => "thread_position_in_grid",
            Self::ThreadPositionInThreadgroup => "thread_position_in_threadgroup",
            Self::ThreadgroupPositionInGrid => "threadgroup_position_in_grid",
            Self::ThreadgroupsPerGrid => "threadgroups_per_grid",
            Self::ThreadsPerSIMDgroup => "threads_per_simdgroup",
            Self::ThreadsPerThreadgroup => "threads_per_threadgroup",
            Self::None => return Ok(()),
        };
        write!(f, "[[{attribute_name}]]")
    }
}