1use core::cmp::Ordering;
2
/// Identifier of a device: a numeric device type plus a device index.
///
/// Equality and hashing use both fields. The [`Ord`] impl compares `type_id` first and
/// `index_id` second.
// NOTE(review): the `new` derive presumably comes from the `derive-new` crate and generates
// `DeviceId::new(type_id, index_id)` (fields in declaration order) — confirm.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
pub struct DeviceId {
    /// Numeric identifier of the device type.
    pub type_id: u16,
    /// Device index; together with `type_id` it identifies the device.
    pub index_id: u32,
}
11
/// A device that can be converted to and from a [`DeviceId`].
///
/// Implementors must be `Default`, `Clone`, `Debug`, `Send + Sync` and `'static`.
pub trait Device: Default + Clone + core::fmt::Debug + Send + Sync + 'static {
    /// Builds the device identified by `device_id`.
    fn from_id(device_id: DeviceId) -> Self;
    /// Returns the [`DeviceId`] of this device.
    fn to_id(&self) -> DeviceId;
    /// Returns the number of devices with the given device `type_id`.
    fn device_count(type_id: u16) -> usize;
    /// Returns the total number of devices.
    ///
    /// The default implementation returns `Self::device_count(0)`, i.e. it only counts the
    /// devices whose type id is `0`.
    // NOTE(review): implementors exposing several device types presumably need to override
    // this default — confirm.
    fn device_count_total() -> usize {
        Self::device_count(0)
    }
}
25
26impl core::fmt::Display for DeviceId {
27 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28 f.write_fmt(format_args!("{self:?}"))
29 }
30}
31
32impl Ord for DeviceId {
33 fn cmp(&self, other: &Self) -> Ordering {
34 match self.type_id.cmp(&other.type_id) {
35 Ordering::Equal => self.index_id.cmp(&other.index_id),
36 other => other,
37 }
38 }
39}
40
41impl PartialOrd for DeviceId {
42 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
43 Some(self.cmp(other))
44 }
45}
46
47pub use context::*;
48
/// Reentrant lock used for device locks.
///
/// With `std`, this is simply `parking_lot`'s reentrant mutex and its guard.
#[cfg(feature = "std")]
mod reentrant {
    pub use parking_lot::{ReentrantMutex, ReentrantMutexGuard};
}
53
54#[cfg(feature = "std")]
/// Interior-mutability cell used by the device state maps (`RefCell`-based with `std`).
mod cell {
    use core::cell::{RefCell, RefMut};
    use core::ops::DerefMut;

    /// Cell holding device state; a plain [`RefCell`] when `std` is available.
    pub type MutCell<T> = RefCell<T>;
    /// Guard of a mutable borrow of a [`MutCell`].
    pub type MutGuard<'a, T> = RefMut<'a, T>;

    /// Mutably borrows `cell` and returns the borrowed value with a caller-chosen lifetime
    /// `'a`, together with the guard that keeps the borrow registered.
    ///
    /// # Panics
    ///
    /// Panics if `cell` is already borrowed, like [`RefCell::borrow_mut`].
    ///
    /// # Safety
    ///
    /// `'a` is chosen by the caller and is tied neither to `cell` nor to the returned guard.
    /// The caller must ensure that:
    /// - the returned reference is not used after the guard has been dropped;
    /// - the value is not accessed through the guard while the reference is in use;
    /// - anything borrowed through the reference that is still used after the guard is
    ///   dropped remains valid (is neither moved nor freed) for as long as it is used.
    pub unsafe fn borrow_mut_split<'a, T>(cell: &MutCell<T>) -> (&'a mut T, MutGuard<'_, T>) {
        let mut guard = cell.borrow_mut();
        let item: *mut T = guard.deref_mut();
        // SAFETY: `item` was just derived from the live borrow of `cell`, so it is non-null,
        // aligned and points to an initialized `T`. Extending its lifetime to `'a` is covered
        // by the caller's contract documented above.
        let item: &'a mut T = unsafe { &mut *item };

        (item, guard)
    }
}
72
73#[cfg(not(feature = "std"))]
74mod cell {
75 use core::ops::{Deref, DerefMut};
76
77 pub struct MutGuard<'a, T> {
78 guard: spin::MutexGuard<'a, T>,
79 }
80
81 pub struct MutCell<T> {
82 lock: spin::Mutex<T>,
83 }
84
85 impl<T> MutCell<T> {
86 pub fn new(item: T) -> Self {
87 Self {
88 lock: spin::Mutex::new(item),
89 }
90 }
91 }
92
93 impl<'a, T> Deref for MutGuard<'a, T> {
94 type Target = T;
95
96 fn deref(&self) -> &Self::Target {
97 self.guard.deref()
98 }
99 }
100
101 impl<'a, T> DerefMut for MutGuard<'a, T> {
102 fn deref_mut(&mut self) -> &mut Self::Target {
103 self.guard.deref_mut()
104 }
105 }
106
107 impl<T> MutCell<T> {
108 pub fn try_borrow_mut(&self) -> Result<MutGuard<'_, T>, ()> {
109 match self.lock.try_lock() {
110 Some(guard) => Ok(MutGuard { guard }),
111 None => Err(()),
112 }
113 }
114 }
115
116 pub unsafe fn borrow_mut_split<'a, T>(
117 cell: &MutCell<T>,
118 ) -> (&'a mut T, spin::MutexGuard<'_, T>) {
119 let mut guard = cell.lock.lock();
120 let item = guard.deref_mut();
121 let item: &'a mut T = unsafe { core::mem::transmute(item) };
122
123 (item, guard)
124 }
125}
126
127#[cfg(not(feature = "std"))]
128mod reentrant {
129 use core::ops::Deref;
130
131 pub struct ReentrantMutex<T> {
132 inner: spin::RwLock<T>,
133 }
134
135 impl<T> ReentrantMutex<T> {
136 pub fn new(item: T) -> Self {
137 Self {
138 inner: spin::RwLock::new(item),
139 }
140 }
141 }
142
143 pub struct ReentrantMutexGuard<'a, T> {
144 guard: spin::RwLockReadGuard<'a, T>,
145 }
146
147 impl<'a, T> Deref for ReentrantMutexGuard<'a, T> {
148 type Target = T;
149
150 fn deref(&self) -> &Self::Target {
151 self.guard.deref()
152 }
153 }
154
155 impl<T> ReentrantMutex<T> {
156 pub fn lock(&self) -> ReentrantMutexGuard<'_, T> {
157 let guard = self.inner.read();
158 ReentrantMutexGuard { guard }
159 }
160 }
161}
162
163mod context {
164 use super::cell::{MutCell, MutGuard};
165 use alloc::boxed::Box;
166 use core::{
167 any::{Any, TypeId},
168 marker::PhantomData,
169 };
170 use hashbrown::HashMap;
171
172 use super::reentrant::{ReentrantMutex, ReentrantMutexGuard};
173
174 use crate::{device::cell::borrow_mut_split, stub::Arc};
175
176 use super::{Device, DeviceId};
177
178 pub trait DeviceState: Send + 'static {
180 fn init(device_id: DeviceId) -> Self;
182 }
183
184 pub struct DeviceContext<S: DeviceState> {
186 lock: DeviceStateLock,
187 lock_kind: Arc<ReentrantMutex<()>>,
188 device_id: DeviceId,
189 _phantom: PhantomData<S>,
190 }
191
192 unsafe impl<S: DeviceState> Sync for DeviceContext<S> {}
194
195 impl<S: DeviceState> Clone for DeviceContext<S> {
196 fn clone(&self) -> Self {
197 Self {
198 lock: self.lock.clone(),
199 lock_kind: self.lock_kind.clone(),
200 _phantom: self._phantom,
201 device_id: self.device_id,
202 }
203 }
204 }
205
206 pub struct DeviceStateGuard<'a, S: DeviceState> {
210 guard_ref: Option<MutGuard<'a, Box<dyn Any + Send + 'static>>>,
211 guard_mutex: Option<ReentrantMutexGuard<'a, DeviceStateMap>>,
212 _phantom: PhantomData<S>,
213 }
214
215 pub struct DeviceGuard<'a> {
219 guard_mutex: Option<ReentrantMutexGuard<'a, DeviceStateMap>>,
220 }
221
222 impl<'a, S: DeviceState> Drop for DeviceStateGuard<'a, S> {
223 fn drop(&mut self) {
224 self.guard_ref = None;
226 self.guard_mutex = None;
227 }
228 }
229
230 impl<'a> Drop for DeviceGuard<'a> {
231 fn drop(&mut self) {
232 self.guard_mutex = None;
233 }
234 }
235
236 impl<'a, S: DeviceState> core::ops::Deref for DeviceStateGuard<'a, S> {
237 type Target = S;
238
239 fn deref(&self) -> &Self::Target {
240 self.guard_ref
241 .as_ref()
242 .expect("The guard to not be dropped")
243 .downcast_ref()
244 .expect("The type to be correct")
245 }
246 }
247
248 impl<'a, S: DeviceState> core::ops::DerefMut for DeviceStateGuard<'a, S> {
249 fn deref_mut(&mut self) -> &mut Self::Target {
250 self.guard_ref
251 .as_mut()
252 .expect("The guard to not be dropped")
253 .downcast_mut()
254 .expect("The type to be correct")
255 }
256 }
257
258 impl<S: DeviceState> DeviceContext<S> {
259 pub fn locate<D: Device + 'static>(device: &D) -> Self {
263 DeviceStateLock::locate(device)
264 }
265
266 pub fn insert<D: Device + 'static>(
272 device: &D,
273 state_new: S,
274 ) -> Result<Self, alloc::string::String> {
275 let lock = Self::locate(device);
276 let id = TypeId::of::<S>();
277
278 let state = lock.lock.lock.lock();
279
280 let (map, map_guard) = unsafe { borrow_mut_split(&state.map) };
282
283 if map.contains_key(&id) {
284 return Err(alloc::format!(
285 "A server is still registered for device {device:?}"
286 ));
287 }
288
289 let any: Box<dyn Any + Send + 'static> = Box::new(state_new);
290 let cell = MutCell::new(any);
291
292 map.insert(id, cell);
293
294 core::mem::drop(map_guard);
295 core::mem::drop(state);
296
297 Ok(lock)
298 }
299
300 pub fn lock_device_kind(&self) -> ReentrantMutexGuard<'_, ()> {
305 self.lock_kind.lock()
306 }
307
308 pub fn lock_device(&self) -> DeviceGuard<'_> {
310 let state = self.lock.lock.lock();
311
312 DeviceGuard {
313 guard_mutex: Some(state),
314 }
315 }
316
317 pub fn lock(&self) -> DeviceStateGuard<'_, S> {
327 let key = TypeId::of::<S>();
328 let state = self.lock.lock.lock();
329
330 let (map, map_guard) = unsafe { borrow_mut_split(&state.map) };
342
343 if !map.contains_key(&key) {
344 let state_default = S::init(self.device_id);
345 let any: Box<dyn Any + Send + 'static> = Box::new(state_default);
346 let cell = MutCell::new(any);
347
348 map.insert(key, cell);
349 }
350
351 let value = map
352 .get(&key)
353 .expect("Just validated the map contains the key.");
354 let ref_guard = match value.try_borrow_mut() {
355 Ok(guard) => guard,
356 #[cfg(feature = "std")]
357 Err(_) => panic!(
358 "State {} is already borrowed by the current thread {:?}",
359 core::any::type_name::<S>(),
360 std::thread::current().id()
361 ),
362 #[cfg(not(feature = "std"))]
363 Err(_) => panic!("State {} is already borrowed", core::any::type_name::<S>(),),
364 };
365
366 core::mem::drop(map_guard);
367
368 DeviceStateGuard {
369 guard_ref: Some(ref_guard),
370 guard_mutex: Some(state),
371 _phantom: PhantomData,
372 }
373 }
374 }
375
376 type Key = (DeviceId, TypeId);
377
378 static GLOBAL: spin::Mutex<DeviceLocator> = spin::Mutex::new(DeviceLocator { state: None });
379
380 #[derive(Default)]
381 struct DeviceLocatorState {
382 device: HashMap<Key, DeviceStateLock>,
383 device_kind: HashMap<TypeId, Arc<ReentrantMutex<()>>>,
384 }
385
386 struct DeviceLocator {
387 state: Option<DeviceLocatorState>,
388 }
389
390 #[derive(Clone)]
391 struct DeviceStateLock {
392 lock: Arc<ReentrantMutex<DeviceStateMap>>,
393 }
394
395 struct DeviceStateMap {
396 map: MutCell<HashMap<TypeId, MutCell<Box<dyn Any + Send + 'static>>>>,
397 }
398
399 impl DeviceStateLock {
400 fn locate<D: Device + 'static, S: DeviceState>(device: &D) -> DeviceContext<S> {
401 let id = device.to_id();
402 let kind = TypeId::of::<D>();
403 let key = (id, TypeId::of::<D>());
404 let mut global = GLOBAL.lock();
405
406 let locator_state = match &mut global.state {
407 Some(state) => state,
408 None => {
409 global.state = Some(Default::default());
410 global.state.as_mut().expect("Just created Option::Some")
411 }
412 };
413
414 let lock = match locator_state.device.get(&key) {
415 Some(value) => value.clone(),
416 None => {
417 let state = DeviceStateMap::new();
418
419 let value = DeviceStateLock {
420 lock: Arc::new(ReentrantMutex::new(state)),
421 };
422
423 locator_state.device.insert(key, value);
424 locator_state
425 .device
426 .get(&key)
427 .expect("Just inserted the key/value")
428 .clone()
429 }
430 };
431 let lock_kind = match locator_state.device_kind.get(&kind) {
432 Some(value) => value.clone(),
433 None => {
434 locator_state
435 .device_kind
436 .insert(kind, Arc::new(ReentrantMutex::new(())));
437 locator_state
438 .device_kind
439 .get(&kind)
440 .expect("Just inserted the key/value")
441 .clone()
442 }
443 };
444
445 DeviceContext {
446 lock,
447 lock_kind,
448 device_id: id,
449 _phantom: PhantomData,
450 }
451 }
452 }
453
454 impl DeviceStateMap {
455 fn new() -> Self {
456 Self {
457 map: MutCell::new(HashMap::new()),
458 }
459 }
460 }
461
462 #[cfg(test)]
463 mod tests {
464 use core::{
465 ops::{Deref, DerefMut},
466 time::Duration,
467 };
468
469 use super::*;
470
471 #[test]
472 fn can_have_multiple_mutate_state() {
473 let device1 = TestDevice::<0>::new(0);
474 let device2 = TestDevice::<1>::new(0);
475
476 let state1_usize = DeviceContext::<usize>::locate(&device1);
477 let state1_u32 = DeviceContext::<u32>::locate(&device1);
478 let state2_usize = DeviceContext::<usize>::locate(&device2);
479
480 let mut guard_usize = state1_usize.lock();
481 let mut guard_u32 = state1_u32.lock();
482
483 let val_usize = guard_usize.deref_mut();
484 let val_u32 = guard_u32.deref_mut();
485
486 *val_usize += 1;
487 *val_u32 += 2;
488
489 assert_eq!(*val_usize, 1);
490 assert_eq!(*val_u32, 2);
491
492 core::mem::drop(guard_usize);
493 core::mem::drop(guard_u32);
494
495 let mut guard_usize = state2_usize.lock();
496
497 let val_usize = guard_usize.deref_mut();
498 *val_usize += 1;
499
500 assert_eq!(*val_usize, 1);
501
502 core::mem::drop(guard_usize);
503
504 let guard_usize = state1_usize.lock();
505 let guard_u32 = state1_u32.lock();
506
507 let val_usize = guard_usize.deref();
508 let val_u32 = guard_u32.deref();
509
510 assert_eq!(*val_usize, 1);
511 assert_eq!(*val_u32, 2);
512 }
513
514 #[test]
515 #[should_panic]
516 fn can_not_have_multiple_mut_ref_to_same_state() {
517 let device1 = TestDevice::<0>::new(0);
518
519 struct DummyState;
520
521 impl DeviceState for DummyState {
522 fn init(_device_id: DeviceId) -> Self {
523 DummyState
524 }
525 }
526
527 fn recursive(total: usize, state: &DeviceContext<DummyState>) {
528 let _guard = state.lock();
529
530 if total > 0 {
531 recursive(total - 1, state);
532 }
533 }
534
535 recursive(5, &DeviceContext::locate(&device1));
536 }
537
538 #[test]
539 fn work_with_many_threads() {
540 let num_threads = 32;
541 let handles: Vec<_> = (0..num_threads)
542 .map(|i| std::thread::spawn(move || thread_main((num_threads * 4) - i)))
543 .collect();
544
545 handles.into_iter().for_each(|h| h.join().unwrap());
546
547 let device1 = TestDevice::<0>::new(0);
548 let device2 = TestDevice::<1>::new(0);
549
550 let state1_i64 = DeviceContext::<i64>::locate(&device1);
551 let state1_i32 = DeviceContext::<i32>::locate(&device1);
552 let state2_i32 = DeviceContext::<i32>::locate(&device2);
553
554 let guard_i64 = state1_i64.lock();
555 let guard_i32 = state1_i32.lock();
556
557 assert_eq!(*guard_i64, num_threads as i64);
558 assert_eq!(*guard_i32, num_threads as i32 * 2);
559
560 core::mem::drop(guard_i64);
561 core::mem::drop(guard_i32);
562
563 let guard_i32 = state2_i32.lock();
564 assert_eq!(*guard_i32, num_threads as i32);
565 }
566
567 fn thread_main(sleep: u64) {
568 let device1 = TestDevice::<0>::new(0);
569 let device2 = TestDevice::<1>::new(0);
570
571 let state1_i64 = DeviceContext::<i64>::locate(&device1);
572 let state1_i32 = DeviceContext::<i32>::locate(&device1);
573 let state2_i32 = DeviceContext::<i32>::locate(&device2);
574
575 let mut guard_i64 = state1_i64.lock();
576 let mut guard_i32 = state1_i32.lock();
577
578 let val_i64 = guard_i64.deref_mut();
579 let val_i32 = guard_i32.deref_mut();
580
581 *val_i64 += 1;
582 *val_i32 += 2;
583
584 core::mem::drop(guard_i64);
585 core::mem::drop(guard_i32);
586
587 std::thread::sleep(Duration::from_millis(sleep));
588
589 let mut guard_i32 = state2_i32.lock();
590
591 let val_i32 = guard_i32.deref_mut();
592 *val_i32 += 1;
593
594 core::mem::drop(guard_i32);
595 }
596
597 #[derive(Debug, Clone, Default, new)]
598 pub struct TestDevice<const TYPE: u8> {
600 index: u32,
601 }
602
603 impl<const TYPE: u8> Device for TestDevice<TYPE> {
604 fn from_id(device_id: DeviceId) -> Self {
605 Self {
606 index: device_id.index_id,
607 }
608 }
609
610 fn to_id(&self) -> DeviceId {
611 DeviceId {
612 type_id: 0,
613 index_id: self.index,
614 }
615 }
616
617 fn device_count(_type_id: u16) -> usize {
618 TYPE as usize + 1
619 }
620 }
621
622 impl DeviceState for usize {
623 fn init(_device_id: DeviceId) -> Self {
624 0
625 }
626 }
627
628 impl DeviceState for u32 {
629 fn init(_device_id: DeviceId) -> Self {
630 0
631 }
632 }
633 impl DeviceState for i32 {
634 fn init(_device_id: DeviceId) -> Self {
635 0
636 }
637 }
638 impl DeviceState for i64 {
639 fn init(_device_id: DeviceId) -> Self {
640 0
641 }
642 }
643 }
644}