#[cfg(test)]
mod tests;
mod spin_mutex;
mod unsafe_list;
use crate::num::NonZeroUsize;
use crate::ops::{Deref, DerefMut};
use crate::time::Duration;
use super::abi::thread;
use super::abi::usercalls;
use fortanix_sgx_abi::{Tcs, EV_UNPARK, WAIT_INDEFINITE};
pub use self::spin_mutex::{try_lock_or_false, SpinMutex, SpinMutexGuard};
use self::unsafe_list::{UnsafeList, UnsafeListEntry};
/// Per-waiter record that a blocked thread links into a `WaitQueue`.
struct WaitEntry {
/// TCS of the thread that created the entry (`thread::current()` at enqueue time).
tcs: Tcs,
/// Set to `true` by `notify_one`/`notify_all` once the entry has been popped
/// from the queue; the waiting thread re-checks it to decide whether it was
/// actually notified.
wake: bool,
}
/// A value of type `T` bundled with the `WaitQueue` of threads waiting on it.
///
/// The wait/notify functions of `WaitQueue` take this type wrapped in a
/// `SpinMutex`, so the queue and the value are guarded by the same lock.
#[derive(Default)]
pub struct WaitVariable<T> {
/// Threads currently waiting on this variable.
queue: WaitQueue,
/// The caller-supplied value, exposed through `lock_var`/`lock_var_mut`.
lock: T,
}
impl<T> WaitVariable<T> {
    /// Wraps `var` together with a fresh, empty wait queue.
    pub const fn new(var: T) -> Self {
        let queue = WaitQueue::new();
        Self { queue, lock: var }
    }

    /// Reports whether no thread is currently queued on this variable.
    pub fn queue_empty(&self) -> bool {
        let Self { queue, .. } = self;
        queue.is_empty()
    }

    /// Shared access to the wrapped value.
    pub fn lock_var(&self) -> &T {
        let Self { lock, .. } = self;
        lock
    }

    /// Exclusive access to the wrapped value.
    pub fn lock_var_mut(&mut self) -> &mut T {
        let Self { lock, .. } = self;
        lock
    }
}
/// Describes which waiter(s) a successful notification targeted.
#[derive(Copy, Clone)]
pub enum NotifiedTcs {
/// Exactly one waiter was popped (by `notify_one`); holds that waiter's TCS.
Single(Tcs),
/// `notify_all` popped `count` waiters; `count` is never zero.
All { count: NonZeroUsize },
}
/// Guard returned by a successful `notify_one`/`notify_all`.
///
/// It keeps the `SpinMutexGuard` on the `WaitVariable` and dereferences to it.
/// When dropped it first releases that lock and then sends `EV_UNPARK` to the
/// notified thread(s).
pub struct WaitGuard<'a, T: 'a> {
/// The held lock; an `Option` so that `Drop` can take it out and release it
/// before the unpark event is sent.
mutex_guard: Option<SpinMutexGuard<'a, WaitVariable<T>>>,
/// Which thread(s) the pending unpark event is for.
notified_tcs: NotifiedTcs,
}
/// Queue of threads blocked on a `WaitVariable`.
///
/// Entries are created as locals inside `wait`/`wait_timeout` and linked into
/// the list by reference, which is why those functions use `unsafe`.
pub struct WaitQueue {
/// Linked waiter entries; each entry carries its own `SpinMutex` so the
/// waiter and the notifier can both access the `wake` flag.
inner: UnsafeList<SpinMutex<WaitEntry>>,
}
// NOTE(review): no justification for this `Send` impl is visible here.
// Presumably it is sound because every access to the list goes through the
// `SpinMutex` wrapping the enclosing `WaitVariable` — confirm against the
// invariants documented on `UnsafeList`.
unsafe impl Send for WaitQueue {}
impl Default for WaitQueue {
fn default() -> Self {
Self::new()
}
}
impl<'a, T> WaitGuard<'a, T> {
    /// Returns which thread(s) will be unparked when this guard is dropped.
    pub fn notified_tcs(&self) -> NotifiedTcs {
        let Self { notified_tcs, .. } = self;
        *notified_tcs
    }

    /// Drops `guard` first and this `WaitGuard` second.
    ///
    /// Use this when another guard has to be released before the notified
    /// thread(s) are unparked.
    pub fn drop_after<U>(self, guard: U) {
        let released_first = guard;
        let released_second = self;
        drop(released_first);
        drop(released_second);
    }
}
impl<'a, T> Deref for WaitGuard<'a, T> {
    type Target = SpinMutexGuard<'a, WaitVariable<T>>;

    /// Borrows the held lock guard.
    ///
    /// Panics if the guard has already been taken, which in this file only
    /// happens inside `Drop`.
    fn deref(&self) -> &Self::Target {
        let held = self.mutex_guard.as_ref();
        held.unwrap()
    }
}
impl<'a, T> DerefMut for WaitGuard<'a, T> {
    /// Mutably borrows the held lock guard.
    ///
    /// Panics if the guard has already been taken, which in this file only
    /// happens inside `Drop`.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let held = self.mutex_guard.as_mut();
        held.unwrap()
    }
}
impl<'a, T> Drop for WaitGuard<'a, T> {
    /// Releases the lock, then delivers the pending `EV_UNPARK` event.
    fn drop(&mut self) {
        // The lock must be released before the event is sent, so take the
        // guard out of the `Option` and drop it explicitly up front.
        let released = self.mutex_guard.take();
        drop(released);

        // A single notified waiter is addressed by its TCS; for `All` no
        // specific target is passed to the usercall.
        let target_tcs = if let NotifiedTcs::Single(tcs) = self.notified_tcs {
            Some(tcs)
        } else {
            None
        };
        rtunwrap!(Ok, usercalls::send(EV_UNPARK, target_tcs));
    }
}
impl WaitQueue {
pub const fn new() -> Self {
WaitQueue { inner: UnsafeList::new() }
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn wait<T, F: FnOnce()>(mut guard: SpinMutexGuard<'_, WaitVariable<T>>, before_wait: F) {
unsafe {
let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
tcs: thread::current(),
wake: false,
}));
let entry = guard.queue.inner.push(&mut entry);
drop(guard);
before_wait();
while !entry.lock().wake {
let eventset = rtunwrap!(Ok, usercalls::wait(EV_UNPARK, WAIT_INDEFINITE));
rtassert!(eventset & EV_UNPARK == EV_UNPARK);
}
}
}
pub fn wait_timeout<T, F: FnOnce()>(
lock: &SpinMutex<WaitVariable<T>>,
timeout: Duration,
before_wait: F,
) -> bool {
unsafe {
let mut entry = UnsafeListEntry::new(SpinMutex::new(WaitEntry {
tcs: thread::current(),
wake: false,
}));
let entry_lock = lock.lock().queue.inner.push(&mut entry);
before_wait();
usercalls::wait_timeout(EV_UNPARK, timeout, || entry_lock.lock().wake);
let mut guard = lock.lock();
let success = entry_lock.lock().wake;
if !success {
guard.queue.inner.remove(&mut entry);
}
success
}
}
pub fn notify_one<T>(
mut guard: SpinMutexGuard<'_, WaitVariable<T>>,
) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> {
unsafe {
if let Some(entry) = guard.queue.inner.pop() {
let mut entry_guard = entry.lock();
let tcs = entry_guard.tcs;
entry_guard.wake = true;
drop(entry);
Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::Single(tcs) })
} else {
Err(guard)
}
}
}
pub fn notify_all<T>(
mut guard: SpinMutexGuard<'_, WaitVariable<T>>,
) -> Result<WaitGuard<'_, T>, SpinMutexGuard<'_, WaitVariable<T>>> {
unsafe {
let mut count = 0;
while let Some(entry) = guard.queue.inner.pop() {
count += 1;
let mut entry_guard = entry.lock();
entry_guard.wake = true;
}
if let Some(count) = NonZeroUsize::new(count) {
Ok(WaitGuard { mutex_guard: Some(guard), notified_tcs: NotifiedTcs::All { count } })
} else {
Err(guard)
}
}
}
}