use crate::on_early_drop::failure::{BodyDropped, DroppedFailure, FutureDropped};
use crate::on_early_drop::traits::{OnBodyDrop, OnDropCallback, OnFutureDrop};
use crate::trace::OnFailure;
use http::{response, Request, StatusCode};
use std::time::Instant;
use tracing::Span;
/// Early-drop policy that reports dropped request futures and dropped
/// response bodies to a wrapped failure handler.
///
/// The `OnFutureDrop` and `OnBodyDrop` implementations in this file clone the
/// handler into every callback they create.
#[derive(Clone, Copy, Debug)]
pub struct EarlyDropsAsFailures<F> {
    /// Handler that is cloned into each per-request callback.
    on_failure: F,
}
impl<F> EarlyDropsAsFailures<F> {
pub fn new(on_failure: F) -> Self {
Self { on_failure }
}
}
/// Callback state created by `EarlyDropsAsFailures`'s `OnFutureDrop::make`.
///
/// Its `OnDropCallback::on_drop` implementation reports
/// `DroppedFailure::Future` to `on_failure`.
pub struct FutureDropFailureCallback<F> {
/// Captured with `Instant::now()` when the callback is made; `on_drop`
/// reports `start.elapsed()` as the latency.
start: Instant,
/// Failure handler invoked by `on_drop`.
on_failure: F,
/// Span captured with `Span::current()` when the callback is made; it is
/// entered while the handler runs and also passed to the handler.
span: Span,
}
impl<F> OnDropCallback for FutureDropFailureCallback<F>
where
    F: OnFailure<DroppedFailure> + Send + 'static,
{
    /// Reports the dropped future to the failure handler.
    ///
    /// The latency passed to the handler is the time elapsed since `start`,
    /// and the handler is invoked while the captured span is entered.
    fn on_drop(mut self) {
        let elapsed = self.start.elapsed();
        let failure = DroppedFailure::Future(FutureDropped);
        // Keep the guard alive for the whole handler call so the handler runs
        // inside the captured span.
        let _guard = self.span.enter();
        self.on_failure.on_failure(failure, elapsed, &self.span);
    }
}
/// Intermediate body-drop state created by `OnBodyDrop::make_at_call`.
///
/// `OnBodyDrop::make_at_response` moves these fields into a
/// `BodyDropFailureCallback` and adds the response status.
pub struct PreResponseBodyDropCallback<F> {
/// Captured with `Instant::now()` in `make_at_call`.
start: Instant,
/// Clone of the policy's failure handler.
on_failure: F,
/// Span captured with `Span::current()` in `make_at_call`.
span: Span,
}
/// Callback state created by `OnBodyDrop::make_at_response`.
///
/// Its `OnDropCallback::on_drop` implementation reports
/// `DroppedFailure::Body` carrying `status` to `on_failure`.
pub struct BodyDropFailureCallback<F> {
/// Start instant carried over from the `PreResponseBodyDropCallback`;
/// `on_drop` reports `start.elapsed()` as the latency.
start: Instant,
/// Failure handler invoked by `on_drop`.
on_failure: F,
/// Span carried over from the intermediate state; it is entered while the
/// handler runs and also passed to the handler.
span: Span,
/// Status copied from the response parts in `make_at_response`.
status: StatusCode,
}
impl<F> OnDropCallback for BodyDropFailureCallback<F>
where
    F: OnFailure<DroppedFailure> + Send + 'static,
{
    /// Reports the dropped response body, with its status, to the failure
    /// handler.
    ///
    /// The latency passed to the handler is the time elapsed since `start`,
    /// and the handler is invoked while the captured span is entered.
    fn on_drop(mut self) {
        let elapsed = self.start.elapsed();
        let failure = DroppedFailure::Body(BodyDropped {
            status: self.status,
        });
        // Keep the guard alive for the whole handler call so the handler runs
        // inside the captured span.
        let _guard = self.span.enter();
        self.on_failure.on_failure(failure, elapsed, &self.span);
    }
}
impl<F, ReqB> OnFutureDrop<ReqB> for EarlyDropsAsFailures<F>
where
    F: OnFailure<DroppedFailure> + Clone + Send + 'static,
{
    type Callback = FutureDropFailureCallback<F>;

    /// Builds the future-drop callback for one request.
    ///
    /// Captures the current instant, a clone of the failure handler and the
    /// current span. The request itself is not inspected.
    fn make(&mut self, _request: &Request<ReqB>) -> Self::Callback {
        let start = Instant::now();
        let on_failure = self.on_failure.clone();
        let span = Span::current();
        FutureDropFailureCallback {
            start,
            on_failure,
            span,
        }
    }
}
impl<F, ReqB> OnBodyDrop<ReqB> for EarlyDropsAsFailures<F>
where
    F: OnFailure<DroppedFailure> + Clone + Send + 'static,
{
    type Intermediate = PreResponseBodyDropCallback<F>;
    type Callback = BodyDropFailureCallback<F>;

    /// Builds the intermediate body-drop state for one request.
    ///
    /// Captures the current instant, a clone of the failure handler and the
    /// current span. The request itself is not inspected.
    fn make_at_call(&mut self, _request: &Request<ReqB>) -> Self::Intermediate {
        let start = Instant::now();
        let on_failure = self.on_failure.clone();
        let span = Span::current();
        PreResponseBodyDropCallback {
            start,
            on_failure,
            span,
        }
    }

    /// Turns the intermediate state into the final body-drop callback.
    ///
    /// The start instant, handler and span are moved over unchanged; the only
    /// thing taken from the response is its status.
    fn make_at_response(
        &mut self,
        intermediate: Self::Intermediate,
        response_parts: &response::Parts,
    ) -> Self::Callback {
        let PreResponseBodyDropCallback {
            start,
            on_failure,
            span,
        } = intermediate;
        let status = response_parts.status;
        BodyDropFailureCallback {
            start,
            on_failure,
            span,
            status,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::on_early_drop::OnEarlyDropLayer;
    use bytes::Bytes;
    use http::{Request, Response, StatusCode};
    use http_body_util::{BodyExt, Full};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;
    use tokio::time::{sleep, timeout};
    use tower::{service_fn, Layer, ServiceExt};
    use tracing::Span;

    /// Shared log of the failure classes a `RecordingOnFailure` has received.
    type EventLog = Arc<Mutex<Vec<DroppedFailure>>>;

    /// `OnFailure` handler that appends every reported failure to a shared log.
    #[derive(Clone, Default)]
    struct RecordingOnFailure {
        events: EventLog,
    }

    impl OnFailure<DroppedFailure> for RecordingOnFailure {
        fn on_failure(&mut self, class: DroppedFailure, _latency: Duration, _span: &Span) {
            let mut log = self.events.lock().unwrap();
            log.push(class);
        }
    }

    /// Returns a fresh recorder together with a handle to its event log.
    fn recorder_with_log() -> (RecordingOnFailure, EventLog) {
        let recorder = RecordingOnFailure::default();
        let log = Arc::clone(&recorder.events);
        (recorder, log)
    }

    /// Builds the bodiless `/` request used by every test.
    fn empty_request() -> Request<()> {
        Request::builder().uri("/").body(()).unwrap()
    }

    /// Dropping the response future before it resolves must record exactly
    /// one `DroppedFailure::Future`.
    #[tokio::test]
    async fn future_drop_reports_future_failure() {
        let (recorder, log) = recorder_with_log();

        // The inner service sleeps far longer than the timeout below, so the
        // call future is dropped while still pending.
        let slow_service = service_fn(|_req: Request<()>| async move {
            sleep(Duration::from_secs(60)).await;
            let response = Response::builder()
                .status(StatusCode::OK)
                .body(Full::new(Bytes::new()))
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        });

        let policy = EarlyDropsAsFailures::new(recorder);
        let service = OnEarlyDropLayer::new(policy).layer(slow_service);

        let call = service.oneshot(empty_request());
        let _ = timeout(Duration::from_millis(50), call).await;
        sleep(Duration::from_millis(10)).await;

        let captured = log.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert!(matches!(captured[0], DroppedFailure::Future(_)));
    }

    /// Dropping a response whose body never finished must record exactly one
    /// `DroppedFailure::Body` carrying the response status.
    #[tokio::test]
    async fn body_drop_reports_body_failure_with_status() {
        use std::pin::Pin;
        use std::task::{Context, Poll};

        let (recorder, log) = recorder_with_log();

        /// Body that is always pending and never reports end of stream.
        struct PendingBody;

        impl http_body::Body for PendingBody {
            type Data = Bytes;
            type Error = std::convert::Infallible;

            fn poll_frame(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
                Poll::Pending
            }

            fn is_end_stream(&self) -> bool {
                false
            }
        }

        let inner = service_fn(|_req: Request<()>| async move {
            let response = Response::builder()
                .status(StatusCode::CREATED)
                .body(PendingBody)
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        });

        let policy = EarlyDropsAsFailures::new(recorder);
        let service = OnEarlyDropLayer::new(policy).layer(inner);

        let response = service.oneshot(empty_request()).await.unwrap();
        // Drop the response (and with it the unfinished body) explicitly.
        drop(response);

        let captured = log.lock().unwrap();
        assert_eq!(captured.len(), 1);
        match &captured[0] {
            DroppedFailure::Body(body) => assert_eq!(body.status, StatusCode::CREATED),
            other => panic!("expected Body failure, got {:?}", other),
        }
    }

    /// A request whose future resolves and whose body is fully collected must
    /// record no failure at all.
    #[tokio::test]
    async fn completion_suppresses_both() {
        let (recorder, log) = recorder_with_log();

        let ok_service = service_fn(|_req: Request<()>| async move {
            let response = Response::builder()
                .status(StatusCode::OK)
                .body(Full::new(Bytes::from_static(b"hi")))
                .unwrap();
            Ok::<_, std::convert::Infallible>(response)
        });

        let policy = EarlyDropsAsFailures::new(recorder);
        let service = OnEarlyDropLayer::new(policy).layer(ok_service);

        let response = service.oneshot(empty_request()).await.unwrap();
        // Named binding so the collected body lives until the end of the test.
        let _collected = response.into_body().collect().await.unwrap();

        assert!(log.lock().unwrap().is_empty());
    }
}