use core::{cell::Cell, fmt, marker::PhantomData, mem::MaybeUninit, ops::Deref};
use crate::{
hunk::Hunk,
kernel::{
mutex, prelude::*, traits, Cfg, LockMutexError, MarkConsistentMutexError, MutexProtocol,
TryLockMutexError,
},
sync::source::{DefaultSource, Source},
utils::Init,
};
#[doc = include_str!("../common.md")]
pub struct Definer<System, Source> {
    /// Definer of the underlying kernel mutex; configured by `protocol` and
    /// finalized by `finish`.
    mutex: mutex::MutexDefiner<System>,
    /// Provider of the initial value; `finish` turns it into the backing
    /// storage via `into_unsafe_cell_hunk`.
    source: Source,
}
/// A recursive mutex: the task holding it may lock it again, and the
/// underlying kernel mutex is released only when the outermost guard is
/// dropped.
///
/// `Cell` dereferences to the storage of the protected state
/// (`MaybeUninit<MutexInner<T>>`, per the bounds on the `impl` blocks).
/// `Mutex` is the kernel mutex handle that `lock`/`try_lock`/`unlock` are
/// delegated to.
pub struct GenericRecursiveMutex<Cell, Mutex> {
    /// Storage holding the user data and the nesting/abandoned bookkeeping.
    cell: Cell,
    /// Kernel mutex handle providing the actual exclusion.
    mutex: Mutex,
}
/// A recursive mutex whose storage (a `Hunk`) and kernel mutex
/// (`StaticMutex`) are both created at configuration time by
/// `define()` … `.finish(cfg)`.
#[doc = crate::tests::doc_test!(
/// ```rust
/// use core::cell::Cell;
/// use r3::{kernel::StaticTask, sync::StaticRecursiveMutex};
///
/// struct Objects {
/// mutex: StaticRecursiveMutex<System, Cell<i32>>,
/// }
///
/// const fn configure_app<C>(cfg: &mut Cfg<C>) -> Objects
/// where
/// C: ~const traits::CfgTask<System = System> +
/// ~const traits::CfgMutex,
/// {
/// StaticTask::define()
/// .start(task1_body)
/// .priority(2)
/// .active(true)
/// .finish(cfg);
///
/// let mutex = StaticRecursiveMutex::define()
/// .init(|| Cell::new(1))
/// .finish(cfg);
///
/// Objects { mutex }
/// }
///
/// fn task1_body() {
/// let guard = COTTAGE.mutex.lock().unwrap();
/// assert_eq!(guard.get(), 1);
/// guard.set(2);
///
/// {
/// // Recursive lock is allowed
/// let guard2 = COTTAGE.mutex.lock().unwrap();
/// assert_eq!(guard2.get(), 2);
/// guard2.set(3);
/// }
///
/// assert_eq!(guard.get(), 3);
/// # exit(0);
/// }
/// ```
)]
pub type StaticRecursiveMutex<System, T> =
    GenericRecursiveMutex<Hunk<System, MaybeUninit<MutexInner<T>>>, mutex::StaticMutex<System>>;
// SAFETY: NOTE(review): these impls are written by hand because
// `MutexInner` contains a `core::cell::Cell` (the `level` word), which would
// presumably stop the auto traits from applying for cell types such as
// `Hunk`. In this file, `level` is accessed in two ways:
// - right after an operation on `self.mutex` (lock / try_lock /
//   mark_consistent);
// - from a live guard.
// The data is handed out only as `&T` (through guards) or as a raw pointer
// (`get_ptr`). Hence `T: Send` is required, not `T: Sync`.
// Neither impl bounds `Cell` or `Mutex` themselves. Confirm that every
// instantiation (e.g. `Hunk`, `StaticMutex`) is sound to send/share.
unsafe impl<Cell, Mutex, T: Send> Send for GenericRecursiveMutex<Cell, Mutex> where
    Cell: Deref<Target = MaybeUninit<MutexInner<T>>>
{
}
unsafe impl<Cell, Mutex, T: Send> Sync for GenericRecursiveMutex<Cell, Mutex> where
    Cell: Deref<Target = MaybeUninit<MutexInner<T>>>
{
}
/// The state stored in a recursive mutex's cell: the user data plus a
/// bookkeeping word.
pub struct MutexInner<T> {
    /// Bit `LEVEL_ABANDONED` is the "abandoned" flag. The remaining bits,
    /// shifted by `LEVEL_COUNT_SHIFT`, count the nested acquisitions made by
    /// the owning task beyond the first one.
    level: Cell<usize>,
    /// The protected value.
    data: T,
}
impl<T> MutexInner<T> {
#[inline]
pub const fn new(data: T) -> Self {
Self {
level: Cell::new(0),
data,
}
}
}
impl<T: Init> Init for MutexInner<T> {
const INIT: Self = Self::new(T::INIT);
}
impl<T: ~const Default> const Default for MutexInner<T> {
#[inline]
fn default() -> Self {
Self::new(T::default())
}
}
impl<T> const From<T> for MutexInner<T> {
#[inline]
fn from(x: T) -> Self {
Self::new(x)
}
}
/// Bit of `MutexInner::level` that is set while the mutex is in the
/// abandoned (inconsistent) state as tracked by this wrapper. The flag is
/// cleared only by `mark_consistent`.
const LEVEL_ABANDONED: usize = 1;
/// The nested-acquisition count lives in `MutexInner::level` shifted left by
/// this many bits, above the `LEVEL_ABANDONED` flag. A count of zero means
/// the lock is held by a single guard, so dropping that guard unlocks the
/// kernel mutex.
const LEVEL_COUNT_SHIFT: u32 = 1;
/// Guard returned by `GenericRecursiveMutex::lock` and `try_lock`. It
/// dereferences to the protected `T`.
///
/// Dropping a guard releases one level of the recursive lock. The kernel
/// mutex is unlocked when the outermost guard is dropped.
#[must_use = "if unused the GenericRecursiveMutex will immediately unlock"]
pub struct GenericMutexGuard<'a, Cell, Mutex, T>
where
    Cell: Deref<Target = MaybeUninit<MutexInner<T>>>,
    Mutex: mutex::MutexHandle,
{
    /// The mutex of which this guard holds one lock level.
    mutex: &'a GenericRecursiveMutex<Cell, Mutex>,
    /// Raw-pointer marker that makes the guard `!Send` and `!Sync`. An
    /// explicit `unsafe impl Sync` restores `Sync` when `T: Sync`.
    /// Presumably `!Send` exists so the unlock happens on the owning task.
    _no_send_sync: PhantomData<*mut ()>,
}
// SAFETY: through a shared `&GenericMutexGuard`, only `&T` can be obtained
// (`Deref`, `Debug`, `Display`). The `level` bookkeeping is touched only by
// `Drop`, which requires ownership of the guard. Sharing the guard is
// therefore equivalent to sharing `&T`, hence the `T: Sync` bound.
unsafe impl<Cell, Mutex, T: Sync> Sync for GenericMutexGuard<'_, Cell, Mutex, T>
where
    Cell: Deref<Target = MaybeUninit<MutexInner<T>>>,
    Mutex: mutex::MutexHandle,
{
}
/// Result of `GenericRecursiveMutex::lock`: the guard, or a `LockError`.
/// The error's `Abandoned` variant still carries a guard.
pub type LockResult<Guard> = Result<Guard, LockError<Guard>>;
/// Result of `GenericRecursiveMutex::try_lock`: the guard, or a
/// `TryLockError`. The error's `Abandoned` variant still carries a guard.
pub type TryLockResult<Guard> = Result<Guard, TryLockError<Guard>>;
/// Error type of `GenericRecursiveMutex::lock`.
///
/// The discriminants are those of the corresponding `LockMutexError`
/// variants, cast to `i8`.
#[repr(i8)]
pub enum LockError<Guard> {
    /// Forwarded from `LockMutexError::BadContext`.
    BadContext = LockMutexError::BadContext as i8,
    /// Forwarded from `LockMutexError::Interrupted`.
    Interrupted = LockMutexError::Interrupted as i8,
    /// Forwarded from `LockMutexError::BadParam`.
    BadParam = LockMutexError::BadParam as i8,
    /// The lock *was* acquired, and the guard is provided, but the mutex is
    /// in the abandoned state. Call `mark_consistent` to clear that state.
    Abandoned(Guard) = LockMutexError::Abandoned as i8,
}
impl<Guard> fmt::Debug for LockError<Guard> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Self::BadContext => "BadContext",
Self::Interrupted => "Interrupted",
Self::BadParam => "BadParam",
Self::Abandoned(_) => "Abandoned",
})
}
}
/// Error type of `GenericRecursiveMutex::try_lock`.
///
/// The discriminants are those of the corresponding `TryLockMutexError`
/// variants, cast to `i8`.
#[repr(i8)]
pub enum TryLockError<Guard> {
    /// Forwarded from `TryLockMutexError::BadContext`.
    BadContext = TryLockMutexError::BadContext as i8,
    /// Mapped from `TryLockMutexError::Timeout`: the mutex could not be
    /// taken without waiting.
    WouldBlock = TryLockMutexError::Timeout as i8,
    /// Forwarded from `TryLockMutexError::BadParam`. The `Debug` impl of
    /// `GenericRecursiveMutex` renders this case as
    /// `<current priority too high>`.
    BadParam = TryLockMutexError::BadParam as i8,
    /// The lock *was* acquired, and the guard is provided, but the mutex is
    /// in the abandoned state. Call `mark_consistent` to clear that state.
    Abandoned(Guard) = TryLockMutexError::Abandoned as i8,
}
impl<Guard> fmt::Debug for TryLockError<Guard> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Self::BadContext => "BadContext",
Self::WouldBlock => "WouldBlock",
Self::BadParam => "BadParam",
Self::Abandoned(_) => "Abandoned",
})
}
}
/// Error type of `GenericRecursiveMutex::mark_consistent`.
///
/// The discriminants are those of the corresponding
/// `MarkConsistentMutexError` variants, cast to `i8`.
#[derive(Debug)]
#[repr(i8)]
pub enum MarkConsistentError {
    /// Forwarded from `MarkConsistentMutexError::BadContext`.
    BadContext = MarkConsistentMutexError::BadContext as i8,
    /// The mutex is not in the abandoned state. The kernel reported
    /// `BadObjectState`, and this wrapper's abandoned flag was clear.
    Consistent = MarkConsistentMutexError::BadObjectState as i8,
}
impl<System, T: 'static> StaticRecursiveMutex<System, T>
where
    System: traits::KernelMutex + traits::KernelStatic,
{
    /// Starts defining a `StaticRecursiveMutex`.
    ///
    /// The returned `Definer` starts out with `DefaultSource` as the
    /// initial-value source. Adjust it with the source setters (e.g.
    /// `.init(..)`) and `.protocol(..)`, then call `.finish(cfg)`.
    pub const fn define() -> Definer<System, DefaultSource<MutexInner<T>>> {
        let mutex_definer = mutex::MutexRef::define();
        Definer {
            mutex: mutex_definer,
            source: DefaultSource::INIT,
        }
    }
}
impl<System, Source> Definer<System, Source>
where
    System: traits::KernelMutex + traits::KernelStatic,
{
    /// Sets the locking protocol of the underlying kernel mutex by
    /// forwarding to `MutexDefiner::protocol`. The initial-value source is
    /// kept unchanged.
    pub const fn protocol(self, protocol: MutexProtocol) -> Self {
        let Self {
            mutex: mutex_definer,
            source,
        } = self;
        Self {
            mutex: mutex_definer.protocol(protocol),
            source,
        }
    }
}
// Generates the initial-value setters on `Definer`, such as the `.init(..)`
// used in the `StaticRecursiveMutex` example. `autowrap(MutexInner::new, ..)`
// presumably wraps each user-supplied `T` into `MutexInner<T>`, so that
// `finish` receives a source with `Target = MutexInner<T>`. Confirm this
// against the macro definition.
impl_source_setter!(
    #[autowrap(MutexInner::new, MutexInner)]
    impl Definer<System, #Source>
);
impl<System, Source> Definer<System, Source>
where
    System: traits::KernelMutex + traits::KernelStatic,
{
    /// Completes the definition and returns the configured
    /// `StaticRecursiveMutex`.
    ///
    /// The storage is created from the source first, and only then is the
    /// kernel mutex registered with `cfg`.
    pub const fn finish<C: ~const traits::CfgMutex<System = System>, T>(
        self,
        cfg: &mut Cfg<C>,
    ) -> StaticRecursiveMutex<System, T>
    where
        Source: ~const self::Source<System, Target = MutexInner<T>>,
    {
        let Self {
            mutex: mutex_definer,
            source,
        } = self;

        // SAFETY: NOTE(review): `transmute` presumably reinterprets the hunk
        // returned by `into_unsafe_cell_hunk` as
        // `Hunk<System, MaybeUninit<MutexInner<T>>>`. This relies on the two
        // element types being layout-compatible; confirm against
        // `Hunk::transmute`.
        let cell = unsafe { source.into_unsafe_cell_hunk(cfg).transmute() };
        let kernel_mutex = mutex_definer.finish(cfg);

        GenericRecursiveMutex {
            cell,
            mutex: kernel_mutex,
        }
    }
}
impl<Cell, Mutex, T> GenericRecursiveMutex<Cell, Mutex>
where
Cell: Deref<Target = MaybeUninit<MutexInner<T>>>,
Mutex: mutex::MutexHandle,
{
pub fn lock(&self) -> LockResult<GenericMutexGuard<'_, Cell, Mutex, T>> {
let level;
match self.mutex.lock() {
Ok(()) => {
level = unsafe { &self.cell.assume_init_ref().level };
}
Err(LockMutexError::WouldDeadlock) => {
level = unsafe { &self.cell.assume_init_ref().level };
level.update(|x| {
x.checked_add(1 << LEVEL_COUNT_SHIFT)
.expect("nesting count overflow")
});
}
Err(LockMutexError::NoAccess) => unreachable!(),
Err(LockMutexError::BadContext) => return Err(LockError::BadContext),
Err(LockMutexError::Interrupted) => return Err(LockError::Interrupted),
Err(LockMutexError::BadParam) => return Err(LockError::BadParam),
Err(LockMutexError::Abandoned) => {
level = unsafe { &self.cell.assume_init_ref().level };
level.set(LEVEL_ABANDONED);
self.mutex.mark_consistent().unwrap();
}
}
if (level.get() & LEVEL_ABANDONED) != 0 {
Err(LockError::Abandoned(GenericMutexGuard {
mutex: self,
_no_send_sync: PhantomData,
}))
} else {
Ok(GenericMutexGuard {
mutex: self,
_no_send_sync: PhantomData,
})
}
}
pub fn try_lock(&self) -> TryLockResult<GenericMutexGuard<'_, Cell, Mutex, T>> {
let level;
match self.mutex.try_lock() {
Ok(()) => {
level = unsafe { &self.cell.assume_init_ref().level };
}
Err(TryLockMutexError::WouldDeadlock) => {
level = unsafe { &self.cell.assume_init_ref().level };
level.update(|x| {
x.checked_add(1 << LEVEL_COUNT_SHIFT)
.expect("nesting count overflow")
});
}
Err(TryLockMutexError::NoAccess) => unreachable!(),
Err(TryLockMutexError::BadContext) => return Err(TryLockError::BadContext),
Err(TryLockMutexError::Timeout) => return Err(TryLockError::WouldBlock),
Err(TryLockMutexError::BadParam) => return Err(TryLockError::BadParam),
Err(TryLockMutexError::Abandoned) => {
level = unsafe { &self.cell.assume_init_ref().level };
level.set(LEVEL_ABANDONED);
self.mutex.mark_consistent().unwrap();
}
}
if (level.get() & LEVEL_ABANDONED) != 0 {
Err(TryLockError::Abandoned(GenericMutexGuard {
mutex: self,
_no_send_sync: PhantomData,
}))
} else {
Ok(GenericMutexGuard {
mutex: self,
_no_send_sync: PhantomData,
})
}
}
pub fn mark_consistent(&self) -> Result<(), MarkConsistentError> {
match self.mutex.mark_consistent() {
Ok(()) => {
let level = unsafe { &self.cell.assume_init_ref().level };
level.set(0);
Ok(())
}
Err(MarkConsistentMutexError::NoAccess) => unreachable!(),
Err(MarkConsistentMutexError::BadContext) => Err(MarkConsistentError::BadContext),
Err(MarkConsistentMutexError::BadObjectState) => {
let level = unsafe { &self.cell.assume_init_ref().level };
if (level.get() & LEVEL_ABANDONED) != 0 {
level.update(|x| x & !LEVEL_ABANDONED);
Ok(())
} else {
Err(MarkConsistentError::Consistent)
}
}
}
}
#[inline]
pub fn get_ptr(&self) -> *mut T {
unsafe { core::ptr::addr_of!((*self.cell.as_ptr()).data) as *mut T }
}
}
impl<Cell, Mutex, T: fmt::Debug> fmt::Debug for GenericRecursiveMutex<Cell, Mutex>
where
    Cell: Deref<Target = MaybeUninit<MutexInner<T>>>,
    Mutex: mutex::MutexHandle,
{
    /// Formats as `GenericRecursiveMutex { data: … }`.
    ///
    /// The data is shown only if `try_lock` succeeds. Otherwise an
    /// angle-bracketed placeholder names the reason. Note that this calls
    /// `try_lock`, so formatting has its side effects. For example, an
    /// abandoned mutex gets its kernel state marked consistent.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Debug-formats as the wrapped text, verbatim and unquoted.
        struct Placeholder(&'static str);

        impl fmt::Debug for Placeholder {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str(self.0)
            }
        }

        // Probe first, then format. Any guard stays alive in `attempt` until
        // the output has been written.
        let attempt = self.try_lock();
        let mut out = f.debug_struct("GenericRecursiveMutex");
        match &attempt {
            Ok(guard) => out.field("data", &&**guard),
            Err(TryLockError::BadContext) => out.field("data", &Placeholder("<bad context>")),
            Err(TryLockError::WouldBlock) => out.field("data", &Placeholder("<locked>")),
            Err(TryLockError::Abandoned(_)) => out.field("data", &Placeholder("<abandoned>")),
            Err(TryLockError::BadParam) => {
                out.field("data", &Placeholder("<current priority too high>"))
            }
        };
        out.finish()
    }
}
impl<Cell, Mutex, T: fmt::Debug> fmt::Debug for GenericMutexGuard<'_, Cell, Mutex, T>
where
    Cell: Deref<Target = MaybeUninit<MutexInner<T>>>,
    Mutex: mutex::MutexHandle,
{
    /// Formats the guarded value exactly as `T`'s own `Debug` would. The
    /// guard adds nothing to the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Debug>::fmt(value, f)
    }
}
impl<Cell, Mutex, T: fmt::Display> fmt::Display for GenericMutexGuard<'_, Cell, Mutex, T>
where
    Cell: Deref<Target = MaybeUninit<MutexInner<T>>>,
    Mutex: mutex::MutexHandle,
{
    /// Formats the guarded value exactly as `T`'s own `Display` would. The
    /// guard adds nothing to the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let value: &T = self;
        <T as fmt::Display>::fmt(value, f)
    }
}
impl<Cell, Mutex, T> Drop for GenericMutexGuard<'_, Cell, Mutex, T>
where
Cell: Deref<Target = MaybeUninit<MutexInner<T>>>,
Mutex: mutex::MutexHandle,
{
#[inline]
fn drop(&mut self) {
let level = unsafe { &self.mutex.cell.assume_init_ref().level };
if level.get() == 0 || level.get() == LEVEL_ABANDONED {
self.mutex.mutex.unlock().unwrap();
} else {
level.update(|x| x - (1 << LEVEL_COUNT_SHIFT));
}
}
}
impl<Cell, Mutex, T> Deref for GenericMutexGuard<'_, Cell, Mutex, T>
where
    Cell: Deref<Target = MaybeUninit<MutexInner<T>>>,
    Mutex: mutex::MutexHandle,
{
    type Target = T;

    /// Gives shared access to the protected value for as long as the guard
    /// is borrowed.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: a live guard means the calling task holds the lock. The
        // cell is presumably initialized (by `finish` for the static variant);
        // confirm this for other `Cell` types.
        let inner: &MutexInner<T> = unsafe { self.mutex.cell.assume_init_ref() };
        &inner.data
    }
}
// SAFETY: `deref` returns a reference into the storage behind
// `self.mutex.cell`. The guard only borrows that storage
// (`&'a GenericRecursiveMutex`), so moving the guard does not move the
// target.
// NOTE(review): this also assumes that `Cell::deref` yields the same address
// every time, which is presumably true for a static `Hunk`. Confirm this for
// other `Cell` types.
unsafe impl<Cell, Mutex, T> stable_deref_trait::StableDeref
    for GenericMutexGuard<'_, Cell, Mutex, T>
where
    Cell: Deref<Target = MaybeUninit<MutexInner<T>>>,
    Mutex: mutex::MutexHandle,
{
}