#![cfg(feature = "async")]
use core::{
mem,
task::{Context, Poll, Waker},
};
use futures_util::Stream;
use std::{
sync::mpsc::{self, TryRecvError},
thread,
};
/// State machine that runs a blocking, fallible iterator constructor on a
/// dedicated thread and hands its items back through channels.
///
/// `Func` is the closure that builds the iterator, `Item` is what the iterator
/// yields and `Err` is the error the closure may return instead of an iterator.
pub(crate) enum Unblock<Func, Item, Err> {
/// The worker thread has not been started yet.
Unspawned {
/// Closure that will be called on the worker thread to build the iterator.
spawn: Func,
},
/// The worker thread has been started.
Spawned {
/// Receives the iterator's items (`Ok`) or the constructor's error (`Err`)
/// from the worker thread.
items: mpsc::Receiver<Result<Item, Err>>,
/// Sends the polling task's waker to the worker thread so it can wake the
/// task once it has made progress.
wakers: mpsc::Sender<Waker>,
},
/// Placeholder left behind by `mem::replace` while the value is being moved
/// between states; polling in this state panics.
Hole,
}
// Inherent API: construction, starting the worker thread and polling it.
impl<
        Item: Send + 'static,
        Iter: Iterator<Item = Item>,
        Err: Send + 'static,
        Func: FnOnce() -> Result<Iter, Err> + Send + 'static,
    > Unblock<Func, Item, Err>
{
    /// Wraps `spawn` without running it.
    ///
    /// Nothing happens until the value is first polled; at that point `spawn`
    /// is called on a freshly started thread and the iterator it returns is
    /// drained there.
    pub(crate) fn new(spawn: Func) -> Self {
        Self::Unspawned { spawn }
    }

    /// Starts the worker thread and moves `self` into the `Spawned` state.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not `Unspawned`, or if the operating system refuses
    /// to create the thread.
    fn spawn(&mut self) {
        let spawn = match mem::replace(self, Self::Hole) {
            Self::Unspawned { spawn } => spawn,
            _ => unreachable!("Unblock::spawn called twice"),
        };

        let (items_tx, items_rx) = mpsc::channel();
        let (wakers_tx, wakers_rx) = mpsc::channel::<Waker>();

        thread::Builder::new()
            .name("breadx-unblock".into())
            .spawn(move || {
                // Wakes every task that registered a waker. With
                // `wait_for_drop` set, keeps blocking for further wakers until
                // the sending half has been dropped by the consumer.
                let wake_all = move |wait_for_drop: bool| {
                    if wait_for_drop {
                        while let Ok(waker) = wakers_rx.recv() {
                            waker.wake();
                        }
                    } else {
                        while let Ok(waker) = wakers_rx.try_recv() {
                            waker.wake();
                        }
                    }
                };

                match spawn() {
                    Ok(iter) => {
                        for item in iter {
                            // A failed send means the consumer dropped the
                            // stream; stop producing instead of panicking on
                            // what is a perfectly legitimate early drop.
                            if items_tx.send(Ok(item)).is_err() {
                                break;
                            }
                            wake_all(false);
                        }
                    }
                    Err(err) => {
                        // Nothing to do if the consumer is already gone.
                        items_tx.send(Err(err)).ok();
                    }
                }

                // Close the item channel *before* the final blocking wake
                // loop, on the error path as well: otherwise the consumer
                // keeps seeing an empty-but-open channel and is woken forever
                // without the stream ever terminating.
                mem::drop(items_tx);
                wake_all(true);
            })
            .expect("failed to spawn unblock thread");

        *self = Self::Spawned {
            items: items_rx,
            wakers: wakers_tx,
        };
    }

    /// Polls for the next item, starting the worker thread on first use.
    ///
    /// Returns `Ready(Ok(Some(item)))` for each item, `Ready(Err(_))` if the
    /// constructor closure failed, `Ready(Ok(None))` once the worker has
    /// closed the item channel, and `Pending` (after registering the task's
    /// waker with the worker) when nothing is available yet.
    ///
    /// # Panics
    ///
    /// Panics if polled again after `Ready(Ok(None))` was returned, and if the
    /// worker thread cannot be started.
    fn poll_for_item(&mut self, ctx: &mut Context<'_>) -> Poll<Result<Option<Item>, Err>> {
        loop {
            match mem::replace(self, Self::Hole) {
                Self::Hole => panic!("cannot poll an empty hole"),
                mut this @ Self::Unspawned { .. } => {
                    this.spawn();
                    *self = this;
                }
                Self::Spawned { items, wakers } => {
                    let mut received = items.try_recv();
                    if matches!(received, Err(TryRecvError::Empty)) {
                        // Register the waker first and then look again: an
                        // item sent between the first check and the
                        // registration would otherwise go unnoticed until the
                        // worker produces its next item.
                        wakers.send(ctx.waker().clone()).ok();
                        received = items.try_recv();
                    }

                    return match received {
                        Ok(item) => {
                            *self = Self::Spawned { items, wakers };
                            Poll::Ready(item.map(Some))
                        }
                        Err(TryRecvError::Empty) => {
                            *self = Self::Spawned { items, wakers };
                            Poll::Pending
                        }
                        // Both channel halves are dropped here, which lets the
                        // worker's final wake loop finish and the thread exit.
                        Err(TryRecvError::Disconnected) => Poll::Ready(Ok(None)),
                    };
                }
            }
        }
    }
}

// Stream adapter: yields `Ok(item)` per iterator item, `Err(_)` if the
// constructor closure failed, and ends once the worker thread is done.
impl<
        Item: Send + 'static,
        Iter: Iterator<Item = Item> + Unpin,
        Err: Send + 'static,
        Func: FnOnce() -> Result<Iter, Err> + Send + Unpin + 'static,
    > Stream for Unblock<Func, Item, Err>
{
    type Item = Result<Item, Err>;

    fn poll_next(
        self: core::pin::Pin<&mut Self>,
        ctx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // `Result<Option<_>, _>` -> `Option<Result<_, _>>`, as `Stream` expects.
        self.get_mut()
            .poll_for_item(ctx)
            .map(|item| item.transpose())
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use core::convert::Infallible;
    use futures_util::{stream::iter, StreamExt};
    use std::{thread::sleep, time::Duration};

    /// A deliberately slow blocking iterator over `0..10` must come out of
    /// the stream complete and in order.
    #[test]
    fn unblock_works() {
        let stream = Unblock::new(|| {
            let slow_numbers = (0..10).map(|value| {
                // Simulate a blocking producer.
                sleep(Duration::from_millis(1));
                value
            });
            Result::<_, Infallible>::Ok(slow_numbers)
        });

        // Pair every produced item with the value expected at that position.
        let check = stream
            .zip(iter(0..10))
            .for_each(|(produced, expected)| async move {
                assert_eq!(produced.unwrap(), expected);
            });

        spin_on::spin_on(check);
    }
}