cubecl_common/
device.rs

1use core::cmp::Ordering;
2
3/// The device id.
4#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
5pub struct DeviceId {
6    /// The type id identifies the type of the device.
7    pub type_id: u16,
8    /// The index id identifies the device number.
9    pub index_id: u32,
10}
11
12/// Device trait for all cubecl devices.
13pub trait Device: Default + Clone + core::fmt::Debug + Send + Sync + 'static {
14    /// Create a device from its [id](DeviceId).
15    fn from_id(device_id: DeviceId) -> Self;
16    /// Retrieve the [device id](DeviceId) from the device.
17    fn to_id(&self) -> DeviceId;
18    /// Returns the number of devices available under the provided type id.
19    fn device_count(type_id: u16) -> usize;
20    /// Returns the total number of devices that can be handled by the runtime.
21    fn device_count_total() -> usize {
22        Self::device_count(0)
23    }
24}
25
26impl core::fmt::Display for DeviceId {
27    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28        f.write_fmt(format_args!("{self:?}"))
29    }
30}
31
32impl Ord for DeviceId {
33    fn cmp(&self, other: &Self) -> Ordering {
34        match self.type_id.cmp(&other.type_id) {
35            Ordering::Equal => self.index_id.cmp(&other.index_id),
36            other => other,
37        }
38    }
39}
40
41impl PartialOrd for DeviceId {
42    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43        Some(self.cmp(other))
44    }
45}
46
47pub use context::*;
48
#[cfg(feature = "std")]
mod reentrant {
    //! With `std`, the reentrant mutex and its guard come straight from `parking_lot`.
    pub use parking_lot::ReentrantMutex;
    pub use parking_lot::ReentrantMutexGuard;
}
53
// MutCell and MutGuard are implemented differently depending on whether the `std` feature is enabled.
55
#[cfg(feature = "std")]
mod cell {
    use core::cell::{RefCell, RefMut};
    use core::ops::DerefMut;

    /// With `std`, the cell is a plain [`RefCell`]: conflicting borrows are caught at runtime.
    pub type MutCell<T> = RefCell<T>;
    /// Mutable borrow of a [`MutCell`], released on drop.
    pub type MutGuard<'a, T> = RefMut<'a, T>;

    /// Mutably borrows `cell`, returning the borrowed value under the caller-chosen lifetime
    /// `'a` alongside the guard that keeps the borrow active.
    ///
    /// # Panics
    ///
    /// Panics if `cell` is already borrowed (this is `RefCell::borrow_mut`'s behavior).
    ///
    /// # Safety
    ///
    /// The returned reference is detached from the guard's lifetime. The caller must not use
    /// it after the guard is dropped, and anything borrowed through it must only outlive the
    /// guard when the caller can otherwise guarantee that the pointee stays valid and is not
    /// aliased mutably.
    pub unsafe fn borrow_mut_split<'a, T>(cell: &MutCell<T>) -> (&'a mut T, MutGuard<'_, T>) {
        let mut borrow = cell.borrow_mut();
        // SAFETY: only the lifetime of the reference changes; upholding its validity is the
        // caller's responsibility, as documented above.
        let detached: &'a mut T = unsafe { core::mem::transmute(borrow.deref_mut()) };
        (detached, borrow)
    }
}
72
73#[cfg(not(feature = "std"))]
74mod cell {
75    use core::ops::{Deref, DerefMut};
76
77    pub struct MutGuard<'a, T> {
78        guard: spin::MutexGuard<'a, T>,
79    }
80
81    pub struct MutCell<T> {
82        lock: spin::Mutex<T>,
83    }
84
85    impl<T> MutCell<T> {
86        pub fn new(item: T) -> Self {
87            Self {
88                lock: spin::Mutex::new(item),
89            }
90        }
91    }
92
93    impl<'a, T> Deref for MutGuard<'a, T> {
94        type Target = T;
95
96        fn deref(&self) -> &Self::Target {
97            self.guard.deref()
98        }
99    }
100
101    impl<'a, T> DerefMut for MutGuard<'a, T> {
102        fn deref_mut(&mut self) -> &mut Self::Target {
103            self.guard.deref_mut()
104        }
105    }
106
107    impl<T> MutCell<T> {
108        pub fn try_borrow_mut(&self) -> Result<MutGuard<'_, T>, ()> {
109            match self.lock.try_lock() {
110                Some(guard) => Ok(MutGuard { guard }),
111                None => Err(()),
112            }
113        }
114    }
115
116    pub unsafe fn borrow_mut_split<'a, T>(
117        cell: &MutCell<T>,
118    ) -> (&'a mut T, spin::MutexGuard<'_, T>) {
119        let mut guard = cell.lock.lock();
120        let item = guard.deref_mut();
121        let item: &'a mut T = unsafe { core::mem::transmute(item) };
122
123        (item, guard)
124    }
125}
126
127#[cfg(not(feature = "std"))]
128mod reentrant {
129    use core::ops::Deref;
130
131    pub struct ReentrantMutex<T> {
132        inner: spin::RwLock<T>,
133    }
134
135    impl<T> ReentrantMutex<T> {
136        pub fn new(item: T) -> Self {
137            Self {
138                inner: spin::RwLock::new(item),
139            }
140        }
141    }
142
143    pub struct ReentrantMutexGuard<'a, T> {
144        guard: spin::RwLockReadGuard<'a, T>,
145    }
146
147    impl<'a, T> Deref for ReentrantMutexGuard<'a, T> {
148        type Target = T;
149
150        fn deref(&self) -> &Self::Target {
151            self.guard.deref()
152        }
153    }
154
155    impl<T> ReentrantMutex<T> {
156        pub fn lock(&self) -> ReentrantMutexGuard<'_, T> {
157            let guard = self.inner.read();
158            ReentrantMutexGuard { guard }
159        }
160    }
161}
162
163mod context {
164    use super::cell::{MutCell, MutGuard};
165    use alloc::boxed::Box;
166    use core::{
167        any::{Any, TypeId},
168        marker::PhantomData,
169    };
170    use hashbrown::HashMap;
171
172    use super::reentrant::{ReentrantMutex, ReentrantMutexGuard};
173
174    use crate::{device::cell::borrow_mut_split, stub::Arc};
175
176    use super::{Device, DeviceId};
177
178    /// A state that can be saved inside the [DeviceContext].
179    pub trait DeviceState: Send + 'static {
180        /// Initialize a new state on the given device.
181        fn init(device_id: DeviceId) -> Self;
182    }
183
184    /// Handle for accessing a [DeviceState] associated with a specific device.
185    pub struct DeviceContext<S: DeviceState> {
186        lock: DeviceStateLock,
187        lock_kind: Arc<ReentrantMutex<()>>,
188        device_id: DeviceId,
189        _phantom: PhantomData<S>,
190    }
191
192    /// There is nothing to read without a lock, and it's fine to allow locking a context reference.
193    unsafe impl<S: DeviceState> Sync for DeviceContext<S> {}
194
195    impl<S: DeviceState> Clone for DeviceContext<S> {
196        fn clone(&self) -> Self {
197            Self {
198                lock: self.lock.clone(),
199                lock_kind: self.lock_kind.clone(),
200                _phantom: self._phantom,
201                device_id: self.device_id,
202            }
203        }
204    }
205
206    /// Guard providing mutable access to [DeviceState].
207    ///
208    /// Automatically releases the lock when dropped.
209    pub struct DeviceStateGuard<'a, S: DeviceState> {
210        guard_ref: Option<MutGuard<'a, Box<dyn Any + Send + 'static>>>,
211        guard_mutex: Option<ReentrantMutexGuard<'a, DeviceStateMap>>,
212        _phantom: PhantomData<S>,
213    }
214
215    /// Guard making sure only the locked device can be used.
216    ///
217    /// Automatically releases the lock when dropped.
218    pub struct DeviceGuard<'a> {
219        guard_mutex: Option<ReentrantMutexGuard<'a, DeviceStateMap>>,
220    }
221
222    impl<'a, S: DeviceState> Drop for DeviceStateGuard<'a, S> {
223        fn drop(&mut self) {
224            // Important to drop the ref before.
225            self.guard_ref = None;
226            self.guard_mutex = None;
227        }
228    }
229
230    impl<'a> Drop for DeviceGuard<'a> {
231        fn drop(&mut self) {
232            self.guard_mutex = None;
233        }
234    }
235
236    impl<'a, S: DeviceState> core::ops::Deref for DeviceStateGuard<'a, S> {
237        type Target = S;
238
239        fn deref(&self) -> &Self::Target {
240            self.guard_ref
241                .as_ref()
242                .expect("The guard to not be dropped")
243                .downcast_ref()
244                .expect("The type to be correct")
245        }
246    }
247
248    impl<'a, S: DeviceState> core::ops::DerefMut for DeviceStateGuard<'a, S> {
249        fn deref_mut(&mut self) -> &mut Self::Target {
250            self.guard_ref
251                .as_mut()
252                .expect("The guard to not be dropped")
253                .downcast_mut()
254                .expect("The type to be correct")
255        }
256    }
257
258    impl<S: DeviceState> DeviceContext<S> {
259        /// Creates a [DeviceState<S>] handle for the given device.
260        ///
261        /// Registers the device-type combination globally if needed.
262        pub fn locate<D: Device + 'static>(device: &D) -> Self {
263            DeviceStateLock::locate(device)
264        }
265
266        /// Inserts a new state associated with the device.
267        ///
268        /// # Returns
269        ///
270        /// An error if the device already has a registered state.
271        pub fn insert<D: Device + 'static>(
272            device: &D,
273            state_new: S,
274        ) -> Result<Self, alloc::string::String> {
275            let lock = Self::locate(device);
276            let id = TypeId::of::<S>();
277
278            let state = lock.lock.lock.lock();
279
280            // It is safe for the same reasons enumerated in the lock function.
281            let (map, map_guard) = unsafe { borrow_mut_split(&state.map) };
282
283            if map.contains_key(&id) {
284                return Err(alloc::format!(
285                    "A server is still registered for device {device:?}"
286                ));
287            }
288
289            let any: Box<dyn Any + Send + 'static> = Box::new(state_new);
290            let cell = MutCell::new(any);
291
292            map.insert(id, cell);
293
294            core::mem::drop(map_guard);
295            core::mem::drop(state);
296
297            Ok(lock)
298        }
299
300        /// Locks all devices under the same kind.
301        ///
302        /// This is useful when you need mutable access to multiple devices at once, which can lead
303        /// to deadlocks.
304        pub fn lock_device_kind(&self) -> ReentrantMutexGuard<'_, ()> {
305            self.lock_kind.lock()
306        }
307
308        /// Locks the current device making sure this device can be used.
309        pub fn lock_device(&self) -> DeviceGuard<'_> {
310            let state = self.lock.lock.lock();
311
312            DeviceGuard {
313                guard_mutex: Some(state),
314            }
315        }
316
317        /// Acquires exclusive mutable access to the [DeviceState].
318        ///
319        /// The same device can lock multiple types at the same time.
320        ///
321        /// # Panics
322        ///
323        /// If the same state type is locked multiple times on the same thread.
324        /// This can only happen with recursive locking of the same state, which isn't allowed
325        /// since having multiple mutable references to the same state isn't valid.
326        pub fn lock(&self) -> DeviceStateGuard<'_, S> {
327            let key = TypeId::of::<S>();
328            let state = self.lock.lock.lock();
329
330            // It is safe for multiple reasons.
331            //
332            // 1. The mutability of the map is handled by each map entry with a RefCell.
333            //    Therefore, multiple mutable references to a map entry are checked.
334            // 2. Map items are never cleaned up, therefore it's impossible to remove the validity of
335            //    an entry.
336            // 3. Because of the lock, no race condition is possible.
337            //
338            // The reason why unsafe is necessary is that the [DeviceStateGuard] doesn't keep track
339            // of the borrowed map entry lifetime. But since it keeps track of both the [RefCell]
340            // and the [ReentrantMutex] guards, it is fine to erase the lifetime here.
341            let (map, map_guard) = unsafe { borrow_mut_split(&state.map) };
342
343            if !map.contains_key(&key) {
344                let state_default = S::init(self.device_id);
345                let any: Box<dyn Any + Send + 'static> = Box::new(state_default);
346                let cell = MutCell::new(any);
347
348                map.insert(key, cell);
349            }
350
351            let value = map
352                .get(&key)
353                .expect("Just validated the map contains the key.");
354            let ref_guard = match value.try_borrow_mut() {
355                Ok(guard) => guard,
356                #[cfg(feature = "std")]
357                Err(_) => panic!(
358                    "State {} is already borrowed by the current thread {:?}",
359                    core::any::type_name::<S>(),
360                    std::thread::current().id()
361                ),
362                #[cfg(not(feature = "std"))]
363                Err(_) => panic!("State {} is already borrowed", core::any::type_name::<S>(),),
364            };
365
366            core::mem::drop(map_guard);
367
368            DeviceStateGuard {
369                guard_ref: Some(ref_guard),
370                guard_mutex: Some(state),
371                _phantom: PhantomData,
372            }
373        }
374    }
375
376    type Key = (DeviceId, TypeId);
377
378    static GLOBAL: spin::Mutex<DeviceLocator> = spin::Mutex::new(DeviceLocator { state: None });
379
380    #[derive(Default)]
381    struct DeviceLocatorState {
382        device: HashMap<Key, DeviceStateLock>,
383        device_kind: HashMap<TypeId, Arc<ReentrantMutex<()>>>,
384    }
385
386    struct DeviceLocator {
387        state: Option<DeviceLocatorState>,
388    }
389
390    #[derive(Clone)]
391    struct DeviceStateLock {
392        lock: Arc<ReentrantMutex<DeviceStateMap>>,
393    }
394
395    struct DeviceStateMap {
396        map: MutCell<HashMap<TypeId, MutCell<Box<dyn Any + Send + 'static>>>>,
397    }
398
399    impl DeviceStateLock {
400        fn locate<D: Device + 'static, S: DeviceState>(device: &D) -> DeviceContext<S> {
401            let id = device.to_id();
402            let kind = TypeId::of::<D>();
403            let key = (id, TypeId::of::<D>());
404            let mut global = GLOBAL.lock();
405
406            let locator_state = match &mut global.state {
407                Some(state) => state,
408                None => {
409                    global.state = Some(Default::default());
410                    global.state.as_mut().expect("Just created Option::Some")
411                }
412            };
413
414            let lock = match locator_state.device.get(&key) {
415                Some(value) => value.clone(),
416                None => {
417                    let state = DeviceStateMap::new();
418
419                    let value = DeviceStateLock {
420                        lock: Arc::new(ReentrantMutex::new(state)),
421                    };
422
423                    locator_state.device.insert(key, value);
424                    locator_state
425                        .device
426                        .get(&key)
427                        .expect("Just inserted the key/value")
428                        .clone()
429                }
430            };
431            let lock_kind = match locator_state.device_kind.get(&kind) {
432                Some(value) => value.clone(),
433                None => {
434                    locator_state
435                        .device_kind
436                        .insert(kind, Arc::new(ReentrantMutex::new(())));
437                    locator_state
438                        .device_kind
439                        .get(&kind)
440                        .expect("Just inserted the key/value")
441                        .clone()
442                }
443            };
444
445            DeviceContext {
446                lock,
447                lock_kind,
448                device_id: id,
449                _phantom: PhantomData,
450            }
451        }
452    }
453
454    impl DeviceStateMap {
455        fn new() -> Self {
456            Self {
457                map: MutCell::new(HashMap::new()),
458            }
459        }
460    }
461
#[cfg(test)]
mod tests {
    use core::{
        ops::{Deref, DerefMut},
        time::Duration,
    };

    use super::*;

    #[test]
    fn can_have_multiple_mutate_state() {
        let device1 = TestDevice::<0>::new(0);
        let device2 = TestDevice::<1>::new(0);

        let state1_usize = DeviceContext::<usize>::locate(&device1);
        let state1_u32 = DeviceContext::<u32>::locate(&device1);
        let state2_usize = DeviceContext::<usize>::locate(&device2);

        let mut guard_usize = state1_usize.lock();
        let mut guard_u32 = state1_u32.lock();

        let val_usize = guard_usize.deref_mut();
        let val_u32 = guard_u32.deref_mut();

        *val_usize += 1;
        *val_u32 += 2;

        assert_eq!(*val_usize, 1);
        assert_eq!(*val_u32, 2);

        core::mem::drop(guard_usize);
        core::mem::drop(guard_u32);

        let mut guard_usize = state2_usize.lock();

        let val_usize = guard_usize.deref_mut();
        *val_usize += 1;

        assert_eq!(*val_usize, 1);

        core::mem::drop(guard_usize);

        let guard_usize = state1_usize.lock();
        let guard_u32 = state1_u32.lock();

        let val_usize = guard_usize.deref();
        let val_u32 = guard_u32.deref();

        assert_eq!(*val_usize, 1);
        assert_eq!(*val_u32, 2);
    }

    #[test]
    #[should_panic]
    fn can_not_have_multiple_mut_ref_to_same_state() {
        let device1 = TestDevice::<0>::new(0);

        struct DummyState;

        impl DeviceState for DummyState {
            fn init(_device_id: DeviceId) -> Self {
                DummyState
            }
        }

        fn recursive(total: usize, state: &DeviceContext<DummyState>) {
            let _guard = state.lock();

            if total > 0 {
                recursive(total - 1, state);
            }
        }

        recursive(5, &DeviceContext::locate(&device1));
    }

    /// Regression test: registering new state types on a device while a guard for another
    /// state of that device is alive grows the per-device map. The existing guard must stay
    /// valid, which requires map entries to have stable addresses.
    #[test]
    fn guard_survives_insertion_of_other_states() {
        // Dedicated device type so this test does not share state with the other tests,
        // which all go through the same global locator.
        let device = TestDevice::<2>::new(0);
        let context = DeviceContext::<u8>::locate(&device);

        let mut guard = context.lock();
        *guard += 1;

        // Enough new state types to force the map to grow (and rehash) several times.
        *DeviceContext::<u16>::locate(&device).lock() += 1;
        *DeviceContext::<u64>::locate(&device).lock() += 1;
        *DeviceContext::<u128>::locate(&device).lock() += 1;
        *DeviceContext::<i8>::locate(&device).lock() += 1;
        *DeviceContext::<i16>::locate(&device).lock() += 1;
        *DeviceContext::<i128>::locate(&device).lock() += 1;
        *DeviceContext::<isize>::locate(&device).lock() += 1;

        *guard += 1;
        assert_eq!(*guard, 2);
        core::mem::drop(guard);

        assert_eq!(*context.lock(), 2);
        assert_eq!(*DeviceContext::<i128>::locate(&device).lock(), 1);
    }

    #[test]
    fn work_with_many_threads() {
        let num_threads = 32;
        let handles: Vec<_> = (0..num_threads)
            .map(|i| std::thread::spawn(move || thread_main((num_threads * 4) - i)))
            .collect();

        handles.into_iter().for_each(|h| h.join().unwrap());

        let device1 = TestDevice::<0>::new(0);
        let device2 = TestDevice::<1>::new(0);

        let state1_i64 = DeviceContext::<i64>::locate(&device1);
        let state1_i32 = DeviceContext::<i32>::locate(&device1);
        let state2_i32 = DeviceContext::<i32>::locate(&device2);

        let guard_i64 = state1_i64.lock();
        let guard_i32 = state1_i32.lock();

        assert_eq!(*guard_i64, num_threads as i64);
        assert_eq!(*guard_i32, num_threads as i32 * 2);

        core::mem::drop(guard_i64);
        core::mem::drop(guard_i32);

        let guard_i32 = state2_i32.lock();
        assert_eq!(*guard_i32, num_threads as i32);
    }

    fn thread_main(sleep: u64) {
        let device1 = TestDevice::<0>::new(0);
        let device2 = TestDevice::<1>::new(0);

        let state1_i64 = DeviceContext::<i64>::locate(&device1);
        let state1_i32 = DeviceContext::<i32>::locate(&device1);
        let state2_i32 = DeviceContext::<i32>::locate(&device2);

        let mut guard_i64 = state1_i64.lock();
        let mut guard_i32 = state1_i32.lock();

        let val_i64 = guard_i64.deref_mut();
        let val_i32 = guard_i32.deref_mut();

        *val_i64 += 1;
        *val_i32 += 2;

        core::mem::drop(guard_i64);
        core::mem::drop(guard_i32);

        std::thread::sleep(Duration::from_millis(sleep));

        let mut guard_i32 = state2_i32.lock();

        let val_i32 = guard_i32.deref_mut();
        *val_i32 += 1;

        core::mem::drop(guard_i32);
    }

    #[derive(Debug, Clone, Default, new)]
    /// Type is only to create different type ids.
    pub struct TestDevice<const TYPE: u8> {
        index: u32,
    }

    impl<const TYPE: u8> Device for TestDevice<TYPE> {
        fn from_id(device_id: DeviceId) -> Self {
            Self {
                index: device_id.index_id,
            }
        }

        fn to_id(&self) -> DeviceId {
            DeviceId {
                type_id: 0,
                index_id: self.index,
            }
        }

        fn device_count(_type_id: u16) -> usize {
            TYPE as usize + 1
        }
    }

    // Implements `DeviceState` for integer types, each initialized to zero.
    macro_rules! zero_initialized_state {
        ($($ty:ty),* $(,)?) => {
            $(
                impl DeviceState for $ty {
                    fn init(_device_id: DeviceId) -> Self {
                        0
                    }
                }
            )*
        };
    }

    zero_initialized_state!(usize, u32, i32, i64, u8, u16, u64, u128, i8, i16, i128, isize);
}
644}