use std::{
fmt::Debug,
sync::Arc,
task::{Context, Poll},
};
use futures_core::ready;
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use tower::Service;
use super::{
future::ResponseFuture,
message::Message,
worker::{Handle, Worker},
BatchControl,
};
/// A `tower::Service` handle that forwards requests to a background batch
/// worker over a channel.
///
/// Cloning a `Batch` yields another sender into the same channel sharing the
/// same semaphore; each clone acquires its own permit in `poll_ready`.
#[derive(Debug)]
pub struct Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    /// Carries each request, its response sender, tracing span and permit to
    /// the worker. The channel itself is unbounded; capacity is enforced by
    /// `semaphore` instead.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,
    /// Caps how many permits can be outstanding at once. Permits are acquired
    /// in `poll_ready` and travel inside each sent `Message`.
    semaphore: PollSemaphore,
    /// Permit reserved by a successful `poll_ready`, consumed by the next `call`.
    permit: Option<OwnedSemaphorePermit>,
    /// Handle to the worker, used to fetch its error once it has stopped.
    handle: Handle,
}
impl<T, Request> Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new `Batch` wrapping `service` and spawns its worker with
    /// `tokio::spawn`.
    ///
    /// `size` is passed to the worker and also bounds how many requests may
    /// hold a reservation (a permit taken in `poll_ready`) at the same time.
    /// `time` is passed straight to the worker — presumably the maximum delay
    /// before a partial batch is flushed; TODO confirm against `Worker::new`.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, or if `size` exceeds
    /// `tokio::sync::Semaphore::MAX_PERMITS`.
    pub fn new(service: T, size: usize, time: std::time::Duration) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (service, worker) = Self::pair(service, size, time);
        tokio::spawn(worker);
        service
    }

    /// Creates a new `Batch` and returns it together with its `Worker`, which
    /// the caller is responsible for driving (e.g. spawning on an executor).
    ///
    /// See [`Batch::new`] for the meaning of `size` and `time`.
    ///
    /// # Panics
    ///
    /// Panics if `size` exceeds `tokio::sync::Semaphore::MAX_PERMITS`.
    pub fn pair(service: T, size: usize, time: std::time::Duration) -> (Self, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Future: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (tx, rx) = mpsc::unbounded_channel();

        // The channel is unbounded; back-pressure comes from this semaphore,
        // whose permits are taken in `poll_ready` and travel with each message.
        // With zero permits `poll_ready` would have nothing to acquire and would
        // stay pending, so always allow at least one outstanding reservation.
        let bound = size.max(1);
        let semaphore = Arc::new(Semaphore::new(bound));

        let (handle, worker) = Worker::new(rx, service, size, time, &semaphore);
        let batch = Self {
            tx,
            semaphore: PollSemaphore::new(semaphore),
            permit: None,
            handle,
        };
        (batch, worker)
    }

    /// Fetches the error recorded by the worker handle.
    ///
    /// Used whenever the request channel or the semaphore reports that the
    /// worker is no longer accepting requests.
    fn get_worker_error(&self) -> crate::BoxError {
        self.handle.get_error_on_closed()
    }
}
impl<T, Request> Service<Request> for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    /// Reserves a slot for the next `call`.
    ///
    /// Returns the worker's error if the request channel is closed or the
    /// semaphore has been closed, and `Pending` while no permit is available.
    /// A permit still held from an earlier successful poll is reused.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        tracing::debug!("checking if service is ready");

        let worker_gone = self.tx.is_closed();
        if worker_gone {
            return Poll::Ready(Err(self.get_worker_error()));
        }

        if self.permit.is_none() {
            // `PollSemaphore::poll_acquire` yields `None` once the semaphore is closed.
            match ready!(self.semaphore.poll_acquire(cx)) {
                Some(acquired) => self.permit = Some(acquired),
                None => return Poll::Ready(Err(self.get_worker_error())),
            }
        }

        Poll::Ready(Ok(()))
    }

    /// Sends `request` to the worker, handing over the permit reserved by
    /// `poll_ready` along with the current tracing span.
    ///
    /// If the message cannot be sent, returns `ResponseFuture::failed` carrying
    /// the worker's error.
    ///
    /// # Panics
    ///
    /// Panics if no permit is currently held (for example, `poll_ready` was not
    /// called successfully since the previous `call`).
    fn call(&mut self, request: Request) -> Self::Future {
        tracing::debug!("sending request to batch worker");

        let reserved = self.permit.take();
        let permit = reserved.expect("batch full; poll_ready must be called first");

        let message_span = tracing::Span::current();
        let (response_tx, response_rx) = oneshot::channel();
        let message = Message {
            request,
            tx: response_tx,
            span: message_span,
            _permit: permit,
        };

        match self.tx.send(message) {
            Ok(()) => ResponseFuture::new(response_rx),
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
        }
    }
}
impl<T, Request> Clone for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    /// Returns another handle to the same channel, semaphore and worker handle.
    ///
    /// The clone starts without a reserved permit, so it must go through
    /// `poll_ready` before its first `call`.
    fn clone(&self) -> Self {
        let tx = self.tx.clone();
        let semaphore = self.semaphore.clone();
        let handle = self.handle.clone();
        Self {
            tx,
            semaphore,
            permit: None,
            handle,
        }
    }
}