use super::{Extension, FromRequest, FromRequestParts};
use crate::{
body::{Body, Bytes, HttpBody},
BoxError, Error,
};
use async_trait::async_trait;
use futures_util::stream::Stream;
use http::{request::Parts, Request, Uri};
use std::{
convert::Infallible,
fmt,
pin::Pin,
task::{Context, Poll},
};
use sync_wrapper::SyncWrapper;
/// Newtype around a [`Uri`], usable as an extractor when the `original-uri`
/// feature is enabled.
///
/// The `FromRequestParts` implementation in this file first looks for an
/// `OriginalUri` stored in the request extensions and, when none can be
/// extracted, falls back to a clone of the request's current URI.
///
/// NOTE(review): presumably the stored value is the URI as it looked before
/// any routing layer rewrote it; confirm against the code that inserts the
/// extension.
#[cfg(feature = "original-uri")]
#[derive(Debug, Clone)]
pub struct OriginalUri(pub Uri);
#[cfg(feature = "original-uri")]
#[async_trait]
impl<S> FromRequestParts<S> for OriginalUri
where
    S: Send + Sync,
{
    /// Extraction cannot fail: a missing extension is replaced by a fallback.
    type Rejection = Infallible;

    /// Returns the `OriginalUri` found in the request extensions, or one
    /// built from a clone of `parts.uri` when the extension extractor rejects.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // Finish the extension lookup first so the mutable borrow of `parts`
        // has ended before the fallback arm reads `parts.uri`.
        let lookup = Extension::<Self>::from_request_parts(parts, state).await;

        let resolved = match lookup {
            Ok(Extension(stored)) => stored,
            // Any rejection is treated the same way: use the URI currently
            // on the request.
            Err(_) => OriginalUri(parts.uri.clone()),
        };

        Ok(resolved)
    }
}
/// Request body exposed through the [`Stream`] trait, yielding
/// `Result<Bytes, Error>` items.
///
/// Internally this is a pinned, boxed, type-erased [`HttpBody`] whose data
/// type is [`Bytes`] and whose error type is [`Error`], held inside a
/// [`SyncWrapper`].
///
/// NOTE(review): the trait object is only declared `Send`; the
/// `body_stream_traits` test in this file asserts `BodyStream` is both `Send`
/// and `Sync`, so the `SyncWrapper` is presumably what supplies `Sync`.
/// Confirm against the `sync_wrapper` documentation.
pub struct BodyStream(
SyncWrapper<Pin<Box<dyn HttpBody<Data = Bytes, Error = Error> + Send + 'static>>>,
);
impl Stream for BodyStream {
type Item = Result<Bytes, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(self.0.get_mut()).poll_data(cx)
}
}
#[async_trait]
impl<S, B> FromRequest<S, B> for BodyStream
where
B: HttpBody + Send + 'static,
B::Data: Into<Bytes>,
B::Error: Into<BoxError>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
let body = req
.into_body()
.map_data(Into::into)
.map_err(|err| Error::new(err.into()));
let stream = BodyStream(SyncWrapper::new(Box::pin(body)));
Ok(stream)
}
}
impl fmt::Debug for BodyStream {
    /// Formats as a field-less tuple struct named `BodyStream`; the wrapped
    /// body is not included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_tuple("BodyStream");
        // No `.field(..)` calls: only the type name is emitted.
        builder.finish()
    }
}
/// Compile-time check that `BodyStream` implements both `Send` and `Sync`.
#[test]
fn body_stream_traits() {
    use crate::test_helpers::{assert_send, assert_sync};

    assert_send::<BodyStream>();
    assert_sync::<BodyStream>();
}
/// Extractor newtype that holds a request body of type `B` as-is.
///
/// `B` defaults to the crate's [`Body`] type. The `FromRequest`
/// implementation in this file fills the field with the value returned by
/// `Request::into_body`, without adapting or buffering it.
#[derive(Debug, Default, Clone)]
pub struct RawBody<B = Body>(pub B);
#[async_trait]
impl<S, B> FromRequest<S, B> for RawBody<B>
where
    B: Send,
    S: Send + Sync,
{
    /// Taking the body out of the request cannot fail.
    type Rejection = Infallible;

    /// Consumes the request and returns its body wrapped in `RawBody`.
    /// The rest of the request and the state argument are not used.
    async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
        let body = req.into_body();
        Ok(Self(body))
    }
}
#[cfg(test)]
mod tests {
    use crate::{extract::Extension, routing::get, test_helpers::*, Router};
    use http::{Method, StatusCode};

    /// A handler taking `http::request::Parts` sees the method, URI, version,
    /// headers, and extensions of the incoming request.
    #[crate::test]
    async fn extract_request_parts() {
        /// Marker type inserted through the `Extension` layer.
        #[derive(Clone)]
        struct Ext;

        async fn handler(parts: http::request::Parts) {
            assert_eq!(parts.method, Method::GET);
            assert_eq!(parts.uri, "/");
            assert_eq!(parts.version, http::Version::HTTP_11);
            assert_eq!(parts.headers["x-foo"], "123");
            // Panics if the layer did not place `Ext` in the extensions.
            let _marker: &Ext = parts.extensions.get::<Ext>().unwrap();
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(Extension(Ext));
        let client = TestClient::new(app);

        let request = client.get("/").header("x-foo", "123");
        let res = request.send().await;

        assert_eq!(res.status(), StatusCode::OK);
    }
}