use alloc::vec::Vec;
use crate::priority::Priority;
/// Tracks a base priority together with the priorities of recorded waiters.
///
/// The effective priority reported by `effective` is the greatest of the base
/// priority and every waiter priority currently recorded.
#[derive(Debug, Clone)]
pub struct PriorityInheritance {
/// Priority supplied to `new` and returned unchanged by `base`.
base: Priority,
/// One entry per `on_block` call that has not been removed by `on_unblock`.
waiters: Vec<Priority>,
}
impl PriorityInheritance {
    /// Creates a tracker with the given base priority and no recorded waiters.
    #[must_use]
    pub fn new(base: Priority) -> Self {
        let waiters = Vec::new();
        Self { base, waiters }
    }

    /// Returns the base priority passed to `new`.
    #[must_use]
    pub fn base(&self) -> Priority {
        self.base
    }

    /// Returns the greatest of the base priority and all recorded waiter priorities.
    ///
    /// With no waiters, or when no waiter exceeds the base, this is the base priority.
    #[must_use]
    pub fn effective(&self) -> Priority {
        match self.waiters.iter().copied().max() {
            // Only a waiter that strictly outranks the base raises the result.
            Some(top) if top > self.base => top,
            _ => self.base,
        }
    }

    /// Records `waiter` and returns the effective priority after the addition.
    pub fn on_block(&mut self, waiter: Priority) -> Priority {
        self.waiters.push(waiter);
        self.effective()
    }

    /// Removes the first recorded entry equal to `waiter`, if there is one, and
    /// returns the effective priority afterwards.
    ///
    /// A `waiter` that was never recorded leaves the state untouched.
    pub fn on_unblock(&mut self, waiter: Priority) -> Priority {
        let found = self.waiters.iter().position(|queued| *queued == waiter);
        if let Some(index) = found {
            self.waiters.remove(index);
        }
        self.effective()
    }

    /// Returns how many waiter entries are currently recorded.
    #[must_use]
    pub fn waiter_count(&self) -> usize {
        self.waiters.len()
    }
}
#[cfg(feature = "std")]
pub use std_impl::{RtMutex, RtMutexGuard};
#[cfg(feature = "std")]
#[allow(clippy::expect_used)]
mod std_impl {
use super::Priority;
use alloc::vec::Vec;
use std::sync::{Condvar, Mutex, MutexGuard};
/// Grant bookkeeping for an `RtMutex`, stored behind its `coord` mutex.
#[derive(Debug)]
struct Coord {
/// `true` from the moment a caller is granted the mutex until `unlock` clears it.
locked: bool,
/// Priority passed by the current holder, or `None` while unlocked.
holder: Option<Priority>,
/// Priorities of callers currently waiting inside `RtMutex::lock`.
waiters: Vec<Priority>,
}
impl Coord {
    /// Returns the greatest priority among the queued waiters, or `None` when
    /// the waiter list is empty.
    fn top_waiter(&self) -> Option<Priority> {
        self.waiters.iter().max().copied()
    }
}
/// Mutex that hands ownership to the highest-priority caller blocked in `lock`.
#[derive(Debug)]
pub struct RtMutex<T> {
/// The protected value; locked only after `coord` has granted ownership.
data: Mutex<T>,
/// Grant state: locked flag, holder priority and waiting priorities.
coord: Mutex<Coord>,
/// Notified by `unlock` so blocked callers re-check whether they may proceed.
cv: Condvar,
}
/// Guard returned by `RtMutex::lock` and `RtMutex::try_lock`; dropping it
/// releases the mutex.
#[derive(Debug)]
pub struct RtMutexGuard<'a, T> {
/// Mutex whose `unlock` is called when this guard is dropped.
mutex: &'a RtMutex<T>,
/// Guard over the protected value; taken (left as `None`) in `drop` before
/// the mutex is unlocked.
data: Option<MutexGuard<'a, T>>,
}
impl<T> RtMutex<T> {
    /// Creates an unlocked mutex that protects `value`.
    #[must_use]
    pub fn new(value: T) -> Self {
        Self {
            data: Mutex::new(value),
            coord: Mutex::new(Coord {
                locked: false,
                holder: None,
                waiters: Vec::new(),
            }),
            cv: Condvar::new(),
        }
    }

    /// Blocks until the mutex is free and no queued waiter outranks
    /// `my_priority`, then takes ownership and returns a guard.
    ///
    /// While blocked, `my_priority` is listed as a waiter and is therefore
    /// visible through `effective_holder_priority`.
    ///
    /// # Panics
    ///
    /// Panics if the coordination mutex or the data mutex is poisoned. When the
    /// data mutex is poisoned the grant is rolled back before unwinding, so
    /// later callers also panic instead of blocking forever.
    #[allow(clippy::missing_panics_doc)]
    pub fn lock(&self, my_priority: Priority) -> RtMutexGuard<'_, T> {
        {
            let mut coord = self.coord.lock().expect("rt-mutex coord poisoned");
            coord.waiters.push(my_priority);
            // Sleep until the mutex is free and this caller is (one of) the
            // highest-priority waiters; the predicate is checked before the
            // first wait, so an uncontended call never sleeps.
            let mut coord = self
                .cv
                .wait_while(coord, |state| {
                    state.locked || state.top_waiter() != Some(my_priority)
                })
                .expect("rt-mutex cv poisoned");
            // Leave the queue: remove exactly one entry for this caller.
            if let Some(pos) = coord.waiters.iter().position(|&w| w == my_priority) {
                coord.waiters.remove(pos);
            }
            coord.locked = true;
            coord.holder = Some(my_priority);
        }
        self.acquire_data()
    }

    /// Attempts to take the mutex without blocking on contention.
    ///
    /// Returns `None` when the mutex is currently held, or when a caller
    /// blocked in `lock` has a strictly higher priority than `my_priority`
    /// (so a free mutex is never taken ahead of a higher-priority waiter).
    ///
    /// # Panics
    ///
    /// Panics if the coordination mutex or the data mutex is poisoned; a
    /// poisoned data mutex does not leave the mutex marked as held.
    #[allow(clippy::missing_panics_doc)]
    pub fn try_lock(&self, my_priority: Priority) -> Option<RtMutexGuard<'_, T>> {
        {
            let mut coord = self.coord.lock().expect("rt-mutex coord poisoned");
            // Mirrors the grant rule used by `lock`: only proceed when no queued
            // waiter strictly outranks this caller.
            let outranked = matches!(coord.top_waiter(), Some(queued) if queued > my_priority);
            if coord.locked || outranked {
                return None;
            }
            coord.locked = true;
            coord.holder = Some(my_priority);
        }
        Some(self.acquire_data())
    }

    /// Returns the priority the holder currently runs at for inheritance
    /// purposes: the greater of the holder's own priority and the top queued
    /// waiter's priority, or `None` when nobody holds the mutex.
    ///
    /// # Panics
    ///
    /// Panics if the coordination mutex is poisoned.
    #[allow(clippy::missing_panics_doc)]
    #[must_use]
    pub fn effective_holder_priority(&self) -> Option<Priority> {
        let coord = self.coord.lock().expect("rt-mutex coord poisoned");
        coord.holder.map(|h| match coord.top_waiter() {
            Some(w) if w > h => w,
            _ => h,
        })
    }

    /// Locks the protected value for a caller that `coord` has already marked
    /// as the holder, and wraps it in a guard.
    fn acquire_data(&self) -> RtMutexGuard<'_, T> {
        // Build the guard first: if the data mutex is poisoned and `expect`
        // panics, dropping this guard during unwinding still calls `unlock`,
        // which clears `locked`/`holder` and wakes the waiters.
        let mut guard = RtMutexGuard {
            mutex: self,
            data: None,
        };
        guard.data = Some(self.data.lock().expect("rt-mutex data poisoned"));
        guard
    }

    /// Marks the mutex as free and wakes every blocked caller so each can
    /// re-evaluate whether it is now the one to be granted.
    fn unlock(&self) {
        let mut coord = self.coord.lock().expect("rt-mutex coord poisoned");
        coord.locked = false;
        coord.holder = None;
        // Release the coordination lock before notifying so woken threads can
        // acquire it immediately.
        drop(coord);
        self.cv.notify_all();
    }
}
impl<T> RtMutexGuard<'_, T> {
    /// Returns a shared reference to the protected value.
    ///
    /// Panics with "guard active" if the inner data guard has already been taken.
    #[must_use]
    pub fn get(&self) -> &T {
        let inner = self.data.as_ref().expect("guard active");
        &**inner
    }

    /// Returns an exclusive reference to the protected value.
    ///
    /// Panics with "guard active" if the inner data guard has already been taken.
    pub fn get_mut(&mut self) -> &mut T {
        let inner = self.data.as_mut().expect("guard active");
        &mut **inner
    }
}
impl<T> core::ops::Deref for RtMutexGuard<'_, T> {
    type Target = T;

    /// Dereferences to the protected value by delegating to `get`.
    fn deref(&self) -> &Self::Target {
        Self::get(self)
    }
}
impl<T> core::ops::DerefMut for RtMutexGuard<'_, T> {
    /// Mutably dereferences to the protected value by delegating to `get_mut`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        Self::get_mut(self)
    }
}
impl<T> Drop for RtMutexGuard<'_, T> {
    /// Releases the data guard first, then marks the mutex free and wakes waiters.
    fn drop(&mut self) {
        // The inner guard must be gone before `unlock` lets the next holder in.
        let released = self.data.take();
        drop(released);
        self.mutex.unlock();
    }
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
use super::*;
/// Test shorthand: builds a `Priority` from `v`, panicking if `Priority::new` rejects it.
fn p(v: i16) -> Priority {
    let built = Priority::new(v);
    built.unwrap()
}
/// A fresh tracker reports its base priority and has no waiters.
#[test]
fn effective_is_base_without_waiters() {
    let base = p(10);
    let tracker = PriorityInheritance::new(base);
    assert_eq!(tracker.effective(), base);
    assert_eq!(tracker.waiter_count(), 0);
}
/// Each `on_block` returns the highest priority seen so far.
#[test]
fn owner_inherits_highest_waiter() {
    let mut tracker = PriorityInheritance::new(p(10));
    // (priority of the new waiter, effective priority expected afterwards)
    for (incoming, expected) in [(50, 50), (30, 50), (90, 90)] {
        assert_eq!(tracker.on_block(p(incoming)), p(expected));
    }
    assert_eq!(tracker.waiter_count(), 3);
}
/// Removing waiters lowers the effective priority step by step back to the base.
#[test]
fn priority_reverts_on_unblock() {
    let mut tracker = PriorityInheritance::new(p(10));
    for waiter in [50, 90] {
        tracker.on_block(p(waiter));
    }
    assert_eq!(tracker.effective(), p(90));
    // (priority of the departing waiter, effective priority expected afterwards)
    for (leaving, expected) in [(90, 50), (50, 10)] {
        assert_eq!(tracker.on_unblock(p(leaving)), p(expected));
    }
}
/// Locking exposes the value and the holder priority; dropping the guard clears the holder.
#[cfg(feature = "std")]
#[test]
fn rt_mutex_basic_lock_unlock() {
    let mutex = RtMutex::new(0u32);

    let mut guard = mutex.lock(p(10));
    *guard += 5;
    assert_eq!(mutex.effective_holder_priority(), Some(p(10)));
    drop(guard);

    assert_eq!(mutex.effective_holder_priority(), None);
    let reread = mutex.lock(p(1));
    assert_eq!(*reread, 5);
}
/// `try_lock` fails while the mutex is held and succeeds once it is released.
#[cfg(feature = "std")]
#[test]
fn rt_mutex_try_lock_contended() {
    let mutex = RtMutex::new(());
    {
        let _held = mutex.lock(p(10));
        assert!(mutex.try_lock(p(20)).is_none());
    }
    assert!(mutex.try_lock(p(20)).is_some());
}
/// With the mutex held, a priority-80 waiter must be granted before a
/// priority-40 waiter once the holder releases.
#[cfg(feature = "std")]
#[test]
fn rt_mutex_grants_highest_priority_waiter_first() {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    let m = Arc::new(RtMutex::new(()));
    let order = Arc::new(std::sync::Mutex::new(alloc::vec::Vec::<i16>::new()));
    // Counts waiter threads that have started running.
    let started = Arc::new(AtomicU32::new(0));

    // Hold the mutex so the spawned threads cannot be granted until `g` is dropped.
    let g = m.lock(p(1));

    let mut handles = alloc::vec::Vec::new();
    for prio in [p(40), p(80)] {
        let m = Arc::clone(&m);
        let order = Arc::clone(&order);
        let started = Arc::clone(&started);
        handles.push(std::thread::spawn(move || {
            started.fetch_add(1, Ordering::SeqCst);
            let _lg = m.lock(prio);
            order.lock().unwrap().push(prio.value());
        }));
    }

    // Release only after both threads have started and the priority-80 waiter
    // is visibly queued. The counter cannot prove the priority-40 thread is
    // already queued, but it narrows the window in which it is not; the
    // expected order holds either way because the queued 80 waiter outranks it.
    while started.load(Ordering::SeqCst) < 2 || m.effective_holder_priority() != Some(p(80)) {
        std::thread::yield_now();
    }
    drop(g);

    for h in handles {
        h.join().unwrap();
    }
    let seq = order.lock().unwrap().clone();
    assert_eq!(seq, alloc::vec![80, 40]);
}
}