use std::cell::UnsafeCell;
/// Sending half of a [`Promise`]; it is consumed when it delivers the value.
#[must_use = "You should call Sender::send with the result"]
pub struct Sender<T>(std::sync::mpsc::Sender<T>);

impl<T> Sender<T> {
    /// Delivers `value` to the connected [`Promise`], consuming this sender.
    ///
    /// If the receiving promise has already been dropped, the value is silently discarded.
    pub fn send(self, value: T) {
        let Sender(channel) = self;
        // A send error only means the receiver is gone; nobody is waiting, so ignore it.
        let _ = channel.send(value);
    }
}
/// A value that is either already available or will be delivered later through a [`Sender`].
///
/// Pending promises come from `Promise::new` (or one of the `spawn_*` constructors);
/// already-resolved ones come from `Promise::from_ready`.
#[must_use]
pub struct Promise<T: Send + 'static> {
/// Pending/ready state; while pending it owns the channel receiver.
data: PromiseImpl<T>,
/// Handle of the tokio task started by `spawn_async` / `spawn_blocking`; `None` for the other constructors.
#[cfg(feature = "tokio")]
join_handle: Option<tokio::task::JoinHandle<()>>,
}
// The two async backends are mutually exclusive: `Promise::spawn_async` declares a
// feature-specific `future` parameter for each, so enabling both cannot work.
#[cfg(all(feature = "tokio", feature = "web"))]
compile_error!("You cannot specify both the 'tokio' and 'web' feature");
// Compile-time guarantees: a promise may be moved to another thread (`Send`) but not shared
// between threads (`!Sync`). The `&self` methods that mutate through an `UnsafeCell` rely on
// the latter.
static_assertions::assert_not_impl_all!(Promise<u32>: Sync);
static_assertions::assert_impl_all!(Promise<u32>: Send);
impl<T: Send + 'static> Promise<T> {
    /// Creates a pending promise together with the [`Sender`] that will resolve it.
    pub fn new() -> (Sender<T>, Self) {
        let (tx, rx) = std::sync::mpsc::channel();
        let promise = Self {
            data: PromiseImpl(UnsafeCell::new(PromiseStatus::Pending(rx))),
            #[cfg(feature = "tokio")]
            join_handle: None,
        };
        (Sender(tx), promise)
    }

    /// Creates a promise that is already resolved with `value`.
    pub fn from_ready(value: T) -> Self {
        let status = PromiseStatus::Ready(value);
        Self {
            data: PromiseImpl(UnsafeCell::new(status)),
            #[cfg(feature = "tokio")]
            join_handle: None,
        }
    }

    /// Spawns `future` on the enabled async executor and returns a promise for its output.
    ///
    /// With `tokio` the future runs in a spawned tokio task (and must be `Send`); with `web`
    /// it runs via `wasm_bindgen_futures::spawn_local`.
    #[cfg(any(feature = "tokio", feature = "web"))]
    pub fn spawn_async(
        #[cfg(feature = "tokio")] future: impl std::future::Future<Output = T> + 'static + Send,
        #[cfg(feature = "web")] future: impl std::future::Future<Output = T> + 'static,
    ) -> Self {
        #[cfg(feature = "tokio")]
        {
            let (sender, mut promise) = Self::new();
            let task = tokio::task::spawn(async move { sender.send(future.await) });
            promise.join_handle = Some(task);
            promise
        }
        #[cfg(feature = "web")]
        {
            let (sender, promise) = Self::new();
            wasm_bindgen_futures::spawn_local(async move { sender.send(future.await) });
            promise
        }
    }

    /// Runs the blocking closure `f` inside a spawned tokio task via
    /// `tokio::task::block_in_place`, returning a promise for its result.
    ///
    /// NOTE(review): tokio documents that `block_in_place` panics on a current-thread
    /// runtime — confirm callers only use the multi-threaded runtime.
    #[cfg(feature = "tokio")]
    pub fn spawn_blocking<F>(f: F) -> Self
    where
        F: FnOnce() -> T + Send + 'static,
    {
        let (sender, mut promise) = Self::new();
        let task = tokio::task::spawn(async move {
            let value = tokio::task::block_in_place(f);
            sender.send(value);
        });
        promise.join_handle = Some(task);
        promise
    }

    /// Runs `f` on a newly spawned OS thread named `thread_name` and returns a promise for
    /// its result. The thread is detached.
    ///
    /// # Panics
    /// Panics if the thread cannot be spawned.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn spawn_thread<F>(thread_name: impl Into<String>, f: F) -> Self
    where
        F: FnOnce() -> T + Send + 'static,
    {
        let (sender, promise) = Self::new();
        let builder = std::thread::Builder::new().name(thread_name.into());
        builder
            .spawn(move || sender.send(f()))
            .expect("Failed to spawn thread");
        promise
    }

    /// Returns the value if it has arrived, without blocking; `None` otherwise.
    ///
    /// # Panics
    /// Panics if the [`Sender`] was dropped without sending a value.
    pub fn ready(&self) -> Option<&T> {
        if let std::task::Poll::Ready(value) = self.poll() {
            Some(value)
        } else {
            None
        }
    }

    /// Mutable counterpart of [`Self::ready`]: returns the value if it has arrived,
    /// without blocking; `None` otherwise.
    pub fn ready_mut(&mut self) -> Option<&mut T> {
        if let std::task::Poll::Ready(value) = self.poll_mut() {
            Some(value)
        } else {
            None
        }
    }

    /// Takes the value out if it has arrived; otherwise hands the promise back as `Err(self)`.
    ///
    /// # Panics
    /// Panics if the [`Sender`] was dropped without sending a value.
    pub fn try_take(self) -> Result<T, Self> {
        #[cfg(feature = "tokio")]
        let join_handle = self.join_handle;
        match self.data.try_take() {
            Ok(value) => Ok(value),
            Err(data) => Err(Self {
                data,
                #[cfg(feature = "tokio")]
                join_handle,
            }),
        }
    }

    /// Blocks the current thread until the value arrives and returns a reference to it.
    ///
    /// # Panics
    /// Panics if the [`Sender`] was dropped without sending a value.
    pub fn block_until_ready(&self) -> &T {
        self.data.block_until_ready()
    }

    /// Blocks the current thread until the value arrives and returns a mutable reference to it.
    ///
    /// # Panics
    /// Panics if the [`Sender`] was dropped without sending a value.
    pub fn block_until_ready_mut(&mut self) -> &mut T {
        self.data.block_until_ready_mut()
    }

    /// Blocks until the value arrives, then consumes the promise and returns the value.
    ///
    /// # Panics
    /// Panics if the [`Sender`] was dropped without sending a value.
    pub fn block_and_take(self) -> T {
        let data = self.data;
        data.block_until_ready();
        if let PromiseStatus::Ready(value) = data.0.into_inner() {
            value
        } else {
            unreachable!()
        }
    }

    /// Non-blocking check: `Poll::Ready(&value)` once the value has arrived, else `Poll::Pending`.
    ///
    /// # Panics
    /// Panics if the [`Sender`] was dropped without sending a value.
    pub fn poll(&self) -> std::task::Poll<&T> {
        self.data.poll()
    }

    /// Mutable counterpart of [`Self::poll`].
    pub fn poll_mut(&mut self) -> std::task::Poll<&mut T> {
        self.data.poll_mut()
    }

    /// Consumes the promise and aborts the tokio task backing it, if there is one.
    #[cfg(feature = "tokio")]
    pub fn abort(self) {
        if let Some(task) = self.join_handle {
            task.abort();
        }
    }
}
/// State of a promise: still waiting on the channel, or holding the received value.
///
/// The only transition is `Pending` -> `Ready`; nothing ever moves a promise back to `Pending`.
enum PromiseStatus<T: Send + 'static> {
    /// No value yet; it will arrive on this channel.
    Pending(std::sync::mpsc::Receiver<T>),
    /// The value has been received (or was supplied up front).
    Ready(T),
}

impl<T: Send + 'static> PromiseStatus<T> {
    /// Replaces the current state with `Ready(value)` and returns a mutable reference to the
    /// stored value. This drops the channel receiver if the state was `Pending`.
    fn resolve(&mut self, value: T) -> &mut T {
        *self = PromiseStatus::Ready(value);
        match self {
            PromiseStatus::Ready(value) => value,
            PromiseStatus::Pending(_) => unreachable!(),
        }
    }
}

/// Interior-mutable holder for a [`PromiseStatus`].
///
/// The `UnsafeCell` lets `&self` methods perform the `Pending` -> `Ready` transition. It also
/// makes the type `!Sync`, so a shared reference can never be used from two threads at once.
struct PromiseImpl<T: Send + 'static>(UnsafeCell<PromiseStatus<T>>);

impl<T: Send + 'static> PromiseImpl<T> {
    /// Non-blocking check for the value through an exclusive borrow.
    ///
    /// # Panics
    /// Panics if the sender was dropped without sending a value. This matches `poll`,
    /// `try_take` and the blocking methods, instead of reporting `Pending` forever.
    fn poll_mut(&mut self) -> std::task::Poll<&mut T> {
        let status = self.0.get_mut();
        match status {
            PromiseStatus::Pending(rx) => match rx.try_recv() {
                Ok(value) => std::task::Poll::Ready(status.resolve(value)),
                Err(std::sync::mpsc::TryRecvError::Empty) => std::task::Poll::Pending,
                Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                    panic!("The Promise Sender was dropped")
                }
            },
            PromiseStatus::Ready(value) => std::task::Poll::Ready(value),
        }
    }

    /// Consumes the holder and returns the value if it is available; otherwise returns the
    /// (still pending) holder as `Err(self)`.
    ///
    /// # Panics
    /// Panics if the sender was dropped without sending a value.
    fn try_take(self) -> Result<T, Self> {
        match self.0.into_inner() {
            PromiseStatus::Ready(value) => Ok(value),
            PromiseStatus::Pending(rx) => match rx.try_recv() {
                Ok(value) => Ok(value),
                Err(std::sync::mpsc::TryRecvError::Empty) => {
                    Err(PromiseImpl(UnsafeCell::new(PromiseStatus::Pending(rx))))
                }
                Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                    panic!("The Promise Sender was dropped")
                }
            },
        }
    }

    /// Non-blocking check for the value through a shared borrow.
    ///
    /// # Panics
    /// Panics if the sender was dropped without sending a value.
    #[allow(unsafe_code)]
    fn poll(&self) -> std::task::Poll<&T> {
        // SAFETY: only a shared reference is created here. The cell is `!Sync`, so no other
        // thread can access it. The only `&mut` borrows of the cell come from `&mut self`
        // methods, which cannot overlap with this `&self` borrow, or are taken below in the
        // `Pending` state and not used after their method returns.
        if let PromiseStatus::Ready(value) = unsafe { &*self.0.get() } {
            // Return early so a `&mut` is never created while callers may still hold `&T`
            // from an earlier call; that would invalidate their references (aliasing UB).
            return std::task::Poll::Ready(value);
        }
        // SAFETY: the state is `Pending`. References into the cell are only handed out for the
        // `Ready` state, which never reverts to `Pending`, so no outstanding borrow can be
        // invalidated by this unique reference. The cell is `!Sync`, so access is single-threaded.
        let status = unsafe { &mut *self.0.get() };
        match status {
            PromiseStatus::Pending(rx) => match rx.try_recv() {
                Ok(value) => std::task::Poll::Ready(&*status.resolve(value)),
                Err(std::sync::mpsc::TryRecvError::Empty) => std::task::Poll::Pending,
                Err(std::sync::mpsc::TryRecvError::Disconnected) => {
                    panic!("The Promise Sender was dropped")
                }
            },
            PromiseStatus::Ready(_) => unreachable!("the Ready state was handled above"),
        }
    }

    /// Blocks the current thread until the value arrives; returns a mutable reference to it.
    ///
    /// # Panics
    /// Panics if the sender was dropped without sending a value.
    fn block_until_ready_mut(&mut self) -> &mut T {
        let status = self.0.get_mut();
        match status {
            PromiseStatus::Pending(rx) => {
                let value = rx.recv().expect("The Promise Sender was dropped");
                status.resolve(value)
            }
            PromiseStatus::Ready(value) => value,
        }
    }

    /// Blocks the current thread until the value arrives; returns a shared reference to it.
    ///
    /// # Panics
    /// Panics if the sender was dropped without sending a value.
    #[allow(unsafe_code)]
    fn block_until_ready(&self) -> &T {
        // SAFETY: only a shared reference is created here. The cell is `!Sync`, so no other
        // thread can access it. The only `&mut` borrows of the cell come from `&mut self`
        // methods, which cannot overlap with this `&self` borrow, or are taken below in the
        // `Pending` state and not used after their method returns.
        if let PromiseStatus::Ready(value) = unsafe { &*self.0.get() } {
            // Return early so a `&mut` is never created while callers may still hold `&T`.
            return value;
        }
        // SAFETY: the state is `Pending`. References into the cell are only handed out for the
        // `Ready` state, which never reverts to `Pending`, so no outstanding borrow can be
        // invalidated by this unique reference. The cell is `!Sync`, so access is single-threaded.
        let status = unsafe { &mut *self.0.get() };
        match status {
            PromiseStatus::Pending(rx) => {
                let value = rx.recv().expect("The Promise Sender was dropped");
                status.resolve(value)
            }
            PromiseStatus::Ready(_) => unreachable!("the Ready state was handled above"),
        }
    }
}