use assert_matches::{assert_matches, debug_assert_matches};
use core::{fmt, hash, marker::PhantomData};
use super::{
state, task, timeout, utils,
wait::{WaitPayload, WaitQueue},
BadIdError, Id, Kernel, KernelCfg1, LockMutexError, LockMutexPrecheckError,
LockMutexTimeoutError, MarkConsistentMutexError, PortThreading, QueryMutexError,
TryLockMutexError, UnlockMutexError,
};
use crate::{time::Duration, utils::Init};
#[doc(include = "../common.md")]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum MutexProtocol {
    // No locking protocol. NOTE(review): presumably corresponds to
    // `MutexCb::ceiling == None`; the conversion is not in this file.
    None,
    // Ceiling protocol carrying a ceiling value. NOTE(review): presumably
    // converted into `MutexCb::ceiling`, which `lock_core` applies to the
    // owner's effective priority — confirm where the `usize` is converted.
    Ceiling(usize),
}
#[doc(include = "../common.md")]
// Field 0 is the kernel object ID. `PhantomData<System>` binds the handle to a
// kernel type without storing anything for it. `repr(transparent)` makes the
// handle layout-identical to `Id`.
#[repr(transparent)]
pub struct Mutex<System>(Id, PhantomData<System>);
// `Clone`/`Copy` are written by hand instead of derived. `#[derive]` would add
// `System: Clone` / `System: Copy` bounds, even though `System` only appears
// inside `PhantomData`.
impl<System> Clone for Mutex<System> {
    fn clone(&self) -> Self {
        // The type is `Copy`, so the canonical `Clone` is a plain copy
        // (clippy: `non_canonical_clone_impl`).
        *self
    }
}
impl<System> Copy for Mutex<System> {}
impl<System> PartialEq for Mutex<System> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<System> Eq for Mutex<System> {}
impl<System> hash::Hash for Mutex<System> {
fn hash<H>(&self, state: &mut H)
where
H: hash::Hasher,
{
hash::Hash::hash(&self.0, state);
}
}
impl<System> fmt::Debug for Mutex<System> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("Mutex").field(&self.0).finish()
}
}
impl<System> Mutex<System> {
pub const unsafe fn from_id(id: Id) -> Self {
Self(id, PhantomData)
}
pub const fn id(self) -> Id {
self.0
}
}
impl<System: Kernel> Mutex<System> {
    /// Resolve this handle to the mutex's control block.
    ///
    /// Object IDs are 1-based: the ID minus one is the index passed to
    /// `System::get_mutex_cb`. Yields `BadIdError::BadId` when that lookup
    /// returns `None`.
    fn mutex_cb(self) -> Result<&'static MutexCb<System>, BadIdError> {
        let index = self.0.get() - 1;
        System::get_mutex_cb(index).ok_or(BadIdError::BadId)
    }

    /// Report whether any task currently owns the mutex.
    ///
    /// CPU lock is acquired before the ID is validated.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn is_locked(self) -> Result<bool, QueryMutexError> {
        let cpu_lock = utils::lock_cpu::<System>()?;
        let cb = self.mutex_cb()?;
        let owner = cb.owning_task.get(&*cpu_lock);
        Ok(owner.is_some())
    }

    /// Release the mutex held by the calling task.
    ///
    /// Requires a waitable context. `unlock_mutex` additionally rejects
    /// callers that do not own the mutex (`NotOwner`). It also rejects
    /// out-of-order unlocks (`BadObjectState`): only the most recently locked
    /// mutex may be released.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn unlock(self) -> Result<(), UnlockMutexError> {
        let cpu_lock = utils::lock_cpu::<System>()?;
        state::expect_waitable_context::<System>()?;
        let cb = self.mutex_cb()?;
        unlock_mutex(cb, cpu_lock)
    }

    /// Acquire the mutex, blocking while another task owns it.
    ///
    /// If the mutex is marked inconsistent (abandoned), ownership is still
    /// acquired but `Abandoned` is returned; see `mark_consistent`.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn lock(self) -> Result<(), LockMutexError> {
        let cpu_lock = utils::lock_cpu::<System>()?;
        state::expect_waitable_context::<System>()?;
        let cb = self.mutex_cb()?;
        lock_mutex(cb, cpu_lock)
    }

    /// Same as `lock`, but the wait is bounded by `timeout`.
    ///
    /// The duration is converted (and validated) before CPU lock is taken.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn lock_timeout(self, timeout: Duration) -> Result<(), LockMutexTimeoutError> {
        let limit = timeout::time32_from_duration(timeout)?;
        let cpu_lock = utils::lock_cpu::<System>()?;
        state::expect_waitable_context::<System>()?;
        let cb = self.mutex_cb()?;
        lock_mutex_timeout(cb, cpu_lock, limit)
    }

    /// Acquire the mutex only if it is free right now; never blocks.
    ///
    /// Unlike the blocking variants, this only requires a task context.
    /// A mutex owned by another task is reported as `Timeout`.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn try_lock(self) -> Result<(), TryLockMutexError> {
        let cpu_lock = utils::lock_cpu::<System>()?;
        state::expect_task_context::<System>()?;
        let cb = self.mutex_cb()?;
        try_lock_mutex(cb, cpu_lock)
    }

    /// Clear the inconsistent ("abandoned") flag of the mutex.
    ///
    /// Fails with `BadObjectState` if the flag was not set. No calling-context
    /// check is made here.
    #[cfg_attr(not(feature = "inline_syscall"), inline(never))]
    pub fn mark_consistent(self) -> Result<(), MarkConsistentMutexError> {
        let mut cpu_lock = utils::lock_cpu::<System>()?;
        let cb = self.mutex_cb()?;
        // The flag is cleared unconditionally; its previous value decides the result.
        let was_inconsistent = cb.inconsistent.replace(&mut *cpu_lock, false);
        if !was_inconsistent {
            return Err(MarkConsistentMutexError::BadObjectState);
        }
        Ok(())
    }
}
/// Kernel-side control block of a mutex.
///
/// The mutable fields are `CpuLockCell`s. They are read and written through
/// a CPU-lock guard (`get(&*lock)`, `replace(&mut *lock, ..)`).
/// `TaskPriority` is a separate type parameter that defaults to the
/// configured task priority type.
#[doc(hidden)]
pub struct MutexCb<
    System: PortThreading,
    TaskPriority: 'static = <System as KernelCfg1>::TaskPriority,
> {
    /// Priority ceiling, if any. While a task holds the mutex, `lock_core`
    /// lowers its effective priority value to at most this. Locking is
    /// refused with `BadParam` when the ceiling value exceeds the caller's
    /// base priority value.
    pub(super) ceiling: Option<TaskPriority>,
    /// Set by `abandon_held_mutexes`; cleared by `Mutex::mark_consistent`.
    /// A task that acquires the mutex while it is set receives `Abandoned`.
    pub(super) inconsistent: utils::CpuLockCell<System, bool>,
    /// Tasks blocked in `lock`/`lock_timeout` waiting for ownership.
    pub(super) wait_queue: WaitQueue<System>,
    /// The mutex the owner held just before this one. Together with
    /// `TaskCb::last_mutex_held` (the head), this links a stack of the
    /// owner's held mutexes.
    pub(super) prev_mutex_held: utils::CpuLockCell<System, Option<&'static Self>>,
    /// The current owner; `None` while the mutex is unlocked.
    pub(super) owning_task: utils::CpuLockCell<System, Option<&'static task::TaskCb<System>>>,
}
impl<System: PortThreading> Init for MutexCb<System> {
    // The lint fires because `MutexCb` has interior mutability (its
    // `CpuLockCell` fields). `Init` needs a const here; every use of a const
    // produces a fresh value, so no state is shared through it.
    #[allow(clippy::declare_interior_mutable_const)]
    const INIT: Self = Self {
        owning_task: Init::INIT,
        prev_mutex_held: Init::INIT,
        wait_queue: Init::INIT,
        inconsistent: Init::INIT,
        ceiling: Init::INIT,
    };
}
impl<System: Kernel> fmt::Debug for MutexCb<System> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let this: *const Self = self;
        f.debug_struct("MutexCb")
            .field("self", &this)
            .field("ceiling", &self.ceiling)
            .field("inconsistent", &self.inconsistent)
            .field("wait_queue", &self.wait_queue)
            // The two links below are printed as raw addresses instead of
            // formatting the referenced control blocks. This is presumably to
            // keep the output bounded (the links chain to other control blocks).
            .field(
                "prev_mutex_held",
                &self
                    .prev_mutex_held
                    .debug_fmt_with(|held, f| held.map(|cb| cb as *const _).fmt(f)),
            )
            .field(
                "owning_task",
                &self
                    .owning_task
                    .debug_fmt_with(|owner, f| owner.map(|tcb| tcb as *const _).fmt(f)),
            )
            .finish()
    }
}
/// Validate a lock request on `mutex_cb` and return the running task.
///
/// Fails with `WouldDeadlock` when the running task already owns the mutex.
/// Fails with `BadParam` when the mutex's ceiling value is greater than the
/// task's base priority value.
/// Panics if there is no running task. NOTE(review): presumably ruled out by
/// the context checks made before this is reached.
#[inline]
fn precheck_and_get_running_task<System: Kernel>(
    mut lock: utils::CpuLockGuardBorrowMut<'_, System>,
    mutex_cb: &'static MutexCb<System>,
) -> Result<&'static task::TaskCb<System>, LockMutexPrecheckError> {
    let running = System::state().running_task(lock.borrow_mut()).unwrap();

    // Re-locking a mutex the caller already owns could never complete.
    let owner = ptr_from_option_ref(mutex_cb.owning_task.get(&*lock));
    if owner == running {
        return Err(LockMutexPrecheckError::WouldDeadlock);
    }

    match mutex_cb.ceiling {
        Some(ceiling) if ceiling > running.base_priority.get(&*lock) => {
            Err(LockMutexPrecheckError::BadParam)
        }
        _ => Ok(running),
    }
}
/// Decide whether a task holding `mutex_cb` may switch to `new_base_priority`.
///
/// Mutexes without a ceiling never object. A ceiling mutex objects when its
/// ceiling value is greater than the proposed base priority value. This is
/// the same comparison that `precheck_and_get_running_task` applies at lock time.
#[inline]
pub(super) fn does_held_mutex_allow_new_task_base_priority<System: Kernel>(
    _lock: utils::CpuLockGuardBorrowMut<'_, System>,
    mutex_cb: &'static MutexCb<System>,
    new_base_priority: System::TaskPriority,
) -> bool {
    match mutex_cb.ceiling {
        Some(ceiling) if ceiling > new_base_priority => false,
        _ => true,
    }
}
/// Check `new_base_priority` against every mutex currently held by `task`.
///
/// Walks the held-mutex stack from `task.last_mutex_held` through
/// `prev_mutex_held`. Stops at the first mutex that rejects the new
/// priority.
#[inline]
pub(super) fn do_held_mutexes_allow_new_task_base_priority<System: Kernel>(
    mut lock: utils::CpuLockGuardBorrowMut<'_, System>,
    task: &'static task::TaskCb<System>,
    new_base_priority: System::TaskPriority,
) -> bool {
    let mut cursor = task.last_mutex_held.get(&*lock);
    loop {
        let held = match cursor {
            Some(held) => held,
            None => return true,
        };
        let allowed = does_held_mutex_allow_new_task_base_priority(
            lock.borrow_mut(),
            held,
            new_base_priority,
        );
        if !allowed {
            return false;
        }
        cursor = held.prev_mutex_held.get(&*lock);
    }
}
/// Compute the effective priority `task` would have with `base_priority`.
///
/// The result is the minimum of `base_priority` and the ceilings of all
/// mutexes on the task's held-mutex stack. Mutexes without a ceiling are
/// skipped.
pub(super) fn evaluate_task_effective_priority<System: Kernel>(
    lock: utils::CpuLockGuardBorrowMut<'_, System>,
    task: &'static task::TaskCb<System>,
    base_priority: System::TaskPriority,
) -> System::TaskPriority {
    core::iter::successors(task.last_mutex_held.get(&*lock), |held| {
        held.prev_mutex_held.get(&*lock)
    })
    .filter_map(|held| held.ceiling)
    .fold(base_priority, |acc, ceiling| acc.min(ceiling))
}
/// Try to give `running_task` ownership of `mutex_cb` without blocking.
///
/// Returns `false` and changes nothing if the mutex already has an owner.
/// Otherwise calls `lock_core` and returns `true`.
#[inline]
fn poll_core<System: Kernel>(
    mutex_cb: &'static MutexCb<System>,
    running_task: &'static task::TaskCb<System>,
    lock: utils::CpuLockGuardBorrowMut<'_, System>,
) -> bool {
    let already_owned = mutex_cb.owning_task.get(&*lock).is_some();
    if already_owned {
        return false;
    }
    lock_core(mutex_cb, running_task, lock);
    true
}
/// Make `task` the owner of `mutex_cb` and apply the mutex's ceiling.
///
/// No ownership check is made here. The callers are `poll_core` (after
/// seeing the mutex free) and `unlock_mutex_unchecked` (handing the mutex to
/// its first waiter).
#[inline]
// NOTE(review): the parentheses around the or-pattern below are presumably
// needed for `debug_assert_matches!` to accept it as one pattern; they
// trigger `unused_parens`, hence the allow.
#[allow(unused_parens)]
fn lock_core<System: Kernel>(
    mutex_cb: &'static MutexCb<System>,
    task: &'static task::TaskCb<System>,
    mut lock: utils::CpuLockGuardBorrowMut<'_, System>,
) {
    // The new owner is either the running task (fast path) or a waiting
    // task receiving the mutex on unlock.
    debug_assert_matches!(
        task.st.read(&*lock),
        (task::TaskSt::Running | task::TaskSt::Waiting)
    );
    mutex_cb.owning_task.replace(&mut *lock, Some(task));
    // Push `mutex_cb` onto the task's held-mutex stack: it becomes the new
    // head and links to the previous head.
    let prev_mutex_held = task.last_mutex_held.replace(&mut *lock, Some(mutex_cb));
    mutex_cb
        .prev_mutex_held
        .replace(&mut *lock, prev_mutex_held);
    // Apply the ceiling: the effective priority value never exceeds it
    // while the mutex is held.
    if let Some(ceiling) = mutex_cb.ceiling {
        let effective_priority = task.effective_priority.write(&mut *lock);
        *effective_priority = (*effective_priority).min(ceiling);
    }
}
/// Blocking acquisition used by `Mutex::lock`.
///
/// If the mutex is busy, the running task waits on `wait_queue` until
/// `unlock_mutex_unchecked` transfers ownership to it. Wait errors propagate
/// via `?`. After acquiring, `Abandoned` is reported if the mutex is marked
/// inconsistent.
fn lock_mutex<System: Kernel>(
    mutex_cb: &'static MutexCb<System>,
    mut lock: utils::CpuLockGuard<System>,
) -> Result<(), LockMutexError> {
    let current = precheck_and_get_running_task(lock.borrow_mut(), mutex_cb)?;

    let acquired_immediately = poll_core(mutex_cb, current, lock.borrow_mut());
    if !acquired_immediately {
        let payload = WaitPayload::Mutex(mutex_cb);
        mutex_cb.wait_queue.wait(lock.borrow_mut(), payload)?;
    }

    match mutex_cb.inconsistent.get(&*lock) {
        true => Err(LockMutexError::Abandoned),
        false => Ok(()),
    }
}
/// Non-blocking acquisition used by `Mutex::try_lock`.
///
/// A busy mutex is reported as `Timeout`. On success, `Abandoned` is
/// reported if the mutex is marked inconsistent; ownership is kept in that
/// case.
fn try_lock_mutex<System: Kernel>(
    mutex_cb: &'static MutexCb<System>,
    mut lock: utils::CpuLockGuard<System>,
) -> Result<(), TryLockMutexError> {
    let current = precheck_and_get_running_task(lock.borrow_mut(), mutex_cb)?;

    if poll_core(mutex_cb, current, lock.borrow_mut()) {
        if mutex_cb.inconsistent.get(&*lock) {
            Err(TryLockMutexError::Abandoned)
        } else {
            Ok(())
        }
    } else {
        Err(TryLockMutexError::Timeout)
    }
}
/// Time-bounded acquisition used by `Mutex::lock_timeout`.
///
/// Same as `lock_mutex`, except that the wait on `wait_queue` is limited by
/// `time32`.
fn lock_mutex_timeout<System: Kernel>(
    mutex_cb: &'static MutexCb<System>,
    mut lock: utils::CpuLockGuard<System>,
    time32: timeout::Time32,
) -> Result<(), LockMutexTimeoutError> {
    let current = precheck_and_get_running_task(lock.borrow_mut(), mutex_cb)?;

    let acquired_immediately = poll_core(mutex_cb, current, lock.borrow_mut());
    if !acquired_immediately {
        let payload = WaitPayload::Mutex(mutex_cb);
        mutex_cb
            .wait_queue
            .wait_timeout(lock.borrow_mut(), payload, time32)?;
    }

    let abandoned = mutex_cb.inconsistent.get(&*lock);
    if abandoned {
        return Err(LockMutexTimeoutError::Abandoned);
    }
    Ok(())
}
/// Release `mutex_cb` on behalf of the running task and hand it to the next
/// waiter, if any.
///
/// # Errors
///
/// - `UnlockMutexError::NotOwner`: the running task does not own the mutex.
/// - `UnlockMutexError::BadObjectState`: the task owns it but it is not the
///   head of the task's held-mutex stack. Mutexes must be unlocked in
///   reverse lock order.
///
/// On success the guard is consumed by `task::unlock_cpu_and_check_preemption`.
fn unlock_mutex<System: Kernel>(
    mutex_cb: &'static MutexCb<System>,
    mut lock: utils::CpuLockGuard<System>,
) -> Result<(), UnlockMutexError> {
    let task = System::state().running_task(lock.borrow_mut()).unwrap();
    if ptr_from_option_ref(mutex_cb.owning_task.get(&*lock)) != task {
        return Err(UnlockMutexError::NotOwner);
    }
    if ptr_from_option_ref(task.last_mutex_held.get(&*lock)) != mutex_cb {
        return Err(UnlockMutexError::BadObjectState);
    }
    // Pop `mutex_cb` off the task's held-mutex stack.
    let prev_mutex_held = mutex_cb.prev_mutex_held.get(&*lock);
    task.last_mutex_held.replace(&mut *lock, prev_mutex_held);
    // Recompute the effective priority from the base priority and the
    // ceilings of the mutexes still held. This may drop the ceiling boost.
    let base_priority = task.base_priority.get(&*lock);
    let effective_priority =
        evaluate_task_effective_priority(lock.borrow_mut(), task, base_priority);
    task.effective_priority
        .replace(&mut *lock, effective_priority);
    // Transfer ownership to the first waiter, or mark the mutex free.
    unlock_mutex_unchecked(mutex_cb, lock.borrow_mut());
    // Consumes the guard. NOTE(review): presumably releases CPU lock and lets
    // a now-higher-priority task (e.g. the woken waiter) preempt.
    task::unlock_cpu_and_check_preemption(lock);
    Ok(())
}
/// Release every mutex held by `task`, marking each one inconsistent.
///
/// Each mutex goes to its first waiter (if any) through
/// `unlock_mutex_unchecked`. When that waiter's `lock`/`lock_timeout` call
/// resumes, it checks `inconsistent` and reports `Abandoned`.
///
/// NOTE(review): `task.effective_priority` is not recomputed here, although
/// the task no longer holds any mutex. Presumably the caller resets it;
/// confirm at the call site.
pub(super) fn abandon_held_mutexes<System: Kernel>(
    mut lock: utils::CpuLockGuardBorrowMut<'_, System>,
    task: &'static task::TaskCb<System>,
) {
    // Detach the whole held-mutex stack from the task up front.
    let mut maybe_mutex_cb = task.last_mutex_held.replace(&mut *lock, None);
    while let Some(mutex_cb) = maybe_mutex_cb {
        // Read the link *before* unlocking. Handing the mutex to a new owner
        // (`lock_core`) overwrites `prev_mutex_held`.
        maybe_mutex_cb = mutex_cb.prev_mutex_held.get(&*lock);
        mutex_cb.inconsistent.replace(&mut *lock, true);
        unlock_mutex_unchecked(mutex_cb, lock.borrow_mut());
    }
}
/// Pass `mutex_cb` to its first waiter, or mark it unowned if nobody waits.
///
/// Performs no owner or ordering checks. The caller has already detached the
/// mutex from its previous owner's held-mutex stack.
fn unlock_mutex_unchecked<System: Kernel>(
    mutex_cb: &'static MutexCb<System>,
    mut lock: utils::CpuLockGuardBorrowMut<'_, System>,
) {
    match mutex_cb.wait_queue.first_waiting_task(lock.borrow_mut()) {
        None => {
            // No waiter: the mutex becomes free.
            mutex_cb.owning_task.replace(&mut *lock, None);
        }
        Some(next_task) => {
            // Ownership is transferred while the task is still queued
            // (`lock_core` accepts `Waiting`). Only then is it woken.
            lock_core(mutex_cb, next_task, lock.borrow_mut());
            assert!(mutex_cb.wait_queue.wake_up_one(lock.borrow_mut()));
        }
    }
}
/// Convert an optional reference into a raw pointer, mapping `None` to null.
///
/// Used to compare an `Option<&T>` against a reference by address.
#[inline]
fn ptr_from_option_ref<T>(x: Option<&T>) -> *const T {
    x.map_or(core::ptr::null(), |r| r as *const T)
}