use crate::{
    device::{
        DeviceId, DeviceService,
        handle::{CallError, DeviceHandleSpec, ServerUtilitiesHandle, ServiceCreationError},
    },
    stream_id::StreamId,
    stub::{Arc, Mutex, RwLock},
};
use alloc::boxed::Box;
use core::{
    any::{Any, TypeId},
    marker::PhantomData,
};
use hashbrown::{HashMap, hash_map::Entry};
/// Device handle that executes every task synchronously on the calling thread,
/// serializing access to the shared service instance through a mutex.
///
/// Cloning a handle is cheap: clones share the same underlying service state.
pub struct MutexDeviceHandle<S: DeviceService> {
    /// Device this handle talks to.
    device_id: DeviceId,
    /// Shared, type-erased service state for `(device_id, S)`.
    state: MutexDeviceState,
    // `fn(S)` records the dependency on `S` without storing an `S` value.
    _phantom: PhantomData<fn(S)>,
}
/// State shared by every handle to one `(device, service type)` pair.
#[derive(Clone)]
struct MutexDeviceState {
    /// Utilities obtained from the service when it was registered.
    utilities: ServerUtilitiesHandle,
    /// Device the service belongs to.
    device_id: DeviceId,
    /// The type-erased service; handles downcast it back to their concrete `S`.
    service: Arc<Mutex<Box<dyn Any + Send>>>,
}
/// Services registered for a single device, keyed by the concrete service type.
type DeviceRegistry = HashMap<TypeId, MutexDeviceState>;

/// Global map from device id to that device's services.
///
/// Starts as `None` and is filled in by the first handle that gets created.
static DEVICE_REGISTRY: spin::Mutex<Option<HashMap<DeviceId, DeviceRegistry>>> =
    spin::Mutex::new(None);
impl<S: DeviceService + 'static> DeviceHandleSpec<S> for MutexDeviceHandle<S> {
const BLOCKING: bool = true;
fn new(device_id: DeviceId) -> Self {
let mut guard = DEVICE_REGISTRY.lock();
if guard.is_none() {
*guard = Some(HashMap::new());
};
let device_map: &mut HashMap<_, _> = match guard.as_mut() {
Some(val) => val.entry(device_id).or_insert_with(HashMap::new),
None => unreachable!(),
};
let type_id = TypeId::of::<S>();
let state = device_map
.entry(type_id)
.or_insert_with(|| {
let state = S::init(device_id);
let utilities = state.utilities();
MutexDeviceState {
service: Arc::new(Mutex::new(Box::new(state))),
device_id,
utilities,
}
})
.clone();
Self {
state,
device_id,
_phantom: PhantomData,
}
}
fn device_id(&self) -> DeviceId {
self.device_id
}
fn utilities(&self) -> ServerUtilitiesHandle {
self.state.utilities.clone()
}
fn flush_queue(&self) {}
fn submit_blocking<'a, R: Send, T: FnOnce(&mut S) -> R + Send + 'a>(
&self,
task: T,
) -> Result<R, CallError> {
let mut guard = self.state.service.lock().unwrap();
let state = guard.downcast_mut::<S>().expect("State type mismatch");
Ok(task(state))
}
fn submit<T: FnOnce(&mut S) + Send + 'static>(&self, task: T) {
let mut guard = self.state.service.lock().unwrap();
let state = guard.downcast_mut::<S>().expect("State type mismatch");
task(state);
}
fn insert(device_id: DeviceId, service: S) -> Result<Self, ServiceCreationError> {
let mut guard = DEVICE_REGISTRY.lock();
if guard.is_none() {
*guard = Some(HashMap::new());
};
let device_map: &mut HashMap<_, _> = match guard.as_mut() {
Some(val) => val.entry(device_id).or_insert_with(HashMap::new),
None => unreachable!(),
};
let type_id = TypeId::of::<S>();
if device_map.contains_key(&type_id) {
return Err(ServiceCreationError::new("Service already created".into()));
}
let state = device_map
.entry(type_id)
.or_insert_with(|| {
let utilities = service.utilities();
MutexDeviceState {
service: Arc::new(Mutex::new(Box::new(service))),
device_id,
utilities,
}
})
.clone();
Ok(Self {
state,
device_id,
_phantom: PhantomData,
})
}
fn exclusive<R: Send, T: FnOnce() -> R + Send>(&self, task: T) -> Result<R, CallError> {
let lock = self.device_lock();
let guard = lock.lock();
let result = Ok(task());
core::mem::drop(guard);
result
}
}
/// Stream-aware lock that guards `exclusive` sections on one device.
struct DeviceLock {
    /// Held for the whole lifetime of a non-reentrant acquisition.
    main: spin::Mutex<()>,
    /// Stream that currently owns the lock, or `None` when it is free.
    lock: RwLock<Option<StreamId>>,
}

/// Per-device [`DeviceLock`]s, created on demand; `None` until the first one is requested.
static DEVICE_LOCK: spin::Mutex<Option<HashMap<DeviceId, Arc<DeviceLock>>>> =
    spin::Mutex::new(None);
/// Guard returned by [`DeviceLock::lock`].
enum Guard<'lock> {
    /// The calling stream acquired the lock; dropping this clears ownership.
    Main(spin::MutexGuard<'lock, ()>, &'lock DeviceLock),
    /// The calling stream already owned the lock; dropping this does nothing.
    Reentrant,
}
impl Drop for Guard<'_> {
    /// Releases stream ownership of the device lock for a `Main` guard.
    ///
    /// Ownership is cleared here; the inner `main` mutex guard is released afterwards,
    /// when the enum's fields are dropped.
    fn drop(&mut self) {
        if let Guard::Main(_, owner) = self {
            *owner.lock.write().unwrap() = None;
        }
    }
}
impl DeviceLock {
    /// Acquires exclusive access to the device for the current stream.
    ///
    /// If the current stream already owns the lock, a no-op [`Guard::Reentrant`] is returned,
    /// so nested acquisitions from the same stream do not deadlock. Otherwise this busy-waits
    /// until no stream owns the lock, records the current stream as the owner and returns a
    /// [`Guard::Main`] that clears ownership when dropped.
    pub fn lock(&self) -> Guard<'_> {
        let stream_id = StreamId::current();
        loop {
            let mut owner = self.lock.write().unwrap();
            match owner.as_ref() {
                None => {
                    *owner = Some(stream_id);
                    let guard = self.main.lock();
                    return Guard::Main(guard, self);
                }
                Some(current) => {
                    if *current == stream_id {
                        return Guard::Reentrant;
                    }
                }
            }
            // Another stream owns the device. Release the state lock before retrying so the
            // owner can clear it on drop, and hint the CPU that this is a spin-wait.
            core::mem::drop(owner);
            core::hint::spin_loop();
        }
    }
}
impl<S: DeviceService> MutexDeviceHandle<S> {
    /// Returns the [`DeviceLock`] for this handle's device, creating it on first use.
    ///
    /// The lock is keyed by device id only, so handles for different service types on the
    /// same device share it.
    fn device_lock(&self) -> Arc<DeviceLock> {
        let mut locks = DEVICE_LOCK.lock();
        locks
            .get_or_insert_with(HashMap::new)
            .entry(self.device_id)
            .or_insert_with(|| {
                Arc::new(DeviceLock {
                    lock: RwLock::new(None),
                    main: spin::Mutex::new(()),
                })
            })
            .clone()
    }
}
impl<S: DeviceService> Clone for MutexDeviceHandle<S> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
device_id: self.device_id,
_phantom: PhantomData,
}
}
}