use super::{ThreadBeamFlags, ThreadBeamRx, ThreadBeamState, ThreadBeamTx};
use core::{mem::MaybeUninit, ptr::NonNull};
#[cfg(feature = "parking_lot")]
use parking_lot::{Condvar, Mutex};
#[cfg(not(feature = "parking_lot"))]
use std::sync::{Condvar, Mutex};
/// Blocks on a `std::sync::Condvar`, rebinding the guard identifier to the
/// guard handed back by `wait`.
///
/// Usage: `cvar_wait!(guard = condvar)`. Panics if `wait` reports a poisoned
/// mutex, because the result is unwrapped.
#[cfg(not(feature = "parking_lot"))]
macro_rules! cvar_wait {
    ($guard:ident = $condvar:expr) => {
        $guard = $condvar.wait($guard).unwrap();
    };
}
/// Blocks on a `parking_lot::Condvar`, which waits through a mutable borrow
/// of the guard instead of consuming and returning it.
///
/// Usage: `cvar_wait!(guard = condvar)` — same call shape as the std variant.
#[cfg(feature = "parking_lot")]
macro_rules! cvar_wait {
    ($guard:ident = $condvar:expr) => {
        $condvar.wait(&mut $guard);
    };
}
/// Locks a `std::sync::Mutex` and evaluates to its guard.
///
/// Panics if the mutex is poisoned, because the lock result is unwrapped.
#[cfg(not(feature = "parking_lot"))]
macro_rules! lock_mutex {
    ($target:expr) => {
        $target.lock().unwrap()
    };
}
/// Locks a `parking_lot::Mutex` and evaluates to its guard.
///
/// Unlike the std variant there is no `Result` to unwrap here.
#[cfg(feature = "parking_lot")]
macro_rules! lock_mutex {
    ($target:expr) => {
        $target.lock()
    };
}
/// Heap-allocated state shared by a `ThreadBeamTx` / `ThreadBeamRx` pair.
///
/// Both handles hold a raw pointer to one instance of this struct; see
/// `channel` for how it is allocated and the `Drop` impls for how it is freed.
pub(super) struct ThreadBeamInner<T> {
    /// Mutex guarding the beam's data slot and flag bits.
    lock: Mutex<ThreadBeamState<T>>,
    /// Condition variable the receiver waits on; notified by `send` and when
    /// the sender is dropped.
    cvar: Condvar,
}
impl<T: Send> ThreadBeamTx<T> {
    /// Sends `value` through the beam, consuming the sender.
    ///
    /// The value is stored under the shared mutex and every thread waiting on
    /// the condition variable is notified while the lock is still held. The
    /// sender's `Drop` impl runs when this method returns.
    pub fn send(self, value: T) {
        // SAFETY: the shared allocation is only freed by whichever handle's
        // `Drop` impl decides to deallocate, and `self` has not been dropped
        // yet, so the pointer is still valid here.
        let shared = unsafe { self.0.as_ref() };
        let mut state = lock_mutex!(shared.lock);
        state.set_data(value);
        shared.cvar.notify_all();
    }
}
impl<T: Send> Drop for ThreadBeamTx<T> {
    /// Marks the sending half as gone, wakes any waiting receiver, and frees
    /// the shared allocation when `drop_tx` reports that this handle is
    /// responsible for doing so.
    fn drop(&mut self) {
        let deallocate = {
            // SAFETY: the allocation is only freed from a handle's `Drop`
            // impl, and this handle has not freed it yet.
            let inner = unsafe { self.0.as_ref() };
            let mut lock = lock_mutex!(inner.lock);
            let deallocate = lock.drop_tx();
            // Wake the receiver so it can observe the updated state.
            inner.cvar.notify_all();
            deallocate
            // The guard is released here, before any deallocation below, so
            // the mutex is never freed while it is still locked by us.
        };
        if deallocate {
            // SAFETY: the pointer originates from `Box::into_raw` in
            // `channel`, and `drop_tx` returned true, which this code treats
            // as meaning no other handle will touch the allocation again.
            // The explicit `drop` makes the intent clear and avoids the
            // `unused_must_use` warning on `Box::from_raw`.
            drop(unsafe { Box::from_raw(self.0.as_ptr()) });
        }
    }
}
impl<T: Send> ThreadBeamRx<T> {
    /// Blocks until a value is available or the beam is hung up, consuming
    /// the receiver.
    ///
    /// Returns `Some(value)` if data was stored in the beam, or `None` if the
    /// state reports `hung_up()` with no data present.
    pub fn recv(self) -> Option<T> {
        // SAFETY: the allocation is only freed from a handle's `Drop` impl,
        // and `self` has not been dropped yet.
        let inner = unsafe { self.0.as_ref() };
        let mut lock = lock_mutex!(inner.lock);

        // Condition variables may wake spuriously, so the predicate must be
        // re-checked in a loop; a single wait could otherwise return `None`
        // while the sender is still alive and about to send.
        // NOTE(review): relies on `hung_up()` becoming true once the sender
        // has been dropped (it is defined outside this file) — confirm.
        while !lock.has_data() && !lock.hung_up() {
            cvar_wait!(lock = inner.cvar);
        }

        if lock.has_data() {
            // Clear the flag first so the slot is treated as empty from now
            // on and the value is not read out a second time.
            lock.flags &= !ThreadBeamFlags::HAS_DATA;
            // SAFETY: `HAS_DATA` was set, which this code treats as meaning
            // the slot holds an initialized `T` that has not been read yet.
            Some(unsafe { lock.data.assume_init_read() })
        } else {
            None
        }
    }
}
impl<T: Send> Drop for ThreadBeamRx<T> {
    /// Marks the receiving half as gone and frees the shared allocation when
    /// `drop_rx` reports that this handle is responsible for doing so.
    fn drop(&mut self) {
        let deallocate = {
            // SAFETY: the allocation is only freed from a handle's `Drop`
            // impl, and this handle has not freed it yet.
            let inner = unsafe { self.0.as_ref() };
            // The temporary guard is released at the end of this block,
            // before any deallocation below.
            lock_mutex!(inner.lock).drop_rx()
        };
        if deallocate {
            // SAFETY: the pointer originates from `Box::into_raw` in
            // `channel`, and `drop_rx` returned true, which this code treats
            // as meaning no other handle will touch the allocation again.
            // The explicit `drop` makes the intent clear and avoids the
            // `unused_must_use` warning on `Box::from_raw`.
            drop(unsafe { Box::from_raw(self.0.as_ptr()) });
        }
    }
}
/// Creates a connected sender/receiver pair sharing one heap allocation.
///
/// The shared state starts with an uninitialized data slot and both the `TX`
/// and `RX` flags set. Ownership of the allocation is handed to the two
/// handles as a raw pointer; their `Drop` impls decide which one frees it.
pub fn channel<T: Send>() -> (ThreadBeamTx<T>, ThreadBeamRx<T>) {
    let initial_state = ThreadBeamState {
        data: MaybeUninit::uninit(),
        flags: ThreadBeamFlags::TX | ThreadBeamFlags::RX,
    };
    let boxed = Box::new(ThreadBeamInner {
        lock: Mutex::new(initial_state),
        cvar: Condvar::new(),
    });
    // SAFETY: `Box::into_raw` never returns a null pointer.
    let shared = unsafe { NonNull::new_unchecked(Box::into_raw(boxed)) };
    (ThreadBeamTx(shared), ThreadBeamRx(shared))
}
/// Spawns a thread running `spawn` with a fresh sender, then blocks the
/// calling thread on the matching receiver.
///
/// Returns whatever `recv` produced together with the new thread's
/// `JoinHandle`. The thread is not joined here; the caller decides when to
/// do that.
#[inline]
pub fn spawn<T, R, F>(spawn: F) -> (Option<T>, std::thread::JoinHandle<R>)
where
    F: FnOnce(ThreadBeamTx<T>) -> R + Send + 'static,
    T: Send + 'static,
    R: Send + 'static,
{
    let (sender, receiver) = channel();
    let handle = std::thread::spawn(move || spawn(sender));
    // Wait for the beamed value before handing the join handle back.
    let received = receiver.recv();
    (received, handle)
}