use http::{Request, Response, StatusCode};
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;
/// [`Layer`] that wraps services in [`SetStatus`].
///
/// Every `Ok` response produced by a wrapped service has its status code
/// replaced with the one stored in this layer.
#[derive(Clone, Copy, Debug)]
pub struct SetStatusLayer {
    // Status code that will be written onto responses.
    status: StatusCode,
}
impl SetStatusLayer {
pub fn new(status: StatusCode) -> Self {
SetStatusLayer { status }
}
}
impl<S> Layer<S> for SetStatusLayer {
    type Service = SetStatus<S>;

    /// Wraps `inner` in a [`SetStatus`] that carries this layer's status code.
    fn layer(&self, inner: S) -> Self::Service {
        let status = self.status;
        SetStatus::new(inner, status)
    }
}
/// Middleware that replaces the status code of responses from the inner service.
///
/// Only `Ok` responses are rewritten; errors returned by the inner service are
/// passed through unchanged.
#[derive(Clone, Copy, Debug)]
pub struct SetStatus<S> {
    // The wrapped service that actually handles requests.
    inner: S,
    // Status code written onto each successful response.
    status: StatusCode,
}
impl<S> SetStatus<S> {
pub fn new(inner: S, status: StatusCode) -> Self {
Self { status, inner }
}
define_inner_service_accessors!();
pub fn layer(status: StatusCode) -> SetStatusLayer {
SetStatusLayer::new(status)
}
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SetStatus<S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    /// Readiness is delegated entirely to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the inner service. The returned future rewrites the
    /// status code of the eventual `Ok` response.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let pending = self.inner.call(req);
        ResponseFuture {
            status: Some(self.status),
            inner: pending,
        }
    }
}
pin_project! {
    /// Response future for [`SetStatus`].
    ///
    /// Resolves to the inner future's output. An `Ok` response has its status
    /// code replaced by the stored one; errors pass through unchanged.
    #[derive(Debug)]
    pub struct ResponseFuture<F> {
        // The inner service's response future (structurally pinned).
        #[pin]
        inner: F,
        // Status code to apply; taken (left as `None`) once an `Ok` response
        // has been produced.
        status: Option<StatusCode>,
    }
}
impl<F, B, E> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<B>, E>>,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut response = futures_core::ready!(this.inner.poll(cx)?);
*response.status_mut() = this.status.take().expect("future polled after completion");
Poll::Ready(Ok(response))
}
}