1use std::sync::Arc;
2use tokio::runtime::{Builder, Runtime};
3use tokio::sync::{mpsc, oneshot};
4use tokio::task::LocalSet;
5
6pub struct TaskHandler<T> {
14 tx: mpsc::UnboundedSender<T>,
15 highprio_tx: mpsc::UnboundedSender<T>,
16 cb_tx: mpsc::UnboundedSender<T>,
17}
18
19impl<T> Clone for TaskHandler<T> {
20 fn clone(&self) -> Self {
21 Self {
22 tx: self.tx.clone(),
23 highprio_tx: self.highprio_tx.clone(),
24 cb_tx: self.cb_tx.clone(),
25 }
26 }
27}
28
29impl<T> TaskHandler<T> {
30 pub fn send(&self, ctx: T) {
31 if let Err(e) = self.tx.send(ctx) {
32 panic!("failed to send context through mpsc channel, reason: {e}");
33 }
34 }
35
36 pub fn send_highprio(&self, ctx: T) {
37 if let Err(e) = self.highprio_tx.send(ctx) {
38 panic!("failed to send context through high prio mpsc channel, reason: {e}");
39 }
40 }
41
42 pub fn send_cb(&self, ctx: T) {
43 if let Err(e) = self.cb_tx.send(ctx) {
44 panic!("failed to send context through callback mpsc channel, reason: {e}");
45 }
46 }
47}
48
49pub trait Task<T: 'static> {
50
51 fn handler(&mut self, ctx: T) -> impl std::future::Future<Output = ()>;
56
57 fn start(mut self) -> TaskHandler<T> where Self: Sized + 'static {
59 use futures_lite::future;
60 let (tx, mut rx) = mpsc::unbounded_channel::<T>();
61 let (highprio_tx, mut highprio_rx) = mpsc::unbounded_channel::<T>();
62 let (cb_tx, mut cb_rx) = mpsc::unbounded_channel::<T>();
63 tokio::task::spawn_local(async move {
64 while let Some(ctx) = future::or(cb_rx.recv(), future::or(highprio_rx.recv(), rx.recv())).await {
65 self.handler(ctx).await;
66 }
67 });
68 TaskHandler { tx: tx, highprio_tx: highprio_tx, cb_tx: cb_tx }
69 }
70}
71
72pub struct LocalSpawner<C, T> {
73 send: mpsc::UnboundedSender<(T, oneshot::Sender<TaskHandler<C>>)>,
74}
75
76impl<C, T> Clone for LocalSpawner<C, T> {
77 fn clone(&self) -> Self {
78 Self {
79 send: self.send.clone(),
80 }
81 }
82}
83
84impl<C: 'static + Send, T: Task<C> + 'static + Send> LocalSpawner<C, T> {
85 pub fn new_current() -> Self {
87 Self::new(None)
88 }
89
90 pub fn new(runtime: Option<Arc<Runtime>>) -> Self {
92 let (send, mut recv) = mpsc::unbounded_channel::<(T, oneshot::Sender<TaskHandler<C>>)>();
93
94 let rt = if let Some(r) = runtime {
95 r.clone()
96 } else {
97 let r = Builder::new_current_thread()
98 .enable_all()
99 .build()
100 .unwrap();
101 Arc::new(r)
102 };
103
104 std::thread::spawn(move || {
105 let local = LocalSet::new();
106
107 local.spawn_local(async move {
108 while let Some((task, tx)) = recv.recv().await {
109 let task_handle = task.start();
110 let _ = tx.send(task_handle);
112 }
113 });
116
117 rt.block_on(local);
120 });
121
122 Self {
123 send,
124 }
125 }
126
127 pub fn spawn(&self, task: T, tx: oneshot::Sender<TaskHandler<C>>) {
144 self.send.send((task, tx)).expect("Thread with LocalSet has shut down.");
145 }
146}