use std::{
fmt::{self, Debug, Formatter},
hash::Hash,
ops::{Deref, DerefMut},
pin::Pin,
task::{self, Poll},
};
use consume_on_drop::{Consume, ConsumeOnDrop};
use derivative::Derivative;
use futures::FutureExt as _;
use pin_project::pin_project;
/// A guard that owns a value of type `T` together with a "catch" callback.
///
/// The value is stored in a `CatchInner` wrapped in `ConsumeOnDrop`. When a `Catch` is
/// dropped, `CatchInner::consume` runs and passes the value to the callback. Use
/// `Catch::disarm` to take the value (and the unused callback) back without running it.
/// The value itself is reachable through the `Deref`/`DerefMut` impls.
///
/// Comparison and hashing come from the wrapped `CatchInner`, whose derivative impls
/// ignore the callback. NOTE(review): this assumes `ConsumeOnDrop` forwards these traits
/// to its contents, so `Catch` compares/hashes by `T` alone — confirm.
#[derive(Eq, PartialEq, Ord, PartialOrd, Hash)]
#[repr(transparent)]
pub struct Catch<'a, T>(ConsumeOnDrop<CatchInner<'a, T>>);
impl<T: Debug> Debug for Catch<'_, T> {
    /// Formats exactly like the guarded value; the callback is not shown.
    ///
    /// Delegates to `T`'s own `Debug::fmt` with the caller's `Formatter`. This keeps
    /// flags such as `{:#?}` (pretty-printing), width and fill intact. Re-formatting
    /// via `write!(f, "{:?}", ..)` would reset them.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        Debug::fmt(&self.0.this, f)
    }
}
impl<T: Future + Unpin> Future for Catch<'_, T> {
    type Output = T::Output;

    /// Forwards the poll to the guarded future.
    ///
    /// Because `T: Unpin`, the inner future can be re-pinned from a plain `&mut`.
    /// Completing the future does not disarm the guard: dropping the `Catch`
    /// afterwards still hands the (finished) future to the callback.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let inner: &mut T = &mut self.get_mut().0.this;
        Pin::new(inner).poll(cx)
    }
}
impl<'a, T> Catch<'a, T> {
    /// Wraps `this` so that `catch` is called with it if the returned guard is dropped.
    ///
    /// Only the callback must be `Send`; the value itself need not be.
    pub fn new(this: T, catch: Box<dyn FnOnce(T) + Send + 'a>) -> Self {
        Catch(ConsumeOnDrop::new(CatchInner { this, catch }))
    }

    /// Takes the value back out of the guard without running the callback.
    ///
    /// Returns the value together with the unused callback. The callback keeps its
    /// `Send` bound, so the caller can still move it to another thread.
    pub fn disarm(Catch(this): Self) -> (T, Box<dyn FnOnce(T) + Send + 'a>) {
        let CatchInner { this, catch } = ConsumeOnDrop::into_inner(this);
        (this, catch)
    }
}

impl<'a, T: Send + 'a> Catch<'a, T> {
    /// Creates a guard whose callback parks the value in a shared, mutex-protected slot.
    ///
    /// Returns the guard plus a closure that empties the slot. The closure yields
    /// `Some(value)` once after the guard has been dropped (without being disarmed),
    /// and `None` before that or on any later call.
    pub fn sync(this: T) -> (Self, impl Fn() -> Option<T>) {
        use parking_lot::Mutex;
        use std::sync::Arc;
        let v = Arc::new(Mutex::new(None));
        let send = {
            let v = v.clone();
            Box::new(move |this| *v.lock() = Some(this))
        };
        // `take` leaves `None` behind, so the value is handed out at most once.
        let recv = move || v.lock().take();
        (Catch::new(this, send), recv)
    }

    /// Creates a guard whose callback sends the value over a tokio oneshot channel.
    ///
    /// The returned future resolves to `Some(value)` once the guard is dropped without
    /// being disarmed. It resolves to `None` if the callback (which owns the sender) is
    /// dropped without being called, e.g. after `disarm` once the returned callback is
    /// discarded.
    pub fn future(this: T) -> (Self, impl Future<Output = Option<T>> + Send + 'a) {
        let (send, recv) = tokio::sync::oneshot::channel();
        // If the receiver is already gone, `send` returns the value; just drop it.
        let send = Box::new(move |this| send.send(this).unwrap_or_else(drop));
        (Catch::new(this, send), recv.map(Result::ok))
    }
}
impl<T> Deref for Catch<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0.this
}
}
impl<T> DerefMut for Catch<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0.this
}
}
/// Payload stored inside `Catch`: the guarded value plus the callback that receives it.
///
/// The `derivative` impls look only at `this`. `catch` is ignored for `Debug`,
/// `PartialEq`, `PartialOrd`, `Ord` and `Hash`, and each impl is bounded on the
/// matching trait for `T`.
#[derive(Derivative)]
#[derivative(Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[derivative(
Debug = "transparent",
Debug(bound = "T: Debug"),
Eq(bound = "T: Eq"),
PartialEq(bound = "T: PartialEq"),
Ord(bound = "T: Ord"),
PartialOrd(bound = "T: PartialOrd"),
Hash(bound = "T: Hash")
)]
// NOTE(review): no field is marked `#[pin]` and nothing visible here calls `project()`,
// so this attribute appears unused — confirm before removing it (and its import).
#[pin_project]
struct CatchInner<'a, T> {
/// The guarded value.
this: T,
/// Receives `this` when the owning `Catch` is dropped without being disarmed.
#[derivative(
Debug = "ignore",
PartialEq = "ignore",
PartialOrd = "ignore",
Ord = "ignore",
Hash = "ignore"
)]
catch: Box<dyn FnOnce(T) + Send + 'a>,
}
impl<T> Consume for CatchInner<'_, T> {
fn consume(self) {
(self.catch)(self.this)
}
}