use std::{
future::Future,
mem,
ops::Add,
pin::Pin,
sync::{Arc, Mutex, Weak},
task::{Context, Poll},
time::{Duration, Instant},
};
use futures_core::ready;
use tokio::{
sync::{mpsc, Semaphore},
time::{sleep_until, Sleep},
};
use tower::Service;
use tracing::{debug, trace};
use super::{
error::{Closed, ServiceError},
message::{Message, Tx},
BatchControl,
};
/// Shared slot through which the worker records the service error that stopped it.
///
/// `Worker::new` keeps one clone inside the worker's `Bridge` and returns another to
/// its caller; `get_error_on_closed` reads the slot.
#[derive(Debug)]
pub(crate) struct Handle {
    /// `None` until `Bridge::failed` stores the first service error.
    inner: Arc<Mutex<Option<ServiceError>>>,
}
/// Channel-facing half of the worker.
///
/// It receives request messages, holds a message back when the service was not
/// ready for it, and records service failure.
#[derive(Debug)]
struct Bridge<Fut, Request> {
    /// Receiver the worker pulls request messages from.
    rx: mpsc::UnboundedReceiver<Message<Request, Fut>>,
    /// Shared error slot; a clone of it is returned by `Worker::new`.
    handle: Handle,
    /// Message handed back through `return_msg`.
    ///
    /// `poll_next_msg` yields it again before reading from `rx`, unless its `tx`
    /// has been closed in the meantime.
    current_message: Option<Message<Request, Fut>>,
    /// Weak reference to the semaphore.
    ///
    /// Taken and closed by `close_semaphore`, which runs from `failed` and on drop.
    close: Option<Weak<Semaphore>>,
    /// Error recorded by `failed`, once the service has failed.
    failed: Option<ServiceError>,
}
/// The batch currently being accumulated, plus the timer bounding how long it stays open.
#[derive(Debug)]
struct Lot<Fut> {
    /// Size against which `is_full` compares the number of collected responses.
    max_size: usize,
    /// Delay, measured from the first item of a batch, after which the timer fires.
    max_time: Duration,
    /// One entry per batched request: its reply sender and either the service's
    /// response future or an error.
    responses: Vec<(Tx<Fut>, Result<Fut, ServiceError>)>,
    /// Deadline timer.
    ///
    /// Armed by `add` when the first item enters an empty batch; cleared by `notify`.
    time_elapses: Option<Pin<Box<Sleep>>>,
    /// Latched to `true` once the timer has fired, so the expiry is reported only once.
    ///
    /// Reset by `notify`.
    time_elapsed: bool,
}
pin_project_lite::pin_project! {
    // Worker state machine:
    // - `Collecting` accepts requests into the current batch.
    // - `Flushing` drives a `BatchControl::Flush` call.
    // - `Finished` is terminal.
    #[project = StateProj]
    #[derive(Debug)]
    enum State<Fut> {
        // Receiving requests and adding their response futures to the batch.
        Collecting,
        // Flushing the current batch.
        Flushing {
            // Why the flush was triggered (e.g. "time" or "size"); only used in log output.
            reason: Option<String>,
            // `None` until the service is ready and `call(BatchControl::Flush)` has been
            // issued; afterwards, the in-flight flush future (pinned in place).
            #[pin]
            flush_fut: Option<Fut>,
        },
        // The worker has stopped; polling again returns `Ready(())` immediately.
        Finished
    }
}
pin_project_lite::pin_project! {
    // Future that drives batching.
    //
    // It reads requests from the channel and calls the wrapped service for each one.
    // It flushes the batch when the batch is full or its deadline timer fires.
    #[derive(Debug)]
    pub struct Worker<T, Request>
    where
        T: Service<BatchControl<Request>>,
        T::Error: Into<crate::BoxError>,
    {
        // The wrapped service, which receives both items and flush commands.
        service: T,
        // Channel side: incoming messages and failure propagation.
        bridge: Bridge<T::Future, Request>,
        // The batch currently being accumulated.
        lot: Lot<T::Future>,
        // Current state; pinned because `Flushing` holds the flush future.
        #[pin]
        state: State<T::Future>,
    }
}
impl<T, Request> Worker<T, Request>
where
T: Service<BatchControl<Request>>,
T::Error: Into<crate::BoxError>,
{
pub(crate) fn new(
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
service: T,
max_size: usize,
max_time: Duration,
semaphore: &Arc<Semaphore>,
) -> (Handle, Worker<T, Request>) {
trace!("creating Batch worker");
let handle = Handle {
inner: Arc::new(Mutex::new(None)),
};
let semaphore = Arc::downgrade(semaphore);
let worker = Self {
service,
bridge: Bridge {
rx,
current_message: None,
handle: handle.clone(),
close: Some(semaphore),
failed: None,
},
lot: Lot::new(max_size, max_time),
state: State::Collecting,
};
(handle, worker)
}
}
impl<T, Request> Future for Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    type Output = ();

    /// Drives the batching state machine.
    ///
    /// * `Collecting`: pulls the next request and waits for service readiness. It then
    ///   calls the service with the item and adds the response future to the batch.
    ///   A full batch or an expired batch deadline switches to `Flushing`. Once the
    ///   channel is exhausted, items still in the batch are flushed first (reason
    ///   `"shutdown"`), so their callers still get their responses; only then does the
    ///   worker finish.
    /// * `Flushing`: waits for readiness, issues `BatchControl::Flush` and awaits it.
    ///   It then hands every batched response to its caller and returns to `Collecting`.
    /// * `Finished`: terminal; completes immediately.
    ///
    /// A service error is recorded through `Bridge::failed` and sent to the callers in
    /// the current batch; the worker then finishes.
    // One arm per state keeps the whole state machine readable in one place.
    #[allow(clippy::too_many_lines)]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        trace!("polling worker");
        let mut this = self.project();

        // Check the batch deadline on every wake-up. Polling the timer also registers
        // this task's waker with it, so an idle batch is still flushed on time.
        if matches!(this.state.as_ref().get_ref(), State::Collecting) {
            if let Poll::Ready(Some(())) = this.lot.poll_max_time(cx) {
                this.state.set(State::flushing("time".to_owned(), None));
            }
        }

        loop {
            match this.state.as_mut().project() {
                StateProj::Collecting => match ready!(this.bridge.poll_next_msg(cx)) {
                    Some((msg, first)) => {
                        let guard = msg.span.enter();
                        trace!(resumed = !first, message = "worker received request");
                        trace!(message = "waiting for service readiness");
                        match this.service.poll_ready(cx) {
                            Poll::Ready(Ok(())) => {
                                debug!(service.ready = true, message = "adding item");
                                let response = this.service.call(msg.request.into());
                                this.lot.add((msg.tx, Ok(response)));
                                if this.lot.is_full() {
                                    this.state.set(State::flushing("size".to_owned(), None));
                                } else if this.lot.poll_max_time(cx).is_ready() {
                                    this.state.set(State::flushing("time".to_owned(), None));
                                }
                            }
                            Poll::Pending => {
                                drop(guard);
                                debug!(service.ready = false, message = "delay item addition");
                                // Keep the message so it is retried before any new one.
                                this.bridge.return_msg(msg);
                                return Poll::Pending;
                            }
                            Poll::Ready(Err(e)) => {
                                drop(guard);
                                this.bridge.failed("item addition", e.into());
                                if let Some(ref e) = this.bridge.failed {
                                    // Fail this request together with the rest of the batch.
                                    this.lot.add((msg.tx, Err(e.clone())));
                                    this.lot.notify(Some(e));
                                }
                                this.state.set(State::Finished);
                                return Poll::Ready(());
                            }
                        }
                    }
                    None if this.lot.responses.is_empty() => {
                        trace!("shutting down, no more requests _ever_");
                        this.state.set(State::Finished);
                        return Poll::Ready(());
                    }
                    None => {
                        // No more requests will arrive. The items already handed to the
                        // service still need a flush to complete. After that flush the
                        // worker comes back here, finds the batch empty and finishes.
                        trace!("no more requests _ever_; flushing remaining items");
                        this.state
                            .set(State::flushing("shutdown".to_owned(), None));
                    }
                },
                StateProj::Flushing { reason, flush_fut } => match flush_fut.as_pin_mut() {
                    None => {
                        trace!(
                            reason = reason.as_deref().unwrap_or_default(),
                            message = "waiting for service readiness"
                        );
                        match this.service.poll_ready(cx) {
                            Poll::Ready(Ok(())) => {
                                debug!(
                                    service.ready = true,
                                    reason = reason.as_deref().unwrap_or_default(),
                                    message = "flushing batch"
                                );
                                let response = this.service.call(BatchControl::Flush);
                                let reason = reason.take().expect("missing reason");
                                this.state.set(State::flushing(reason, Some(response)));
                            }
                            Poll::Pending => {
                                debug!(
                                    service.ready = false,
                                    reason = reason.as_deref().unwrap_or_default(),
                                    message = "delay flush"
                                );
                                return Poll::Pending;
                            }
                            Poll::Ready(Err(e)) => {
                                this.bridge.failed("flush", e.into());
                                if let Some(ref e) = this.bridge.failed {
                                    this.lot.notify(Some(e));
                                }
                                this.state.set(State::Finished);
                                return Poll::Ready(());
                            }
                        }
                    }
                    Some(future) => match ready!(future.poll(cx)) {
                        Ok(_) => {
                            debug!(
                                reason = reason.as_deref().unwrap_or_default(),
                                "batch flushed"
                            );
                            this.lot.notify(None);
                            this.state.set(State::Collecting);
                        }
                        Err(e) => {
                            this.bridge.failed("flush", e.into());
                            if let Some(ref e) = this.bridge.failed {
                                this.lot.notify(Some(e));
                            }
                            this.state.set(State::Finished);
                            return Poll::Ready(());
                        }
                    },
                },
                StateProj::Finished => {
                    return Poll::Ready(());
                }
            }
        }
    }
}
impl<Fut> State<Fut> {
    /// Builds a `Flushing` state carrying the log `reason` and the flush future `f`.
    ///
    /// Pass `None` for `f` while the flush call has not been issued yet.
    fn flushing(reason: String, f: Option<Fut>) -> Self {
        let reason = Some(reason);
        Self::Flushing {
            flush_fut: f,
            reason,
        }
    }
}
impl<Fut, Request> Drop for Bridge<Fut, Request> {
    /// Ensures the semaphore is closed when the bridge goes away.
    ///
    /// This is a no-op beyond logging if `close_semaphore` has already run.
    fn drop(&mut self) {
        Self::close_semaphore(self);
    }
}
impl<Fut, Request> Bridge<Fut, Request> {
fn close_semaphore(&mut self) {
if let Some(close) = self
.close
.take()
.as_ref()
.and_then(Weak::<Semaphore>::upgrade)
{
debug!("buffer closing; waking pending tasks");
close.close();
} else {
trace!("buffer already closed");
}
}
fn failed(&mut self, action: &str, error: crate::BoxError) {
debug!(action, %error , "service failed");
let error = ServiceError::new(error);
let mut inner = self.handle.inner.lock().unwrap();
if inner.is_some() {
return;
}
*inner = Some(error.clone());
drop(inner);
self.rx.close();
self.close_semaphore();
self.failed = Some(error);
}
fn poll_next_msg(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<(Message<Request, Fut>, bool)>> {
trace!("worker polling for next message");
if let Some(msg) = self.current_message.take() {
if !msg.tx.is_closed() {
trace!("resuming buffered request");
return Poll::Ready(Some((msg, false)));
}
trace!("dropping cancelled buffered request");
}
while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
if !msg.tx.is_closed() {
trace!("processing new request");
return Poll::Ready(Some((msg, true)));
}
trace!("dropping cancelled request");
}
Poll::Ready(None)
}
fn return_msg(&mut self, msg: Message<Request, Fut>) {
self.current_message = Some(msg);
}
}
impl<Fut> Lot<Fut> {
fn new(max_size: usize, max_time: Duration) -> Self {
Self {
max_size,
max_time,
responses: Vec::with_capacity(max_size),
time_elapses: None,
time_elapsed: false,
}
}
fn poll_max_time(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
if self.time_elapsed {
return Poll::Ready(None);
}
if let Some(ref mut sleep) = self.time_elapses {
if Pin::new(sleep).poll(cx).is_ready() {
self.time_elapsed = true;
return Poll::Ready(Some(()));
}
}
Poll::Pending
}
fn is_full(&self) -> bool {
self.responses.len() == self.max_size
}
fn add(&mut self, item: (Tx<Fut>, Result<Fut, ServiceError>)) {
if self.responses.is_empty() {
self.time_elapses = Some(Box::pin(sleep_until(
Instant::now().add(self.max_time).into(),
)));
}
self.responses.push(item);
}
fn notify(&mut self, err: Option<&ServiceError>) {
for (tx, response) in mem::replace(&mut self.responses, Vec::with_capacity(self.max_size)) {
if let Some(err) = err {
let _ = tx.send(Err(err.clone()));
} else {
let _ = tx.send(response);
}
}
self.time_elapses = None;
self.time_elapsed = false;
}
}
impl Handle {
    /// Returns the recorded service error if the worker failed; otherwise a fresh `Closed` error.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
        let guard = self.inner.lock().unwrap();
        if let Some(svc_err) = guard.as_ref() {
            return svc_err.clone().into();
        }
        Closed::new().into()
    }
}
impl Clone for Handle {
fn clone(&self) -> Self {
Handle {
inner: self.inner.clone(),
}
}
}