use std::{
error::Error,
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
use axum_core::{body::Body, response::Response};
use bytes::Bytes;
use http_body::Body as HttpBody;
use pin_project_lite::pin_project;
use tower::{Layer, Service};
/// [`Layer`] that wraps a service in [`ResponseAxumBody`], converting the body of
/// every response the wrapped service produces into an axum [`Body`].
///
/// This is a stateless unit struct; `ResponseAxumBodyLayer` and
/// `ResponseAxumBodyLayer::default()` are equivalent.
#[derive(Debug, Clone, Default)]
pub struct ResponseAxumBodyLayer;
/// Wraps any inner service `S` in [`ResponseAxumBody`]; the layer itself carries
/// no configuration.
impl<S> Layer<S> for ResponseAxumBodyLayer {
    type Service = ResponseAxumBody<S>;

    fn layer(&self, service: S) -> Self::Service {
        ResponseAxumBody(service)
    }
}
/// Middleware service that forwards requests to the wrapped inner service and
/// converts the body of each response it returns into an axum [`Body`].
///
/// Constructed by [`ResponseAxumBodyLayer`]; the single field is the inner service.
#[derive(Debug, Clone)]
pub struct ResponseAxumBody<S>(S);
impl<S, Request, ResBody> Service<Request> for ResponseAxumBody<S>
where
S: Service<Request, Response = Response<ResBody>>,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
<ResBody as HttpBody>::Error: Error + Send + Sync,
{
type Response = Response;
type Error = S::Error;
type Future = ResponseAxumBodyFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
fn call(&mut self, req: Request) -> Self::Future {
ResponseAxumBodyFuture {
inner: self.0.call(req),
}
}
}
pin_project! {
pub struct ResponseAxumBodyFuture<Fut> {
#[pin]
inner: Fut,
}
}
impl<Fut, ResBody, E> Future for ResponseAxumBodyFuture<Fut>
where
Fut: Future<Output = Result<Response<ResBody>, E>>,
ResBody: HttpBody<Data = Bytes> + Send + 'static,
<ResBody as HttpBody>::Error: Error + Send + Sync,
{
type Output = Result<Response<Body>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let res = ready!(this.inner.poll(cx)?);
Poll::Ready(Ok(res.map(Body::new)))
}
}