use core::{cell::Cell, fmt, marker::PhantomData};
use crate::{
hunk::{CfgHunkBuilder, DefaultInitTag, Hunk, HunkIniter},
kernel::{
self,
cfg::{CfgBuilder, CfgMutexBuilder},
LockMutexError, MarkConsistentMutexError, MutexProtocol, TryLockMutexError,
},
prelude::*,
utils::Init,
};
/// Configuration-time builder for [`RecursiveMutex`].
///
/// Created by `RecursiveMutex::build` and turned into a `RecursiveMutex` by
/// `Builder::finish`. It simply pairs the builders of the two objects a
/// `RecursiveMutex` is made of.
pub struct Builder<System, T, InitTag> {
/// Builder of the underlying kernel mutex.
mutex: CfgMutexBuilder<System>,
/// Builder of the hunk that stores the `MutexInner<T>` (nesting state and
/// the protected data).
hunk: CfgHunkBuilder<System, MutexInner<T>, InitTag>,
}
/// A mutex that can be locked again while a lock is already held.
///
/// When the kernel mutex reports `WouldDeadlock` on a lock attempt, `lock`
/// and `try_lock` bump a nesting counter stored next to the data instead of
/// failing, and still hand out a guard.
pub struct RecursiveMutex<System, T> {
/// Storage for the nesting/abandonment state and the protected value.
hunk: Hunk<System, MutexInner<T>>,
/// The kernel mutex that provides the actual locking.
mutex: kernel::Mutex<System>,
}
// NOTE(review): these impls presumably rely on every access to `MutexInner`
// (including the non-`Sync` `Cell` holding the level) being serialized by the
// kernel mutex, and on guards only handing out `&T` — confirm against the
// kernel's locking guarantees.
unsafe impl<System: Kernel, T: 'static + Send> Send for RecursiveMutex<System, T> {}
unsafe impl<System: Kernel, T: 'static + Send> Sync for RecursiveMutex<System, T> {}
/// The contents of the hunk backing a `RecursiveMutex`.
#[doc(hidden)]
pub struct MutexInner<T> {
/// Packed lock state: the `LEVEL_ABANDONED` bit is the abandonment flag and
/// the bits from `LEVEL_COUNT_SHIFT` upward hold the number of nested
/// (re-entrant) acquisitions on top of the outermost one.
level: Cell<usize>,
/// The value protected by the mutex.
data: T,
}
impl<T: Init> Init for MutexInner<T> {
    /// The initial state: the protected value is `T::INIT`, the nesting count
    /// is zero and the abandonment flag is clear.
    const INIT: Self = Self {
        data: <T as Init>::INIT,
        level: Cell::new(0),
    };
}
/// Bit in `MutexInner::level` that is set when the kernel reported the mutex
/// as abandoned; cleared again by `RecursiveMutex::mark_consistent`.
const LEVEL_ABANDONED: usize = 1;
/// Position of the nesting count inside `MutexInner::level`; one nested
/// acquisition is worth `1 << LEVEL_COUNT_SHIFT`.
const LEVEL_COUNT_SHIFT: u32 = 1;
/// Guard representing one acquisition of a [`RecursiveMutex`].
///
/// Dereferences to `&T`. Dropping it either decrements the nesting count or,
/// for the outermost acquisition, unlocks the kernel mutex.
#[must_use = "if unused the RecursiveMutex will immediately unlock"]
pub struct MutexGuard<'a, System: Kernel, T: 'static> {
/// The mutex this guard was obtained from.
mutex: &'a RecursiveMutex<System, T>,
/// A raw-pointer marker, which removes the auto-implemented `Send` and
/// `Sync` from the guard.
_no_send_sync: PhantomData<*mut ()>,
}
// Re-adds `Sync` (but not `Send`) for guards over `Sync` data.
// NOTE(review): presumably sound because a shared guard only exposes `&T` —
// confirm no other state is reachable through `&MutexGuard`.
unsafe impl<System: Kernel, T: 'static + Sync> Sync for MutexGuard<'_, System, T> {}
/// Result type of `RecursiveMutex::lock`.
pub type LockResult<Guard> = Result<Guard, LockError<Guard>>;
/// Result type of `RecursiveMutex::try_lock`.
pub type TryLockResult<Guard> = Result<Guard, TryLockError<Guard>>;
/// Error type of `RecursiveMutex::lock`.
///
/// Each variant shares its discriminant with the like-named
/// `LockMutexError` variant.
#[repr(i8)]
pub enum LockError<Guard> {
/// The kernel mutex reported `LockMutexError::BadContext`.
BadContext = LockMutexError::BadContext as i8,
/// The kernel mutex reported `LockMutexError::Interrupted`.
Interrupted = LockMutexError::Interrupted as i8,
/// The kernel mutex reported `LockMutexError::BadParam`.
BadParam = LockMutexError::BadParam as i8,
/// The abandonment flag is set. The lock was acquired nonetheless and the
/// guard is carried in the payload.
Abandoned(Guard) = LockMutexError::Abandoned as i8,
}
impl<Guard> fmt::Debug for LockError<Guard> {
    /// Writes the bare variant name; the guard carried by `Abandoned` is not
    /// printed, so `Guard` does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant_name = match self {
            Self::Abandoned(_) => "Abandoned",
            Self::BadParam => "BadParam",
            Self::Interrupted => "Interrupted",
            Self::BadContext => "BadContext",
        };
        f.write_str(variant_name)
    }
}
/// Error type of `RecursiveMutex::try_lock`.
///
/// Each variant shares its discriminant with a `TryLockMutexError` variant;
/// note that `WouldBlock` corresponds to the kernel's `Timeout`.
#[repr(i8)]
pub enum TryLockError<Guard> {
/// The kernel mutex reported `TryLockMutexError::BadContext`.
BadContext = TryLockMutexError::BadContext as i8,
/// The kernel mutex reported `TryLockMutexError::Timeout`.
WouldBlock = TryLockMutexError::Timeout as i8,
/// The kernel mutex reported `TryLockMutexError::BadParam`.
BadParam = TryLockMutexError::BadParam as i8,
/// The abandonment flag is set. The lock was acquired nonetheless and the
/// guard is carried in the payload.
Abandoned(Guard) = TryLockMutexError::Abandoned as i8,
}
impl<Guard> fmt::Debug for TryLockError<Guard> {
    /// Writes the bare variant name; the guard carried by `Abandoned` is not
    /// printed, so `Guard` does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant_name = match self {
            Self::Abandoned(_) => "Abandoned",
            Self::BadParam => "BadParam",
            Self::WouldBlock => "WouldBlock",
            Self::BadContext => "BadContext",
        };
        f.write_str(variant_name)
    }
}
/// Error type of `RecursiveMutex::mark_consistent`.
#[derive(Debug)]
#[repr(i8)]
pub enum MarkConsistentError {
/// The kernel mutex reported `MarkConsistentMutexError::BadContext`.
BadContext = MarkConsistentMutexError::BadContext as i8,
/// There was nothing to repair: the kernel mutex reported
/// `BadObjectState` and the abandonment flag was not set.
Consistent = MarkConsistentMutexError::BadObjectState as i8,
}
impl<System: Kernel, T: 'static> RecursiveMutex<System, T> {
    /// Start building a `RecursiveMutex`.
    ///
    /// The returned [`Builder`] wraps a fresh kernel-mutex builder and a fresh
    /// hunk builder using the default initializer tag.
    pub const fn build() -> Builder<System, T, DefaultInitTag> {
        let mutex = kernel::Mutex::build();
        let hunk = Hunk::build();
        Builder { mutex, hunk }
    }
}
impl<System: Kernel, T: 'static, InitTag> Builder<System, T, InitTag> {
    /// Choose the locking protocol of the underlying kernel mutex.
    ///
    /// The setting is forwarded to the kernel-mutex builder; the hunk builder
    /// is carried over untouched.
    pub const fn protocol(self, protocol: MutexProtocol) -> Self {
        let mutex = self.mutex.protocol(protocol);
        Self { mutex, ..self }
    }
}
impl<System: Kernel, T: 'static, InitTag: HunkIniter<MutexInner<T>>> Builder<System, T, InitTag> {
    /// Complete the definition, registering both constituent objects in `cfg`.
    ///
    /// The hunk is finished first and the kernel mutex second.
    pub const fn finish(self, cfg: &mut CfgBuilder<System>) -> RecursiveMutex<System, T> {
        let hunk = self.hunk.finish(cfg);
        let mutex = self.mutex.finish(cfg);
        RecursiveMutex { hunk, mutex }
    }
}
impl<System: Kernel, T: 'static> RecursiveMutex<System, T> {
    /// Acquire the mutex through the kernel mutex's `lock`.
    ///
    /// A `WouldDeadlock` report from the kernel is treated as a nested
    /// acquisition: the nesting count is incremented and a guard is returned.
    ///
    /// # Errors
    ///
    /// - `BadContext`, `Interrupted`, `BadParam`: forwarded from the kernel;
    ///   no lock is held.
    /// - `Abandoned(guard)`: the abandonment flag is set (either just now or
    ///   by an earlier acquisition). The lock *is* held and is released when
    ///   `guard` is dropped.
    ///
    /// # Panics
    ///
    /// Panics with "nesting count overflow" if the nesting count cannot be
    /// incremented.
    pub fn lock(&self) -> LockResult<MutexGuard<'_, System, T>> {
        let nesting = &self.hunk.level;

        if let Err(error) = self.mutex.lock() {
            match error {
                LockMutexError::WouldDeadlock => {
                    // Nested acquisition: count it instead of failing.
                    let deeper = nesting
                        .get()
                        .checked_add(1 << LEVEL_COUNT_SHIFT)
                        .expect("nesting count overflow");
                    nesting.set(deeper);
                }
                LockMutexError::Abandoned => {
                    // Remember the abandonment in our own flag (resetting the
                    // nesting count), then clear it at the kernel level.
                    nesting.set(LEVEL_ABANDONED);
                    self.mutex.mark_consistent().unwrap();
                }
                LockMutexError::BadContext => return Err(LockError::BadContext),
                LockMutexError::Interrupted => return Err(LockError::Interrupted),
                LockMutexError::BadParam => return Err(LockError::BadParam),
                LockMutexError::BadId => unreachable!(),
            }
        }

        let guard = MutexGuard {
            mutex: self,
            _no_send_sync: PhantomData,
        };
        if (nesting.get() & LEVEL_ABANDONED) == 0 {
            Ok(guard)
        } else {
            Err(LockError::Abandoned(guard))
        }
    }

    /// Acquire the mutex through the kernel mutex's `try_lock`.
    ///
    /// A `WouldDeadlock` report from the kernel is treated as a nested
    /// acquisition: the nesting count is incremented and a guard is returned.
    ///
    /// # Errors
    ///
    /// - `BadContext`, `BadParam`: forwarded from the kernel; no lock is held.
    /// - `WouldBlock`: the kernel reported `Timeout`; no lock is held.
    /// - `Abandoned(guard)`: the abandonment flag is set (either just now or
    ///   by an earlier acquisition). The lock *is* held and is released when
    ///   `guard` is dropped.
    ///
    /// # Panics
    ///
    /// Panics with "nesting count overflow" if the nesting count cannot be
    /// incremented.
    pub fn try_lock(&self) -> TryLockResult<MutexGuard<'_, System, T>> {
        let nesting = &self.hunk.level;

        if let Err(error) = self.mutex.try_lock() {
            match error {
                TryLockMutexError::WouldDeadlock => {
                    // Nested acquisition: count it instead of failing.
                    let deeper = nesting
                        .get()
                        .checked_add(1 << LEVEL_COUNT_SHIFT)
                        .expect("nesting count overflow");
                    nesting.set(deeper);
                }
                TryLockMutexError::Abandoned => {
                    // Remember the abandonment in our own flag (resetting the
                    // nesting count), then clear it at the kernel level.
                    nesting.set(LEVEL_ABANDONED);
                    self.mutex.mark_consistent().unwrap();
                }
                TryLockMutexError::BadContext => return Err(TryLockError::BadContext),
                TryLockMutexError::Timeout => return Err(TryLockError::WouldBlock),
                TryLockMutexError::BadParam => return Err(TryLockError::BadParam),
                TryLockMutexError::BadId => unreachable!(),
            }
        }

        let guard = MutexGuard {
            mutex: self,
            _no_send_sync: PhantomData,
        };
        if (nesting.get() & LEVEL_ABANDONED) == 0 {
            Ok(guard)
        } else {
            Err(TryLockError::Abandoned(guard))
        }
    }

    /// Clear the abandoned state of the mutex.
    ///
    /// If the kernel mutex accepts `mark_consistent`, the whole level word is
    /// reset to zero. If the kernel reports `BadObjectState`, only the
    /// abandonment flag kept in the level word is cleared, leaving the
    /// nesting count as it is.
    ///
    /// # Errors
    ///
    /// - `BadContext`: forwarded from the kernel.
    /// - `Consistent`: the kernel reported `BadObjectState` and the
    ///   abandonment flag was not set.
    pub fn mark_consistent(&self) -> Result<(), MarkConsistentError> {
        let nesting = &self.hunk.level;

        match self.mutex.mark_consistent() {
            Ok(()) => {
                nesting.set(0);
                Ok(())
            }
            Err(MarkConsistentMutexError::BadObjectState) => {
                let current = nesting.get();
                if (current & LEVEL_ABANDONED) == 0 {
                    return Err(MarkConsistentError::Consistent);
                }
                nesting.set(current & !LEVEL_ABANDONED);
                Ok(())
            }
            Err(MarkConsistentMutexError::BadContext) => Err(MarkConsistentError::BadContext),
            Err(MarkConsistentMutexError::BadId) => unreachable!(),
        }
    }

    /// Get a raw pointer to the protected value without locking the mutex.
    #[inline]
    pub fn get_ptr(&self) -> *mut T {
        let data: *const T = core::ptr::raw_const!(self.hunk.data);
        data as *mut T
    }
}
impl<System: Kernel, T: fmt::Debug + 'static> fmt::Debug for RecursiveMutex<System, T> {
    /// Formats the mutex as `RecursiveMutex { data: .. }`.
    ///
    /// The value is shown only if `try_lock` succeeds; otherwise a
    /// placeholder naming the reason takes its place.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        /// Stand-in for the `data` field that prints its text verbatim.
        struct Placeholder(&'static str);

        impl fmt::Debug for Placeholder {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str(self.0)
            }
        }

        /// Emit the struct with `reason` in place of the inaccessible data.
        fn write_unavailable(f: &mut fmt::Formatter<'_>, reason: &'static str) -> fmt::Result {
            f.debug_struct("RecursiveMutex")
                .field("data", &Placeholder(reason))
                .finish()
        }

        // The output is produced inside the arms so that any guard obtained
        // by `try_lock` stays alive until formatting is done.
        match self.try_lock() {
            Ok(guard) => f
                .debug_struct("RecursiveMutex")
                .field("data", &&*guard)
                .finish(),
            Err(TryLockError::Abandoned(_)) => write_unavailable(f, "<abandoned>"),
            Err(TryLockError::WouldBlock) => write_unavailable(f, "<locked>"),
            Err(TryLockError::BadContext) => write_unavailable(f, "<CPU context active>"),
            Err(TryLockError::BadParam) => write_unavailable(f, "<current priority too high>"),
        }
    }
}
impl<System: Kernel, T: fmt::Debug + 'static> fmt::Debug for MutexGuard<'_, System, T> {
    /// Formats the guarded value exactly as `T`'s own `Debug` would.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let data: &T = self;
        <T as fmt::Debug>::fmt(data, f)
    }
}
impl<System: Kernel, T: fmt::Display + 'static> fmt::Display for MutexGuard<'_, System, T> {
    /// Formats the guarded value exactly as `T`'s own `Display` would.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let data: &T = self;
        <T as fmt::Display>::fmt(data, f)
    }
}
impl<System: Kernel, T: 'static> Drop for MutexGuard<'_, System, T> {
#[inline]
fn drop(&mut self) {
let level = &self.mutex.hunk.level;
if level.get() == 0 || level.get() == LEVEL_ABANDONED {
self.mutex.mutex.unlock().unwrap();
} else {
level.update(|x| x - (1 << LEVEL_COUNT_SHIFT));
}
}
}
impl<System: Kernel, T: 'static> core::ops::Deref for MutexGuard<'_, System, T> {
type Target = T;
#[inline]
fn deref(&self) -> &Self::Target {
&self.mutex.hunk.data
}
}