cubecl_cpp/metal/address_space.rs

1use cubecl_core::compute::{Location, Visibility};
2
3use crate::{
4    Dialect,
5    shared::{Binding, Component, Variable},
6};
7
8use super::BufferAttribute;
9use std::fmt::Display;
10
/// Metal Shading Language address-space qualifier attached to a declaration.
///
/// Each variant is rendered by the `Display` impl as its qualifier text;
/// `None` renders as nothing at all.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddressSpace {
    /// Rendered as `constant`.
    Constant,
    /// Rendered as `const device`.
    ConstDevice,
    /// Rendered as `device`.
    Device,
    /// Rendered as `thread`.
    Thread,
    /// Rendered as `threadgroup`.
    ThreadGroup,
    /// No qualifier is emitted.
    None,
}
20
21impl AddressSpace {
22    pub fn attribute(&self) -> BufferAttribute {
23        match self {
24            AddressSpace::Constant | AddressSpace::ConstDevice | AddressSpace::Device => {
25                BufferAttribute::Buffer
26            }
27            AddressSpace::ThreadGroup => BufferAttribute::ThreadGroup,
28            _ => BufferAttribute::None,
29        }
30    }
31}
32
33impl Display for AddressSpace {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            AddressSpace::Constant => f.write_str("constant"),
37            AddressSpace::ConstDevice => f.write_str("const device"),
38            AddressSpace::Device => f.write_str("device"),
39            AddressSpace::ThreadGroup => f.write_str("threadgroup"),
40            AddressSpace::Thread => f.write_str("thread"),
41            AddressSpace::None => Ok(()),
42        }
43    }
44}
45
46impl From<AddressSpace> for Visibility {
47    fn from(val: AddressSpace) -> Self {
48        match val {
49            AddressSpace::Constant => Visibility::Read,
50            _ => Visibility::ReadWrite,
51        }
52    }
53}
54
55impl<D: Dialect> From<&Binding<D>> for AddressSpace {
56    fn from(value: &Binding<D>) -> Self {
57        match value.vis {
58            Visibility::Read => AddressSpace::ConstDevice,
59            Visibility::ReadWrite => match value.location {
60                Location::Storage => AddressSpace::Device,
61                Location::Cube => AddressSpace::ThreadGroup,
62            },
63        }
64    }
65}
66
67impl<D: Dialect> From<&Variable<D>> for AddressSpace {
68    fn from(value: &Variable<D>) -> Self {
69        match value {
70            Variable::AbsolutePosBaseName
71            | Variable::AbsolutePosX
72            | Variable::AbsolutePosY
73            | Variable::AbsolutePosZ
74            | Variable::UnitPosBaseName
75            | Variable::UnitPosX
76            | Variable::UnitPosY
77            | Variable::UnitPosZ
78            | Variable::CubePosBaseName
79            | Variable::CubePosX
80            | Variable::CubePosY
81            | Variable::CubePosZ
82            | Variable::CubeDimBaseName
83            | Variable::CubeDimX
84            | Variable::CubeDimY
85            | Variable::CubeDimZ
86            | Variable::CubeCountBaseName
87            | Variable::CubeCountX
88            | Variable::CubeCountY
89            | Variable::CubeCountZ
90            | Variable::PlaneDim
91            | Variable::UnitPosPlane => AddressSpace::None,
92            Variable::GlobalInputArray(..) => AddressSpace::ConstDevice,
93            Variable::GlobalOutputArray(..) => AddressSpace::Device,
94            Variable::GlobalScalar { .. } => {
95                if value.is_const() {
96                    AddressSpace::ConstDevice
97                } else {
98                    AddressSpace::Device
99                }
100            }
101            Variable::SharedMemory(..) => AddressSpace::ThreadGroup,
102            _ => AddressSpace::Thread,
103        }
104    }
105}