use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use pin_project_lite::pin_project;
use tower::Layer;
use tower_service::Service;
use super::guard::SubgraphRequestGuard;
use super::tracker::RouterOverheadTracker;
use crate::services::http::HttpRequest;
use crate::services::http::HttpResponse;
/// Tower layer that wraps a subgraph HTTP service in an `OverheadService`.
///
/// The layer itself carries no configuration; it only decides which service type wraps the
/// inner one. `OverheadLayer::new()` and `OverheadLayer::default()` are equivalent.
#[derive(Clone, Debug, Default)]
pub(crate) struct OverheadLayer;

impl OverheadLayer {
    /// Creates the layer (same as [`Default::default`]).
    pub(crate) fn new() -> Self {
        Self
    }
}
impl<S> Layer<S> for OverheadLayer {
    type Service = OverheadService<S>;

    /// Wraps `service` in an `OverheadService`, which forwards every request to it.
    fn layer(&self, service: S) -> Self::Service {
        OverheadService { inner: service }
    }
}
/// Service produced by `OverheadLayer`.
///
/// It forwards each request to `inner`. When the request context carries a
/// `RouterOverheadTracker`, it also attaches a guard created from that tracker to the returned
/// future. Cloning is available whenever the wrapped service is `Clone`.
#[derive(Clone, Debug)]
pub(crate) struct OverheadService<S> {
    /// The wrapped subgraph HTTP service.
    inner: S,
}
impl<S> Service<HttpRequest> for OverheadService<S>
where
    S: Service<HttpRequest, Response = HttpResponse> + Send,
    S::Future: Send + 'static,
{
    type Response = HttpResponse;
    type Error = S::Error;
    type Future = OverheadFuture<S::Future>;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `request` to the wrapped service.
    ///
    /// If the request context holds a `RouterOverheadTracker`, a guard is created from it
    /// *before* the inner service is invoked. The guard is stored in the returned future. With
    /// no tracker present, the request is forwarded without a guard.
    fn call(&mut self, request: HttpRequest) -> Self::Future {
        // Copy the tracker out of the extensions so the lock is released before calling inward.
        let tracker: Option<RouterOverheadTracker> = request
            .context
            .extensions()
            .with_lock(|extensions| extensions.get::<RouterOverheadTracker>().cloned());

        // Must happen before `self.inner.call`, so any synchronous work the inner service does
        // in `call` already runs under the guard.
        let guard = tracker.map(|t| t.create_guard());

        let response_future = self.inner.call(request);
        OverheadFuture {
            inner: response_future,
            _guard: guard,
        }
    }
}
pin_project! {
    /// Future returned by `OverheadService::call`.
    ///
    /// It drives the wrapped service's response future. It also owns the optional guard that
    /// was created from the request's `RouterOverheadTracker` when the request was issued.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub(crate) struct OverheadFuture<F> {
        // The wrapped service's response future.
        #[pin]
        inner: F,
        // Guard created from the request context's `RouterOverheadTracker`, or `None` when the
        // context had no tracker. NOTE(review): presumably dropping it marks the subgraph
        // request as finished — confirm against `SubgraphRequestGuard`.
        _guard: Option<SubgraphRequestGuard>,
    }
}
impl<F, E> Future for OverheadFuture<F>
where
    F: Future<Output = Result<HttpResponse, E>>,
{
    type Output = Result<HttpResponse, E>;

    /// Polls the wrapped subgraph request future.
    ///
    /// As soon as the inner future resolves (with a response or an error), the tracker guard is
    /// dropped. Previously it was only dropped when this future itself was dropped. A caller
    /// that keeps a completed future alive (e.g. inside a combinator) therefore no longer keeps
    /// the request registered as active with the tracker.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match this.inner.poll(cx) {
            Poll::Ready(output) => {
                // Release the guard eagerly. Setting it to `None` also makes a (contract-violating)
                // re-poll harmless as far as the guard is concerned.
                *this._guard = None;
                Poll::Ready(output)
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tower::Service;
    use tower::ServiceBuilder;
    use tower::ServiceExt;

    use super::*;
    use crate::Context;
    use crate::services::http::HttpRequest;
    use crate::services::http::HttpResponse;
    use crate::services::router::body;

    /// Downstream service that immediately answers with an empty response.
    #[derive(Clone)]
    struct MockHttpService;

    impl Service<HttpRequest> for MockHttpService {
        type Response = HttpResponse;
        type Error = tower::BoxError;
        type Future = std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
        >;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: HttpRequest) -> Self::Future {
            Box::pin(async move {
                Ok(HttpResponse {
                    http_response: http::Response::new(body::empty()),
                    context: req.context,
                })
            })
        }
    }

    /// Downstream service that, while it is handling a request, asserts that the tracker in the
    /// request context sees exactly one active subgraph request (the one the layer registered).
    #[derive(Clone)]
    struct ActiveRequestProbe;

    impl Service<HttpRequest> for ActiveRequestProbe {
        type Response = HttpResponse;
        type Error = tower::BoxError;
        type Future = std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
        >;

        fn poll_ready(
            &mut self,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: HttpRequest) -> Self::Future {
            let tracker = req
                .context
                .extensions()
                .with_lock(|lock| lock.get::<RouterOverheadTracker>().cloned());
            Box::pin(async move {
                // Polled from inside the layer's future, i.e. while its guard is still held.
                let active =
                    tracker.map(|tracker| tracker.calculate_overhead().active_subgraph_requests);
                assert_eq!(
                    active,
                    Some(1),
                    "expected the layer to register exactly one in-flight subgraph request"
                );
                Ok(HttpResponse {
                    http_response: http::Response::new(body::empty()),
                    context: req.context,
                })
            })
        }
    }

    #[tokio::test]
    async fn test_layer_creates_guard_when_tracker_present() {
        let tracker = RouterOverheadTracker::new();
        let context = Context::new();
        context.extensions().with_lock(|lock| {
            lock.insert(tracker.clone());
        });
        let mut service = ServiceBuilder::new()
            .layer(OverheadLayer::new())
            .service(MockHttpService);
        let request = HttpRequest {
            http_request: http::Request::new(body::empty()),
            context: context.clone(),
        };
        tokio::time::sleep(Duration::from_millis(10)).await;
        let _response = service.ready().await.unwrap().call(request).await.unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        let result = tracker.calculate_overhead();
        assert_eq!(result.active_subgraph_requests, 0);
        assert!(
            result.overhead >= Duration::from_millis(10)
                && result.overhead <= Duration::from_millis(60),
            "overhead was {:?}",
            result.overhead
        );
    }

    /// Unlike the timing test above, this one fails if the layer never creates a guard: the
    /// probe observes the active count while the request is in flight.
    #[tokio::test]
    async fn test_guard_is_active_while_request_in_flight() {
        let tracker = RouterOverheadTracker::new();
        let context = Context::new();
        context.extensions().with_lock(|lock| {
            lock.insert(tracker.clone());
        });
        let mut service = ServiceBuilder::new()
            .layer(OverheadLayer::new())
            .service(ActiveRequestProbe);
        let request = HttpRequest {
            http_request: http::Request::new(body::empty()),
            context,
        };
        let _response = service.ready().await.unwrap().call(request).await.unwrap();
        assert_eq!(tracker.calculate_overhead().active_subgraph_requests, 0);
    }

    #[tokio::test]
    async fn test_layer_works_without_tracker() {
        let context = Context::new();
        let mut service = ServiceBuilder::new()
            .layer(OverheadLayer::new())
            .service(MockHttpService);
        let request = HttpRequest {
            http_request: http::Request::new(body::empty()),
            context,
        };
        let _response = service.ready().await.unwrap().call(request).await.unwrap();
    }
}