use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll, ready};
use std::{cmp, future::Future, future::poll_fn, hash, pin::Pin, sync::Arc};
use async_channel::{Receiver, Sender, TrySendError, unbounded};
use atomic_waker::AtomicWaker;
use core_affinity::CoreId;
use ntex_rt::{Arbiter, spawn};
use ntex_service::{Pipeline, PipelineBinding, Service, ServiceFactory};
use ntex_util::future::{Either, Stream, select, stream_recv};
use ntex_util::time::{Millis, sleep, timeout_checked};
use crate::ServerConfiguration;
/// Shutdown timeout used when a stop request carries a zero timeout, and when
/// the worker stops on its own because one of its channels closed.
const STOP_TIMEOUT: Millis = Millis(3_000);
/// Request sent to a worker asking it to stop.
#[derive(Debug)]
struct Shutdown {
    /// Time the worker may spend shutting its service down; `run_worker`
    /// substitutes `STOP_TIMEOUT` when this is zero.
    timeout: Millis,
    /// Receives `true` if the service shut down in time and `false` otherwise.
    /// It also receives `false` when the worker was stopped before its service
    /// was created.
    result: oneshot::Sender<bool>,
}
/// Externally observable state of a worker, as reported by `Worker::status`.
///
/// The variant order is significant: `Ord`/`PartialOrd` are derived from it.
#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum WorkerStatus {
    /// The worker's service reported itself ready to accept work.
    Available,
    /// The worker's service is not currently ready (also the `Default`).
    #[default]
    Unavailable,
    /// The worker's availability sender was dropped, i.e. its task ended or
    /// failed to start.
    Failed,
}
/// Handle to a worker running on its own arbiter.
///
/// Clones share the same channels and status flags. Equality, ordering and
/// hashing consider only `name`.
#[derive(Debug)]
pub struct Worker<T> {
    /// Worker name; also used as the name of the worker's arbiter.
    name: String,
    /// Sends work items to the worker.
    tx1: Sender<T>,
    /// Sends shutdown requests to the worker.
    tx2: Sender<Shutdown>,
    /// Read side of the worker's availability state.
    avail: WorkerAvailability,
    /// Latched to `true` once `wait_for_status` observes that the worker failed.
    failed: Arc<AtomicBool>,
}
// A worker's identity is its name: equality, ordering and hashing all look
// only at `name`, which keeps `Eq`, `Ord` and `Hash` mutually consistent.
impl<T> PartialEq for Worker<T> {
    fn eq(&self, other: &Worker<T>) -> bool {
        self.name.eq(&other.name)
    }
}
impl<T> Eq for Worker<T> {}
impl<T> hash::Hash for Worker<T> {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        hash::Hash::hash(&self.name, state);
    }
}
impl<T> cmp::PartialOrd for Worker<T> {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(cmp::Ord::cmp(self, other))
    }
}
impl<T> cmp::Ord for Worker<T> {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        self.name.as_str().cmp(other.name.as_str())
    }
}
/// Future returned by `Worker::stop`.
///
/// Resolves to the worker's reply. It resolves to `true` if the reply channel
/// is dropped without an answer.
#[derive(Debug)]
pub struct WorkerStop(oneshot::AsyncReceiver<bool>);
impl<T> Worker<T> {
pub fn start<F>(name: String, cfg: F, cid: Option<CoreId>) -> Worker<T>
where
T: Send + 'static,
F: ServerConfiguration<Item = T>,
{
let (tx1, rx1) = unbounded();
let (tx2, rx2) = unbounded();
let (avail, avail_tx) = WorkerAvailability::create();
let name2 = name.clone();
Arbiter::with_name(name.clone()).handle().spawn(async move {
if let Some(cid) = cid
&& core_affinity::set_for_current(cid)
{
log::info!("Set affinity to {cid:?} for worker {name2:?}");
}
spawn(async move {
log::info!("Starting worker {name2:?}");
log::debug!("Creating server instance in {name2:?}");
let factory = cfg.create().await;
match create(name2.clone(), rx1, rx2, factory, avail_tx).await {
Ok((svc, wrk)) => {
log::debug!("Server instance has been created in {name2:?}");
run_worker(svc, wrk).await;
}
Err(e) => {
log::error!("Cannot start worker {name2:?}: {e:?}");
}
}
Arbiter::current().stop();
});
});
Worker {
tx1,
tx2,
name,
avail,
failed: Arc::new(AtomicBool::new(false)),
}
}
pub fn name(&self) -> &str {
&self.name
}
pub fn send(&self, msg: T) -> Result<(), T> {
self.tx1.try_send(msg).map_err(TrySendError::into_inner)
}
pub fn status(&self) -> WorkerStatus {
if self.failed.load(Ordering::Acquire) {
WorkerStatus::Failed
} else if self.avail.available() {
WorkerStatus::Available
} else {
WorkerStatus::Unavailable
}
}
pub async fn wait_for_status(&mut self) -> WorkerStatus {
if self.failed.load(Ordering::Acquire) {
WorkerStatus::Failed
} else {
self.avail.wait_for_update().await;
if self.avail.failed() {
self.failed.store(true, Ordering::Release);
}
self.status()
}
}
pub fn stop(&self, timeout: Millis) -> WorkerStop {
let (result, rx) = oneshot::async_channel();
let _ = self.tx2.try_send(Shutdown { timeout, result });
WorkerStop(rx)
}
}
impl<T> Clone for Worker<T> {
fn clone(&self) -> Self {
Worker {
tx1: self.tx1.clone(),
tx2: self.tx2.clone(),
name: self.name.clone(),
avail: self.avail.clone(),
failed: self.failed.clone(),
}
}
}
impl Future for WorkerStop {
type Output = bool;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match ready!(Pin::new(&mut self.0).poll(cx)) {
Ok(res) => Poll::Ready(res),
Err(_) => Poll::Ready(true),
}
}
}
/// Read side of a worker's availability state. It shares `Inner` with the
/// matching `WorkerAvailabilityTx`.
#[derive(Debug, Clone)]
struct WorkerAvailability {
    inner: Arc<Inner>,
}
/// Write side of a worker's availability state. It is handed to the worker
/// task through `create`/`WorkerSt`.
///
/// Dropping it marks the worker as failed (see its `Drop` impl).
/// NOTE(review): this type derives `Clone`, and `Drop` runs for every clone,
/// so dropping any clone would mark the worker failed. Confirm no clones are
/// ever made, or remove the derive.
#[derive(Debug, Clone)]
struct WorkerAvailabilityTx {
    inner: Arc<Inner>,
}
/// State shared between `WorkerAvailability` and `WorkerAvailabilityTx`.
#[derive(Debug)]
struct Inner {
    /// Waker of the task waiting in `WorkerAvailability::wait_for_update`.
    waker: AtomicWaker,
    /// Set by the writer when it publishes a change; cleared by the reader.
    updated: AtomicBool,
    /// Whether the worker's service is currently ready.
    available: AtomicBool,
    /// Set once the writer is dropped.
    failed: AtomicBool,
}
impl WorkerAvailability {
    /// Creates a linked reader/writer pair sharing one `Inner`.
    ///
    /// All flags start `false`: the worker is neither available nor failed
    /// until the writer says otherwise.
    fn create() -> (Self, WorkerAvailabilityTx) {
        let inner = Arc::new(Inner {
            waker: AtomicWaker::new(),
            updated: AtomicBool::new(false),
            available: AtomicBool::new(false),
            failed: AtomicBool::new(false),
        });
        let avail = WorkerAvailability {
            inner: inner.clone(),
        };
        let avail_tx = WorkerAvailabilityTx { inner };
        (avail, avail_tx)
    }

    /// Returns `true` once the writer side has been dropped.
    fn failed(&self) -> bool {
        self.inner.failed.load(Ordering::Acquire)
    }

    /// Returns the availability last published by the writer.
    fn available(&self) -> bool {
        self.inner.available.load(Ordering::Acquire)
    }

    /// Resolves once the writer has published an update, consuming it.
    ///
    /// Assumes a single waiter. `AtomicWaker` holds one waker, and clones of
    /// this reader share it along with the `updated` flag.
    async fn wait_for_update(&self) {
        poll_fn(|cx| {
            // Register before checking the flag. `AtomicWaker` only guarantees
            // delivery of a `wake()` that happens after `register()`. Checking
            // first could miss an update published in between and leave this
            // task parked forever.
            self.inner.waker.register(cx.waker());
            // Test and clear in one atomic step. A separate load + store could
            // erase an update published between the two operations.
            if self.inner.updated.swap(false, Ordering::AcqRel) {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        })
        .await;
    }
}
impl WorkerAvailabilityTx {
    /// Publishes the worker's availability.
    ///
    /// The reader is notified only when the value actually changes. Setting
    /// the same value again is otherwise a no-op.
    fn set(&self, val: bool) {
        let previous = self.inner.available.swap(val, Ordering::Release);
        if previous == val {
            return;
        }
        // Raise the change flag before waking so the woken reader observes it.
        self.inner.updated.store(true, Ordering::Release);
        self.inner.waker.wake();
    }
}
impl Drop for WorkerAvailabilityTx {
    /// Marks the worker as failed and unavailable, then wakes the reader.
    ///
    /// This runs whenever the writer goes away: after `run_worker` returns, or
    /// when `create` bails out.
    fn drop(&mut self) {
        // `failed` is stored before `updated`. A reader that sees `updated`
        // via an Acquire load is therefore guaranteed to also see `failed`.
        self.inner.failed.store(true, Ordering::Release);
        self.inner.updated.store(true, Ordering::Release);
        // Stored after `updated`, so a reader may briefly still see the old
        // value. `Worker::wait_for_status` checks `failed` first, which covers
        // this.
        self.inner.available.store(false, Ordering::Release);
        self.inner.waker.wake();
    }
}
/// State owned by a running worker task.
struct WorkerSt<T, F: ServiceFactory<T>> {
    /// Worker name, used in log messages.
    name: String,
    /// Incoming work items.
    rx: Receiver<T>,
    /// Incoming shutdown requests.
    stop: Pin<Box<dyn Stream<Item = Shutdown>>>,
    /// Factory used to re-create the service after `poll_ready` fails.
    factory: F,
    /// Write side of the availability state reported to `Worker` handles.
    availability: WorkerAvailabilityTx,
}
/// Main loop of a worker task.
///
/// While `svc` is ready, each item received on `wrk.rx` is passed to
/// `svc.call` on its own spawned task; the call's result is discarded. The
/// loop reacts to the following events:
/// - a `Shutdown` on `wrk.stop`: stop the service with the requested timeout
///   (`STOP_TIMEOUT` if zero) and reply on its `result` channel;
/// - `wrk.rx` or `wrk.stop` closing: stop the service with `STOP_TIMEOUT`,
///   without a reply;
/// - `poll_ready` returning an error: shut the old service down on a background
///   task and build a new one from `wrk.factory`, retrying every second on
///   failure.
async fn run_worker<T, F>(mut svc: PipelineBinding<F::Service, T>, mut wrk: WorkerSt<T, F>)
where
    T: Send + 'static,
    F: ServiceFactory<T> + 'static,
{
    loop {
        let mut recv = std::pin::pin!(wrk.rx.recv());
        // Resolves to Ok(true) after dispatching one item, Ok(false) when the
        // item channel is closed, or Err when the service fails readiness.
        let fut = poll_fn(|cx| {
            // Mirror the service's readiness into the shared availability state.
            match svc.poll_ready(cx) {
                Poll::Ready(Ok(())) => {
                    wrk.availability.set(true);
                }
                Poll::Ready(Err(err)) => {
                    wrk.availability.set(false);
                    return Poll::Ready(Err(err));
                }
                Poll::Pending => {
                    wrk.availability.set(false);
                    return Poll::Pending;
                }
            }
            // Only pull an item once the service is ready to take it.
            if let Ok(item) = ready!(recv.as_mut().poll(cx)) {
                let fut = svc.call(item);
                spawn(async move {
                    let _ = fut.await;
                });
                Poll::Ready(Ok::<_, F::Error>(true))
            } else {
                log::error!("Server is gone");
                Poll::Ready(Ok(false))
            }
        });
        match select(fut, stream_recv(&mut wrk.stop)).await {
            Either::Left(Ok(true)) => continue,
            Either::Left(Err(_)) => {
                // The failed service is shut down in the background (not awaited);
                // a replacement is built in the loop below.
                ntex_rt::spawn(async move {
                    svc.shutdown().await;
                });
            }
            Either::Right(Some(Shutdown { timeout, result })) => {
                wrk.availability.set(false);
                let timeout = if timeout.is_zero() { STOP_TIMEOUT } else { timeout };
                stop_svc(&wrk.name, svc, timeout, Some(result)).await;
                return;
            }
            Either::Left(Ok(false)) | Either::Right(None) => {
                wrk.availability.set(false);
                stop_svc(&wrk.name, svc, STOP_TIMEOUT, None).await;
                return;
            }
        }
        // Re-create the service after a readiness failure.
        loop {
            match select(wrk.factory.create(()), stream_recv(&mut wrk.stop)).await {
                Either::Left(Ok(service)) => {
                    svc = Pipeline::new(service).bind();
                    break;
                }
                // NOTE(review): `wrk.stop` is not polled during this sleep, so a
                // stop request is only noticed after the sleep ends.
                Either::Left(Err(_)) => sleep(Millis::ONE_SEC).await,
                // NOTE(review): a `Shutdown` received here is dropped without a
                // reply, so `WorkerStop` resolves to `true`, whereas `create`
                // explicitly replies `false` in the analogous case. Confirm this
                // is intended.
                Either::Right(_) => return,
            }
        }
    }
}
/// Shuts `svc` down, allowing at most `timeout`, then logs that the worker
/// stopped.
///
/// If `result` is provided, it receives whether `timeout_checked` completed
/// successfully. Presumably an error there means the timeout elapsed first. A
/// closed reply channel is ignored.
async fn stop_svc<T, F>(
    name: &str,
    svc: PipelineBinding<F, T>,
    timeout: Millis,
    result: Option<oneshot::Sender<bool>>,
) where
    T: Send + 'static,
    F: Service<T> + 'static,
{
    let graceful = timeout_checked(timeout, svc.shutdown()).await.is_ok();
    if let Some(reply) = result {
        // The requester may have stopped waiting; that is not an error here.
        let _ = reply.send(graceful);
    }
    log::info!("Worker {name:?} has been stopped");
}
async fn create<T, F>(
name: String,
rx: Receiver<T>,
stop: Receiver<Shutdown>,
factory: Result<F, ()>,
availability: WorkerAvailabilityTx,
) -> Result<(PipelineBinding<F::Service, T>, WorkerSt<T, F>), ()>
where
T: Send + 'static,
F: ServiceFactory<T> + 'static,
{
availability.set(false);
let factory = factory?;
let mut stop = Box::pin(stop);
let svc = match select(factory.create(()), stream_recv(&mut stop)).await {
Either::Left(Ok(svc)) => Pipeline::new(svc).bind(),
Either::Right(Some(Shutdown { result, .. })) => {
log::trace!("Shutdown uninitialized worker");
let _ = result.send(false);
return Err(());
}
Either::Left(Err(_)) | Either::Right(None) => return Err(()),
};
availability.set(true);
Ok((
svc,
WorkerSt {
name,
rx,
factory,
availability,
stop: Box::pin(stop),
},
))
}