use crate::{
error::{Closed, Error, ServiceError},
message::Message,
};
use futures_core::ready;
use pin_project::pin_project;
use std::sync::{Arc, Mutex};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::mpsc;
use tower_service::Service;
/// Task that owns the inner service and processes requests sent to it over
/// an `mpsc` channel.
///
/// For every received request the worker waits for the service to become
/// ready, calls it, and sends the resulting response future back through the
/// request's `tx`. If the service fails, the failure is recorded and every
/// remaining request is answered with that error.
#[pin_project]
#[derive(Debug)]
pub struct Worker<T, Request>
where
    T: Service<Request>,
    T::Error: Into<Error>,
{
    /// A request already taken off `rx` that is parked until the service
    /// reports readiness again; it is retried before pulling a new one.
    current_message: Option<Message<Request, T::Future>>,
    /// Channel the worker pulls new requests from.
    rx: mpsc::Receiver<Message<Request, T::Future>>,
    /// The wrapped service that requests are dispatched to.
    service: T,
    /// Set once `rx` is exhausted; after that the worker future completes
    /// immediately on every poll.
    finish: bool,
    /// The error recorded after `poll_ready` failed; every request seen
    /// afterwards is answered with a clone of it.
    failed: Option<ServiceError>,
    /// Failure slot shared with the `Handle` returned from `Worker::new`.
    handle: Handle,
}
/// Shared view of the worker's failure state.
///
/// `Worker::new` keeps one clone inside the worker and returns another. The
/// worker writes the error into it when its service fails, and
/// `get_error_on_closed` reads it back.
#[derive(Debug)]
pub(crate) struct Handle {
    /// The error recorded by the worker, or `None` while it has not failed.
    inner: Arc<Mutex<Option<ServiceError>>>,
}
impl<T, Request> Worker<T, Request>
where
T: Service<Request>,
T::Error: Into<Error>,
{
pub(crate) fn new(
service: T,
rx: mpsc::Receiver<Message<Request, T::Future>>,
) -> (Handle, Worker<T, Request>) {
let handle = Handle {
inner: Arc::new(Mutex::new(None)),
};
let worker = Worker {
current_message: None,
finish: false,
failed: None,
rx,
service,
handle: handle.clone(),
};
(handle, worker)
}
fn poll_next_msg(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
if self.finish {
return Poll::Ready(None);
}
tracing::trace!("worker polling for next message");
if let Some(mut msg) = self.current_message.take() {
if msg.tx.poll_closed(cx).is_pending() {
tracing::trace!("resuming buffered request");
return Poll::Ready(Some((msg, false)));
}
tracing::trace!("dropping cancelled buffered request");
}
while let Some(mut msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
if msg.tx.poll_closed(cx).is_pending() {
tracing::trace!("processing new request");
return Poll::Ready(Some((msg, true)));
}
tracing::trace!("dropping cancelled request");
}
Poll::Ready(None)
}
fn failed(&mut self, error: Error) {
let error = ServiceError::new(error);
let mut inner = self.handle.inner.lock().unwrap();
if inner.is_some() {
return;
}
*inner = Some(error.clone());
drop(inner);
self.rx.close();
self.failed = Some(error);
}
}
impl<T, Request> Future for Worker<T, Request>
where
T: Service<Request>,
T::Error: Into<Error>,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.finish {
return Poll::Ready(());
}
loop {
match ready!(self.poll_next_msg(cx)) {
Some((msg, first)) => {
let _guard = msg.span.enter();
if let Some(ref failed) = self.failed {
tracing::trace!("notifying caller about worker failure");
let _ = msg.tx.send(Err(failed.clone()));
continue;
}
tracing::trace!(
resumed = !first,
message = "worker received request; waiting for service readiness"
);
match self.service.poll_ready(cx) {
Poll::Ready(Ok(())) => {
tracing::debug!(service.ready = true, message = "processing request");
let response = self.service.call(msg.request);
tracing::trace!("returning response future");
let _ = msg.tx.send(Ok(response));
}
Poll::Pending => {
tracing::trace!(service.ready = false, message = "delay");
drop(_guard);
self.current_message = Some(msg);
return Poll::Pending;
}
Poll::Ready(Err(e)) => {
let error = e.into();
tracing::debug!({ %error }, "service failed");
drop(_guard);
self.failed(error);
let _ = msg.tx.send(Err(self
.failed
.as_ref()
.expect("Worker::failed did not set self.failed?")
.clone()));
}
}
}
None => {
self.finish = true;
return Poll::Ready(());
}
}
}
}
}
impl Handle {
    /// Returns the error to report once the worker's channel has closed.
    ///
    /// This is the service failure recorded by the worker if there is one,
    /// otherwise a fresh [`Closed`] error.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    pub(crate) fn get_error_on_closed(&self) -> Error {
        let slot = self.inner.lock().unwrap();
        match slot.as_ref() {
            Some(svc_err) => svc_err.clone().into(),
            None => Closed::new().into(),
        }
    }
}
impl Clone for Handle {
fn clone(&self) -> Handle {
Handle {
inner: self.inner.clone(),
}
}
}