use crate::{
HttpRequest,
error::Error,
http::Parts,
http::endpoints::args::{FromPayload, FromRequestParts, FromRequestRef, Payload, Source},
};
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
pin_project! {
    /// Future returned by [`FromPayload::from_payload`] for `Result<T, Error>`.
    ///
    /// Wraps the future produced by the inner extractor `T` and resolves to
    /// `Ok(inner_result)`. An extraction failure is therefore delivered to the
    /// handler as an `Err` value instead of failing the outer extraction.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct ResultFromPayloadFuture<F> {
        // The inner extractor's future; structurally pinned so it can be
        // polled in place through `project()`.
        #[pin]
        inner: F,
    }
}
impl<F, T> Future for ResultFromPayloadFuture<F>
where
    F: Future<Output = Result<T, Error>>,
{
    /// The inner extractor's result, lifted into an always-`Ok` outer layer.
    type Output = Result<Result<T, Error>, Error>;

    /// Polls the wrapped extractor future once.
    ///
    /// `Pending` is passed through unchanged. A ready value, whether the
    /// inner extractor succeeded or failed, is wrapped in an outer `Ok`.
    #[inline]
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.project().inner.poll(cx).map(Ok)
    }
}
impl<T: FromRequestRef> FromRequestRef for Result<T, Error> {
    /// Runs `T`'s extractor against the request and hands its outcome over as-is.
    ///
    /// The outer result is always `Ok`. A failure from `T` becomes the inner
    /// `Err` rather than rejecting the extraction.
    #[inline]
    fn from_request(req: &HttpRequest) -> Result<Self, Error> {
        let outcome = T::from_request(req);
        Ok(outcome)
    }
}
impl<T: FromRequestParts> FromRequestParts for Result<T, Error> {
    /// Runs `T`'s extractor against the request parts and hands its outcome over as-is.
    ///
    /// The outer result is always `Ok`. A failure from `T` becomes the inner
    /// `Err` rather than rejecting the extraction.
    #[inline]
    fn from_parts(parts: &Parts) -> Result<Self, Error> {
        let outcome = T::from_parts(parts);
        Ok(outcome)
    }
}
impl<T: FromPayload> FromPayload for Result<T, Error> {
    type Future = ResultFromPayloadFuture<T::Future>;

    // The wrapper consumes exactly the same payload kind as the extractor it wraps.
    const SOURCE: Source = T::SOURCE;

    /// Starts `T`'s payload extraction and wraps its future so the outcome
    /// resolves as `Ok(inner_result)`.
    #[inline]
    fn from_payload(payload: Payload<'_>) -> Self::Future {
        let inner = T::from_payload(payload);
        ResultFromPayloadFuture { inner }
    }
}
// Unit tests for the `Result<T, Error>` extractor wrappers: they check that the
// outer result is always `Ok` and that the inner result mirrors `T`'s outcome.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HttpBody;
    use crate::http::endpoints::route::{PathArg, PathArgs};
    use futures_util::future::{Ready, err, ok, pending};
    use hyper::Request;

    /// Builds the parts of an empty `GET /` request.
    fn request_parts() -> Parts {
        let (parts, _) = Request::get("/").body(()).unwrap().into_parts();
        parts
    }

    /// Builds an `HttpRequest` for `GET /` with an empty body.
    fn empty_http_request() -> HttpRequest {
        let (parts, body) = Request::get("/")
            .body(HttpBody::empty())
            .unwrap()
            .into_parts();
        HttpRequest::from_parts(parts, body)
    }

    /// Extractor that always succeeds, from any extraction entry point.
    struct SuccessExtractor;

    impl FromPayload for SuccessExtractor {
        type Future = Ready<Result<Self, Error>>;
        const SOURCE: Source = Source::Parts;

        fn from_payload(_: Payload<'_>) -> Self::Future {
            ok(SuccessExtractor)
        }
    }

    impl FromRequestParts for SuccessExtractor {
        fn from_parts(_: &Parts) -> Result<Self, Error> {
            Ok(SuccessExtractor)
        }
    }

    impl FromRequestRef for SuccessExtractor {
        fn from_request(_: &HttpRequest) -> Result<Self, Error> {
            Ok(SuccessExtractor)
        }
    }

    /// Extractor that always fails with a client error, from any entry point.
    struct FailureExtractor;

    impl FromPayload for FailureExtractor {
        type Future = Ready<Result<Self, Error>>;
        const SOURCE: Source = Source::Parts;

        fn from_payload(_: Payload<'_>) -> Self::Future {
            err(Error::client_error("Test error"))
        }
    }

    impl FromRequestParts for FailureExtractor {
        fn from_parts(_: &Parts) -> Result<Self, Error> {
            Err(Error::client_error("Test error"))
        }
    }

    impl FromRequestRef for FailureExtractor {
        fn from_request(_: &HttpRequest) -> Result<Self, Error> {
            Err(Error::client_error("Test error"))
        }
    }

    /// Extractor that succeeds only when given a body payload.
    struct BodyExtractor(String);

    impl FromPayload for BodyExtractor {
        type Future = Ready<Result<Self, Error>>;
        const SOURCE: Source = Source::Body;

        fn from_payload(payload: Payload<'_>) -> Self::Future {
            match payload {
                Payload::Body(_) => ok(BodyExtractor("body content".to_string())),
                _ => err(Error::client_error("Expected body payload")),
            }
        }
    }

    /// Extractor that parses a single path argument as `u32`.
    struct PathExtractor(u32);

    impl FromPayload for PathExtractor {
        type Future = Ready<Result<Self, Error>>;
        const SOURCE: Source = Source::Path;

        fn from_payload(payload: Payload<'_>) -> Self::Future {
            let Payload::Path(param) = payload else {
                return err(Error::client_error("Expected path payload"));
            };
            match param.value.parse::<u32>() {
                Ok(id) => ok(PathExtractor(id)),
                Err(_) => err(Error::client_error("Invalid path parameter")),
            }
        }
    }

    #[tokio::test]
    async fn it_extracts_result_returns_ok_on_success() {
        let parts = request_parts();
        let result = Result::<SuccessExtractor, Error>::from_payload(Payload::Parts(&parts)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok());
    }

    #[test]
    fn it_extracts_result_returns_ok_from_parts() {
        let parts = request_parts();
        let result = Result::<SuccessExtractor, Error>::from_parts(&parts);
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok());
    }

    #[test]
    fn it_extracts_result_returns_ok_from_request_ref() {
        let req = empty_http_request();
        let result = Result::<SuccessExtractor, Error>::from_request(&req);
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok());
    }

    #[tokio::test]
    async fn it_extracts_result_returns_err_on_failure() {
        let parts = request_parts();
        let result = Result::<FailureExtractor, Error>::from_payload(Payload::Parts(&parts)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_err());
    }

    #[test]
    fn it_extracts_result_returns_err_from_parts() {
        let parts = request_parts();
        let result = Result::<FailureExtractor, Error>::from_parts(&parts);
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_err());
    }

    #[test]
    fn it_extracts_result_returns_err_from_request_ref() {
        let req = empty_http_request();
        let result = Result::<FailureExtractor, Error>::from_request(&req);
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_err());
    }

    // Pure constant comparison; nothing is awaited, so no async runtime is needed.
    #[test]
    fn it_extracts_result_preserves_source() {
        assert_eq!(Result::<SuccessExtractor, Error>::SOURCE, Source::Parts);
        assert_eq!(Result::<BodyExtractor, Error>::SOURCE, Source::Body);
        assert_eq!(Result::<PathExtractor, Error>::SOURCE, Source::Path);
    }

    #[tokio::test]
    async fn it_extracts_result_with_body_extractor() {
        let body = HttpBody::empty();
        let result = Result::<BodyExtractor, Error>::from_payload(Payload::Body(body)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok());
        assert_eq!(inner_result.unwrap().0, "body content");
    }

    #[tokio::test]
    async fn it_extracts_result_with_body_extractor_with_wrong_payload() {
        let parts = request_parts();
        let result = Result::<BodyExtractor, Error>::from_payload(Payload::Parts(&parts)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_err());
    }

    #[tokio::test]
    async fn it_extracts_result_with_path_extractor() {
        let param = PathArg {
            name: "id".into(),
            value: "123".into(),
        };
        let result = Result::<PathExtractor, Error>::from_payload(Payload::Path(param)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok());
        assert_eq!(inner_result.unwrap().0, 123);
    }

    #[tokio::test]
    async fn it_extracts_result_with_path_extractor_returns_invalid_value() {
        let param = PathArg {
            name: "id".into(),
            value: "invalid".into(),
        };
        let result = Result::<PathExtractor, Error>::from_payload(Payload::Path(param)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_err());
    }

    #[tokio::test]
    async fn it_extracts_result_with_path_extractor_returns_wrong_payload() {
        let parts = request_parts();
        let result = Result::<PathExtractor, Error>::from_payload(Payload::Parts(&parts)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_err());
    }

    #[tokio::test]
    async fn it_extracts_result_with_primitive_types() {
        let param = PathArg {
            name: "id".into(),
            value: "42".into(),
        };
        let result = Result::<i32, Error>::from_payload(Payload::Path(param)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok());
        assert_eq!(inner_result.unwrap(), 42);

        let param = PathArg {
            name: "id".into(),
            value: "invalid".into(),
        };
        let result = Result::<i32, Error>::from_payload(Payload::Path(param)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_err());

        let param = PathArg {
            name: "name".into(),
            value: "test".into(),
        };
        let result = Result::<String, Error>::from_payload(Payload::Path(param)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok());
        assert_eq!(inner_result.unwrap(), "test");
    }

    #[tokio::test]
    async fn it_extracts_result_with_nested_result() {
        let parts = request_parts();

        let result =
            Result::<Result<SuccessExtractor, Error>, Error>::from_payload(Payload::Parts(&parts))
                .await;
        assert!(result.is_ok());
        let outer = result.unwrap();
        assert!(outer.is_ok());
        let inner = outer.unwrap();
        assert!(inner.is_ok());

        let result =
            Result::<Result<FailureExtractor, Error>, Error>::from_payload(Payload::Parts(&parts))
                .await;
        assert!(result.is_ok());
        let outer = result.unwrap();
        assert!(outer.is_ok());
        let inner = outer.unwrap();
        assert!(inner.is_err());
    }

    #[test]
    fn it_extracts_result_future_poll_ready_ok() {
        let inner_future = ok(SuccessExtractor);
        let mut result_future = ResultFromPayloadFuture {
            inner: inner_future,
        };
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let result = Pin::new(&mut result_future).poll(&mut cx);
        assert!(
            matches!(result, Poll::Ready(Ok(Ok(_)))),
            "Expected Poll::Ready(Ok(Ok(_)))"
        );
    }

    #[test]
    fn it_extracts_result_future_poll_ready_err() {
        let inner_future = err::<SuccessExtractor, Error>(Error::client_error("test"));
        let mut result_future = ResultFromPayloadFuture {
            inner: inner_future,
        };
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let result = Pin::new(&mut result_future).poll(&mut cx);
        assert!(
            matches!(result, Poll::Ready(Ok(Err(_)))),
            "Expected Poll::Ready(Ok(Err(_)))"
        );
    }

    #[test]
    fn it_extracts_result_future_poll_pending() {
        let inner_future = pending::<Result<SuccessExtractor, Error>>();
        let mut result_future = ResultFromPayloadFuture {
            inner: inner_future,
        };
        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let result = Pin::new(&mut result_future).poll(&mut cx);
        assert!(matches!(result, Poll::Pending), "Expected Poll::Pending");
    }

    #[tokio::test]
    async fn it_extracts_result_integration_with_real_extractors() {
        use crate::NamedPath;
        use serde::Deserialize;

        #[derive(Deserialize)]
        struct Params {
            id: u32,
        }

        let args: PathArgs = smallvec::smallvec![PathArg {
            name: "id".into(),
            value: "123".into()
        }]
        .into();
        let result =
            Result::<NamedPath<Params>, Error>::from_payload(Payload::PathArgs(&args)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok());
        assert_eq!(inner_result.unwrap().id, 123);

        let result =
            Result::<NamedPath<Params>, Error>::from_payload(Payload::PathArgs(&PathArgs::new()))
                .await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_err());
    }

    #[tokio::test]
    async fn it_extracts_result_integration_with_path_extractor() {
        use crate::Path;

        let args: PathArgs = smallvec::smallvec![
            PathArg {
                name: "id".into(),
                value: "123".into()
            },
            PathArg {
                name: "name".into(),
                value: "John".into()
            }
        ]
        .into();
        let result =
            Result::<Path<(i32, String)>, Error>::from_payload(Payload::PathArgs(&args)).await;
        assert!(result.is_ok());
        let result = result.unwrap();
        assert!(result.is_ok());
        let result = result.unwrap();
        assert_eq!(result.0.0, 123);
        assert_eq!(result.0.1, "John");

        let result =
            Result::<Path<(i32, String)>, Error>::from_payload(Payload::PathArgs(&PathArgs::new()))
                .await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_err());
    }

    #[tokio::test]
    async fn it_extracts_result_with_different_error_types() {
        struct CustomExtractor;

        impl FromPayload for CustomExtractor {
            type Future = Ready<Result<Self, Error>>;
            const SOURCE: Source = Source::Parts;

            fn from_payload(_: Payload<'_>) -> Self::Future {
                err(Error::server_error("Custom internal error"))
            }
        }

        let parts = request_parts();
        let result = Result::<CustomExtractor, Error>::from_payload(Payload::Parts(&parts)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_err());
        // `unwrap_err` would require `CustomExtractor: Debug`, hence `err().unwrap()`.
        let error = inner_result.err().unwrap();
        assert!(error.to_string().contains("Custom internal error"));
    }

    #[tokio::test]
    async fn it_extracts_result_maintains_success_value() {
        struct ValueExtractor(i32);

        impl FromPayload for ValueExtractor {
            type Future = Ready<Result<Self, Error>>;
            const SOURCE: Source = Source::Parts;

            fn from_payload(_: Payload<'_>) -> Self::Future {
                ok(ValueExtractor(42))
            }
        }

        let parts = request_parts();
        let result = Result::<ValueExtractor, Error>::from_payload(Payload::Parts(&parts)).await;
        assert!(result.is_ok());
        let inner_result = result.unwrap();
        assert!(inner_result.is_ok());
        assert_eq!(inner_result.unwrap().0, 42);
    }
}