cubecl_cpp/metal/
address_space.rs1use cubecl_core::compute::{Location, Visibility};
2
3use crate::{
4 Dialect,
5 shared::{Binding, Component, Variable},
6};
7
8use super::BufferAttribute;
9use std::fmt::Display;
10
/// Address-space qualifier attached to a variable or buffer in generated Metal code.
///
/// Each variant is rendered by this type's `Display` impl as the matching keyword;
/// [`AddressSpace::None`] renders as nothing, meaning no qualifier is emitted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddressSpace {
    /// `constant`
    Constant,
    /// `const device` — produced for read-only bindings and read-only globals.
    ConstDevice,
    /// `device` — writable device memory.
    Device,
    /// `thread`
    Thread,
    /// `threadgroup` — used for shared memory.
    ThreadGroup,
    /// No qualifier at all.
    None,
}
20
21impl AddressSpace {
22 pub fn attribute(&self) -> BufferAttribute {
23 match self {
24 AddressSpace::Constant | AddressSpace::ConstDevice | AddressSpace::Device => {
25 BufferAttribute::Buffer
26 }
27 AddressSpace::ThreadGroup => BufferAttribute::ThreadGroup,
28 _ => BufferAttribute::None,
29 }
30 }
31}
32
33impl Display for AddressSpace {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 AddressSpace::Constant => f.write_str("constant"),
37 AddressSpace::ConstDevice => f.write_str("const device"),
38 AddressSpace::Device => f.write_str("device"),
39 AddressSpace::ThreadGroup => f.write_str("threadgroup"),
40 AddressSpace::Thread => f.write_str("thread"),
41 AddressSpace::None => Ok(()),
42 }
43 }
44}
45
46impl From<AddressSpace> for Visibility {
47 fn from(val: AddressSpace) -> Self {
48 match val {
49 AddressSpace::Constant => Visibility::Read,
50 _ => Visibility::ReadWrite,
51 }
52 }
53}
54
55impl<D: Dialect> From<&Binding<D>> for AddressSpace {
56 fn from(value: &Binding<D>) -> Self {
57 match value.vis {
58 Visibility::Read => AddressSpace::ConstDevice,
59 Visibility::ReadWrite => match value.location {
60 Location::Storage => AddressSpace::Device,
61 Location::Cube => AddressSpace::ThreadGroup,
62 },
63 }
64 }
65}
66
67impl<D: Dialect> From<&Variable<D>> for AddressSpace {
68 fn from(value: &Variable<D>) -> Self {
69 match value {
70 Variable::AbsolutePosBaseName
71 | Variable::AbsolutePosX
72 | Variable::AbsolutePosY
73 | Variable::AbsolutePosZ
74 | Variable::UnitPosBaseName
75 | Variable::UnitPosX
76 | Variable::UnitPosY
77 | Variable::UnitPosZ
78 | Variable::CubePosBaseName
79 | Variable::CubePosX
80 | Variable::CubePosY
81 | Variable::CubePosZ
82 | Variable::CubeDimBaseName
83 | Variable::CubeDimX
84 | Variable::CubeDimY
85 | Variable::CubeDimZ
86 | Variable::CubeCountBaseName
87 | Variable::CubeCountX
88 | Variable::CubeCountY
89 | Variable::CubeCountZ
90 | Variable::PlaneDim
91 | Variable::UnitPosPlane => AddressSpace::None,
92 Variable::GlobalInputArray(..) => AddressSpace::ConstDevice,
93 Variable::GlobalOutputArray(..) => AddressSpace::Device,
94 Variable::GlobalScalar { .. } => {
95 if value.is_const() {
96 AddressSpace::ConstDevice
97 } else {
98 AddressSpace::Device
99 }
100 }
101 Variable::SharedMemory(..) => AddressSpace::ThreadGroup,
102 _ => AddressSpace::Thread,
103 }
104 }
105}