mod de;
use super::rejection::ExtensionsAlreadyExtracted;
use crate::{
body::{boxed, Full},
extract::{rejection::*, FromRequest, RequestParts},
response::{IntoResponse, Response},
routing::{InvalidUtf8InPathParam, UrlParams},
};
use async_trait::async_trait;
use http::StatusCode;
use serde::de::DeserializeOwned;
use std::{
borrow::Cow,
fmt,
ops::{Deref, DerefMut},
};
/// Extractor that deserializes the path parameters captured by the router into `T`.
///
/// The wrapped value is public and the wrapper dereferences to it, so
/// `path.field` and `*path` both reach the inner `T`.
#[derive(Debug)]
pub struct Path<T>(pub T);

impl<T> Deref for Path<T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &T {
        let Path(inner) = self;
        inner
    }
}

impl<T> DerefMut for Path<T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        let Path(inner) = self;
        inner
    }
}
#[async_trait]
impl<T, B> FromRequest<B> for Path<T>
where
T: DeserializeOwned + Send,
B: Send,
{
type Rejection = PathRejection;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let ext = req
.extensions_mut()
.ok_or_else::<Self::Rejection, _>(|| ExtensionsAlreadyExtracted::default().into())?;
let params = match ext.get::<Option<UrlParams>>() {
Some(Some(UrlParams(Ok(params)))) => Cow::Borrowed(params),
Some(Some(UrlParams(Err(InvalidUtf8InPathParam { key })))) => {
let err = PathDeserializationError {
kind: ErrorKind::InvalidUtf8InPathParam {
key: key.as_str().to_owned(),
},
};
let err = FailedToDeserializePathParams(err);
return Err(err.into());
}
Some(None) => Cow::Owned(Vec::new()),
None => {
return Err(MissingPathParams.into());
}
};
T::deserialize(de::PathDeserializer::new(&*params))
.map_err(|err| {
PathRejection::FailedToDeserializePathParams(FailedToDeserializePathParams(err))
})
.map(Path)
}
}
/// Error produced while deserializing path parameters; the details live in `kind`.
#[derive(Debug)]
pub(crate) struct PathDeserializationError {
    pub(super) kind: ErrorKind,
}

impl PathDeserializationError {
    /// Wraps `kind` in a new error.
    pub(super) fn new(kind: ErrorKind) -> Self {
        PathDeserializationError { kind }
    }

    /// Starts building an `ErrorKind::WrongNumberOfParameters` error.
    ///
    /// Chain `.got(n)` and then `.expected(m)` to obtain the finished error.
    pub(super) fn wrong_number_of_parameters() -> WrongNumberOfParameters<()> {
        WrongNumberOfParameters { got: () }
    }

    /// Builds an `ErrorKind::UnsupportedType` error naming the offending type.
    pub(super) fn unsupported_type(name: &'static str) -> Self {
        let kind = ErrorKind::UnsupportedType { name };
        Self::new(kind)
    }
}
/// Builder for `ErrorKind::WrongNumberOfParameters`.
///
/// `G` starts as `()`. After `got` has recorded the received count as a `usize`,
/// `expected` becomes available and produces the finished error.
pub(super) struct WrongNumberOfParameters<G> {
    got: G,
}

impl<G> WrongNumberOfParameters<G> {
    /// Records the number of parameters that were received, replacing the previous value.
    // `self` is taken only so the builder can be chained; its old `got` is discarded.
    #[allow(clippy::unused_self)]
    pub(super) fn got<G2>(self, got: G2) -> WrongNumberOfParameters<G2> {
        WrongNumberOfParameters::<G2> { got }
    }
}

impl WrongNumberOfParameters<usize> {
    /// Finishes the builder with the number of parameters that was expected.
    pub(super) fn expected(self, expected: usize) -> PathDeserializationError {
        let kind = ErrorKind::WrongNumberOfParameters {
            got: self.got,
            expected,
        };
        PathDeserializationError::new(kind)
    }
}
/// Lets serde report free-form failures; they are stored as `ErrorKind::Message`.
impl serde::de::Error for PathDeserializationError {
    #[inline]
    fn custom<T>(msg: T) -> Self
    where
        T: fmt::Display,
    {
        let message = msg.to_string();
        Self {
            kind: ErrorKind::Message(message),
        }
    }
}
/// Displays exactly what the wrapped `ErrorKind` displays.
impl fmt::Display for PathDeserializationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.kind, f)
    }
}

impl std::error::Error for PathDeserializationError {}
/// The specific reason why path parameters could not be deserialized.
///
/// A rejection's kind can be retrieved with `FailedToDeserializePathParams::into_kind`.
// `Eq` and `Clone` are derived because every field supports them; callers can
// then compare kinds fully and keep copies of a kind they inspected.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ErrorKind {
    /// The number of captured parameters did not match what the target type expects.
    WrongNumberOfParameters {
        /// Number of parameters that were captured.
        got: usize,
        /// Number of parameters the target type expects.
        expected: usize,
    },
    /// The value of the parameter named `key` could not be parsed as `expected_type`.
    ParseErrorAtKey {
        /// Name of the parameter that failed to parse.
        key: String,
        /// The value that failed to parse.
        value: String,
        /// Name of the type the value was being parsed into.
        expected_type: &'static str,
    },
    /// The parameter at position `index` could not be parsed as `expected_type`.
    ParseErrorAtIndex {
        /// Position of the parameter that failed to parse.
        index: usize,
        /// The value that failed to parse.
        value: String,
        /// Name of the type the value was being parsed into.
        expected_type: &'static str,
    },
    /// A parameter value could not be parsed as `expected_type`.
    ParseError {
        /// The value that failed to parse.
        value: String,
        /// Name of the type the value was being parsed into.
        expected_type: &'static str,
    },
    /// The parameter named `key` was not valid UTF-8.
    InvalidUtf8InPathParam {
        /// Name of the offending parameter.
        key: String,
    },
    /// The requested target type cannot be deserialized from path parameters.
    UnsupportedType {
        /// Name of the unsupported type.
        name: &'static str,
    },
    /// A free-form error message, e.g. one reported through serde's `Error::custom`.
    Message(String),
}

impl fmt::Display for ErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Delegate to `String`'s `Display` so width/alignment flags are honored.
            ErrorKind::Message(error) => fmt::Display::fmt(error, f),
            ErrorKind::InvalidUtf8InPathParam { key } => write!(f, "Invalid UTF-8 in `{}`", key),
            ErrorKind::WrongNumberOfParameters { got, expected } => write!(
                f,
                "Wrong number of parameters. Expected {} but got {}",
                expected, got
            ),
            ErrorKind::UnsupportedType { name } => write!(f, "Unsupported type `{}`", name),
            ErrorKind::ParseErrorAtKey {
                key,
                value,
                expected_type,
            } => write!(
                f,
                "Cannot parse `{}` with value `{:?}` to a `{}`",
                key, value, expected_type
            ),
            ErrorKind::ParseError {
                value,
                expected_type,
            } => write!(f, "Cannot parse `{:?}` to a `{}`", value, expected_type),
            ErrorKind::ParseErrorAtIndex {
                index,
                value,
                expected_type,
            } => write!(
                f,
                "Cannot parse value at index {} with value `{:?}` to a `{}`",
                index, value, expected_type
            ),
        }
    }
}
/// Rejection returned when the captured path parameters cannot be turned into the
/// requested type. This covers deserialization failures and parameters that are not
/// valid UTF-8.
#[derive(Debug)]
pub struct FailedToDeserializePathParams(PathDeserializationError);

impl FailedToDeserializePathParams {
    /// Consumes the rejection and returns the `ErrorKind` describing what went wrong.
    pub fn into_kind(self) -> ErrorKind {
        let FailedToDeserializePathParams(inner) = self;
        inner.kind
    }
}
/// Turns the rejection into a plain-text response.
///
/// `UnsupportedType` becomes a `500 Internal Server Error` carrying the bare message.
/// Every other kind becomes a `400 Bad Request` whose body is prefixed with `Invalid URL: `.
impl IntoResponse for FailedToDeserializePathParams {
    fn into_response(self) -> Response {
        let kind = self.0.kind;
        // Exhaustive on purpose: a new `ErrorKind` variant must pick a status explicitly.
        let (status, body) = match &kind {
            ErrorKind::UnsupportedType { .. } => {
                (StatusCode::INTERNAL_SERVER_ERROR, kind.to_string())
            }
            ErrorKind::Message(_)
            | ErrorKind::InvalidUtf8InPathParam { .. }
            | ErrorKind::WrongNumberOfParameters { .. }
            | ErrorKind::ParseError { .. }
            | ErrorKind::ParseErrorAtIndex { .. }
            | ErrorKind::ParseErrorAtKey { .. } => {
                (StatusCode::BAD_REQUEST, format!("Invalid URL: {}", kind))
            }
        };

        let mut response = Response::new(boxed(Full::from(body)));
        *response.status_mut() = status;
        response
    }
}
impl fmt::Display for FailedToDeserializePathParams {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl std::error::Error for FailedToDeserializePathParams {}
#[cfg(test)]
// Tests that drive `Path` through a `Router` using `TestClient`.
mod tests {
use super::*;
use crate::{routing::get, test_helpers::*, Router};
use http::{Request, StatusCode};
use hyper::Body;
use std::collections::HashMap;
// The same `:id` capture is extracted as a single `i32` (GET) and as a map keyed
// by parameter name (POST).
#[tokio::test]
async fn extracting_url_params() {
let app = Router::new().route(
"/users/:id",
get(|Path(id): Path<i32>| async move {
assert_eq!(id, 42);
})
.post(|Path(params_map): Path<HashMap<String, i32>>| async move {
assert_eq!(params_map.get("id").unwrap(), &1337);
}),
);
let client = TestClient::new(app);
let res = client.get("/users/42").send().await;
assert_eq!(res.status(), StatusCode::OK);
let res = client.post("/users/1337").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
// `Path` can appear twice in one handler because the extractor only reads
// (`get`) the params extension and never removes it.
#[tokio::test]
async fn extracting_url_params_multiple_times() {
let app = Router::new().route("/users/:id", get(|_: Path<i32>, _: Path<String>| async {}));
let client = TestClient::new(app);
let res = client.get("/users/42").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
// `%20` in the captured segment arrives decoded as a space.
#[tokio::test]
async fn percent_decoding() {
let app = Router::new().route(
"/:key",
get(|Path(param): Path<String>| async move { param }),
);
let client = TestClient::new(app);
let res = client.get("/one%20two").send().await;
assert_eq!(res.text().await, "one two");
}
// Both signed and unsigned 128-bit integers deserialize from a path segment.
#[tokio::test]
async fn supports_128_bit_numbers() {
let app = Router::new()
.route(
"/i/:key",
get(|Path(param): Path<i128>| async move { param.to_string() }),
)
.route(
"/u/:key",
get(|Path(param): Path<u128>| async move { param.to_string() }),
);
let client = TestClient::new(app);
let res = client.get("/i/123").send().await;
assert_eq!(res.text().await, "123");
let res = client.get("/u/123").send().await;
assert_eq!(res.text().await, "123");
}
// A `*rest` wildcard captures the remainder of the path including its leading `/`,
// whether it is extracted directly or through a map.
#[tokio::test]
async fn wildcard() {
let app = Router::new()
.route(
"/foo/*rest",
get(|Path(param): Path<String>| async move { param }),
)
.route(
"/bar/*rest",
get(|Path(params): Path<HashMap<String, String>>| async move {
params.get("rest").unwrap().clone()
}),
);
let client = TestClient::new(app);
let res = client.get("/foo/bar/baz").send().await;
assert_eq!(res.text().await, "/bar/baz");
let res = client.get("/bar/baz/qux").send().await;
assert_eq!(res.text().await, "/baz/qux");
}
// A `:key` capture requires a non-empty segment, so `/` does not match the route.
#[tokio::test]
async fn captures_dont_match_empty_segments() {
let app = Router::new().route("/:key", get(|| async {}));
let client = TestClient::new(app);
let res = client.get("/").send().await;
assert_eq!(res.status(), StatusCode::NOT_FOUND);
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::OK);
}
// The `Request<Body>` extractor runs first and takes the extensions. `Path` is then
// rejected with `ExtensionsAlreadyExtracted`, which yields a 500 with the message below.
#[tokio::test]
async fn when_extensions_are_missing() {
let app = Router::new().route("/:key", get(|_: Request<Body>, _: Path<String>| async {}));
let client = TestClient::new(app);
let res = client.get("/foo").send().await;
assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(res.text().await, "Extensions taken by other extractor");
}
}