use core::cmp::Ordering;

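/// Unique identifier of a device: the device type and the device index within
/// that type.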
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
pub struct DeviceId {
    pub type_id: u16,
    pub index_id: u32,
}

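/// A device handle that can be converted to and from its [`DeviceId`].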
pub trait Device: Default + Clone + core::fmt::Debug + Send + Sync + 'static {
    /// Creates a device from its id.
    fn from_id(device_id: DeviceId) -> Self;
    /// Returns the id of this device.
    fn to_id(&self) -> DeviceId;
    /// Returns the number of available devices of the given type.
    fn device_count(type_id: u16) -> usize;
    /// Returns the total number of available devices; the default implementation
    /// falls back to the count for type `0`.
    fn device_count_total() -> usize {
        Self::device_count(0)
    }
}

impl core::fmt::Display for DeviceId {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("{self:?}"))
    }
}

impl Ord for DeviceId {
    fn cmp(&self, other: &Self) -> Ordering {
        match self.type_id.cmp(&other.type_id) {
            Ordering::Equal => self.index_id.cmp(&other.index_id),
            other => other,
        }
    }
}

impl PartialOrd for DeviceId {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

pub use context::*;

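// Locking primitives: with the `std` feature, interior mutability and re-entrant
// locking come from `RefCell` and `parking_lot`; without it, the fallbacks below
// are built on `spin` (the no_std "reentrant" mutex hands out shared read guards
// of a `spin::RwLock`, so taking it again from the same thread does not deadlock).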
#[cfg(feature = "std")]
mod reentrant {
    pub use parking_lot::{ReentrantMutex, ReentrantMutexGuard};
}

#[cfg(feature = "std")]
mod cell {
    use core::cell::{RefCell, RefMut};
    use core::ops::DerefMut;

    pub type MutCell<T> = RefCell<T>;
    pub type MutGuard<'a, T> = RefMut<'a, T>;

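    /// Splits a mutable borrow of the cell into the borrowed value and the guard
    /// that keeps the borrow alive.
    ///
    /// # Safety
    ///
    /// The returned reference is transmuted to the caller-chosen lifetime `'a`; it
    /// must not be used after the returned guard is dropped.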
    pub unsafe fn borrow_mut_split<'a, T>(cell: &MutCell<T>) -> (&'a mut T, MutGuard<'_, T>) {
        let mut guard = cell.borrow_mut();
        let item = guard.deref_mut();
        let item: &'a mut T = unsafe { core::mem::transmute(item) };

        (item, guard)
    }
}

#[cfg(not(feature = "std"))]
mod cell {
    use core::ops::{Deref, DerefMut};

    pub struct MutGuard<'a, T> {
        guard: spin::MutexGuard<'a, T>,
    }

    pub struct MutCell<T> {
        lock: spin::Mutex<T>,
    }

    impl<T> MutCell<T> {
        pub fn new(item: T) -> Self {
            Self {
                lock: spin::Mutex::new(item),
            }
        }
    }

    impl<'a, T> Deref for MutGuard<'a, T> {
        type Target = T;

        fn deref(&self) -> &Self::Target {
            self.guard.deref()
        }
    }

    impl<'a, T> DerefMut for MutGuard<'a, T> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            self.guard.deref_mut()
        }
    }

    impl<T> MutCell<T> {
        pub fn try_borrow_mut(&self) -> Result<MutGuard<'_, T>, ()> {
            match self.lock.try_lock() {
                Some(guard) => Ok(MutGuard { guard }),
                None => Err(()),
            }
        }
    }

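    /// Splits a locked borrow of the cell into the borrowed value and the guard
    /// that keeps the lock held.
    ///
    /// # Safety
    ///
    /// The returned reference is transmuted to the caller-chosen lifetime `'a`; it
    /// must not be used after the returned guard is dropped.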
    pub unsafe fn borrow_mut_split<'a, T>(
        cell: &MutCell<T>,
    ) -> (&'a mut T, spin::MutexGuard<'_, T>) {
        let mut guard = cell.lock.lock();
        let item = guard.deref_mut();
        let item: &'a mut T = unsafe { core::mem::transmute(item) };

        (item, guard)
    }
}

#[cfg(not(feature = "std"))]
mod reentrant {
    use core::ops::Deref;

    pub struct ReentrantMutex<T> {
        inner: spin::RwLock<T>,
    }

    impl<T> ReentrantMutex<T> {
        pub fn new(item: T) -> Self {
            Self {
                inner: spin::RwLock::new(item),
            }
        }
    }

    pub struct ReentrantMutexGuard<'a, T> {
        guard: spin::RwLockReadGuard<'a, T>,
    }

    impl<'a, T> Deref for ReentrantMutexGuard<'a, T> {
        type Target = T;

        fn deref(&self) -> &Self::Target {
            self.guard.deref()
        }
    }

    impl<T> ReentrantMutex<T> {
        pub fn lock(&self) -> ReentrantMutexGuard<'_, T> {
            let guard = self.inner.read();
            ReentrantMutexGuard { guard }
        }
    }
}

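// Lazily-initialized, per-device state registry. Devices are keyed globally by
// `(DeviceId, TypeId)` of the device type, and each device owns a map from state
// type to its boxed state, guarded by a re-entrant lock.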
mod context {
    use super::cell::{MutCell, MutGuard};
    use alloc::boxed::Box;
    use core::{
        any::{Any, TypeId},
        marker::PhantomData,
    };
    use hashbrown::HashMap;

    use super::reentrant::{ReentrantMutex, ReentrantMutexGuard};

    use crate::{device::cell::borrow_mut_split, stub::Arc};

    use super::{Device, DeviceId};

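    /// State that is lazily initialized for a device the first time it is locked.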
    pub trait DeviceState: Send + 'static {
        fn init(device_id: DeviceId) -> Self;
    }

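    /// Handle to the state of type `S` associated with a single device.
    ///
    /// A hypothetical usage sketch (the `MyDevice` and `Counter` types below are
    /// illustrative, not part of this module):
    ///
    /// ```ignore
    /// let device = MyDevice::default();
    /// let context = DeviceContext::<Counter>::locate(&device);
    /// let mut counter = context.lock();
    /// counter.value += 1;
    /// ```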
    pub struct DeviceContext<S: DeviceState> {
        lock: DeviceStateLock,
        device_id: DeviceId,
        _phantom: PhantomData<S>,
    }

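    // SAFETY: the state `S` is never accessed through a shared `DeviceContext`
    // without first acquiring the re-entrant device lock and the per-state cell,
    // so sharing the context across threads cannot yield unsynchronized access.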
    unsafe impl<S: DeviceState> Sync for DeviceContext<S> {}

    impl<S: DeviceState> Clone for DeviceContext<S> {
        fn clone(&self) -> Self {
            Self {
                lock: self.lock.clone(),
                _phantom: self._phantom,
                device_id: self.device_id,
            }
        }
    }

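    /// Guard returned by [`DeviceContext::lock`]; dereferences to the state `S`
    /// and keeps both the device lock and the state borrow alive.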
    pub struct DeviceStateGuard<'a, S: DeviceState> {
        guard_ref: Option<MutGuard<'a, Box<dyn Any + Send + 'static>>>,
        guard_mutex: Option<ReentrantMutexGuard<'a, DeviceStateMap>>,
        _phantom: PhantomData<S>,
    }

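    /// Guard returned by [`DeviceContext::lock_device`]; holds the device-wide
    /// lock without borrowing any particular state.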
    pub struct DeviceGuard<'a> {
        guard_mutex: Option<ReentrantMutexGuard<'a, DeviceStateMap>>,
    }

    impl<'a, S: DeviceState> Drop for DeviceStateGuard<'a, S> {
        fn drop(&mut self) {
            // Release the state borrow before releasing the device lock.
            self.guard_ref = None;
            self.guard_mutex = None;
        }
    }

    impl<'a> Drop for DeviceGuard<'a> {
        fn drop(&mut self) {
            self.guard_mutex = None;
        }
    }

    impl<'a, S: DeviceState> core::ops::Deref for DeviceStateGuard<'a, S> {
        type Target = S;

        fn deref(&self) -> &Self::Target {
            self.guard_ref
                .as_ref()
                .expect("The guard to not be dropped")
                .downcast_ref()
                .expect("The type to be correct")
        }
    }

    impl<'a, S: DeviceState> core::ops::DerefMut for DeviceStateGuard<'a, S> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            self.guard_ref
                .as_mut()
                .expect("The guard to not be dropped")
                .downcast_mut()
                .expect("The type to be correct")
        }
    }

    impl<S: DeviceState> DeviceContext<S> {
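        /// Returns the context for the state `S` of the given device, creating the
        /// underlying registry entry if it does not exist yet.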
        pub fn locate<D: Device + 'static>(device: &D) -> Self {
            DeviceStateLock::locate(device)
        }

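        /// Registers `state_new` as the state `S` for the given device.
        ///
        /// Returns an error if a state of this type is already registered for the
        /// device.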
        pub fn insert<D: Device + 'static>(
            device: &D,
            state_new: S,
        ) -> Result<Self, alloc::string::String> {
            let lock = Self::locate(device);
            let id = TypeId::of::<S>();

            let state = lock.lock.lock.lock();

            // SAFETY: `map` is only used while `map_guard` is alive.
            let (map, map_guard) = unsafe { borrow_mut_split(&state.map) };

            if map.contains_key(&id) {
                return Err(alloc::format!(
                    "A state is already registered for device {:?}",
                    device
                ));
            }

            let any: Box<dyn Any + Send + 'static> = Box::new(state_new);
            let cell = MutCell::new(any);

            map.insert(id, cell);

            core::mem::drop(map_guard);
            core::mem::drop(state);

            Ok(lock)
        }

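        /// Locks the device without borrowing any particular state.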
        pub fn lock_device(&self) -> DeviceGuard<'_> {
            let state = self.lock.lock.lock();

            DeviceGuard {
                guard_mutex: Some(state),
            }
        }

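        /// Locks the state `S` for this device, initializing it with
        /// [`DeviceState::init`] on first access.
        ///
        /// # Panics
        ///
        /// Panics if the same state is already borrowed, for example when the
        /// current thread calls `lock` re-entrantly for the same `S`.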
        pub fn lock(&self) -> DeviceStateGuard<'_, S> {
            let key = TypeId::of::<S>();
            let state = self.lock.lock.lock();

            // SAFETY: `map` is only used while `map_guard` is alive.
            let (map, map_guard) = unsafe { borrow_mut_split(&state.map) };

            if !map.contains_key(&key) {
                let state_default = S::init(self.device_id);
                let any: Box<dyn Any + Send + 'static> = Box::new(state_default);
                let cell = MutCell::new(any);

                map.insert(key, cell);
            }

            let value = map
                .get(&key)
                .expect("Just validated the map contains the key.");
            let ref_guard = match value.try_borrow_mut() {
                Ok(guard) => guard,
                #[cfg(feature = "std")]
                Err(_) => panic!(
                    "State {} is already borrowed by the current thread {:?}",
                    core::any::type_name::<S>(),
                    std::thread::current().id()
                ),
                #[cfg(not(feature = "std"))]
                Err(_) => panic!("State {} is already borrowed", core::any::type_name::<S>()),
            };

            core::mem::drop(map_guard);

            DeviceStateGuard {
                guard_ref: Some(ref_guard),
                guard_mutex: Some(state),
                _phantom: PhantomData,
            }
        }
    }

    type Key = (DeviceId, TypeId);

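    /// Process-wide registry mapping `(DeviceId, device type)` to the lock that
    /// guards that device's state map.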
    static GLOBAL: spin::Mutex<DeviceLocator> = spin::Mutex::new(DeviceLocator { state: None });

    struct DeviceLocator {
        state: Option<HashMap<Key, DeviceStateLock>>,
    }

    #[derive(Clone)]
    struct DeviceStateLock {
        lock: Arc<ReentrantMutex<DeviceStateMap>>,
    }

    struct DeviceStateMap {
        map: MutCell<HashMap<TypeId, MutCell<Box<dyn Any + Send + 'static>>>>,
    }

    impl DeviceStateLock {
        fn locate<D: Device + 'static, S: DeviceState>(device: &D) -> DeviceContext<S> {
            let id = device.to_id();
            let key = (id, TypeId::of::<D>());
            let mut global = GLOBAL.lock();

            let map = match &mut global.state {
                Some(state) => state,
                None => {
                    global.state = Some(HashMap::default());
                    global.state.as_mut().expect("Just created Option::Some")
                }
            };

            let lock = match map.get(&key) {
                Some(value) => value.clone(),
                None => {
                    let state = DeviceStateMap::new();

                    let value = DeviceStateLock {
                        lock: Arc::new(ReentrantMutex::new(state)),
                    };

                    map.insert(key, value);
                    map.get(&key).expect("Just inserted the key/value").clone()
                }
            };

            DeviceContext {
                lock,
                device_id: id,
                _phantom: PhantomData,
            }
        }
    }

    impl DeviceStateMap {
        fn new() -> Self {
            Self {
                map: MutCell::new(HashMap::new()),
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use core::{
            ops::{Deref, DerefMut},
            time::Duration,
        };

        use super::*;

        #[test]
        fn can_have_multiple_mutate_state() {
            let device1 = TestDevice::<0>::new(0);
            let device2 = TestDevice::<1>::new(0);

            let state1_usize = DeviceContext::<usize>::locate(&device1);
            let state1_u32 = DeviceContext::<u32>::locate(&device1);
            let state2_usize = DeviceContext::<usize>::locate(&device2);

            let mut guard_usize = state1_usize.lock();
            let mut guard_u32 = state1_u32.lock();

            let val_usize = guard_usize.deref_mut();
            let val_u32 = guard_u32.deref_mut();

            *val_usize += 1;
            *val_u32 += 2;

            assert_eq!(*val_usize, 1);
            assert_eq!(*val_u32, 2);

            core::mem::drop(guard_usize);
            core::mem::drop(guard_u32);

            let mut guard_usize = state2_usize.lock();

            let val_usize = guard_usize.deref_mut();
            *val_usize += 1;

            assert_eq!(*val_usize, 1);

            core::mem::drop(guard_usize);

            let guard_usize = state1_usize.lock();
            let guard_u32 = state1_u32.lock();

            let val_usize = guard_usize.deref();
            let val_u32 = guard_u32.deref();

            assert_eq!(*val_usize, 1);
            assert_eq!(*val_u32, 2);
        }

        #[test]
        #[should_panic]
        fn can_not_have_multiple_mut_ref_to_same_state() {
            let device1 = TestDevice::<0>::new(0);

            struct DummyState;

            impl DeviceState for DummyState {
                fn init(_device_id: DeviceId) -> Self {
                    DummyState
                }
            }

            fn recursive(total: usize, state: &DeviceContext<DummyState>) {
                let _guard = state.lock();

                if total > 0 {
                    recursive(total - 1, state);
                }
            }

            recursive(5, &DeviceContext::locate(&device1));
        }

        #[test]
        fn work_with_many_threads() {
            let num_threads = 32;
            let handles: Vec<_> = (0..num_threads)
                .map(|i| std::thread::spawn(move || thread_main((num_threads * 4) - i)))
                .collect();

            handles.into_iter().for_each(|h| h.join().unwrap());

            let device1 = TestDevice::<0>::new(0);
            let device2 = TestDevice::<1>::new(0);

            let state1_i64 = DeviceContext::<i64>::locate(&device1);
            let state1_i32 = DeviceContext::<i32>::locate(&device1);
            let state2_i32 = DeviceContext::<i32>::locate(&device2);

            let guard_i64 = state1_i64.lock();
            let guard_i32 = state1_i32.lock();

            assert_eq!(*guard_i64, num_threads as i64);
            assert_eq!(*guard_i32, num_threads as i32 * 2);

            core::mem::drop(guard_i64);
            core::mem::drop(guard_i32);

            let guard_i32 = state2_i32.lock();
            assert_eq!(*guard_i32, num_threads as i32);
        }

        fn thread_main(sleep: u64) {
            let device1 = TestDevice::<0>::new(0);
            let device2 = TestDevice::<1>::new(0);

            let state1_i64 = DeviceContext::<i64>::locate(&device1);
            let state1_i32 = DeviceContext::<i32>::locate(&device1);
            let state2_i32 = DeviceContext::<i32>::locate(&device2);

            let mut guard_i64 = state1_i64.lock();
            let mut guard_i32 = state1_i32.lock();

            let val_i64 = guard_i64.deref_mut();
            let val_i32 = guard_i32.deref_mut();

            *val_i64 += 1;
            *val_i32 += 2;

            core::mem::drop(guard_i64);
            core::mem::drop(guard_i32);

            std::thread::sleep(Duration::from_millis(sleep));

            let mut guard_i32 = state2_i32.lock();

            let val_i32 = guard_i32.deref_mut();
            *val_i32 += 1;

            core::mem::drop(guard_i32);
        }

        #[derive(Debug, Clone, Default, new)]
        pub struct TestDevice<const TYPE: u8> {
            index: u32,
        }

        impl<const TYPE: u8> Device for TestDevice<TYPE> {
            fn from_id(device_id: DeviceId) -> Self {
                Self {
                    index: device_id.index_id,
                }
            }

            fn to_id(&self) -> DeviceId {
                DeviceId {
                    type_id: 0,
                    index_id: self.index,
                }
            }

            fn device_count(_type_id: u16) -> usize {
                TYPE as usize + 1
            }
        }

        impl DeviceState for usize {
            fn init(_device_id: DeviceId) -> Self {
                0
            }
        }

        impl DeviceState for u32 {
            fn init(_device_id: DeviceId) -> Self {
                0
            }
        }

        impl DeviceState for i32 {
            fn init(_device_id: DeviceId) -> Self {
                0
            }
        }

        impl DeviceState for i64 {
            fn init(_device_id: DeviceId) -> Self {
                0
            }
        }
    }
}