use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::task::Context;
use std::task::Poll;
use tower::BoxError;
use tower::Layer;
use tower::buffer::Buffer;
use tower::buffer::future::ResponseFuture;
use tower_service::Service;
/// A [`Layer`] that wraps services in an [`UnconstrainedBuffer`].
///
/// `Clone` and `Copy` are implemented by hand instead of derived. `#[derive]`
/// would add `Request: Clone` / `Request: Copy` bounds, so the layer could not
/// be cloned for request types that are not themselves `Clone`. Yet the layer
/// only stores a `usize` and a zero-sized marker.
pub struct UnconstrainedBufferLayer<Request> {
    /// Capacity forwarded to [`UnconstrainedBuffer::new`] for every wrapped service.
    bound: usize,
    /// Ties the layer to a request type without owning a `Request`.
    /// `fn(Request)` is always `Send + Sync + Copy`, whatever `Request` is.
    _p: PhantomData<fn(Request)>,
}

impl<Request> Clone for UnconstrainedBufferLayer<Request> {
    fn clone(&self) -> Self {
        // The type is `Copy` for every `Request`, so cloning is a plain copy.
        *self
    }
}

impl<Request> Copy for UnconstrainedBufferLayer<Request> {}
impl<Request> UnconstrainedBufferLayer<Request> {
pub const fn new(bound: usize) -> Self {
UnconstrainedBufferLayer {
bound,
_p: PhantomData,
}
}
}
impl<S, Request> Layer<S> for UnconstrainedBufferLayer<Request>
where
    Request: Send + 'static,
    S: Service<Request> + Send + 'static,
    S::Error: Into<BoxError> + Send + Sync,
    S::Future: Send,
{
    type Service = UnconstrainedBuffer<Request, S::Future>;

    /// Wraps `service` in an [`UnconstrainedBuffer`] sized by this layer's bound.
    fn layer(&self, service: S) -> Self::Service {
        let bound = self.bound;
        UnconstrainedBuffer::new(service, bound)
    }
}
impl<Request> fmt::Debug for UnconstrainedBufferLayer<Request> {
    /// Prints only `bound`. The `Request` marker carries no data, and leaving it
    /// out avoids requiring `Request: Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("UnconstrainedBufferLayer");
        builder.field("bound", &self.bound);
        builder.finish()
    }
}
#[derive(Debug)]
pub struct UnconstrainedBuffer<Req, F> {
inner: Buffer<Req, F>,
}
impl<Req, F> UnconstrainedBuffer<Req, F>
where
    F: 'static,
{
    /// Builds a tower [`Buffer`] with capacity `bound` around `service` and wraps it.
    ///
    /// NOTE: per tower's `Buffer::new` documentation, the buffer's worker is
    /// spawned on the current Tokio runtime. This must therefore be called
    /// from within a runtime context.
    pub fn new<S>(service: S, bound: usize) -> Self
    where
        Req: Send + 'static,
        S: Service<Req, Future = F> + Send + 'static,
        S::Error: Into<BoxError> + Send + Sync,
        F: Send,
    {
        Self {
            inner: Buffer::new(service, bound),
        }
    }
}
impl<Req, Rsp, F, E> Service<Req> for UnconstrainedBuffer<Req, F>
where
F: Future<Output = Result<Rsp, E>> + Send + 'static,
E: Into<BoxError>,
Req: Send + 'static,
{
type Response = Rsp;
type Error = BoxError;
type Future = ResponseFuture<F>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
std::pin::pin!(tokio::task::unconstrained(std::future::poll_fn(|cx| {
self.inner.poll_ready(cx)
})))
.as_mut()
.poll(cx)
}
fn call(&mut self, request: Req) -> Self::Future {
self.inner.call(request)
}
}
impl<Req, F> Clone for UnconstrainedBuffer<Req, F>
where
    Req: Send + 'static,
    F: Send + 'static,
{
    /// Produces another handle by cloning the inner tower `Buffer`.
    fn clone(&self) -> Self {
        let inner = Clone::clone(&self.inner);
        Self { inner }
    }
}
#[cfg(test)]
mod tests {
// Tests verifying that `UnconstrainedBuffer::poll_ready` is not affected by an
// exhausted Tokio coop budget, while a genuinely full buffer still produces
// backpressure (observed as `Overloaded` through `LoadShed`).
use std::future::poll_fn;
use std::task::Poll;
use tokio::task::JoinSet;
use tokio::task::coop::has_budget_remaining;
use tokio::task::coop::poll_proceed;
use tower::BoxError;
use tower::Service;
use tower::load_shed::LoadShed;
use super::*;
// Drains the current task's coop budget: calls `poll_proceed` until it returns
// `Pending` and returns how many units were taken. `made_progress()` marks each
// unit as spent (per tokio's `RestoreOnPending` docs, an unmarked unit is
// restored when the guard drops).
fn consume_all_budget(cx: &mut Context) -> usize {
let mut consumed = 0;
loop {
let restore = poll_proceed(cx);
match restore {
Poll::Ready(r) => {
consumed += 1;
r.made_progress();
continue;
}
Poll::Pending => return consumed,
}
}
}
// Checks `UnconstrainedBuffer::poll_ready` directly: Ready with a fresh budget,
// and still Ready after the budget has been fully drained.
#[tokio::test]
async fn coop_budget_exhaustion_should_not_cause_buffer_poll_ready_to_return_pending() {
let inner = tower::service_fn(|_: ()| async { Ok::<_, BoxError>("ok") });
let mut inner_buffered = UnconstrainedBuffer::new(inner, 1000);
// Yield so the next poll presumably starts with a full coop budget (the
// assert below checks this).
tokio::task::yield_now().await;
// Baseline: fresh budget, readiness and call behave normally.
poll_fn(|cx| {
assert!(has_budget_remaining(), "Budget should not be exhausted");
assert!(
matches!(inner_buffered.poll_ready(cx), Poll::Ready(Ok(()))),
"Buffer::poll_ready should return Ready"
);
let fut = inner_buffered.call(());
let mut fut = std::pin::pin!(fut);
// Pending is acceptable: the worker may not have produced the response yet.
assert!(
matches!(fut.as_mut().poll(cx), Poll::Ready(Ok(_)) | Poll::Pending),
"Buffer::call should succeed"
);
Poll::Ready(())
})
.await;
tokio::task::yield_now().await;
// Same checks, but only after the coop budget has been exhausted.
poll_fn(|cx| {
let budget_consumed = consume_all_budget(cx);
assert_ne!(
budget_consumed,
0,
"Expected non-zero budget units consumed"
);
assert!(
!has_budget_remaining(),
"Expected budget to be exhausted after consuming all units, but poll_proceed is still Ready"
);
// The property under test: readiness ignores the exhausted budget.
assert!(
matches!(inner_buffered.poll_ready(cx), Poll::Ready(Ok(()))),
"Buffer::poll_ready should return Ready even with exhausted budget"
);
let fut = inner_buffered.call(());
let mut fut = std::pin::pin!(fut);
assert!(
matches!(fut.as_mut().poll(cx), Poll::Ready(Ok(_)) | Poll::Pending),
"Buffer::call should succeed"
);
Poll::Ready(())
})
.await;
}
// Same scenario through `LoadShed`: an exhausted budget must not make the
// buffer look not-ready, which `LoadShed` would report as `Overloaded`.
#[tokio::test]
async fn coop_budget_exhaustion_should_not_cause_false_shedding() {
let inner = tower::service_fn(|_: ()| async { Ok::<_, BoxError>("ok") });
let inner_buffered = UnconstrainedBuffer::new(inner, 1000);
let mut load_shed = LoadShed::new(inner_buffered);
tokio::task::yield_now().await;
// Baseline with a fresh budget.
poll_fn(|cx| {
assert!(has_budget_remaining(), "budget should not be exhausted");
assert!(
matches!(load_shed.poll_ready(cx), Poll::Ready(Ok(()))),
"LoadShed::poll_ready should return Ready"
);
let fut = load_shed.call(());
let mut fut = std::pin::pin!(fut);
assert!(
!matches!(fut.as_mut().poll(cx), Poll::Ready(Err(_))),
"requests should not be shed with fresh budget"
);
Poll::Ready(())
})
.await;
tokio::task::yield_now().await;
// Drain the budget, then confirm the call is not shed.
poll_fn(|cx| {
let budget_consumed = consume_all_budget(cx);
assert_ne!(
budget_consumed,
0,
"Expected non-zero budget units consumed"
);
assert!(
!has_budget_remaining(),
"Expected budget to be exhausted after consuming all units, but poll_proceed is still Ready"
);
assert!(
matches!(load_shed.poll_ready(cx), Poll::Ready(Ok(()))),
"LoadShed::poll_ready should return Ready"
);
let fut = load_shed.call(());
let mut fut = std::pin::pin!(fut);
// Only an `Overloaded` error counts as shedding; Pending or other outcomes do not.
let shed = match fut.as_mut().poll(cx) {
Poll::Ready(Err(e)) => e
.downcast_ref::<tower::load_shed::error::Overloaded>()
.is_some(),
_ => false,
};
assert!(
!shed,
"Load should not be shed (Overloaded) when there's enough Buffer permits"
);
Poll::Ready(())
})
.await;
}
// Negative control: with a capacity-1 buffer that is actually full,
// `LoadShed` must still report `Overloaded`, so real backpressure survives.
#[tokio::test]
async fn full_buffer_should_still_cause_load_shedding() {
use std::sync::Arc;
use tokio::sync::Semaphore;
// NOTE(review): response futures are dropped right away below, so these
// service futures are likely never polled and the gate may not actually be
// exercised. Confirm whether it is needed.
let gate = Arc::new(Semaphore::new(0));
let gate_clone = gate.clone();
let inner = tower::service_fn(move |_: ()| {
let gate = gate_clone.clone();
async move {
let _permit = gate.acquire().await.unwrap();
Ok::<_, BoxError>("ok")
}
});
let inner_buffered = UnconstrainedBuffer::new(inner, 1);
let mut load_shed = LoadShed::new(inner_buffered);
poll_fn(|cx| load_shed.poll_ready(cx)).await.unwrap();
drop(load_shed.call(()));
// Presumably lets the buffer worker drain the first queued request.
tokio::task::yield_now().await;
poll_fn(|cx| load_shed.poll_ready(cx)).await.unwrap();
drop(load_shed.call(()));
// Deliberately no yield here: the second request is expected to still
// occupy the single buffer slot, so the next readiness check should find no
// capacity and `LoadShed` should shed.
poll_fn(|cx| {
assert!(matches!(load_shed.poll_ready(cx), Poll::Ready(Ok(()))));
let fut = load_shed.call(());
let mut fut = std::pin::pin!(fut);
let is_overloaded = match fut.as_mut().poll(cx) {
Poll::Ready(Err(e)) => e
.downcast_ref::<tower::load_shed::error::Overloaded>()
.is_some(),
_ => false,
};
assert!(
is_overloaded,
"Expected Overloaded when buffer is genuinely full; \
UnconstrainedBuffer must not suppress real backpressure"
);
Poll::Ready(())
})
.await;
gate.add_permits(2);
}
// Stress test on a 4-worker runtime: outer buffer -> LoadShed -> inner buffer.
// 500 rounds of 100 concurrent requests (50,000 tasks in total); no request may
// be shed or fail.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn should_not_shed_under_load() {
let iterations: usize = 500;
let total_requests: usize = 100;
let buffer_capacity = 200;
let service = tower::service_fn(move |_: ()| async move { Ok::<_, BoxError>("ok") });
let inner_buffer = UnconstrainedBuffer::new(service, buffer_capacity);
let load_shed = LoadShed::new(inner_buffer);
let outer_buffer = UnconstrainedBuffer::new(load_shed, buffer_capacity);
let mut shed = 0usize;
let mut other_err = 0usize;
let mut tasks = JoinSet::new();
for _ in 0..iterations {
// Spawn one round of concurrent requests, each using its own clone of the outer buffer.
for _ in 0..total_requests {
let svc = outer_buffer.clone();
tasks.spawn(async move {
let mut svc = svc;
let svc = tower::ServiceExt::ready(&mut svc).await;
match svc {
Ok(svc) => svc.call(()).await,
Err(e) => Err(e),
}
});
}
// Join the whole round and sort each error into shed vs. other.
while let Some(handle) = tasks.join_next().await {
if let Err(e) = handle.expect("task panicked") {
if e.downcast_ref::<tower::load_shed::error::Overloaded>()
.is_some()
{
shed += 1;
} else {
other_err += 1;
}
}
}
}
assert_eq!(shed, 0, "Expected all requests to succeed without shedding");
assert_eq!(
other_err, 0,
"Expected all requests to succeed without errors"
);
}
}