1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use futures::channel::oneshot::{channel, Receiver};
6use futures::task::{Context, Poll};
7use hyper::rt;
8
/// A boxed, pinned, type-erased future that is `Send + 'static` and yields `()`.
///
/// This is the unit of work handed to [`Spawn::spawn`].
pub type FutureObj = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;

/// A boxed, type-erased closure that is `Send + 'static` and can be called once.
///
/// This is the unit of work handed to [`Spawn::spawn_blocking`].
pub type BlockingObj = Box<dyn FnOnce() + Send + 'static>;
14
15pub trait Spawn {
17 fn spawn(&self, fut: FutureObj);
19
20 fn spawn_blocking(&self, task: BlockingObj);
22}
23
24#[derive(Clone)]
26pub struct Executor(pub(crate) Arc<dyn 'static + Send + Sync + Spawn>);
27
28pub struct JoinHandle<T>(Receiver<T>);
30
31impl Executor {
32 #[inline]
34 pub fn spawn<Fut>(&self, fut: Fut) -> JoinHandle<Fut::Output>
35 where
36 Fut: 'static + Send + Future,
37 Fut::Output: 'static + Send,
38 {
39 let (sender, recv) = channel();
40 self.0.spawn(Box::pin(async move {
41 if sender.send(fut.await).is_err() {
42 };
44 }));
45 JoinHandle(recv)
46 }
47
48 #[inline]
50 pub fn spawn_blocking<T, R>(&self, task: T) -> JoinHandle<R>
51 where
52 T: 'static + Send + FnOnce() -> R,
53 R: 'static + Send,
54 {
55 let (sender, recv) = channel();
56 self.0.spawn_blocking(Box::new(|| {
57 if sender.send(task()).is_err() {
58 };
60 }));
61 JoinHandle(recv)
62 }
63}
64
65impl<T> Future for JoinHandle<T> {
66 type Output = T;
67 #[inline]
68 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
69 let ready = futures::ready!(Pin::new(&mut self.0).poll(cx));
70 Poll::Ready(ready.expect("receiver in JoinHandle shouldn't be canceled"))
71 }
72}
73
74impl<F> rt::Executor<F> for Executor
75where
76 F: 'static + Send + Future,
77 F::Output: 'static + Send,
78{
79 #[inline]
80 fn execute(&self, fut: F) {
81 self.0.spawn(Box::pin(async move {
82 let _ = fut.await;
83 }));
84 }
85}
86
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::{BlockingObj, Executor, FutureObj, Spawn};

    /// Minimal [`Spawn`] implementation backed by the ambient tokio runtime.
    pub struct Exec;

    impl Spawn for Exec {
        fn spawn(&self, fut: FutureObj) {
            // tokio's own handle is not needed: results come back through the
            // oneshot channel inside `Executor`'s `JoinHandle`.
            tokio::task::spawn(fut);
        }

        fn spawn_blocking(&self, task: BlockingObj) {
            tokio::task::spawn_blocking(task);
        }
    }

    /// Builds an `Executor` that forwards to the tokio-backed `Exec`.
    fn tokio_executor() -> Executor {
        Executor(Arc::new(Exec))
    }

    #[tokio::test]
    async fn spawn() {
        let handle = tokio_executor().spawn(async { 1 });
        assert_eq!(1, handle.await);
    }

    #[tokio::test]
    async fn spawn_blocking() {
        let handle = tokio_executor().spawn_blocking(|| 1);
        assert_eq!(1, handle.await);
    }
}