Skip to main content

ntex_server/
wrk.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::task::{Context, Poll, ready};
3use std::{cmp, future::Future, future::poll_fn, hash, pin::Pin, sync::Arc};
4
5use async_channel::{Receiver, Sender, unbounded};
6use atomic_waker::AtomicWaker;
7use core_affinity::CoreId;
8
9use ntex_rt::{Arbiter, spawn};
10use ntex_service::{Pipeline, PipelineBinding, Service, ServiceFactory};
11use ntex_util::future::{Either, Stream, select, stream_recv};
12use ntex_util::time::{Millis, sleep, timeout_checked};
13
14use crate::ServerConfiguration;
15
16const STOP_TIMEOUT: Millis = Millis(3000);
17
18#[derive(Debug)]
19/// Shutdown worker
20struct Shutdown {
21    timeout: Millis,
22    result: oneshot::Sender<bool>,
23}
24
/// Worker status
///
/// Reported by [`Worker::status`] and [`Worker::wait_for_status`].
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum WorkerStatus {
    /// The worker's service is ready to accept new items.
    Available,
    /// The worker is not currently ready for items (e.g. still starting, its service
    /// is not ready, or it is stopping). This is the default status.
    #[default]
    Unavailable,
    /// The worker task has gone away and can no longer process items.
    Failed,
}
33
34#[derive(Debug)]
35/// Server worker
36///
37/// Worker accepts message via unbounded channel and starts processing.
38pub struct Worker<T> {
39    name: String,
40    tx1: Sender<T>,
41    tx2: Sender<Shutdown>,
42    avail: WorkerAvailability,
43    failed: Arc<AtomicBool>,
44}
45
46impl<T> cmp::Ord for Worker<T> {
47    fn cmp(&self, other: &Self) -> cmp::Ordering {
48        self.name.cmp(&other.name)
49    }
50}
51
52impl<T> cmp::PartialOrd for Worker<T> {
53    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
54        Some(self.cmp(other))
55    }
56}
57
58impl<T> hash::Hash for Worker<T> {
59    fn hash<H: hash::Hasher>(&self, state: &mut H) {
60        self.name.hash(state);
61    }
62}
63
64impl<T> Eq for Worker<T> {}
65
66impl<T> PartialEq for Worker<T> {
67    fn eq(&self, other: &Worker<T>) -> bool {
68        self.name == other.name
69    }
70}
71
72#[derive(Debug)]
73/// Stop worker process
74///
75/// Stop future resolves when worker completes processing
76/// incoming items and stop arbiter
77pub struct WorkerStop(oneshot::Receiver<bool>);
78
79impl<T> Worker<T> {
80    /// Start worker.
81    pub fn start<F>(name: String, cfg: F, cid: Option<CoreId>) -> Worker<T>
82    where
83        T: Send + 'static,
84        F: ServerConfiguration<Item = T>,
85    {
86        let (tx1, rx1) = unbounded();
87        let (tx2, rx2) = unbounded();
88        let (avail, avail_tx) = WorkerAvailability::create();
89        let name2 = name.clone();
90
91        Arbiter::with_name(name.clone()).exec_fn(move || {
92            if let Some(cid) = cid
93                && core_affinity::set_for_current(cid)
94            {
95                log::info!("Set affinity to {cid:?} for worker {name2:?}");
96            }
97
98            let _ = spawn(async move {
99                log::info!("Starting worker {name2:?}");
100
101                log::debug!("Creating server instance in {name2:?}");
102                let factory = cfg.create().await;
103
104                match create(name2.clone(), rx1, rx2, factory, avail_tx).await {
105                    Ok((svc, wrk)) => {
106                        log::debug!("Server instance has been created in {name2:?}");
107                        run_worker(svc, wrk).await;
108                    }
109                    Err(e) => {
110                        log::error!("Cannot start worker {name2:?}: {e:?}");
111                    }
112                }
113                Arbiter::current().stop();
114            });
115        });
116
117        Worker {
118            tx1,
119            tx2,
120            name,
121            avail,
122            failed: Arc::new(AtomicBool::new(false)),
123        }
124    }
125
126    /// Worker name
127    pub fn name(&self) -> &str {
128        &self.name
129    }
130
131    /// Send message to the worker.
132    ///
133    /// Returns `Ok` if message got accepted by the worker.
134    /// Otherwise return message back as `Err`
135    pub fn send(&self, msg: T) -> Result<(), T> {
136        self.tx1.try_send(msg).map_err(|msg| msg.into_inner())
137    }
138
139    /// Check worker status.
140    pub fn status(&self) -> WorkerStatus {
141        if self.failed.load(Ordering::Acquire) {
142            WorkerStatus::Failed
143        } else if self.avail.available() {
144            WorkerStatus::Available
145        } else {
146            WorkerStatus::Unavailable
147        }
148    }
149
150    /// Wait for worker status updates
151    pub async fn wait_for_status(&mut self) -> WorkerStatus {
152        if self.failed.load(Ordering::Acquire) {
153            WorkerStatus::Failed
154        } else {
155            self.avail.wait_for_update().await;
156            if self.avail.failed() {
157                self.failed.store(true, Ordering::Release);
158            }
159            self.status()
160        }
161    }
162
163    /// Stop worker.
164    ///
165    /// If timeout value is zero, force shutdown worker
166    pub fn stop(&self, timeout: Millis) -> WorkerStop {
167        let (result, rx) = oneshot::channel();
168        let _ = self.tx2.try_send(Shutdown { timeout, result });
169        WorkerStop(rx)
170    }
171}
172
173impl<T> Clone for Worker<T> {
174    fn clone(&self) -> Self {
175        Worker {
176            tx1: self.tx1.clone(),
177            tx2: self.tx2.clone(),
178            name: self.name.clone(),
179            avail: self.avail.clone(),
180            failed: self.failed.clone(),
181        }
182    }
183}
184
185impl Future for WorkerStop {
186    type Output = bool;
187
188    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189        match ready!(Pin::new(&mut self.0).poll(cx)) {
190            Ok(res) => Poll::Ready(res),
191            Err(_) => Poll::Ready(true),
192        }
193    }
194}
195
196#[derive(Debug, Clone)]
197struct WorkerAvailability {
198    inner: Arc<Inner>,
199}
200
201#[derive(Debug, Clone)]
202struct WorkerAvailabilityTx {
203    inner: Arc<Inner>,
204}
205
206#[derive(Debug)]
207struct Inner {
208    waker: AtomicWaker,
209    updated: AtomicBool,
210    available: AtomicBool,
211    failed: AtomicBool,
212}
213
214impl WorkerAvailability {
215    fn create() -> (Self, WorkerAvailabilityTx) {
216        let inner = Arc::new(Inner {
217            waker: AtomicWaker::new(),
218            updated: AtomicBool::new(false),
219            available: AtomicBool::new(false),
220            failed: AtomicBool::new(false),
221        });
222
223        let avail = WorkerAvailability {
224            inner: inner.clone(),
225        };
226        let avail_tx = WorkerAvailabilityTx { inner };
227        (avail, avail_tx)
228    }
229
230    fn failed(&self) -> bool {
231        self.inner.failed.load(Ordering::Acquire)
232    }
233
234    fn available(&self) -> bool {
235        self.inner.available.load(Ordering::Acquire)
236    }
237
238    async fn wait_for_update(&self) {
239        poll_fn(|cx| {
240            if self.inner.updated.load(Ordering::Acquire) {
241                self.inner.updated.store(false, Ordering::Release);
242                Poll::Ready(())
243            } else {
244                self.inner.waker.register(cx.waker());
245                Poll::Pending
246            }
247        })
248        .await;
249    }
250}
251
252impl WorkerAvailabilityTx {
253    fn set(&self, val: bool) {
254        let old = self.inner.available.swap(val, Ordering::Release);
255        if old != val {
256            self.inner.updated.store(true, Ordering::Release);
257            self.inner.waker.wake();
258        }
259    }
260}
261
262impl Drop for WorkerAvailabilityTx {
263    fn drop(&mut self) {
264        self.inner.failed.store(true, Ordering::Release);
265        self.inner.updated.store(true, Ordering::Release);
266        self.inner.available.store(false, Ordering::Release);
267        self.inner.waker.wake();
268    }
269}
270
271/// Service worker
272///
273/// Worker accepts message via unbounded channel and starts processing.
274struct WorkerSt<T, F: ServiceFactory<T>> {
275    name: String,
276    rx: Receiver<T>,
277    stop: Pin<Box<dyn Stream<Item = Shutdown>>>,
278    factory: F,
279    availability: WorkerAvailabilityTx,
280}
281
282async fn run_worker<T, F>(mut svc: PipelineBinding<F::Service, T>, mut wrk: WorkerSt<T, F>)
283where
284    T: Send + 'static,
285    F: ServiceFactory<T> + 'static,
286{
287    loop {
288        let mut recv = std::pin::pin!(wrk.rx.recv());
289        let fut = poll_fn(|cx| {
290            match svc.poll_ready(cx) {
291                Poll::Ready(Ok(())) => {
292                    wrk.availability.set(true);
293                }
294                Poll::Ready(Err(err)) => {
295                    wrk.availability.set(false);
296                    return Poll::Ready(Err(err));
297                }
298                Poll::Pending => {
299                    wrk.availability.set(false);
300                    return Poll::Pending;
301                }
302            }
303
304            match ready!(recv.as_mut().poll(cx)) {
305                Ok(item) => {
306                    let fut = svc.call(item);
307                    let _ = spawn(async move {
308                        let _ = fut.await;
309                    });
310                    Poll::Ready(Ok::<_, F::Error>(true))
311                }
312                Err(_) => {
313                    log::error!("Server is gone");
314                    Poll::Ready(Ok(false))
315                }
316            }
317        });
318
319        match select(fut, stream_recv(&mut wrk.stop)).await {
320            Either::Left(Ok(true)) => continue,
321            Either::Left(Err(_)) => {
322                let _ = ntex_rt::spawn(async move {
323                    svc.shutdown().await;
324                });
325            }
326            Either::Right(Some(Shutdown { timeout, result })) => {
327                wrk.availability.set(false);
328
329                let timeout = if timeout.is_zero() { STOP_TIMEOUT } else { timeout };
330
331                stop_svc(&wrk.name, svc, timeout, Some(result)).await;
332                return;
333            }
334            Either::Left(Ok(false)) | Either::Right(None) => {
335                wrk.availability.set(false);
336                stop_svc(&wrk.name, svc, STOP_TIMEOUT, None).await;
337                return;
338            }
339        }
340
341        // re-create service
342        loop {
343            match select(wrk.factory.create(()), stream_recv(&mut wrk.stop)).await {
344                Either::Left(Ok(service)) => {
345                    svc = Pipeline::new(service).bind();
346                    break;
347                }
348                Either::Left(Err(_)) => sleep(Millis::ONE_SEC).await,
349                Either::Right(_) => return,
350            }
351        }
352    }
353}
354
355async fn stop_svc<T, F>(
356    name: &str,
357    svc: PipelineBinding<F, T>,
358    timeout: Millis,
359    result: Option<oneshot::Sender<bool>>,
360) where
361    T: Send + 'static,
362    F: Service<T> + 'static,
363{
364    let res = timeout_checked(timeout, svc.shutdown()).await;
365    if let Some(result) = result {
366        let _ = result.send(res.is_ok());
367    }
368
369    log::info!("Worker {name:?} has been stopped");
370}
371
372async fn create<T, F>(
373    name: String,
374    rx: Receiver<T>,
375    stop: Receiver<Shutdown>,
376    factory: Result<F, ()>,
377    availability: WorkerAvailabilityTx,
378) -> Result<(PipelineBinding<F::Service, T>, WorkerSt<T, F>), ()>
379where
380    T: Send + 'static,
381    F: ServiceFactory<T> + 'static,
382{
383    availability.set(false);
384    let factory = factory?;
385    let mut stop = Box::pin(stop);
386
387    let svc = match select(factory.create(()), stream_recv(&mut stop)).await {
388        Either::Left(Ok(svc)) => Pipeline::new(svc).bind(),
389        Either::Left(Err(_)) => return Err(()),
390        Either::Right(Some(Shutdown { result, .. })) => {
391            log::trace!("Shutdown uninitialized worker");
392            let _ = result.send(false);
393            return Err(());
394        }
395        Either::Right(None) => return Err(()),
396    };
397    availability.set(true);
398
399    Ok((
400        svc,
401        WorkerSt {
402            name,
403            rx,
404            factory,
405            availability,
406            stop: Box::pin(stop),
407        },
408    ))
409}