Skip to main content

cubecl_runtime/memory_management/memory_pool/
handle.rs

1use crate::memory_management::MemoryHandle;
2use alloc::sync::Arc;
3use core::cell::Cell;
4
/// Managed Memory handle
///
/// Cloning is cheap: both fields are `Arc`s, so a clone only bumps two
/// reference counts. The number of live clones of `handle_count` is what
/// [`ManagedMemoryHandle::can_mut`] inspects to decide whether in-place
/// mutation is allowed, while `descriptor`'s count also includes any
/// [`ManagedMemoryBinding`]s created from this handle.
#[derive(Debug)]
pub struct ManagedMemoryHandle {
    // Shared descriptor carrying the unique id and the current location.
    descriptor: Arc<ManagedMemoryDescriptor>,
    // Holds only the reference counts of the handle.
    handle_count: Arc<()>,
}
12
/// Binding of a memory handle
///
/// Created by consuming a [`ManagedMemoryHandle`] via `binding`. It shares
/// the same descriptor `Arc`, so the memory stays referenced (see
/// `ManagedMemoryHandle::is_free`) while a binding is alive.
#[derive(Debug)]
pub struct ManagedMemoryBinding {
    descriptor: Arc<ManagedMemoryDescriptor>,
}
18
19impl Clone for ManagedMemoryHandle {
20    fn clone(&self) -> Self {
21        Self {
22            descriptor: self.descriptor.clone(),
23            handle_count: self.handle_count.clone(),
24        }
25    }
26}
27
/// Managed memory descriptor.
///
/// The location is wrapped in `Cell` for interior mutability: multiple
/// handles share the same descriptor via `Arc`, yet the memory management
/// system needs to update the location after creation (e.g. during
/// `reserve` / `bind`). All mutation happens on a single device thread,
/// so `Cell` is safe — we just need `unsafe impl Sync` because `Cell`
/// is `!Sync`.
///
/// An alternative would be `spin::Mutex<MemoryLocation>` which avoids the
/// `unsafe impl Sync` at the cost of a lock on every access.
pub(crate) struct ManagedMemoryDescriptor {
    /// Process-wide unique id; also defines descriptor equality.
    pub(crate) id: ManagedMemoryId,
    // Private: mutated only through the `update_*` methods below.
    location: Cell<MemoryLocation>,
}
43
// SAFETY: The channel requires ManagedMemoryHandle to be Send + Sync.
// Cell is _not_ Sync, but, we know that we only access this from the device thread,
// so we lie to the compiler and claim it is Sync. Other code must NOT rely on
// ManagedMemoryDescriptor being Send + Sync.
//
// NOTE(review): this invariant ("only the device thread touches the Cell")
// is enforced by convention, not by the type system — confirm it still holds
// whenever a new call site appears.
unsafe impl Send for ManagedMemoryDescriptor {}
unsafe impl Sync for ManagedMemoryDescriptor {}
50
51impl core::fmt::Debug for ManagedMemoryDescriptor {
52    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53        f.debug_struct("ManagedMemoryDescriptor")
54            .field("id", &self.id)
55            .field("location", &self.location())
56            .finish()
57    }
58}
59
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
/// Managed memory unique identifier.
pub struct ManagedMemoryId {
    // Monotonically increasing value taken from a global atomic counter
    // (see `ManagedMemoryHandle::gen_id`).
    pub(crate) value: usize,
}
65
impl PartialEq for ManagedMemoryDescriptor {
    // Equality is identity-based: only the id is compared. The (interior-
    // mutable) location is deliberately ignored, so two descriptors stay
    // equal across location updates.
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id
    }
}

impl Eq for ManagedMemoryDescriptor {}
73
#[derive(Clone, Copy, Debug)]
/// Defines where the [`ManagedMemoryId`] is located.
pub(crate) struct MemoryLocation {
    /// The memory pool index in the global memory management.
    pub pool: u8,
    /// The memory page index in a memory pool.
    pub page: u16,
    /// The memory slice index in a memory page.
    pub slice: u32,
    /// Whether the memory location is known/initialized (1) or a zeroed
    /// placeholder from [`MemoryLocation::uninit`] (0).
    // NOTE(review): a `bool` would be more idiomatic, but callers/tests read
    // this as an integer — kept as `u8` to preserve the interface.
    pub init: u8,
}
86
87impl ManagedMemoryDescriptor {
88    /// Update the memory location for the given [`ManagedMemoryId`].
89    pub(crate) fn update_location(&self, location: MemoryLocation) {
90        self.location.set(location);
91    }
92
93    /// Update only the slice position for the given [`ManagedMemoryId`].
94    pub(crate) fn update_slice(&self, slice: u32) {
95        self.location.update(|mut loc| {
96            loc.slice = slice;
97            loc
98        });
99    }
100
101    /// Update only the memory page position for the given [`ManagedMemoryId`].
102    pub fn update_page(&self, page: u16) {
103        self.location.update(|mut loc| {
104            loc.page = page;
105            loc
106        });
107    }
108
109    /// Retrieves the current location.
110    pub(crate) fn location(&self) -> MemoryLocation {
111        self.location.get()
112    }
113
114    pub(crate) fn slice(&self) -> usize {
115        self.location.get().slice as usize
116    }
117
118    pub(crate) fn page(&self) -> usize {
119        self.location.get().page as usize
120    }
121}
122
123impl MemoryLocation {
124    /// Creates a new memory location.
125    pub(crate) fn new(pool: u8, page: u16, slice: u32) -> Self {
126        Self {
127            pool,
128            page,
129            slice,
130            init: 1,
131        }
132    }
133
134    /// Creates a new uninitialized memory location.
135    pub(crate) fn uninit() -> Self {
136        Self {
137            pool: 0,
138            page: 0,
139            slice: 0,
140            init: 0,
141        }
142    }
143}
144
145impl ManagedMemoryHandle {
146    /// Creates a new managed memory handle.
147    pub fn new() -> Self {
148        let value = Self::gen_id();
149
150        Self {
151            descriptor: Arc::new(ManagedMemoryDescriptor {
152                id: ManagedMemoryId { value },
153                location: Cell::new(MemoryLocation::uninit()),
154            }),
155            handle_count: Arc::new(()),
156        }
157    }
158
159    /// Retrieves the descriptor for the current handle.
160    pub(crate) fn descriptor(&self) -> &ManagedMemoryDescriptor {
161        &self.descriptor
162    }
163
164    /// Return whether the current handle can be modified in-place.
165    pub fn can_mut(&self) -> bool {
166        Arc::strong_count(&self.handle_count) <= 2
167    }
168
169    /// Return whether the current handle is free.
170    pub fn is_free(&self) -> bool {
171        Arc::strong_count(&self.descriptor) <= 1
172    }
173
174    /// Returns the binding for the current handle.
175    pub fn binding(self) -> ManagedMemoryBinding {
176        ManagedMemoryBinding {
177            descriptor: self.descriptor.clone(),
178        }
179    }
180
181    fn gen_id() -> usize {
182        static COUNTER: core::sync::atomic::AtomicUsize = core::sync::atomic::AtomicUsize::new(0);
183        let value = COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
184        if value == usize::MAX {
185            core::panic!("Memory ID overflowed");
186        }
187        value
188    }
189}
190
impl ManagedMemoryBinding {
    /// Retrieves the descriptor for the current binding.
    pub(crate) fn descriptor(&self) -> &ManagedMemoryDescriptor {
        &self.descriptor
    }
}
197
198impl Default for ManagedMemoryHandle {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204impl Clone for ManagedMemoryBinding {
205    fn clone(&self) -> Self {
206        Self {
207            descriptor: self.descriptor.clone(),
208        }
209    }
210}
211
impl MemoryHandle<ManagedMemoryBinding> for ManagedMemoryHandle {
    // Pure delegation to the inherent methods of the same names.
    fn can_mut(&self) -> bool {
        self.can_mut()
    }

    fn binding(self) -> ManagedMemoryBinding {
        self.binding()
    }
}
221
/// Calculates a best-effort heuristic for the alignment of row-aligned tensors.
/// Prefers contiguous alignments for unit dimensions, 16-byte minimum alignment for non-unit,
/// scaling with input size up to `buffer_align`.
///
/// Returns `elem_size` when `shape == 1`; otherwise the row size in bytes
/// rounded up to a power of two, bounded below by 16 and above by
/// `buffer_align` (with `buffer_align` winning if it is below 16).
pub fn optimal_align(shape: usize, elem_size: usize, buffer_align: usize) -> usize {
    if shape == 1 {
        // A unit dimension packs contiguously: element alignment suffices.
        elem_size
    } else {
        // max/min instead of `clamp(16, buffer_align)`: `clamp` panics when
        // the upper bound is below the lower one, i.e. for buffer_align < 16.
        (shape * elem_size)
            .next_power_of_two()
            .max(16)
            .min(buffer_align)
    }
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Location updates through one handle are visible through any clone,
    /// and locations can be copied between unrelated descriptors.
    #[test]
    fn test_memory_id_mutability() {
        let first = ManagedMemoryHandle::new();
        first.descriptor().update_slice(4);
        assert_eq!(first.descriptor().slice(), 4);

        let second = ManagedMemoryHandle::new();
        let copied = first.descriptor().location();
        second.clone().descriptor().update_location(copied);
        assert_eq!(second.descriptor().slice(), 4);
    }

    /// All fields of a location set via one handle are observable through a
    /// clone sharing the same descriptor `Arc`.
    #[test]
    fn test_location_visible_through_shared_arc() {
        let original = ManagedMemoryHandle::new();
        let alias = original.clone();

        original
            .descriptor()
            .update_location(MemoryLocation::new(1, 2, 3));

        let seen = alias.descriptor().location();
        assert_eq!(seen.pool, 1);
        assert_eq!(seen.page, 2);
        assert_eq!(seen.slice, 3);
        assert_eq!(seen.init, 1);

        original.descriptor().update_slice(42);
        assert_eq!(alias.descriptor().slice(), 42);
    }
}
269}