Skip to main content

cubecl_cpp/metal/
address_space.rs

1use cubecl_core::prelude::Visibility;
2
3use crate::{
4    Dialect,
5    shared::{Component, KernelArg, Variable},
6};
7
8use super::BufferAttribute;
9use std::fmt::Display;
10
/// Address space qualifier used by the Metal dialect.
///
/// The `Display` implementation for this type renders each variant as the
/// keyword noted on it below.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AddressSpace {
    /// Rendered as `constant`.
    Constant,
    /// Rendered as `const device`.
    ConstDevice,
    /// Rendered as `device`.
    Device,
    /// Rendered as `thread`.
    Thread,
    /// Rendered as `threadgroup`.
    ThreadGroup,
    /// No qualifier; rendered as nothing at all.
    None,
}
20
21impl AddressSpace {
22    pub fn attribute(&self) -> BufferAttribute {
23        match self {
24            AddressSpace::Constant | AddressSpace::ConstDevice | AddressSpace::Device => {
25                BufferAttribute::Buffer
26            }
27            AddressSpace::ThreadGroup => BufferAttribute::ThreadGroup,
28            _ => BufferAttribute::None,
29        }
30    }
31}
32
33impl Display for AddressSpace {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            AddressSpace::Constant => f.write_str("constant"),
37            AddressSpace::ConstDevice => f.write_str("const device"),
38            AddressSpace::Device => f.write_str("device"),
39            AddressSpace::ThreadGroup => f.write_str("threadgroup"),
40            AddressSpace::Thread => f.write_str("thread"),
41            AddressSpace::None => Ok(()),
42        }
43    }
44}
45
46impl From<AddressSpace> for Visibility {
47    fn from(val: AddressSpace) -> Self {
48        match val {
49            AddressSpace::Constant => Visibility::Read,
50            _ => Visibility::ReadWrite,
51        }
52    }
53}
54
55impl<D: Dialect> From<&KernelArg<D>> for AddressSpace {
56    fn from(value: &KernelArg<D>) -> Self {
57        match value.vis {
58            Visibility::Read => AddressSpace::ConstDevice,
59            Visibility::ReadWrite => AddressSpace::Device,
60        }
61    }
62}
63
64impl<D: Dialect> From<&Variable<D>> for AddressSpace {
65    fn from(value: &Variable<D>) -> Self {
66        match value {
67            Variable::AbsolutePosBaseName
68            | Variable::AbsolutePosX
69            | Variable::AbsolutePosY
70            | Variable::AbsolutePosZ
71            | Variable::UnitPosBaseName
72            | Variable::UnitPosX
73            | Variable::UnitPosY
74            | Variable::UnitPosZ
75            | Variable::CubePosBaseName
76            | Variable::CubePosX
77            | Variable::CubePosY
78            | Variable::CubePosZ
79            | Variable::CubeDimBaseName
80            | Variable::CubeDimX
81            | Variable::CubeDimY
82            | Variable::CubeDimZ
83            | Variable::CubeCountBaseName
84            | Variable::CubeCountX
85            | Variable::CubeCountY
86            | Variable::CubeCountZ
87            | Variable::PlaneDim
88            | Variable::UnitPosPlane => AddressSpace::None,
89            Variable::GlobalInputArray(..) => AddressSpace::ConstDevice,
90            Variable::GlobalOutputArray(..) => AddressSpace::Device,
91            Variable::GlobalScalar { .. } => {
92                if value.is_const() {
93                    AddressSpace::ConstDevice
94                } else {
95                    AddressSpace::Device
96                }
97            }
98            Variable::SharedArray(..) => AddressSpace::ThreadGroup,
99            _ => AddressSpace::Thread,
100        }
101    }
102}