use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use futures::stream::FuturesUnordered;
use futures::{
channel::{
mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
oneshot,
},
FutureExt, Stream,
};
/// Report produced when a watched future is dropped before it finishes,
/// which normally means it panicked.
///
/// Wraps the message attached when the future was registered. That is `()`
/// for `alert` / `local_alert`, or the caller-supplied value for
/// `alert_msg` / `local_alert_msg`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Panicked<Msg = ()>(pub Msg)
where
    Msg: Send + 'static;
/// Cloneable registration handle for a `PanicDetector` or `PanicMonitor`.
///
/// Obtain one through `AsRef` on the detector or monitor. Every live clone
/// keeps the registration channel open. While any clone exists, the detector
/// therefore cannot conclude that no panic happened.
pub struct DetectorHook<Msg: Send + 'static = ()>(UnboundedSender<RxHandle<Msg>>);
impl<Msg> Clone for DetectorHook<Msg>
where
Msg: Send + 'static,
{
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Registers a future, which may be `!Send`, with a panic detector.
///
/// `'a` bounds the lifetime of the wrapped future and `T` is its output.
/// The returned wrapper must be driven to completion. If it is dropped
/// early, the detector reports the registration as `Panicked`.
pub trait LocalAlert<'a, T> {
    /// Shorthand for `local_alert_msg(hook, ())`.
    fn local_alert<A: AsRef<DetectorHook>>(self, hook: &'_ A) -> LocalPanicAwareFuture<'a, T>;

    /// Registers `self` with the detector behind `hook`.
    ///
    /// `msg` is handed back inside `Panicked` if the future does not finish.
    fn local_alert_msg<Msg: Send + 'static, A: AsRef<DetectorHook<Msg>>>(
        self,
        hook: &'_ A,
        msg: Msg,
    ) -> LocalPanicAwareFuture<'a, T>;
}
impl<'a, F> LocalAlert<'a, F::Output> for F
where
F: Future + 'a,
{
fn local_alert<A>(self, hook: &'_ A) -> LocalPanicAwareFuture<'a, F::Output>
where
A: AsRef<DetectorHook>,
{
self.local_alert_msg(hook, ())
}
fn local_alert_msg<Msg, A>(self, hook: &'_ A, msg: Msg) -> LocalPanicAwareFuture<'a, F::Output>
where
Msg: Send + 'static,
A: AsRef<DetectorHook<Msg>>,
{
let (tx, rx) = oneshot::channel();
hook.as_ref()
.0
.unbounded_send(RxHandle {
signaler: rx,
msg: Some(msg),
})
.expect("detector dropped early");
LocalPanicAwareFuture::new(async move {
let ret = self.await;
let _ = tx.send(());
ret
})
}
}
/// Boxed wrapper returned by [`LocalAlert`]; resolves to the wrapped future's output.
///
/// The wrapped future owns the completion signal for its detector
/// registration. If this wrapper is dropped before the inner future finishes,
/// including when it is never polled, the detector reports the registration as
/// [`Panicked`]. That is why the type is `#[must_use]`.
#[doc(hidden)]
#[must_use = "futures do nothing unless polled; dropping this one early is reported as a panic"]
pub struct LocalPanicAwareFuture<'a, T> {
    /// The watched future, type-erased and pinned on the heap; may be `!Send`.
    inner: Pin<Box<dyn Future<Output = T> + 'a>>,
}

impl<'a, T> LocalPanicAwareFuture<'a, T> {
    /// Boxes and pins `fut`.
    fn new<Fut>(fut: Fut) -> Self
    where
        Fut: Future<Output = T> + 'a,
    {
        Self {
            inner: Box::pin(fut),
        }
    }
}
impl<T> Future for LocalPanicAwareFuture<'_, T> {
    type Output = T;

    /// Forwards polling to the boxed inner future.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The wrapper is `Unpin` (it only holds a `Pin<Box<_>>`), so the field
        // can be reached directly and re-pinned with `as_mut`.
        let pinned_inner = self.inner.as_mut();
        pinned_inner.poll(cx)
    }
}
/// Registers a `Send` future with a panic detector.
///
/// `'a` bounds the lifetime of the wrapped future and `T` is its output. The
/// returned wrapper is itself `Send`, so it can be spawned on multi-threaded
/// executors. It must be driven to completion; if it is dropped early, the
/// detector reports the registration as `Panicked`.
pub trait Alert<'a, T> {
    /// Shorthand for `alert_msg(hook, ())`.
    fn alert<A: AsRef<DetectorHook>>(self, hook: &'_ A) -> PanicAwareFuture<'a, T>;

    /// Registers `self` with the detector behind `hook`.
    ///
    /// `msg` is handed back inside `Panicked` if the future does not finish.
    fn alert_msg<Msg: Send + 'static, A: AsRef<DetectorHook<Msg>>>(
        self,
        hook: &'_ A,
        msg: Msg,
    ) -> PanicAwareFuture<'a, T>;
}
impl<'a, F> Alert<'a, F::Output> for F
where
F: Future + Send + 'a,
{
fn alert<A>(self, hook: &'_ A) -> PanicAwareFuture<'a, F::Output>
where
A: AsRef<DetectorHook>,
{
self.alert_msg(hook, ())
}
fn alert_msg<Msg, A>(self, hook: &'_ A, msg: Msg) -> PanicAwareFuture<'a, F::Output>
where
Msg: Send + 'static,
A: AsRef<DetectorHook<Msg>>,
{
let (tx, rx) = oneshot::channel();
hook.as_ref()
.0
.unbounded_send(RxHandle {
signaler: rx,
msg: Some(msg),
})
.expect("detector dropped early");
PanicAwareFuture::new(async move {
let ret = self.await;
let _ = tx.send(());
ret
})
}
}
/// Boxed `Send` wrapper returned by [`Alert`]; resolves to the wrapped future's output.
///
/// The wrapped future owns the completion signal for its detector
/// registration. If this wrapper is dropped before the inner future finishes,
/// including when it is never polled, the detector reports the registration as
/// [`Panicked`]. That is why the type is `#[must_use]`.
#[doc(hidden)]
#[must_use = "futures do nothing unless polled; dropping this one early is reported as a panic"]
pub struct PanicAwareFuture<'a, T> {
    /// The watched future, type-erased and pinned on the heap; must be `Send`.
    inner: Pin<Box<dyn Future<Output = T> + Send + 'a>>,
}

impl<'a, T> PanicAwareFuture<'a, T> {
    /// Boxes and pins `fut`.
    fn new<Fut>(fut: Fut) -> Self
    where
        Fut: Future<Output = T> + Send + 'a,
    {
        Self {
            inner: Box::pin(fut),
        }
    }
}
impl<T> Future for PanicAwareFuture<'_, T> {
    type Output = T;

    /// Forwards polling to the boxed inner future.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The wrapper is `Unpin` (it only holds a `Pin<Box<_>>`), so the field
        // can be reached directly and re-pinned with `as_mut`.
        let pinned_inner = self.inner.as_mut();
        pinned_inner.poll(cx)
    }
}
/// Detector-side half of one registration.
///
/// It resolves once the watched future either signals completion or drops its
/// sender.
struct RxHandle<Msg: Send + 'static> {
    /// Yields `Ok(())` on normal completion, or `Err(Canceled)` if the sender is dropped first.
    signaler: oneshot::Receiver<()>,
    /// Message returned in the `Panicked` report; taken at most once.
    msg: Option<Msg>,
}
impl<Msg> Future for RxHandle<Msg>
where
    Msg: Unpin + Send + 'static,
{
    type Output = Option<Panicked<Msg>>;

    /// Waits for the watched future's completion signal.
    ///
    /// Yields `None` when the future finished normally (the signal was sent).
    /// Yields `Some(Panicked(msg))` when the sender was dropped without sending.
    ///
    /// # Panics
    ///
    /// Panics with "message already read" if polled again after reporting a panic.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let signal = self.signaler.poll_unpin(cx);
        match signal {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(())) => Poll::Ready(None),
            Poll::Ready(Err(_canceled)) => {
                let payload = self.msg.take().expect("message already read");
                Poll::Ready(Some(Panicked(payload)))
            }
        }
    }
}
/// A future that watches every future registered through its hook.
///
/// It resolves to `Some(Panicked)` as soon as one registered future is
/// dropped before finishing, which normally means it panicked. It resolves to
/// `None` once the registration channel has closed and every registered
/// future has completed normally.
pub struct PanicDetector<Msg: Send + 'static = ()> {
    /// This detector's own registration sender. It is cleared on every poll,
    /// so that afterwards only external hook clones keep the channel open.
    detector: Option<DetectorHook<Msg>>,
    /// Receives one handle per registered future.
    rx: UnboundedReceiver<RxHandle<Msg>>,
    /// Handles drained from `rx` that are still being watched.
    hooks: FuturesUnordered<RxHandle<Msg>>,
    /// Set once `rx` has reported end-of-stream.
    rx_closed: bool,
}
impl<Msg> PanicDetector<Msg>
where
    Msg: Send + 'static,
{
    /// Creates a detector with no registered futures.
    ///
    /// Register futures by passing `&detector` as the hook to `alert` /
    /// `local_alert` (or their `_msg` variants), then `.await` the detector.
    pub fn new() -> Self {
        let (tx, rx) = unbounded();
        PanicDetector {
            detector: Some(DetectorHook(tx)),
            rx,
            hooks: FuturesUnordered::new(),
            rx_closed: false,
        }
    }
}

// Provided alongside `new()` so the type works with `Default`-based APIs
// (and satisfies clippy's `new_without_default`).
impl<Msg> Default for PanicDetector<Msg>
where
    Msg: Send + 'static,
{
    /// Equivalent to [`PanicDetector::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<Msg> AsRef<DetectorHook<Msg>> for PanicDetector<Msg>
where
    Msg: Unpin + Send + 'static,
{
    /// Borrows the hook used to register futures with this detector.
    ///
    /// # Panics
    ///
    /// Panics once the detector has been polled, because polling discards the
    /// detector's own hook.
    fn as_ref(&self) -> &DetectorHook<Msg> {
        if let Some(hook) = &self.detector {
            return hook;
        }
        panic!(
            "This detector has been polled. Create a new detector to receive new panic alerts."
        )
    }
}
impl<Msg> Future for PanicDetector<Msg>
where
    Msg: Unpin + Send + 'static,
{
    type Output = Option<Panicked<Msg>>;

    /// Drives all registered watchers.
    ///
    /// Resolves to `Some(report)` for the first registration whose future was
    /// dropped unfinished. Resolves to `None` once the registration channel is
    /// closed and every watched future completed normally.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Release our own sender: from now on only external `DetectorHook`
        // clones keep the registration channel open.
        self.detector = None;

        // Pull in every registration queued so far. Stopping on `Pending`
        // leaves a waker on `rx`, so new registrations re-poll us.
        if !self.rx_closed {
            loop {
                match Pin::new(&mut self.rx).poll_next(cx) {
                    Poll::Ready(Some(handle)) => self.hooks.push(handle),
                    Poll::Ready(None) => {
                        self.rx_closed = true;
                        break;
                    }
                    Poll::Pending => break,
                }
            }
        }

        // Skip over futures that finished normally until we hit a panic
        // report, run out of watchers, or have to wait.
        loop {
            let next = Pin::new(&mut self.hooks).poll_next(cx);
            match next {
                Poll::Ready(Some(Some(report))) => return Poll::Ready(Some(report)),
                Poll::Ready(Some(None)) => {}
                // No watchers left and no more can arrive: nothing panicked.
                Poll::Ready(None) if self.rx_closed => return Poll::Ready(None),
                // Either watchers are still running, or the set is empty but
                // `rx` may still deliver registrations (its waker is armed above).
                Poll::Ready(None) | Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// A stream that reports every registered future dropped before finishing.
///
/// A future dropped early has normally panicked. The stream yields one
/// `Panicked` per such future. It ends once the registration channel has
/// closed and every registered future has resolved.
pub struct PanicMonitor<Msg: Send + 'static = ()> {
    /// This monitor's own registration sender. It is cleared on every poll,
    /// so that afterwards only external hook clones keep the channel open.
    detector: Option<DetectorHook<Msg>>,
    /// Receives one handle per registered future.
    rx: UnboundedReceiver<RxHandle<Msg>>,
    /// Handles drained from `rx` that are still being watched.
    hooks: FuturesUnordered<RxHandle<Msg>>,
    /// Set once `rx` has reported end-of-stream.
    rx_closed: bool,
}
impl<Msg> PanicMonitor<Msg>
where
    Msg: Send + 'static,
{
    /// Creates a monitor with no registered futures.
    ///
    /// Register futures by passing `&monitor` as the hook to `alert` /
    /// `local_alert` (or their `_msg` variants), then consume the monitor as
    /// a `Stream` of `Panicked` reports.
    pub fn new() -> Self {
        let (tx, rx) = unbounded();
        PanicMonitor {
            detector: Some(DetectorHook(tx)),
            rx,
            hooks: FuturesUnordered::new(),
            rx_closed: false,
        }
    }
}

// Provided alongside `new()` so the type works with `Default`-based APIs
// (and satisfies clippy's `new_without_default`).
impl<Msg> Default for PanicMonitor<Msg>
where
    Msg: Send + 'static,
{
    /// Equivalent to [`PanicMonitor::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<Msg> AsRef<DetectorHook<Msg>> for PanicMonitor<Msg>
where
    Msg: Unpin + Send + 'static,
{
    /// Borrows the hook used to register futures with this monitor.
    ///
    /// # Panics
    ///
    /// Panics once the monitor has been polled, because polling discards the
    /// monitor's own hook.
    fn as_ref(&self) -> &DetectorHook<Msg> {
        if let Some(hook) = &self.detector {
            return hook;
        }
        panic!(
            "This monitor has been polled. Create a new monitor to receive new panic alerts."
        )
    }
}
impl<Msg> Stream for PanicMonitor<Msg>
where
    Msg: Unpin + Send + 'static,
{
    type Item = Panicked<Msg>;

    /// Yields the next panic report.
    ///
    /// Returns `None` once the registration channel is closed and no watched
    /// futures remain.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Release our own sender: from now on only external `DetectorHook`
        // clones keep the registration channel open.
        self.detector = None;

        // Pull in every registration queued so far. Stopping on `Pending`
        // leaves a waker on `rx`, so new registrations re-poll us.
        if !self.rx_closed {
            loop {
                match Pin::new(&mut self.rx).poll_next(cx) {
                    Poll::Ready(Some(handle)) => self.hooks.push(handle),
                    Poll::Ready(None) => {
                        self.rx_closed = true;
                        break;
                    }
                    Poll::Pending => break,
                }
            }
        }

        // Skip over futures that finished normally until we hit a panic
        // report, run out of watchers, or have to wait.
        loop {
            let next = Pin::new(&mut self.hooks).poll_next(cx);
            match next {
                Poll::Ready(Some(Some(report))) => return Poll::Ready(Some(report)),
                Poll::Ready(Some(None)) => {}
                // No watchers left and no more can arrive: the stream is over.
                Poll::Ready(None) if self.rx_closed => return Poll::Ready(None),
                // Either watchers are still running, or the set is empty but
                // `rx` may still deliver registrations (its waker is armed above).
                Poll::Ready(None) | Poll::Pending => return Poll::Pending,
            }
        }
    }
}
// Explicit `Unpin`, limited to `Msg: Unpin`. This lets callers use
// `StreamExt::next` on the monitor without pinning it first.
impl<Msg> Unpin for PanicMonitor<Msg> where Msg: Unpin + Send + 'static {}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// One panicking task among several must be reported; a batch of
    /// well-behaved tasks must resolve the detector to `None`.
    #[tokio::test]
    async fn alert_works() {
        let detector = PanicDetector::new();
        for task_id in 0..=10 {
            let job = async move {
                if task_id == 1 {
                    panic!("What could go wrong");
                }
            };
            tokio::spawn(job.alert(&detector));
        }
        assert!(detector.await.is_some());

        let detector = PanicDetector::new();
        for _ in 0..=10 {
            tokio::spawn(async move {}.alert(&detector));
        }
        assert!(detector.await.is_none());
    }

    /// `!Send`-capable registration through a `LocalSet` completes cleanly.
    #[tokio::test]
    async fn unsend_works() {
        let detector = PanicDetector::new();
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async move {
                let _ = tokio::task::spawn_local(async move {}.local_alert(&detector));
                assert!(detector.await.is_none());
            })
            .await;
    }

    /// The monitor yields exactly one report per panicking task, carrying its message.
    #[tokio::test]
    async fn monitor_works() {
        let mut monitor = PanicMonitor::new();
        for i in 0..=10 {
            let job = async move {
                if i % 3 == 0 {
                    panic!();
                }
            };
            tokio::spawn(job.alert_msg(&monitor, i));
        }
        let mut reports = 0;
        while let Some(Panicked(id)) = monitor.next().await {
            assert_eq!(id % 3, 0);
            reports += 1;
        }
        assert_eq!(reports, 4);
    }
}