// rework/worker.rs

use std::{
    cell::RefCell,
    future::Future,
    marker::PhantomData,
    pin::Pin,
    rc::Rc,
    task::{Context, Poll},
    thread,
};

use futures::{
    channel::mpsc::{channel, unbounded, Receiver, Sender, UnboundedReceiver, UnboundedSender},
    FutureExt, StreamExt,
};
use pandet::{PanicMonitor, UnsendOnPanic};
use tokio::{
    runtime::Builder,
    task,
    time::{self, Duration, Instant, Sleep},
};
use tracing::{error, info, span, warn, Instrument, Level};

use crate::PanicInfo;

use super::{Command, Request, WorkFn, Workload};

thread_local! {
    /// Running count of work requests accepted on the current thread. It starts at zero;
    /// the value right after each increment is used as that request's span id.
    static REQ_COUNTER: RefCell<usize> = RefCell::new(usize::default());
}

30/// Convenient function for sending a message of type `T` through an `UnboundedSender<T>`.
31#[inline]
32fn send<T>(tx: &UnboundedSender<T>, payload: T, err_msg: &str) -> Result<(), T> {
33    if let Err(e) = tx.unbounded_send(payload) {
34        warn!("{}: {:?}", err_msg, e);
35        Err(e.into_inner())
36    } else {
37        Ok(())
38    }
39}
40
41struct ConnTimeout {
42    thres: Duration,
43    last: Instant,
44    sleep: Pin<Box<Sleep>>,
45}
46
47impl ConnTimeout {
48    pub fn new(thres: Duration) -> Self {
49        ConnTimeout {
50            thres,
51            last: Instant::now(),
52            sleep: Box::pin(time::sleep(thres)),
53        }
54    }
55
56    pub fn reset(&mut self) {
57        self.last = Instant::now();
58        self.sleep
59            .as_mut()
60            .reset(Instant::now().checked_add(self.thres).unwrap());
61    }
62}
63
64impl Future for ConnTimeout {
65    type Output = ();
66
67    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68        self.sleep.poll_unpin(cx)
69    }
70}
71
72struct InnerWorker<Req, Resp, Fut, WFn> {
73    resp_tx: UnboundedSender<(Workload, Resp)>,
74    task_fut: WFn,
75    _pd: PhantomData<(Req, Fut)>,
76}
77
78impl<Req, Resp, Fut, WFn> InnerWorker<Req, Resp, Fut, WFn>
79where
80    Req: Request,
81    Resp: Send + 'static,
82    WFn: WorkFn<Req, Resp, Fut> + 'static,
83    Fut: Future<Output = Resp>,
84{
85    async fn work(&self, req: Req) {
86        info!("New request {:?}", &req);
87
88        let wl = req.workload();
89        let resp = self.task_fut.work(req).await;
90        let _ = send(
91            &self.resp_tx,
92            (wl, resp),
93            "Error sending response to Dispatcher",
94        );
95    }
96}
97
98//                   ┌─────────────────┐
99//         cmd_rx -->│                 │
100//                   │  worker-thread  │--> resp_tx
101//         req_rx -->│                 │
102//                   └─────────────────┘
103pub(crate) struct Worker<Req, Resp> {
104    id: usize,
105    agent: WorkerAgent<Req>,
106    req_rx: UnboundedReceiver<Req>,
107    resp_tx: UnboundedSender<(Workload, Resp)>,
108    panic_tx: UnboundedSender<(Workload, PanicInfo)>,
109    cmd_rx: Receiver<Command>,
110    shutdown_grace_period: Duration,
111}
112
113impl<Req, Resp> Worker<Req, Resp> {
114    pub(crate) fn new(
115        id: usize,
116        resp_tx: UnboundedSender<(Workload, Resp)>,
117        panic_tx: UnboundedSender<(Workload, PanicInfo)>,
118        shutdown_grace_period: Duration,
119    ) -> Self {
120        let (req_tx, req_rx) = unbounded();
121        let (cmd_tx, cmd_rx) = channel(0);
122
123        Worker {
124            id,
125            agent: WorkerAgent { id, req_tx, cmd_tx },
126            req_rx,
127            resp_tx,
128            panic_tx,
129            cmd_rx,
130            shutdown_grace_period,
131        }
132    }
133
134    pub(crate) fn new_agent(&self) -> WorkerAgent<Req> {
135        self.agent.clone()
136    }
137
138    /// Deploys the `Worker` onto a spawned `Thread`. Returns `Worker.id` when the thread exits.
139    pub(crate) fn deploy<Fut, WFn, IFn>(
140        mut self,
141        make_work_fn: impl Fn() -> WFn + Clone + Send + 'static,
142        make_init_fn: impl Fn() -> IFn + Clone + Send + 'static,
143    ) -> thread::JoinHandle<usize>
144    where
145        Req: Request,
146        Resp: Send + 'static,
147        Fut: Future<Output = Resp> + 'static,
148        WFn: WorkFn<Req, Resp, Fut> + 'static,
149        Fut: Future<Output = Resp> + 'static,
150        IFn: WorkFn<(), ()> + 'static,
151    {
152        let work = move || {
153            let work_fn = make_work_fn();
154            let init_fn = make_init_fn();
155            let id = self.id;
156            let span = span!(Level::ERROR, "worker", id = id);
157            let _enter = span.enter();
158            info!("Worker thread deployed");
159
160            let rt = match Builder::new_current_thread().enable_all().build() {
161                Ok(rt) => rt,
162                Err(e) => {
163                    error!("Failed to build thread-local tokio runtime: {:?}", e);
164                    return self.id;
165                }
166            };
167            let local = task::LocalSet::new();
168            local.block_on(&rt, async move {
169                init_fn.work(()).await;
170
171                let inner_worker = Box::new(InnerWorker {
172                    resp_tx: self.resp_tx,
173                    task_fut: work_fn,
174                    _pd: PhantomData,
175                });
176
177                // FIXME: memory leak if worker thread exited prematurely
178                let inner_worker: &'static InnerWorker<_, _, _, _> = &*Box::leak(inner_worker);
179
180                let (ref mut monitor, det) = PanicMonitor::<(usize, String, Workload)>::new();
181
182                let fut = async move {
183                    let shutting_down = &mut false;
184                    let conn_timeout = ConnTimeout::new(self.shutdown_grace_period);
185                    tokio::pin!(conn_timeout);
186                    loop {
187                        tokio::select! {
188                            Some(cmd) = self.cmd_rx.next(), if !*shutting_down => {
189                                info!("Received command {cmd:?}");
190                                match cmd {
191                                    Command::Shutdown => {
192                                        *shutting_down = true;
193                                    },
194                                }
195                            },
196                            Some(req) = self.req_rx.next(), if !*shutting_down => {
197                                let req_id = REQ_COUNTER.with(|c| {
198                                    *c.borrow_mut() += 1;
199                                    *c.borrow()
200                                });
201
202                                let kind = format!("{req:?}");
203                                let wl = req.workload();
204                                let work_req_span = span!(Level::ERROR, "work-req", id=req_id);
205                                task::spawn_local(
206                                    async move {
207                                        inner_worker.work(req).await;
208                                    }
209                                    .instrument(work_req_span)
210                                    .unsend_on_panic_info(&det, (req_id, kind, wl))
211                                );
212                                conn_timeout.reset();
213                            },
214                            _ = &mut conn_timeout, if *shutting_down => {
215                                info!("Connections timed out");
216                                break;
217                            },
218                            else => {
219                                warn!("WorkerAgent dropped");
220                                break;
221                            }
222                        }
223                    }
224                };
225                task::spawn_local(fut);
226
227                while let Some(e) = monitor.next().await {
228                    let (req_id, kind, wl) = e.0;
229                    warn!("Panic detected when working on work-req{{id={req_id}}} {kind}");
230                    let _ = send(
231                        &self.panic_tx,
232                        (wl, PanicInfo { request: kind }),
233                        "Error reporting panick to Dispatcher",
234                    );
235                }
236            });
237
238            info!("Worker thread exited");
239            id
240        };
241
242        thread::Builder::new()
243            .name(format!("worker-thread-{}", self.id))
244            .spawn(work)
245            .expect("failed to deploy worker {self.id}")
246    }
247}
248
249pub(crate) struct WorkerAgent<Req> {
250    id: usize,
251    req_tx: UnboundedSender<Req>,
252    cmd_tx: Sender<Command>,
253}
254
255impl<Req> WorkerAgent<Req> {
256    pub(crate) fn request(&self, req: Req) {
257        let _ = send(
258            &self.req_tx,
259            req,
260            &format!("Error sending work request to worker {}", self.id),
261        );
262    }
263
264    pub(crate) fn command(&mut self, cmd: Command) {
265        if let Err(e) = self.cmd_tx.try_send(cmd) {
266            error!("Error sending command to worker {}: {:?}", self.id, e);
267        }
268    }
269}
270
271impl<Req> Clone for WorkerAgent<Req> {
272    fn clone(&self) -> Self {
273        Self {
274            id: self.id,
275            req_tx: self.req_tx.clone(),
276            cmd_tx: self.cmd_tx.clone(),
277        }
278    }
279}