cubecl_common/device.rs

1use core::cmp::Ordering;
2
3/// The device id.
4#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
5pub struct DeviceId {
6    /// The type id identifies the type of the device.
7    pub type_id: u16,
8    /// The index id identifies the device number.
9    pub index_id: u32,
10}
11
12/// Device trait for all cubecl devices.
13pub trait Device: Default + Clone + core::fmt::Debug + Send + Sync + 'static {
14    /// Create a device from its [id](DeviceId).
15    fn from_id(device_id: DeviceId) -> Self;
16    /// Retrieve the [device id](DeviceId) from the device.
17    fn to_id(&self) -> DeviceId;
18    /// Returns the number of devices available under the provided type id.
19    fn device_count(type_id: u16) -> usize;
20    /// Returns the total number of devices that can be handled by the runtime.
21    fn device_count_total() -> usize {
22        Self::device_count(0)
23    }
24}
25
26impl core::fmt::Display for DeviceId {
27    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28        f.write_fmt(format_args!("{self:?}"))
29    }
30}
31
32impl Ord for DeviceId {
33    fn cmp(&self, other: &Self) -> Ordering {
34        match self.type_id.cmp(&other.type_id) {
35            Ordering::Equal => self.index_id.cmp(&other.index_id),
36            other => other,
37        }
38    }
39}
40
41impl PartialOrd for DeviceId {
42    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43        Some(self.cmp(other))
44    }
45}
46
47pub use context::*;
48
#[cfg(feature = "std")]
mod reentrant {
    //! With `std` enabled, the reentrant mutex and its guard are re-exported from `parking_lot`.
    pub use parking_lot::ReentrantMutex;
    pub use parking_lot::ReentrantMutexGuard;
}
53
// `MutCell` and `MutGuard` are implemented differently depending on whether the `std` feature is enabled.
55
#[cfg(feature = "std")]
mod cell {
    use core::cell::{RefCell, RefMut};

    /// Interior-mutability cell used to store device state when `std` is enabled.
    pub type MutCell<T> = RefCell<T>;
    /// Guard returned by a mutable borrow of a [`MutCell`]; releases the borrow on drop.
    pub type MutGuard<'a, T> = RefMut<'a, T>;

    /// Mutably borrows `cell` and returns the borrowed value with a caller-chosen
    /// lifetime `'a`, together with the guard that keeps the borrow registered.
    ///
    /// # Panics
    ///
    /// If `cell` is already borrowed (see [`RefCell::borrow_mut`]).
    ///
    /// # Safety
    ///
    /// The returned reference is NOT tied to the guard or to `cell`. The caller must
    /// ensure the reference (and anything borrowed through it) is only used while the
    /// borrowed value is still alive and not accessed through any other path; in
    /// particular it must not be used to reach the value after the guard is dropped.
    pub unsafe fn borrow_mut_split<'a, T>(cell: &MutCell<T>) -> (&'a mut T, MutGuard<'_, T>) {
        let mut guard = cell.borrow_mut();
        let item: *mut T = &mut *guard;
        // SAFETY: `item` points to the value owned by `cell`, which is valid right now and
        // exclusively borrowed by `guard`. Detaching the lifetime is the caller's
        // responsibility, as documented in `# Safety` above.
        let item: &'a mut T = unsafe { &mut *item };

        (item, guard)
    }
}
72
73#[cfg(not(feature = "std"))]
74mod cell {
75    use core::ops::{Deref, DerefMut};
76
77    pub struct MutGuard<'a, T> {
78        guard: spin::MutexGuard<'a, T>,
79    }
80
81    pub struct MutCell<T> {
82        lock: spin::Mutex<T>,
83    }
84
85    impl<T> MutCell<T> {
86        pub fn new(item: T) -> Self {
87            Self {
88                lock: spin::Mutex::new(item),
89            }
90        }
91    }
92
93    impl<'a, T> Deref for MutGuard<'a, T> {
94        type Target = T;
95
96        fn deref(&self) -> &Self::Target {
97            self.guard.deref()
98        }
99    }
100
101    impl<'a, T> DerefMut for MutGuard<'a, T> {
102        fn deref_mut(&mut self) -> &mut Self::Target {
103            self.guard.deref_mut()
104        }
105    }
106
107    impl<T> MutCell<T> {
108        pub fn try_borrow_mut(&self) -> Result<MutGuard<'_, T>, ()> {
109            match self.lock.try_lock() {
110                Some(guard) => Ok(MutGuard { guard }),
111                None => Err(()),
112            }
113        }
114    }
115
116    pub unsafe fn borrow_mut_split<'a, T>(
117        cell: &MutCell<T>,
118    ) -> (&'a mut T, spin::MutexGuard<'_, T>) {
119        let mut guard = cell.lock.lock();
120        let item = guard.deref_mut();
121        let item: &'a mut T = unsafe { core::mem::transmute(item) };
122
123        (item, guard)
124    }
125}
126
127#[cfg(not(feature = "std"))]
128mod reentrant {
129    use core::ops::Deref;
130
131    pub struct ReentrantMutex<T> {
132        inner: spin::RwLock<T>,
133    }
134
135    impl<T> ReentrantMutex<T> {
136        pub fn new(item: T) -> Self {
137            Self {
138                inner: spin::RwLock::new(item),
139            }
140        }
141    }
142
143    pub struct ReentrantMutexGuard<'a, T> {
144        guard: spin::RwLockReadGuard<'a, T>,
145    }
146
147    impl<'a, T> Deref for ReentrantMutexGuard<'a, T> {
148        type Target = T;
149
150        fn deref(&self) -> &Self::Target {
151            self.guard.deref()
152        }
153    }
154
155    impl<T> ReentrantMutex<T> {
156        pub fn lock(&self) -> ReentrantMutexGuard<'_, T> {
157            let guard = self.inner.read();
158            ReentrantMutexGuard { guard }
159        }
160    }
161}
162
163mod context {
164    use super::cell::{MutCell, MutGuard};
165    use alloc::boxed::Box;
166    use core::{
167        any::{Any, TypeId},
168        marker::PhantomData,
169    };
170    use hashbrown::HashMap;
171
172    use super::reentrant::{ReentrantMutex, ReentrantMutexGuard};
173
174    use crate::{device::cell::borrow_mut_split, stub::Arc};
175
176    use super::{Device, DeviceId};
177
178    /// A state that can be saved inside the [DeviceContext].
179    pub trait DeviceState: Send + 'static {
180        /// Initialize a new state on the given device.
181        fn init(device_id: DeviceId) -> Self;
182    }
183
184    /// Handle for accessing a [DeviceState] associated with a specific device.
185    pub struct DeviceContext<S: DeviceState> {
186        lock: DeviceStateLock,
187        device_id: DeviceId,
188        _phantom: PhantomData<S>,
189    }
190
191    /// There is nothing to read without a lock, and it's fine to allow locking a context reference.
192    unsafe impl<S: DeviceState> Sync for DeviceContext<S> {}
193
194    impl<S: DeviceState> Clone for DeviceContext<S> {
195        fn clone(&self) -> Self {
196            Self {
197                lock: self.lock.clone(),
198                _phantom: self._phantom,
199                device_id: self.device_id,
200            }
201        }
202    }
203
204    /// Guard providing mutable access to [DeviceState].
205    ///
206    /// Automatically releases the lock when dropped.
207    pub struct DeviceStateGuard<'a, S: DeviceState> {
208        guard_ref: Option<MutGuard<'a, Box<dyn Any + Send + 'static>>>,
209        guard_mutex: Option<ReentrantMutexGuard<'a, DeviceStateMap>>,
210        _phantom: PhantomData<S>,
211    }
212
213    /// Guard making sure only the locked device can be used.
214    ///
215    /// Automatically releases the lock when dropped.
216    pub struct DeviceGuard<'a> {
217        guard_mutex: Option<ReentrantMutexGuard<'a, DeviceStateMap>>,
218    }
219
220    impl<'a, S: DeviceState> Drop for DeviceStateGuard<'a, S> {
221        fn drop(&mut self) {
222            // Important to drop the ref before.
223            self.guard_ref = None;
224            self.guard_mutex = None;
225        }
226    }
227
228    impl<'a> Drop for DeviceGuard<'a> {
229        fn drop(&mut self) {
230            self.guard_mutex = None;
231        }
232    }
233
234    impl<'a, S: DeviceState> core::ops::Deref for DeviceStateGuard<'a, S> {
235        type Target = S;
236
237        fn deref(&self) -> &Self::Target {
238            self.guard_ref
239                .as_ref()
240                .expect("The guard to not be dropped")
241                .downcast_ref()
242                .expect("The type to be correct")
243        }
244    }
245
246    impl<'a, S: DeviceState> core::ops::DerefMut for DeviceStateGuard<'a, S> {
247        fn deref_mut(&mut self) -> &mut Self::Target {
248            self.guard_ref
249                .as_mut()
250                .expect("The guard to not be dropped")
251                .downcast_mut()
252                .expect("The type to be correct")
253        }
254    }
255
256    impl<S: DeviceState> DeviceContext<S> {
257        /// Creates a [DeviceState<S>] handle for the given device.
258        ///
259        /// Registers the device-type combination globally if needed.
260        pub fn locate<D: Device + 'static>(device: &D) -> Self {
261            DeviceStateLock::locate(device)
262        }
263
264        /// Inserts a new state associated with the device.
265        ///
266        /// # Returns
267        ///
268        /// An error if the device already has a registered state.
269        pub fn insert<D: Device + 'static>(
270            device: &D,
271            state_new: S,
272        ) -> Result<Self, alloc::string::String> {
273            let lock = Self::locate(device);
274            let id = TypeId::of::<S>();
275
276            let state = lock.lock.lock.lock();
277
278            // It is safe for the same reasons enumerated in the lock function.
279            let (map, map_guard) = unsafe { borrow_mut_split(&state.map) };
280
281            if map.contains_key(&id) {
282                return Err(alloc::format!(
283                    "A server is still registered for device {:?}",
284                    device
285                ));
286            }
287
288            let any: Box<dyn Any + Send + 'static> = Box::new(state_new);
289            let cell = MutCell::new(any);
290
291            map.insert(id, cell);
292
293            core::mem::drop(map_guard);
294            core::mem::drop(state);
295
296            Ok(lock)
297        }
298
299        /// Locks the current device making sure this device can be used.
300        pub fn lock_device(&self) -> DeviceGuard<'_> {
301            let state = self.lock.lock.lock();
302
303            DeviceGuard {
304                guard_mutex: Some(state),
305            }
306        }
307
308        /// Acquires exclusive mutable access to the [DeviceState].
309        ///
310        /// The same device can lock multiple types at the same time.
311        ///
312        /// # Panics
313        ///
314        /// If the same state type is locked multiple times on the same thread.
315        /// This can only happen with recursive locking of the same state, which isn't allowed
316        /// since having multiple mutable references to the same state isn't valid.
317        pub fn lock(&self) -> DeviceStateGuard<'_, S> {
318            let key = TypeId::of::<S>();
319            let state = self.lock.lock.lock();
320
321            // It is safe for multiple reasons.
322            //
323            // 1. The mutability of the map is handled by each map entry with a RefCell.
324            //    Therefore, multiple mutable references to a map entry are checked.
325            // 2. Map items are never cleaned up, therefore it's impossible to remove the validity of
326            //    an entry.
327            // 3. Because of the lock, no race condition is possible.
328            //
329            // The reason why unsafe is necessary is that the [DeviceStateGuard] doesn't keep track
330            // of the borrowed map entry lifetime. But since it keeps track of both the [RefCell]
331            // and the [ReentrantMutex] guards, it is fine to erase the lifetime here.
332            let (map, map_guard) = unsafe { borrow_mut_split(&state.map) };
333
334            if !map.contains_key(&key) {
335                let state_default = S::init(self.device_id);
336                let any: Box<dyn Any + Send + 'static> = Box::new(state_default);
337                let cell = MutCell::new(any);
338
339                map.insert(key, cell);
340            }
341
342            let value = map
343                .get(&key)
344                .expect("Just validated the map contains the key.");
345            let ref_guard = match value.try_borrow_mut() {
346                Ok(guard) => guard,
347                #[cfg(feature = "std")]
348                Err(_) => panic!(
349                    "State {} is already borrowed by the current thread {:?}",
350                    core::any::type_name::<S>(),
351                    std::thread::current().id()
352                ),
353                #[cfg(not(feature = "std"))]
354                Err(_) => panic!("State {} is already borrowed", core::any::type_name::<S>(),),
355            };
356
357            core::mem::drop(map_guard);
358
359            DeviceStateGuard {
360                guard_ref: Some(ref_guard),
361                guard_mutex: Some(state),
362                _phantom: PhantomData,
363            }
364        }
365    }
366
367    type Key = (DeviceId, TypeId);
368
369    static GLOBAL: spin::Mutex<DeviceLocator> = spin::Mutex::new(DeviceLocator { state: None });
370
371    struct DeviceLocator {
372        state: Option<HashMap<Key, DeviceStateLock>>,
373    }
374
375    #[derive(Clone)]
376    struct DeviceStateLock {
377        lock: Arc<ReentrantMutex<DeviceStateMap>>,
378    }
379
380    struct DeviceStateMap {
381        map: MutCell<HashMap<TypeId, MutCell<Box<dyn Any + Send + 'static>>>>,
382    }
383
384    impl DeviceStateLock {
385        fn locate<D: Device + 'static, S: DeviceState>(device: &D) -> DeviceContext<S> {
386            let id = device.to_id();
387            let key = (id, TypeId::of::<D>());
388            let mut global = GLOBAL.lock();
389
390            let map = match &mut global.state {
391                Some(state) => state,
392                None => {
393                    global.state = Some(HashMap::default());
394                    global.state.as_mut().expect("Just created Option::Some")
395                }
396            };
397
398            let lock = match map.get(&key) {
399                Some(value) => value.clone(),
400                None => {
401                    let state = DeviceStateMap::new();
402
403                    let value = DeviceStateLock {
404                        lock: Arc::new(ReentrantMutex::new(state)),
405                    };
406
407                    map.insert(key, value);
408                    map.get(&key).expect("Just inserted the key/value").clone()
409                }
410            };
411
412            DeviceContext {
413                lock,
414                device_id: id,
415                _phantom: PhantomData,
416            }
417        }
418    }
419
420    impl DeviceStateMap {
421        fn new() -> Self {
422            Self {
423                map: MutCell::new(HashMap::new()),
424            }
425        }
426    }
427
428    #[cfg(test)]
429    mod tests {
430        use core::{
431            ops::{Deref, DerefMut},
432            time::Duration,
433        };
434
435        use super::*;
436
437        #[test]
438        fn can_have_multiple_mutate_state() {
439            let device1 = TestDevice::<0>::new(0);
440            let device2 = TestDevice::<1>::new(0);
441
442            let state1_usize = DeviceContext::<usize>::locate(&device1);
443            let state1_u32 = DeviceContext::<u32>::locate(&device1);
444            let state2_usize = DeviceContext::<usize>::locate(&device2);
445
446            let mut guard_usize = state1_usize.lock();
447            let mut guard_u32 = state1_u32.lock();
448
449            let val_usize = guard_usize.deref_mut();
450            let val_u32 = guard_u32.deref_mut();
451
452            *val_usize += 1;
453            *val_u32 += 2;
454
455            assert_eq!(*val_usize, 1);
456            assert_eq!(*val_u32, 2);
457
458            core::mem::drop(guard_usize);
459            core::mem::drop(guard_u32);
460
461            let mut guard_usize = state2_usize.lock();
462
463            let val_usize = guard_usize.deref_mut();
464            *val_usize += 1;
465
466            assert_eq!(*val_usize, 1);
467
468            core::mem::drop(guard_usize);
469
470            let guard_usize = state1_usize.lock();
471            let guard_u32 = state1_u32.lock();
472
473            let val_usize = guard_usize.deref();
474            let val_u32 = guard_u32.deref();
475
476            assert_eq!(*val_usize, 1);
477            assert_eq!(*val_u32, 2);
478        }
479
480        #[test]
481        #[should_panic]
482        fn can_not_have_multiple_mut_ref_to_same_state() {
483            let device1 = TestDevice::<0>::new(0);
484
485            struct DummyState;
486
487            impl DeviceState for DummyState {
488                fn init(_device_id: DeviceId) -> Self {
489                    DummyState
490                }
491            }
492
493            fn recursive(total: usize, state: &DeviceContext<DummyState>) {
494                let _guard = state.lock();
495
496                if total > 0 {
497                    recursive(total - 1, state);
498                }
499            }
500
501            recursive(5, &DeviceContext::locate(&device1));
502        }
503
504        #[test]
505        fn work_with_many_threads() {
506            let num_threads = 32;
507            let handles: Vec<_> = (0..num_threads)
508                .map(|i| std::thread::spawn(move || thread_main((num_threads * 4) - i)))
509                .collect();
510
511            handles.into_iter().for_each(|h| h.join().unwrap());
512
513            let device1 = TestDevice::<0>::new(0);
514            let device2 = TestDevice::<1>::new(0);
515
516            let state1_i64 = DeviceContext::<i64>::locate(&device1);
517            let state1_i32 = DeviceContext::<i32>::locate(&device1);
518            let state2_i32 = DeviceContext::<i32>::locate(&device2);
519
520            let guard_i64 = state1_i64.lock();
521            let guard_i32 = state1_i32.lock();
522
523            assert_eq!(*guard_i64, num_threads as i64);
524            assert_eq!(*guard_i32, num_threads as i32 * 2);
525
526            core::mem::drop(guard_i64);
527            core::mem::drop(guard_i32);
528
529            let guard_i32 = state2_i32.lock();
530            assert_eq!(*guard_i32, num_threads as i32);
531        }
532
533        fn thread_main(sleep: u64) {
534            let device1 = TestDevice::<0>::new(0);
535            let device2 = TestDevice::<1>::new(0);
536
537            let state1_i64 = DeviceContext::<i64>::locate(&device1);
538            let state1_i32 = DeviceContext::<i32>::locate(&device1);
539            let state2_i32 = DeviceContext::<i32>::locate(&device2);
540
541            let mut guard_i64 = state1_i64.lock();
542            let mut guard_i32 = state1_i32.lock();
543
544            let val_i64 = guard_i64.deref_mut();
545            let val_i32 = guard_i32.deref_mut();
546
547            *val_i64 += 1;
548            *val_i32 += 2;
549
550            core::mem::drop(guard_i64);
551            core::mem::drop(guard_i32);
552
553            std::thread::sleep(Duration::from_millis(sleep));
554
555            let mut guard_i32 = state2_i32.lock();
556
557            let val_i32 = guard_i32.deref_mut();
558            *val_i32 += 1;
559
560            core::mem::drop(guard_i32);
561        }
562
563        #[derive(Debug, Clone, Default, new)]
564        /// Type is only to create different type ids.
565        pub struct TestDevice<const TYPE: u8> {
566            index: u32,
567        }
568
569        impl<const TYPE: u8> Device for TestDevice<TYPE> {
570            fn from_id(device_id: DeviceId) -> Self {
571                Self {
572                    index: device_id.index_id,
573                }
574            }
575
576            fn to_id(&self) -> DeviceId {
577                DeviceId {
578                    type_id: 0,
579                    index_id: self.index,
580                }
581            }
582
583            fn device_count(_type_id: u16) -> usize {
584                TYPE as usize + 1
585            }
586        }
587
588        impl DeviceState for usize {
589            fn init(_device_id: DeviceId) -> Self {
590                0
591            }
592        }
593
594        impl DeviceState for u32 {
595            fn init(_device_id: DeviceId) -> Self {
596                0
597            }
598        }
599        impl DeviceState for i32 {
600            fn init(_device_id: DeviceId) -> Self {
601                0
602            }
603        }
604        impl DeviceState for i64 {
605            fn init(_device_id: DeviceId) -> Self {
606                0
607            }
608        }
609    }
610}