use futures_util::ready;
use http::{Request, Response};
use http_body::Body;
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::{Context, Poll},
time::Duration,
};
use tower_layer::Layer;
use tower_service::Service;
/// A [`Layer`] that wraps services in [`InFlightRequests`].
///
/// Every service produced by one layer (or by clones of it) reports into the
/// same [`InFlightRequestsCounter`].
#[derive(Clone, Debug)]
pub struct InFlightRequestsLayer {
    counter: InFlightRequestsCounter,
}

impl InFlightRequestsLayer {
    /// Creates a layer together with a brand-new counter.
    ///
    /// The returned counter shares its state with the layer, so it can be
    /// used to observe the number of in-flight requests from the outside.
    pub fn pair() -> (Self, InFlightRequestsCounter) {
        let shared = InFlightRequestsCounter::new();
        // Tuple fields are evaluated left to right: the clone is handed to
        // the layer before `shared` itself is moved into the second slot.
        (Self::new(shared.clone()), shared)
    }

    /// Creates a layer that reports into the given, already existing `counter`.
    pub fn new(counter: InFlightRequestsCounter) -> Self {
        InFlightRequestsLayer { counter }
    }
}
impl<S> Layer<S> for InFlightRequestsLayer {
type Service = InFlightRequests<S>;
fn layer(&self, inner: S) -> Self::Service {
InFlightRequests {
inner,
counter: self.counter.clone(),
}
}
}
/// Middleware that counts the requests currently being handled by the
/// wrapped service.
///
/// A request is counted from the moment `call` is invoked until the guard
/// attached to its response future / response body is dropped.
#[derive(Clone, Debug)]
pub struct InFlightRequests<S> {
    inner: S,
    counter: InFlightRequestsCounter,
}

impl<S> InFlightRequests<S> {
    /// Wraps `inner` and creates a brand-new counter for it.
    ///
    /// The returned counter shares its state with the service.
    pub fn pair(inner: S) -> (Self, InFlightRequestsCounter) {
        let shared = InFlightRequestsCounter::new();
        // The clone goes into the service first; `shared` is then moved out
        // to the caller.
        (Self::new(inner, shared.clone()), shared)
    }

    /// Wraps `inner`, reporting into the given, already existing `counter`.
    pub fn new(inner: S, counter: InFlightRequestsCounter) -> Self {
        InFlightRequests { inner, counter }
    }

    // NOTE(review): this macro is defined outside this file; presumably it
    // generates accessor methods for the `inner` service - confirm.
    define_inner_service_accessors!();
}
/// A cheaply cloneable handle to a shared count of in-flight requests.
///
/// All clones refer to the same underlying atomic counter.
#[derive(Debug, Clone, Default)]
pub struct InFlightRequestsCounter {
    count: Arc<AtomicUsize>,
}

impl InFlightRequestsCounter {
    /// Creates a new counter starting at zero.
    pub fn new() -> Self {
        Self {
            count: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Returns the current number of in-flight requests.
    ///
    /// The value is read with relaxed ordering, so it is a snapshot that may
    /// already be stale when the caller looks at it.
    pub fn get(&self) -> usize {
        self.count.load(Ordering::Relaxed)
    }

    /// Bumps the count by one and returns a guard that decrements it again
    /// when dropped.
    fn increment(&self) -> IncrementGuard {
        let count = Arc::clone(&self.count);
        count.fetch_add(1, Ordering::Relaxed);
        IncrementGuard { count }
    }

    /// Periodically passes the current count to `emit`.
    ///
    /// On every iteration the loop first checks whether this counter is the
    /// only remaining handle to the shared count; if so it returns. Every
    /// clone of the counter (including those stored in layers and services)
    /// and every outstanding increment guard keeps the loop alive. Otherwise
    /// it waits for the next tick of a `tokio::time::interval` built from
    /// `interval`, reads the count and awaits `emit(count)`.
    ///
    /// # Panics
    ///
    /// `tokio::time::interval` is documented to panic when given a zero
    /// period, so `interval` must be non-zero.
    pub async fn run_emitter<F, Fut>(mut self, interval: Duration, mut emit: F)
    where
        F: FnMut(usize) -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send,
    {
        let mut ticker = tokio::time::interval(interval);
        loop {
            // `try_unwrap` succeeds only when no other handle exists, i.e.
            // nothing is left that could ever change the count.
            self = match Arc::try_unwrap(self.count) {
                Ok(_) => return,
                Err(count) => Self { count },
            };
            ticker.tick().await;
            let in_flight = self.get();
            emit(in_flight).await;
        }
    }
}
/// Guard representing one counted in-flight request.
///
/// Holding the guard keeps the request counted; dropping it subtracts one
/// from the shared count again.
#[must_use = "dropping the guard immediately decrements the in-flight count again"]
struct IncrementGuard {
    count: Arc<AtomicUsize>,
}

impl Drop for IncrementGuard {
    /// Decrements the shared count by one (relaxed ordering).
    fn drop(&mut self) {
        self.count.fetch_sub(1, Ordering::Relaxed);
    }
}
impl<S, R, ResBody> Service<Request<R>> for InFlightRequests<S>
where
S: Service<Request<R>, Response = Response<ResBody>>,
{
type Response = Response<ResponseBody<ResBody>>;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<R>) -> Self::Future {
let guard = self.counter.increment();
ResponseFuture {
inner: self.inner.call(req),
guard: Some(guard),
}
}
}
pin_project! {
/// Response future produced by the `Service` implementation of
/// [`InFlightRequests`].
///
/// Carries the increment guard for its request and moves it into the
/// response body once the inner future resolves successfully.
pub struct ResponseFuture<F> {
// The wrapped service's response future; pinned so it can be polled in place.
#[pin]
inner: F,
// Keeps the request counted. Taken out (leaving `None`) when a successful
// response is produced; otherwise dropped together with this future.
guard: Option<IncrementGuard>,
}
}
impl<F, B, E> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response<B>, E>>,
{
    type Output = Result<Response<ResponseBody<B>>, E>;

    /// Drives the inner future and, on success, wraps the response body in a
    /// [`ResponseBody`] that takes over the increment guard.
    ///
    /// On error the guard is left in place and is released when this future
    /// is dropped.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has already returned a successful
    /// response, because the guard has been moved into that response's body.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let response = match ready!(this.inner.poll(cx)) {
            Ok(response) => response,
            Err(err) => return Poll::Ready(Err(err)),
        };
        // The guard is `Some` from construction until the first successful
        // completion, so this can only fail on a poll-after-completion.
        let guard = this
            .guard
            .take()
            .expect("ResponseFuture polled after completion");
        let response = response.map(move |body| ResponseBody { inner: body, guard });
        Poll::Ready(Ok(response))
    }
}
pin_project! {
/// Response body wrapper used by [`InFlightRequests`].
///
/// Owns the increment guard for its request, so the request stays counted
/// until this body is dropped.
pub struct ResponseBody<B> {
// The wrapped response body; pinned so it can be polled in place.
#[pin]
inner: B,
// Never read; held only so that dropping the body decrements the count.
guard: IncrementGuard,
}
}
impl<B> Body for ResponseBody<B>
where
B: Body,
{
type Data = B::Data;
type Error = B::Error;
#[inline]
fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
self.project().inner.poll_data(cx)
}
#[inline]
fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
self.project().inner.poll_trailers(cx)
}
#[inline]
fn is_end_stream(&self) -> bool {
self.inner.is_end_stream()
}
#[inline]
fn size_hint(&self) -> http_body::SizeHint {
self.inner.size_hint()
}
}
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use http::Request;
    use hyper::Body;
    use tower::{BoxError, ServiceBuilder};

    /// Walks one request through the middleware and checks the counter at
    /// every stage of its lifetime.
    #[tokio::test]
    async fn basic() {
        let (layer, counter) = InFlightRequestsLayer::pair();
        let mut svc = ServiceBuilder::new().layer(layer).service_fn(echo);

        // Nothing has been sent yet.
        assert_eq!(counter.get(), 0);

        // Readiness checks alone must not count as a request.
        futures::future::poll_fn(|cx| svc.poll_ready(cx))
            .await
            .unwrap();
        assert_eq!(counter.get(), 0);

        // The request is counted as soon as `call` returns its future.
        let pending = svc.call(Request::new(Body::empty()));
        assert_eq!(counter.get(), 1);

        // It is still counted while the response body is alive.
        let response = pending.await.unwrap();
        assert_eq!(counter.get(), 1);

        // Consuming the body drops the guard and releases the count.
        let body = response.into_body();
        hyper::body::to_bytes(body).await.unwrap();
        assert_eq!(counter.get(), 0);
    }

    /// Test handler that sends the request body straight back.
    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        let body = req.into_body();
        Ok(Response::new(body))
    }
}