#![allow(clippy::missing_panics_doc)]
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures::{
channel::{mpsc, oneshot},
future::FutureExt as _,
task::LocalFutureObj,
};
/// Failure reported by a [`Task`] in place of the value it was waiting for.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The worker discarded the task before it sent back a result, e.g. because the owning
    /// [`Thread`] was dropped while the task was still pending.
    #[error("thread killed before task completed")]
    Killed(#[from] oneshot::Canceled),
}

/// [`std::result::Result`] whose error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Handle to a dedicated worker thread that executes closures submitted through
/// [`Thread::run`] / [`Thread::run_send`].
///
/// Dropping the handle closes the worker's request queue, which shuts the worker down.
pub struct Thread {
    // Sending half of the worker's request queue.
    sender: mpsc::UnboundedSender<Request>,
}

/// One unit of work for the worker: a `Send` closure that, once it has reached the worker
/// thread, builds the (not necessarily `Send`) future to execute there.
type Request = Box<dyn FnOnce() -> LocalFutureObj<'static, ()> + Send>;

/// Handle to the eventual output of a closure submitted with [`Thread::run`].
///
/// The work is already queued when this is created; dropping the handle only discards the
/// result and does not stop the work.
pub struct Task<T> {
    // Completed by the worker once the submitted future finishes.
    receiver: oneshot::Receiver<T>,
}
/// A [`Task`] whose output type is required to be `Send`, as returned by
/// [`Thread::run_send`].
pub struct SendTask<T>(Task<T>);

impl<T: Send> Future for SendTask<T> {
    type Output = Result<T>;

    fn poll(self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        // The wrapper is `Unpin` (the inner task only holds a oneshot receiver, which the
        // inner impl already polls via `poll_unpin`), so reaching through the pin is fine.
        let inner = &mut self.get_mut().0;
        inner.poll_unpin(context)
    }
}
impl<T> Future for Task<T> {
type Output = Result<T>;
fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
self.receiver
.poll_unpin(context)
.map(|ready| ready.map_err(Into::into))
}
}
impl Thread {
    /// Spawns the worker thread and returns a handle to it.
    ///
    /// The worker drives submitted futures on a single-threaded `LocalPool`, so those
    /// futures never have to be `Send`. It keeps accepting work until this handle is
    /// dropped. At that point the executor is torn down and any task that has not finished
    /// yet is discarded; its [`Task`] then resolves to [`Error::Killed`].
    ///
    /// A submitted future that panics unwinds through the worker's executor and ends the
    /// worker thread. From then on, every pending and future [`Task`] resolves to
    /// [`Error::Killed`].
    ///
    /// # Panics
    ///
    /// Panics if the OS fails to create the thread (see [`std::thread::spawn`]).
    #[must_use]
    pub fn new() -> Self {
        let (sender, mut receiver) = mpsc::unbounded::<Request>();
        std::thread::spawn(move || {
            use futures::{StreamExt as _, executor::LocalPool, task::LocalSpawn as _};
            let mut executor = LocalPool::new();
            let spawner = executor.spawner();
            // Drive the request queue and all spawned tasks together. Once the queue is
            // closed (the `Thread` was dropped), `run_until` returns and dropping `executor`
            // discards whatever tasks are still unfinished.
            executor.run_until(async move {
                while let Some(task) = receiver.next().await {
                    spawner
                        .spawn_local_obj(task())
                        .expect("executor should exist until destroyed");
                }
            });
        });
        Self { sender }
    }

    /// Submits `code` to the worker thread and returns a [`Task`] for its result.
    ///
    /// `context` and `code` are moved to the worker, which calls `code(context)` there and
    /// drives the resulting future. That future never leaves the worker, so `F` need not be
    /// `Send`; only the inputs and the output cross threads.
    ///
    /// The work is queued immediately. The returned [`Task`] only observes the result, and
    /// dropping it does not stop the work.
    ///
    /// If the worker is no longer running (for instance because an earlier task panicked),
    /// nothing is executed and the returned [`Task`] resolves to [`Error::Killed`].
    pub fn run<Context: Post, F: Future<Output: Post> + 'static>(
        &self,
        context: Context,
        code: impl FnOnce(Context) -> F + Send + 'static,
    ) -> Task<F::Output> {
        let (sender, receiver) = oneshot::channel::<F::Output>();
        // A send error means the worker's queue receiver is gone, i.e. the worker thread has
        // exited. The rejected request is handed back inside the error and dropped right
        // here, which drops the oneshot `sender` captured in it. The returned `Task` then
        // reports `Error::Killed` instead of this call panicking.
        let _ = self.sender.unbounded_send(Box::new(move || {
            Box::new(async move {
                // The caller may already have dropped its `Task`; then nobody wants the value.
                let _ = sender.send(code(context).await);
            })
            .into()
        }));
        Task { receiver }
    }

    /// Like [`Thread::run`], but returns a [`SendTask`] for the result.
    pub fn run_send<Context: Post, F: Future<Output: Send> + 'static>(
        &self,
        context: Context,
        code: impl FnOnce(Context) -> F + Send + 'static,
    ) -> SendTask<F::Output> {
        SendTask(self.run(context, code))
    }
}

#[test]
fn dead_worker_reports_killed_instead_of_panicking() {
    let thread = Thread::new();
    let mut pool = futures::executor::LocalPool::new();
    // A panicking task unwinds through the worker's executor and ends the worker thread.
    let crashed = thread.run(true, |explode: bool| async move {
        assert!(!explode, "deliberate panic to kill the worker");
        0u8
    });
    assert!(matches!(pool.run_until(crashed), Err(Error::Killed(_))));
    // Submitting more work afterwards must not panic; the task just reports `Killed`.
    let orphaned = thread.run(1u8, |one| async move { one });
    assert!(matches!(pool.run_until(orphaned), Err(Error::Killed(_))));
}
impl Default for Thread {
fn default() -> Self {
Self::new()
}
}
/// Values that may be moved onto the worker thread: shorthand for `Send + 'static`.
///
/// Blanket-implemented below, so any `Send + 'static` type satisfies it automatically.
pub trait Post: Send + 'static {}

impl<T> Post for T where T: Send + 'static {}
/// Round-trips simple values through the worker via both `run` and `run_send`.
#[test]
fn basic_functionality() {
    // `Thread::new` returns the handle directly; it is not fallible.
    let thread = Thread::new();
    let mut pool = futures::executor::LocalPool::new();

    let task = thread.run(3u8, |three| async move { three + 5 });
    assert_eq!(8u8, pool.run_until(task).unwrap());

    let send_task = thread.run_send(21u32, |n| async move { n * 2 });
    assert_eq!(42u32, pool.run_until(send_task).unwrap());
}