use super::access::AccessQueue;
use crate::{
BlockCount, BlockNumber, Config, MessageId, async_runtime,
errors::{Error, Result, UsageError},
exec, format, msg,
};
use core::{
cell::UnsafeCell,
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};
/// Global counter holding the identifier that the next mutex asking for one will receive.
///
/// It is read and then advanced by `Mutex::get_or_assign_id`.
static mut NEXT_MUTEX_ID: MutexId = MutexId::new();
/// Numeric identifier of a mutex.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub(crate) struct MutexId(u32);

impl MutexId {
    /// Returns the initial identifier, whose numeric value is zero.
    pub const fn new() -> Self {
        Self(0)
    }

    /// Returns the identifier that follows this one, wrapping around to zero
    /// after `u32::MAX`.
    pub fn next(self) -> Self {
        let Self(current) = self;
        Self(current.wrapping_add(1))
    }
}
/// Asynchronous mutual-exclusion primitive protecting a value of type `T`.
///
/// Ownership is tracked per message: the lock state records which message owns the
/// lock and the block number at which that ownership expires.
pub struct Mutex<T> {
// Lazily assigned identifier; stays `None` until `get_or_assign_id` runs for the first time.
id: UnsafeCell<Option<MutexId>>,
// Owner message id and its ownership deadline block; `None` while the mutex is unlocked.
locked: UnsafeCell<Option<(MessageId, BlockNumber)>>,
// The protected value.
value: UnsafeCell<T>,
// Messages queued up waiting for the lock.
queue: AccessQueue,
}
impl<T> From<T> for Mutex<T> {
fn from(t: T) -> Self {
Mutex::new(t)
}
}
impl<T: Default> Default for Mutex<T> {
    /// Creates a mutex around the default value of `T`.
    fn default() -> Self {
        Self::from(T::default())
    }
}
impl<T> Mutex<T> {
    /// Creates a new mutex wrapping `t`.
    ///
    /// The mutex starts unlocked and without an identifier; the identifier is
    /// assigned the first time [`Mutex::lock`] is called.
    pub const fn new(t: T) -> Mutex<T> {
        Self {
            value: UnsafeCell::new(t),
            queue: AccessQueue::new(),
            locked: UnsafeCell::new(None),
            id: UnsafeCell::new(None),
        }
    }

    /// Returns a future that resolves to a [`MutexGuard`] once the polling message
    /// has obtained ownership of the lock.
    ///
    /// The returned future has no explicit ownership duration set.
    pub fn lock(&self) -> MutexLockFuture<'_, T> {
        let mutex_id = self.get_or_assign_id();
        MutexLockFuture {
            mutex: self,
            mutex_id,
            own_up_for: None,
        }
    }

    /// Gives mutable access to the lock state (owner message id and deadline block).
    #[allow(clippy::mut_from_ref)]
    fn locked_by_mut(&self) -> &mut Option<(MessageId, BlockNumber)> {
        let state = self.locked.get();
        // NOTE(review): sound only while no other reference into this cell is alive;
        // presumably guaranteed by the execution model — confirm.
        unsafe { &mut *state }
    }

    /// Returns this mutex's identifier, taking a fresh one from the global counter
    /// on first use and remembering it for later calls.
    fn get_or_assign_id(&self) -> MutexId {
        // NOTE(review): same aliasing assumption as for the lock state — confirm.
        let slot = unsafe { &mut *self.id.get() };
        match *slot {
            Some(assigned) => assigned,
            None => {
                // Take the current counter value and advance the counter for the next mutex.
                let fresh = unsafe {
                    let id = NEXT_MUTEX_ID;
                    NEXT_MUTEX_ID = id.next();
                    id
                };
                *slot = Some(fresh);
                fresh
            }
        }
    }
}
/// Guard granting access to the value protected by a [`Mutex`].
///
/// It is the output of awaiting a [`MutexLockFuture`]; dropping it releases the lock.
pub struct MutexGuard<'a, T> {
// The mutex this guard was obtained from.
mutex: &'a Mutex<T>,
// Message that acquired the lock; every access through the guard is checked against it.
holder_msg_id: MessageId,
}
impl<T> MutexGuard<'_, T> {
    /// Verifies that the message currently being executed is the one holding this guard.
    ///
    /// # Panics
    ///
    /// Panics when the current message id differs from the holder's message id.
    #[track_caller]
    fn ensure_access_by_holder(&self) {
        let accessor_msg_id = msg::id();
        if self.holder_msg_id == accessor_msg_id {
            return;
        }

        panic!(
            "Mutex guard held by message 0x{} is being accessed by message 0x{}",
            hex::encode(self.holder_msg_id),
            hex::encode(accessor_msg_id)
        );
    }
}
impl<T> Drop for MutexGuard<'_, T> {
    /// Releases the lock if this guard's holder still owns it, waking the next
    /// queued message (if any).
    ///
    /// # Panics
    ///
    /// Outside of the holder's signal handler, panics when the guard is dropped by a
    /// message other than its holder, or when the recorded lock owner is not the holder.
    fn drop(&mut self) {
        // Without the `ethexe` feature the drop may legitimately happen inside the
        // signal handler of the holder message; with it that case never applies.
        #[cfg(not(feature = "ethexe"))]
        let in_holder_signal_handler = msg::signal_from() == Ok(self.holder_msg_id);
        #[cfg(feature = "ethexe")]
        let in_holder_signal_handler = false;

        if !in_holder_signal_handler {
            self.ensure_access_by_holder();
        }

        let lock_state = self.mutex.locked_by_mut();
        let owner_msg_id = lock_state.map(|v| v.0);
        let holder_owns_lock = owner_msg_id == Some(self.holder_msg_id);

        if !holder_owns_lock {
            if !in_holder_signal_handler {
                panic!(
                    "Mutex guard held by message 0x{} does not match lock owner message {}",
                    hex::encode(self.holder_msg_id),
                    owner_msg_id.map_or("None".into(), |v| format!("0x{}", hex::encode(v)))
                );
            }
            // In the holder's signal handler with the lock no longer recorded as ours:
            // there is nothing to release.
            return;
        }

        // Hand over to the first waiting message, then mark the mutex as unlocked.
        if let Some(next_msg_id) = self.mutex.queue.dequeue() {
            exec::wake(next_msg_id).expect("Failed to wake the message");
        }
        *lock_state = None;
    }
}
impl<'a, T> AsRef<T> for MutexGuard<'a, T> {
    /// Returns a shared reference to the protected value.
    ///
    /// # Panics
    ///
    /// Panics if the current message is not the holder of this guard.
    fn as_ref(&self) -> &'a T {
        self.ensure_access_by_holder();
        let value_ptr = self.mutex.value.get();
        // NOTE(review): assumes no conflicting mutable reference to the value is alive — confirm.
        unsafe { &*value_ptr }
    }
}
impl<'a, T> AsMut<T> for MutexGuard<'a, T> {
    /// Returns a mutable reference to the protected value.
    ///
    /// # Panics
    ///
    /// Panics if the current message is not the holder of this guard.
    fn as_mut(&mut self) -> &'a mut T {
        self.ensure_access_by_holder();
        let value_ptr = self.mutex.value.get();
        // NOTE(review): assumes no other reference to the value is alive — confirm.
        unsafe { &mut *value_ptr }
    }
}
impl<T> Deref for MutexGuard<'_, T> {
    type Target = T;

    /// Dereferences to the protected value.
    ///
    /// # Panics
    ///
    /// Panics if the current message is not the holder of this guard.
    fn deref(&self) -> &T {
        self.ensure_access_by_holder();
        let value_ptr = self.mutex.value.get();
        // NOTE(review): assumes no conflicting mutable reference to the value is alive — confirm.
        unsafe { &*value_ptr }
    }
}
impl<T> DerefMut for MutexGuard<'_, T> {
    /// Mutably dereferences to the protected value.
    ///
    /// # Panics
    ///
    /// Panics if the current message is not the holder of this guard.
    fn deref_mut(&mut self) -> &mut T {
        self.ensure_access_by_holder();
        let value_ptr = self.mutex.value.get();
        // NOTE(review): assumes no other reference to the value is alive — confirm.
        unsafe { &mut *value_ptr }
    }
}
// Declares `Mutex<T>` shareable across threads for every `T` (no `T: Send` bound),
// which is what allows a `Mutex` to be stored in a `static`.
// NOTE(review): the soundness argument is not visible here; presumably it relies on
// the code never actually running on more than one thread — confirm.
unsafe impl<T> Sync for Mutex<T> {}
/// Future returned by [`Mutex::lock`]; resolves to a [`MutexGuard`] once the polling
/// message owns the lock.
pub struct MutexLockFuture<'a, T> {
// Identifier of the mutex, passed along when lock monitors are inserted or removed.
mutex_id: MutexId,
// The mutex to acquire.
mutex: &'a Mutex<T>,
// Requested ownership duration in blocks; `None` falls back to `Config::mx_lock_duration`.
own_up_for: Option<BlockCount>,
}
impl<'a, T> MutexLockFuture<'a, T> {
pub fn own_up_for(self, block_count: BlockCount) -> Result<Self> {
if block_count == 0 {
Err(Error::Gstd(UsageError::ZeroMxLockDuration))
} else {
Ok(MutexLockFuture {
mutex_id: self.mutex_id,
mutex: self.mutex,
own_up_for: Some(block_count),
})
}
}
fn acquire_lock_ownership(
&mut self,
owner_msg_id: MessageId,
current_block: BlockNumber,
) -> Poll<MutexGuard<'a, T>> {
let owner_deadline_block =
current_block.saturating_add(self.own_up_for.unwrap_or_else(Config::mx_lock_duration));
async_runtime::locks().remove_mx_lock_monitor(owner_msg_id, self.mutex_id);
if let Some(next_rival_msg_id) = self.mutex.queue.first() {
async_runtime::locks().insert_mx_lock_monitor(
*next_rival_msg_id,
self.mutex_id,
owner_deadline_block,
);
}
let locked_by = self.mutex.locked_by_mut();
*locked_by = Some((owner_msg_id, owner_deadline_block));
Poll::Ready(MutexGuard {
mutex: self.mutex,
holder_msg_id: owner_msg_id,
})
}
fn queue_for_lock_ownership(
&mut self,
rival_msg_id: MessageId,
owner_deadline_block: Option<BlockNumber>,
) -> Poll<MutexGuard<'a, T>> {
if !self.mutex.queue.contains(&rival_msg_id) {
self.mutex.queue.enqueue(rival_msg_id);
if let Some(owner_deadline_block) = owner_deadline_block {
if self.mutex.queue.len() == 1 {
async_runtime::locks().insert_mx_lock_monitor(
rival_msg_id,
self.mutex_id,
owner_deadline_block,
);
}
}
}
Poll::Pending
}
}
impl<'a, T> Future for MutexLockFuture<'a, T> {
    type Output = MutexGuard<'a, T>;

    /// Tries to take the lock for the current message.
    ///
    /// * Unlocked mutex: the current message becomes the owner immediately.
    /// * Locked and the owner's deadline block has not been reached: the current
    ///   message is queued and the future stays pending.
    /// * Locked and the deadline has been reached: the owner's task (if it still
    ///   exists) is flagged as having exceeded its lock and woken; then either the
    ///   first other queued message is woken while the current one re-queues, or the
    ///   current message takes ownership.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let current_msg_id = msg::id();
        let current_block = exec::block_height();
        let mutex = this.mutex;
        let locked_by = mutex.locked_by_mut();

        let (lock_owner_msg_id, deadline_block) = match *locked_by {
            Some(lock) => lock,
            None => return this.acquire_lock_ownership(current_msg_id, current_block),
        };

        if current_block < deadline_block {
            return this.queue_for_lock_ownership(current_msg_id, Some(deadline_block));
        }

        // The owner has held the lock past its deadline.
        if let Some(msg_future_task) = async_runtime::futures().get_mut(&lock_owner_msg_id) {
            msg_future_task.set_lock_exceeded();
            exec::wake(lock_owner_msg_id).expect("Failed to wake the message");
        }

        loop {
            match mutex.queue.dequeue() {
                // Nobody else is waiting: the current message takes the lock.
                None => break,
                // Entries of the overdue owner are discarded.
                Some(next_msg_id) if next_msg_id == lock_owner_msg_id => continue,
                // The current message is next in line: it takes the lock.
                Some(next_msg_id) if next_msg_id == current_msg_id => break,
                // Another message is ahead: wake it, free the lock and wait in the queue.
                Some(next_msg_id) => {
                    exec::wake(next_msg_id).expect("Failed to wake the message");
                    *locked_by = None;
                    return this.queue_for_lock_ownership(current_msg_id, None);
                }
            }
        }

        this.acquire_lock_ownership(current_msg_id, current_block)
    }
}