use super::access::AccessQueue;
use crate::{MessageId, exec, msg};
use core::{
cell::{Cell, UnsafeCell},
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
/// Integer type used to count the read guards that are currently alive.
type ReadersCount = u8;
/// Maximum number of read guards that may be held at the same time; a read
/// request that would exceed it is queued instead of granted.
const READERS_LIMIT: ReadersCount = 32;
/// An async reader-writer lock whose holders are identified by message id.
///
/// Either up to `READERS_LIMIT` readers or a single writer may hold the lock
/// at once. A message whose request cannot be granted immediately is put into
/// `queue`, and a guard that frees the lock wakes the first queued message via
/// `exec::wake`.
pub struct RwLock<T> {
    // `Some(id)` while message `id` holds the write lock; `None` otherwise
    // (including while only readers hold the lock).
    locked: UnsafeCell<Option<MessageId>>,
    // The protected value; only reached through the guards.
    value: UnsafeCell<T>,
    // Number of read guards currently alive.
    readers: Cell<ReadersCount>,
    // Messages waiting to acquire the lock.
    queue: AccessQueue,
}
impl<T> From<T> for RwLock<T> {
fn from(t: T) -> Self {
RwLock::new(t)
}
}
impl<T: Default> Default for RwLock<T> {
    /// Builds an unlocked [`RwLock`] protecting `T::default()`.
    fn default() -> Self {
        let initial = T::default();
        Self::new(initial)
    }
}
impl<T> RwLock<T> {
    /// Maximum number of read guards that can be alive simultaneously.
    pub const READERS_LIMIT: ReadersCount = READERS_LIMIT;

    /// Creates a lock protecting `t`, with no writer, no readers and an empty
    /// wait queue.
    ///
    /// This is a `const fn`, so it can initialize a `static`.
    pub const fn new(t: T) -> RwLock<T> {
        Self {
            value: UnsafeCell::new(t),
            locked: UnsafeCell::new(None),
            readers: Cell::new(0),
            queue: AccessQueue::new(),
        }
    }

    /// Returns a future that resolves to a [`RwLockReadGuard`] once the
    /// current message is granted shared access.
    pub fn read(&self) -> RwLockReadFuture<'_, T> {
        let lock = self;
        RwLockReadFuture { lock }
    }

    /// Returns a future that resolves to a [`RwLockWriteGuard`] once the
    /// current message is granted exclusive access.
    pub fn write(&self) -> RwLockWriteFuture<'_, T> {
        let lock = self;
        RwLockWriteFuture { lock }
    }
}
// SAFETY: NOTE(review): this impl has no `T: Send + Sync` bound, and the lock
// state is kept in `UnsafeCell`/`Cell` rather than atomics. It is therefore only
// sound if the lock is never accessed from more than one thread at a time. That
// is presumably guaranteed by the single-threaded execution model this crate
// targets — confirm before reusing this type elsewhere.
unsafe impl<T> Sync for RwLock<T> {}
/// Guard granting shared (read) access to the value protected by a
/// [`RwLock`].
///
/// Produced by awaiting [`RwLock::read`]. Only the message that acquired the
/// guard may dereference or drop it; any other message panics. Dropping the
/// guard releases one reader slot.
// `must_use`, as on std's lock guards: an unused guard is dropped at the end of
// the statement, which releases the lock again straight away.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockReadGuard<'a, T> {
    // The lock this guard belongs to.
    lock: &'a RwLock<T>,
    // Id of the message that acquired the guard.
    holder_msg_id: MessageId,
}
impl<T> RwLockReadGuard<'_, T> {
    /// Checks that the currently executing message is the one that acquired
    /// this guard.
    ///
    /// # Panics
    ///
    /// Panics (with both message ids hex-encoded) when another message
    /// touches the guard.
    fn ensure_access_by_holder(&self) {
        let accessor = msg::id();
        if accessor == self.holder_msg_id {
            return;
        }
        panic!(
            "Read lock guard held by message 0x{} is being accessed by message 0x{}",
            hex::encode(self.holder_msg_id),
            hex::encode(accessor)
        );
    }
}
impl<T> Drop for RwLockReadGuard<'_, T> {
    /// Gives back one reader slot. When the last reader leaves, clears the
    /// writer slot and wakes the first queued message, if there is one.
    ///
    /// Panics if dropped by a message other than the holder.
    fn drop(&mut self) {
        self.ensure_access_by_holder();
        let lock = self.lock;

        let remaining = lock.readers.get().saturating_sub(1);
        lock.readers.set(remaining);
        if remaining != 0 {
            // Other readers still hold the lock; nobody is woken yet.
            return;
        }

        // SAFETY: NOTE(review): assumes no other reference into `locked` is live
        // here, which presumably holds because the lock is never accessed
        // concurrently — confirm.
        unsafe { *lock.locked.get() = None };
        if let Some(next) = lock.queue.dequeue() {
            exec::wake(next).expect("Failed to wake the message");
        }
    }
}
impl<T> AsRef<T> for RwLockReadGuard<'_, T> {
    /// Returns a shared reference to the protected value, borrowed from the
    /// guard.
    ///
    /// The returned reference is tied to `&self`, not to the lock's lifetime
    /// `'a`, so it cannot outlive the guard (and thus the read lock).
    ///
    /// # Panics
    ///
    /// Panics if called by a message other than the one holding the guard.
    fn as_ref(&self) -> &T {
        self.ensure_access_by_holder();
        // SAFETY: NOTE(review): while this read guard is alive `readers > 0`,
        // so no writer can be granted (writers require `readers == 0`).
        // Single-threaded access is assumed — confirm.
        unsafe { &*self.lock.value.get() }
    }
}
impl<T> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value, after verifying that the caller
    /// is the message holding this guard.
    fn deref(&self) -> &T {
        self.ensure_access_by_holder();
        let value_ptr = self.lock.value.get();
        // SAFETY: NOTE(review): a live read guard keeps `readers > 0`, which
        // blocks writers; single-threaded access is assumed — confirm.
        unsafe { &*value_ptr }
    }
}
/// Guard granting exclusive (write) access to the value protected by a
/// [`RwLock`].
///
/// Produced by awaiting [`RwLock::write`]. Only the message that acquired the
/// guard may dereference or drop it; any other message panics. Dropping the
/// guard releases the write lock.
// `must_use`, as on std's lock guards: an unused guard is dropped at the end of
// the statement, which releases the lock again straight away.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockWriteGuard<'a, T> {
    // The lock this guard belongs to.
    lock: &'a RwLock<T>,
    // Id of the message that acquired the guard.
    holder_msg_id: MessageId,
}
impl<T> RwLockWriteGuard<'_, T> {
    /// Checks that the currently executing message is the one that acquired
    /// this guard.
    ///
    /// # Panics
    ///
    /// Panics (with both message ids hex-encoded) when another message
    /// touches the guard.
    fn ensure_access_by_holder(&self) {
        let accessor = msg::id();
        if accessor == self.holder_msg_id {
            return;
        }
        panic!(
            "Write lock guard held by message 0x{} is being accessed by message 0x{}",
            hex::encode(self.holder_msg_id),
            hex::encode(accessor)
        );
    }
}
impl<T> Drop for RwLockWriteGuard<'_, T> {
    /// Releases the write lock and wakes the first queued message, if any.
    ///
    /// Panics if dropped by a message other than the holder, if the lock is
    /// not write-locked at all, or if it is write-locked by a message other
    /// than the one recorded in this guard.
    fn drop(&mut self) {
        self.ensure_access_by_holder();

        // SAFETY: NOTE(review): assumes no other reference into `locked` is live
        // here, which presumably holds because the lock is never accessed
        // concurrently — confirm.
        let owner_slot = unsafe { &mut *self.lock.locked.get() };

        let owner_msg_id = match *owner_slot {
            Some(id) => id,
            None => panic!(
                "Write lock guard held by message 0x{} is being dropped for non-existing lock",
                hex::encode(self.holder_msg_id),
            ),
        };
        assert!(
            owner_msg_id == self.holder_msg_id,
            "Write lock guard held by message 0x{} does not match lock owner message 0x{}",
            hex::encode(self.holder_msg_id),
            hex::encode(owner_msg_id),
        );

        *owner_slot = None;
        if let Some(next) = self.lock.queue.dequeue() {
            exec::wake(next).expect("Failed to wake the message");
        }
    }
}
impl<T> AsRef<T> for RwLockWriteGuard<'_, T> {
    /// Returns a shared reference to the protected value, borrowed from the
    /// guard.
    ///
    /// The returned reference is tied to `&self`, not to the lock's lifetime
    /// `'a`, so it cannot outlive the guard (and thus the write lock).
    ///
    /// # Panics
    ///
    /// Panics if called by a message other than the one holding the guard.
    fn as_ref(&self) -> &T {
        self.ensure_access_by_holder();
        // SAFETY: NOTE(review): while this guard is alive `locked` is
        // `Some(holder)`, so neither readers nor other writers are granted.
        // Single-threaded access is assumed — confirm.
        unsafe { &*self.lock.value.get() }
    }
}
impl<T> AsMut<T> for RwLockWriteGuard<'_, T> {
    /// Returns a mutable reference to the protected value, borrowed mutably
    /// from the guard.
    ///
    /// The reference is tied to `&mut self` rather than `'a`. The previous
    /// `&'a mut T` signature advertised a borrow that outlives the call,
    /// which would let two `&mut T` coexist if callers could rely on it.
    ///
    /// # Panics
    ///
    /// Panics if called by a message other than the one holding the guard.
    fn as_mut(&mut self) -> &mut T {
        self.ensure_access_by_holder();
        // SAFETY: NOTE(review): while this guard is alive `locked` is
        // `Some(holder)`, so no other guard can be granted, and `&mut self`
        // prevents a second borrow through this guard. Single-threaded access
        // is assumed — confirm.
        unsafe { &mut *self.lock.value.get() }
    }
}
impl<T> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Shared access to the protected value, after verifying that the caller
    /// is the message holding this guard.
    fn deref(&self) -> &T {
        self.ensure_access_by_holder();
        let value_ptr = self.lock.value.get();
        // SAFETY: NOTE(review): a live write guard keeps `locked` set, which
        // blocks all other guards; single-threaded access is assumed — confirm.
        unsafe { &*value_ptr }
    }
}
impl<T> DerefMut for RwLockWriteGuard<'_, T> {
    /// Mutable access to the protected value, after verifying that the caller
    /// is the message holding this guard.
    fn deref_mut(&mut self) -> &mut T {
        self.ensure_access_by_holder();
        let value_ptr = self.lock.value.get();
        // SAFETY: NOTE(review): a live write guard keeps `locked` set, which
        // blocks all other guards, and `&mut self` forbids a second borrow
        // through this guard; single-threaded access is assumed — confirm.
        unsafe { &mut *value_ptr }
    }
}
/// Future returned by [`RwLock::read`]; resolves to a [`RwLockReadGuard`].
// Concrete future types don't inherit `Future`'s `must_use`, so mark them
// explicitly: dropping one without awaiting it never acquires the lock.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RwLockReadFuture<'a, T> {
    // The lock being acquired.
    lock: &'a RwLock<T>,
}

/// Future returned by [`RwLock::write`]; resolves to a [`RwLockWriteGuard`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct RwLockWriteFuture<'a, T> {
    // The lock being acquired.
    lock: &'a RwLock<T>,
}
impl<'a, T> Future for RwLockReadFuture<'a, T> {
    type Output = RwLockReadGuard<'a, T>;

    /// Tries to take a read lock for the current message.
    ///
    /// Succeeds when no writer holds the lock and fewer than `READERS_LIMIT`
    /// readers are active. Otherwise the current message is added to the
    /// lock's wait queue (at most once) and `Poll::Pending` is returned. A
    /// guard that frees the lock later wakes the first queued message via
    /// `exec::wake`.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let lock = self.lock;
        let readers_count = lock.readers.get().saturating_add(1);
        let current_msg_id = msg::id();
        // Only a shared read of `locked` is needed here.
        // SAFETY: NOTE(review): assumes no `&mut` into `locked` is live during
        // poll, which presumably holds because the lock is never accessed
        // concurrently — confirm.
        let write_locked = unsafe { (*lock.locked.get()).is_some() };

        if !write_locked && readers_count <= READERS_LIMIT {
            lock.readers.set(readers_count);
            Poll::Ready(RwLockReadGuard {
                lock,
                holder_msg_id: current_msg_id,
            })
        } else {
            // If this message is polled again while still queued, don't add a
            // second entry for it.
            if !lock.queue.contains(&current_msg_id) {
                lock.queue.enqueue(current_msg_id);
            }
            Poll::Pending
        }
    }
}
impl<'a, T> Future for RwLockWriteFuture<'a, T> {
    type Output = RwLockWriteGuard<'a, T>;

    /// Tries to take the write lock for the current message.
    ///
    /// Succeeds only when no writer holds the lock and no readers are active,
    /// in which case the current message is recorded as the owner. Otherwise
    /// the current message is added to the lock's wait queue (at most once)
    /// and `Poll::Pending` is returned. A guard that frees the lock later wakes
    /// the first queued message via `exec::wake`.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let lock = self.lock;
        let current_msg_id = msg::id();
        // SAFETY: NOTE(review): assumes no other reference into `locked` is live
        // during poll, which presumably holds because the lock is never
        // accessed concurrently — confirm.
        let owner = unsafe { &mut *lock.locked.get() };

        if owner.is_none() && lock.readers.get() == 0 {
            *owner = Some(current_msg_id);
            Poll::Ready(RwLockWriteGuard {
                lock,
                holder_msg_id: current_msg_id,
            })
        } else {
            // If this message is polled again while still queued, don't add a
            // second entry for it.
            if !lock.queue.contains(&current_msg_id) {
                lock.queue.enqueue(current_msg_id);
            }
            Poll::Pending
        }
    }
}