cubecl_runtime/memory_management/memory_pool/
handle.rs1use crate::memory_management::MemoryHandle;
2use alloc::sync::Arc;
3use core::cell::Cell;
4
/// Owning handle to a managed memory slice.
///
/// `descriptor` carries the slice's identity and current location and is
/// shared with every clone of this handle and with every
/// [`ManagedMemoryBinding`] derived from it. `handle_count` is a dedicated
/// `Arc<()>` cloned only by handle clones (bindings never touch it), so its
/// strong count tracks live handles alone — see `can_mut`.
#[derive(Debug)]
pub struct ManagedMemoryHandle {
    // Shared descriptor: unique id + current (pool, page, slice) location.
    descriptor: Arc<ManagedMemoryDescriptor>,
    // Cloned per handle only; `can_mut` reads its strong count.
    handle_count: Arc<()>,
}
12
/// Binding to a managed memory slice, obtained by consuming a
/// [`ManagedMemoryHandle`] via `binding`.
///
/// Shares the handle's descriptor `Arc` but not its `handle_count`, so
/// outstanding bindings do not affect `ManagedMemoryHandle::can_mut`.
#[derive(Debug)]
pub struct ManagedMemoryBinding {
    descriptor: Arc<ManagedMemoryDescriptor>,
}
18
19impl Clone for ManagedMemoryHandle {
20 fn clone(&self) -> Self {
21 Self {
22 descriptor: self.descriptor.clone(),
23 handle_count: self.handle_count.clone(),
24 }
25 }
26}
27
/// Shared state describing one managed memory slice: a unique id plus its
/// current location, which the pool rewrites in place through a `Cell`.
pub(crate) struct ManagedMemoryDescriptor {
    pub(crate) id: ManagedMemoryId,
    // Interior mutability lets the pool relocate the slice while handles and
    // bindings hold shared `Arc` references to this descriptor.
    location: Cell<MemoryLocation>,
}
43
// SAFETY: `Cell<MemoryLocation>` is `!Sync`, so these impls assert that all
// accesses to `location` (`update_*` / `location()`) are externally
// synchronized by the memory-management layer.
// NOTE(review): that invariant is not visible in this file — confirm at the
// call sites that no two threads touch the same descriptor concurrently; an
// unsynchronized cross-thread `Cell` access would be a data race.
unsafe impl Send for ManagedMemoryDescriptor {}
unsafe impl Sync for ManagedMemoryDescriptor {}
50
51impl core::fmt::Debug for ManagedMemoryDescriptor {
52 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53 f.debug_struct("ManagedMemoryDescriptor")
54 .field("id", &self.id)
55 .field("location", &self.location())
56 .finish()
57 }
58}
59
/// Unique identifier of a managed memory slice, drawn from a process-wide
/// atomic counter (see `ManagedMemoryHandle::gen_id`).
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct ManagedMemoryId {
    pub(crate) value: usize,
}
65
// Descriptor equality is identity equality: two descriptors compare equal
// iff they carry the same unique id. The mutable `location` is deliberately
// ignored, so equality is stable while the pool relocates the slice.
impl PartialEq for ManagedMemoryDescriptor {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}

impl Eq for ManagedMemoryDescriptor {}
73
/// Location of a slice within the pool hierarchy (pool → page → slice).
#[derive(Clone, Copy, Debug)]
pub(crate) struct MemoryLocation {
    // Index of the owning pool.
    pub pool: u8,
    // Index of the page within the pool.
    pub page: u16,
    // Index of the slice within the page.
    pub slice: u32,
    // 1 once a real location has been assigned (`new`), 0 for `uninit`.
    pub init: u8,
}
86
87impl ManagedMemoryDescriptor {
88 pub(crate) fn update_location(&self, location: MemoryLocation) {
90 self.location.set(location);
91 }
92
93 pub(crate) fn update_slice(&self, slice: u32) {
95 self.location.update(|mut loc| {
96 loc.slice = slice;
97 loc
98 });
99 }
100
101 pub fn update_page(&self, page: u16) {
103 self.location.update(|mut loc| {
104 loc.page = page;
105 loc
106 });
107 }
108
109 pub(crate) fn location(&self) -> MemoryLocation {
111 self.location.get()
112 }
113
114 pub(crate) fn slice(&self) -> usize {
115 self.location.get().slice as usize
116 }
117
118 pub(crate) fn page(&self) -> usize {
119 self.location.get().page as usize
120 }
121}
122
123impl MemoryLocation {
124 pub(crate) fn new(pool: u8, page: u16, slice: u32) -> Self {
126 Self {
127 pool,
128 page,
129 slice,
130 init: 1,
131 }
132 }
133
134 pub(crate) fn uninit() -> Self {
136 Self {
137 pool: 0,
138 page: 0,
139 slice: 0,
140 init: 0,
141 }
142 }
143}
144
145impl ManagedMemoryHandle {
146 pub fn new() -> Self {
148 let value = Self::gen_id();
149
150 Self {
151 descriptor: Arc::new(ManagedMemoryDescriptor {
152 id: ManagedMemoryId { value },
153 location: Cell::new(MemoryLocation::uninit()),
154 }),
155 handle_count: Arc::new(()),
156 }
157 }
158
159 pub(crate) fn descriptor(&self) -> &ManagedMemoryDescriptor {
161 &self.descriptor
162 }
163
164 pub fn can_mut(&self) -> bool {
166 Arc::strong_count(&self.handle_count) <= 2
167 }
168
169 pub fn is_free(&self) -> bool {
171 Arc::strong_count(&self.descriptor) <= 1
172 }
173
174 pub fn binding(self) -> ManagedMemoryBinding {
176 ManagedMemoryBinding {
177 descriptor: self.descriptor.clone(),
178 }
179 }
180
181 fn gen_id() -> usize {
182 static COUNTER: core::sync::atomic::AtomicUsize = core::sync::atomic::AtomicUsize::new(0);
183 let value = COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
184 if value == usize::MAX {
185 core::panic!("Memory ID overflowed");
186 }
187 value
188 }
189}
190
impl ManagedMemoryBinding {
    /// Shared descriptor carrying the id and current location.
    pub(crate) fn descriptor(&self) -> &ManagedMemoryDescriptor {
        &self.descriptor
    }
}
197
impl Default for ManagedMemoryHandle {
    /// Same as [`ManagedMemoryHandle::new`]: fresh id, uninitialized location.
    fn default() -> Self {
        Self::new()
    }
}
203
204impl Clone for ManagedMemoryBinding {
205 fn clone(&self) -> Self {
206 Self {
207 descriptor: self.descriptor.clone(),
208 }
209 }
210}
211
// Adapter wiring the inherent methods into the crate-wide `MemoryHandle`
// trait; both methods are plain delegations to the inherent impls above.
impl MemoryHandle<ManagedMemoryBinding> for ManagedMemoryHandle {
    fn can_mut(&self) -> bool {
        self.can_mut()
    }

    fn binding(self) -> ManagedMemoryBinding {
        self.binding()
    }
}
221
/// Compute the alignment (in bytes) for a buffer of `shape` elements of
/// `elem_size` bytes each, limited by the backend's maximum `buffer_align`.
///
/// A single-element buffer aligns to its element size; otherwise the total
/// byte size is rounded up to the next power of two and clamped to
/// `[min(16, buffer_align), buffer_align]`.
///
/// # Panics
/// Panics in debug builds if `shape * elem_size` overflows `usize`.
pub fn optimal_align(shape: usize, elem_size: usize, buffer_align: usize) -> usize {
    if shape == 1 {
        elem_size
    } else {
        // `clamp(min, max)` panics when `min > max`, so the previous fixed
        // lower bound of 16 blew up for backends with `buffer_align < 16`;
        // cap the lower bound by `buffer_align` instead.
        let lower = buffer_align.min(16);
        (shape * elem_size).next_power_of_two().clamp(lower, buffer_align)
    }
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// The location is mutable through a shared `&ManagedMemoryDescriptor`,
    /// and writes made through a clone are visible via the original handle.
    #[test]
    fn test_memory_id_mutability() {
        let handle1 = ManagedMemoryHandle::new();
        handle1.descriptor().update_slice(4);
        assert_eq!(handle1.descriptor().slice(), 4);

        // Updating through a temporary clone must hit the same descriptor.
        let handle2 = ManagedMemoryHandle::new();
        handle2
            .clone()
            .descriptor()
            .update_location(handle1.descriptor().location());
        assert_eq!(handle2.descriptor().slice(), 4);
    }

    /// Two clones share one descriptor `Arc`, so a location update through
    /// one handle (whole-location or slice-only) is observed by the other.
    #[test]
    fn test_location_visible_through_shared_arc() {
        let handle = ManagedMemoryHandle::new();
        let handle2 = handle.clone();

        let location = MemoryLocation::new(1, 2, 3);
        handle.descriptor().update_location(location);

        assert_eq!(handle2.descriptor().location().pool, 1);
        assert_eq!(handle2.descriptor().location().page, 2);
        assert_eq!(handle2.descriptor().location().slice, 3);
        assert_eq!(handle2.descriptor().location().init, 1);

        // Partial update (slice only) must also be visible through the clone.
        handle.descriptor().update_slice(42);
        assert_eq!(handle2.descriptor().slice(), 42);
    }
}