cubecl_cpp/metal/
address_space.rs1use cubecl_core::prelude::Visibility;
2
3use crate::{
4 Dialect,
5 shared::{Component, KernelArg, Variable},
6};
7
8use super::BufferAttribute;
9use std::fmt::Display;
10
/// Metal address-space qualifier attached to a variable or kernel parameter.
///
/// Each variant is rendered by the `Display` impl as the corresponding Metal
/// keyword; [`AddressSpace::None`] renders as nothing at all.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AddressSpace {
    /// `constant` memory.
    Constant,
    /// `const device` memory (read-only device buffer).
    ConstDevice,
    /// `device` memory (writable device buffer).
    Device,
    /// `thread` (per-thread) memory.
    Thread,
    /// `threadgroup` memory, shared within a threadgroup.
    ThreadGroup,
    /// No qualifier is emitted (used for builtin position/dimension variables).
    None,
}
20
21impl AddressSpace {
22 pub fn attribute(&self) -> BufferAttribute {
23 match self {
24 AddressSpace::Constant | AddressSpace::ConstDevice | AddressSpace::Device => {
25 BufferAttribute::Buffer
26 }
27 AddressSpace::ThreadGroup => BufferAttribute::ThreadGroup,
28 _ => BufferAttribute::None,
29 }
30 }
31}
32
33impl Display for AddressSpace {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 AddressSpace::Constant => f.write_str("constant"),
37 AddressSpace::ConstDevice => f.write_str("const device"),
38 AddressSpace::Device => f.write_str("device"),
39 AddressSpace::ThreadGroup => f.write_str("threadgroup"),
40 AddressSpace::Thread => f.write_str("thread"),
41 AddressSpace::None => Ok(()),
42 }
43 }
44}
45
46impl From<AddressSpace> for Visibility {
47 fn from(val: AddressSpace) -> Self {
48 match val {
49 AddressSpace::Constant => Visibility::Read,
50 _ => Visibility::ReadWrite,
51 }
52 }
53}
54
55impl<D: Dialect> From<&KernelArg<D>> for AddressSpace {
56 fn from(value: &KernelArg<D>) -> Self {
57 match value.vis {
58 Visibility::Read => AddressSpace::ConstDevice,
59 Visibility::ReadWrite => AddressSpace::Device,
60 }
61 }
62}
63
64impl<D: Dialect> From<&Variable<D>> for AddressSpace {
65 fn from(value: &Variable<D>) -> Self {
66 match value {
67 Variable::AbsolutePosBaseName
68 | Variable::AbsolutePosX
69 | Variable::AbsolutePosY
70 | Variable::AbsolutePosZ
71 | Variable::UnitPosBaseName
72 | Variable::UnitPosX
73 | Variable::UnitPosY
74 | Variable::UnitPosZ
75 | Variable::CubePosBaseName
76 | Variable::CubePosX
77 | Variable::CubePosY
78 | Variable::CubePosZ
79 | Variable::CubeDimBaseName
80 | Variable::CubeDimX
81 | Variable::CubeDimY
82 | Variable::CubeDimZ
83 | Variable::CubeCountBaseName
84 | Variable::CubeCountX
85 | Variable::CubeCountY
86 | Variable::CubeCountZ
87 | Variable::PlaneDim
88 | Variable::UnitPosPlane => AddressSpace::None,
89 Variable::GlobalInputArray(..) => AddressSpace::ConstDevice,
90 Variable::GlobalOutputArray(..) => AddressSpace::Device,
91 Variable::GlobalScalar { .. } => {
92 if value.is_const() {
93 AddressSpace::ConstDevice
94 } else {
95 AddressSpace::Device
96 }
97 }
98 Variable::SharedArray(..) => AddressSpace::ThreadGroup,
99 _ => AddressSpace::Thread,
100 }
101 }
102}