use super::{Extension, FromRequestParts};
use async_trait::async_trait;
use http::{request::Parts, Uri};
use std::convert::Infallible;
/// Extractor wrapping a request [`Uri`].
///
/// Only available when the `original-uri` feature is enabled.
///
/// When extracted via its [`FromRequestParts`] implementation, the value is
/// taken from an `Extension<OriginalUri>` on the request when that extraction
/// succeeds; otherwise it is a clone of the URI currently stored in the
/// request parts.
///
/// NOTE(review): presumably the extension is inserted by routing code elsewhere
/// so that handlers can observe the URI as it was before any rewriting — that
/// code is not visible in this file; confirm.
#[cfg(feature = "original-uri")]
#[derive(Debug, Clone)]
pub struct OriginalUri(pub Uri);
// Extractor implementation for `OriginalUri`: prefer a value previously stored
// as a request extension, otherwise fall back to the request's current URI.
#[cfg(feature = "original-uri")]
#[async_trait]
impl<S> FromRequestParts<S> for OriginalUri
where
    S: Send + Sync,
{
    /// Extraction cannot fail, because a missing extension is replaced by the
    /// URI found in the request parts.
    type Rejection = Infallible;

    /// Returns the `OriginalUri` stored in the request extensions, or one built
    /// from a clone of `parts.uri` when the extension cannot be extracted.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let original = match Extension::<Self>::from_request_parts(parts, state).await {
            // An `OriginalUri` was attached to the request: use it as-is.
            Ok(Extension(stored)) => stored,
            // The extension could not be extracted: use the URI the request carries now.
            Err(_) => OriginalUri(parts.uri.clone()),
        };

        Ok(original)
    }
}
// NOTE(review): `__impl_deref!` is defined in `axum_core` and is not visible in
// this file. Presumably it generates `Deref`/`DerefMut` impls so that an
// `OriginalUri` can be used like the `Uri` it wraps — confirm against the
// macro definition.
#[cfg(feature = "original-uri")]
axum_core::__impl_deref!(OriginalUri: Uri);
#[cfg(test)]
mod tests {
    use crate::{extract::Extension, routing::get, test_helpers::*, Router};
    use http::{Method, StatusCode};

    // A handler may take `http::request::Parts` directly and must see the
    // method, URI, version, headers and extensions of the incoming request.
    #[crate::test]
    async fn extract_request_parts() {
        // Marker value added by the `Extension` layer below; the handler
        // looks it up in the request extensions.
        #[derive(Clone)]
        struct Marker;

        async fn inspect_parts(parts: http::request::Parts) {
            assert_eq!(parts.method, Method::GET);
            assert_eq!(parts.uri, "/");
            assert_eq!(parts.version, http::Version::HTTP_11);
            assert_eq!(parts.headers["x-foo"], "123");
            // Panics (failing the request) when the layer's marker is absent.
            parts.extensions.get::<Marker>().unwrap();
        }

        let app = Router::new()
            .route("/", get(inspect_parts))
            .layer(Extension(Marker));
        let client = TestClient::new(app);

        let response = client.get("/").header("x-foo", "123").await;

        // The test expects 200 OK, i.e. every assertion in the handler held.
        assert_eq!(response.status(), StatusCode::OK);
    }
}