cubecl_runtime/
id.rs

1use alloc::sync::Arc;
2
/// Create a new storage ID type.
///
/// `storage_id_type!(Name)` defines a `Copy`, totally ordered ID struct `Name`.
/// Its values come from a counter that is private to this expansion, so every call to
/// `Name::new()` (or `Name::default()`) yields a distinct value. Values increase in the
/// order the counter is advanced.
///
/// # Panics
///
/// `Name::new()` panics with "Memory ID overflowed" once the counter is exhausted.
/// The counter saturates instead of wrapping, so every later call panics as well and
/// no ID is ever handed out twice.
#[macro_export(local_inner_macros)]
macro_rules! storage_id_type {
    ($name:ident) => {
        /// Storage ID.
        #[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, PartialOrd, Ord)]
        pub struct $name {
            value: usize,
        }

        impl $name {
            /// Create a new ID.
            ///
            /// # Panics
            ///
            /// Panics once every `usize` value for this ID type has been handed out.
            pub fn new() -> Self {
                static COUNTER: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                // Only advance the counter when doing so cannot wrap. A plain `fetch_add`
                // would wrap to 0 on the panicking call and then start reissuing live IDs.
                // `Relaxed` is enough: uniqueness only relies on the RMW being atomic.
                let value = COUNTER
                    .fetch_update(
                        core::sync::atomic::Ordering::Relaxed,
                        core::sync::atomic::Ordering::Relaxed,
                        |current| current.checked_add(1),
                    )
                    .unwrap_or_else(|_| core::panic!("Memory ID overflowed"));

                Self { value }
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
35
/// Reference to a buffer handle.
///
/// A handle holds two reference counters:
/// - `id`: shared by every clone of the handle. A `BindingRef` does not hold it;
///   it stores its own copy of the id instead.
/// - `all`: shared by every clone of the handle *and* by every `BindingRef`
///   created from one of them.
#[derive(Clone, Debug)]
pub struct HandleRef<Id> {
    id: Arc<Id>,
    all: Arc<()>,
}

/// Reference to buffer binding.
///
/// Produced by `HandleRef::binding`. A binding keeps the handle's `all` counter alive,
/// so `HandleRef::is_free` reports `false` while the binding exists. It does not
/// contribute to the `id` counter that `HandleRef::can_mut` inspects.
#[derive(Clone, Debug)]
pub struct BindingRef<Id> {
    id: Id,
    _all: Arc<()>,
}

impl<Id> BindingRef<Id> {
    /// The id associated to the buffer.
    pub(crate) fn id(&self) -> &Id {
        &self.id
    }
}

impl<Id> HandleRef<Id> {
    /// Create a new handle.
    ///
    /// The returned handle is the sole holder of both of its counters.
    pub(crate) fn new(id: Id) -> Self {
        Self {
            id: Arc::new(id),
            all: Arc::new(()),
        }
    }

    /// The id associated to the handle.
    pub(crate) fn id(&self) -> &Id {
        &self.id
    }

    /// If the handle can be mut.
    ///
    /// Returns `true` while at most two handle clones share this handle's `id` counter.
    /// Bindings are not counted here.
    pub(crate) fn can_mut(&self) -> bool {
        // 1 memory management reference with 1 tensor reference.
        Arc::strong_count(&self.id) <= 2
    }

    /// If the resource is free.
    ///
    /// Returns `true` when no other handle clone and no binding shares this handle's
    /// `all` counter.
    pub(crate) fn is_free(&self) -> bool {
        Arc::strong_count(&self.all) <= 1
    }
}

impl<Id: Clone> HandleRef<Id> {
    /// Get the binding.
    ///
    /// Consumes this handle. Its `id` counter is released and its `all` counter moves
    /// into the returned binding, which stores a copy of the id.
    pub(crate) fn binding(self) -> BindingRef<Id> {
        BindingRef {
            id: self.id.as_ref().clone(),
            _all: self.all,
        }
    }
}
96
/// Create new memory ID types.
///
/// `memory_id_type!(Id, Handle)` defines:
/// - `Id`: a `Copy` identifier wrapping a `usize`;
/// - `Handle`: a handle wrapping `$crate::id::HandleRef<Id>`, which it dereferences to.
///   `Handle::new()` / `Handle::default()` draw a fresh `Id` from a counter that is
///   private to this expansion.
///
/// `memory_id_type!(Id, Handle, Binding)` additionally defines `Binding`. It is returned
/// by `Handle::binding` and dereferences to `$crate::id::BindingRef<Id>`.
///
/// # Panics
///
/// `Handle::new()` panics with "Memory ID overflowed" once the counter is exhausted.
/// The counter saturates instead of wrapping, so every later call panics as well and
/// no ID is ever handed out twice.
#[macro_export(local_inner_macros)]
macro_rules! memory_id_type {
    ($id:ident, $handle:ident) => {
        /// Memory Handle.
        #[derive(Clone, Debug)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Memory ID.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            pub(crate) value: usize,
        }

        impl $handle {
            /// Create a new ID.
            pub(crate) fn new() -> Self {
                let value = Self::gen_id();
                Self {
                    value: $crate::id::HandleRef::new($id { value }),
                }
            }

            /// Returns the next unused counter value for this ID type.
            ///
            /// # Panics
            ///
            /// Panics once every `usize` value has been handed out.
            fn gen_id() -> usize {
                static COUNTER: core::sync::atomic::AtomicUsize =
                    core::sync::atomic::AtomicUsize::new(0);

                // Only advance the counter when doing so cannot wrap. A plain `fetch_add`
                // would wrap to 0 on the panicking call and then start reissuing live IDs.
                // `Relaxed` is enough: uniqueness only relies on the RMW being atomic.
                COUNTER
                    .fetch_update(
                        core::sync::atomic::Ordering::Relaxed,
                        core::sync::atomic::Ordering::Relaxed,
                        |current| current.checked_add(1),
                    )
                    .unwrap_or_else(|_| core::panic!("Memory ID overflowed"))
            }
        }

        impl core::ops::Deref for $handle {
            type Target = $crate::id::HandleRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }

        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        memory_id_type!($id, $handle);

        /// Binding of a memory handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            /// Consumes the handle and returns its binding.
            pub(crate) fn binding(self) -> $binding {
                $binding {
                    value: self.value.binding(),
                }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

            fn deref(&self) -> &Self::Target {
                &self.value
            }
        }
    };
}