1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
use crate::response;
use crate::response::ResponderError;
use async_trait::async_trait;
use derive_more::{Display, From};
use http::{HeaderMap, HeaderValue};
use hyper::Body;
use hyper::Response;
use serde::de;
use serde_urlencoded;
use std::sync::Arc;

/// Extracts a value of type `T` from a request body, with access to a shared
/// container of type `C`, failing with an error of type `E`.
#[async_trait]
pub trait FromRequestBodyWithContainer<T, E, C>
where
    C: Send + Sync + 'static,
    E: ResponderError + 'static,
    T: de::DeserializeOwned + 'static,
{
    /// Validates the request's `Content-Type` header (if any) before extraction.
    ///
    /// The default implementation accepts every value, including a missing header.
    async fn assert_content_type(_content_type: Option<&HeaderValue>, _: Arc<C>) -> Result<(), E> {
        Ok(())
    }

    /// Builds a `T` from the request headers and body, with the container available.
    async fn extract(headers: &HeaderMap, b: Body, _: Arc<C>) -> Result<T, E>;
}

/// Extracts a value of type `T` from a request body without any container,
/// failing with an error of type `E`.
#[async_trait]
pub trait FromRequestBody<T, E>
where
    E: ResponderError + 'static,
    T: de::DeserializeOwned + 'static,
{
    /// Validates the request's `Content-Type` header (if any) before extraction.
    ///
    /// The default implementation accepts every value, including a missing header.
    async fn assert_content_type(_content_type: Option<&HeaderValue>) -> Result<(), E> {
        Ok(())
    }

    /// Builds a `T` from the request headers and body.
    async fn extract(headers: &HeaderMap, b: Body) -> Result<T, E>;
}

/// Blanket adapter: every `FromRequestBody` extractor also acts as a
/// `FromRequestBodyWithContainer` for any container type `C`.
///
/// The container argument is ignored; both methods delegate to the
/// container-free implementation on `F`.
#[async_trait]
impl<F, T, E, C> FromRequestBodyWithContainer<T, E, C> for F
where
    C: std::any::Any + Send + Sync,
    E: ResponderError + 'static,
    F: FromRequestBody<T, E> + 'static,
    T: de::DeserializeOwned + 'static,
{
    async fn assert_content_type(content_type: Option<&HeaderValue>, _: Arc<C>) -> Result<(), E> {
        // The container is unused here; forward straight to the plain extractor.
        <F as FromRequestBody<T, E>>::assert_content_type(content_type).await
    }

    async fn extract(headers: &HeaderMap, b: Body, _: Arc<C>) -> Result<T, E> {
        <F as FromRequestBody<T, E>>::extract(headers, b).await
    }
}

/// General request-level error.
#[derive(Debug, Display, From)]
pub enum RequestErr {
    /// Displayed as "Not found".
    #[display(fmt = "Not found")]
    NotFound,
}

// Marker impl so this error can be used wherever a `ResponderError` is expected.
impl ResponderError for RequestErr {}

/// A set of errors that can occur while handling a request payload (body):
/// deserialization failures, a missing payload, or a payload that is too large.
#[derive(Debug, Display, From)]
pub enum PayloadError {
    /// The payload could not be deserialized into the target type.
    #[display(fmt = "Payload deserialize error: {}", _0)]
    Deserialize(serde::de::value::Error),
    /// No payload was present.
    #[display(fmt = "Empty Payload")]
    NotExist,
    /// The payload exceeded the allowed size.
    ///
    /// Fields are `(maximum, received)`, in the order the display message uses
    /// them; the message labels the received count as bytes.
    #[display(fmt = "Payload maximum {} exceeded: received {} bytes", _0, _1)]
    Size(u64, u64),
}

// Marker impl so this error can be used wherever a `ResponderError` is expected.
impl ResponderError for PayloadError {}

/// Errors produced while turning a URL query string into a typed value.
#[derive(Debug, Display, From)]
pub enum QueryPayloadError {
    /// The query string could not be deserialized into the target type.
    #[display(fmt = "Query deserialize error: {}", _0)]
    Deserialize(serde::de::value::Error),
    /// No query string was available.
    #[display(fmt = "Empty query")]
    NotExist,
}

impl std::error::Error for QueryPayloadError {}
impl ResponderError for QueryPayloadError {}

/// Conversion from an optional raw query string into a value of type `T`,
/// failing with an error of type `E`.
pub trait FromQuery<T, E> {
    /// Parses `query_str` into a `T`; `None` means no query string is available.
    fn from_query(query_str: Option<&str>) -> Result<T, E>
    where
        E: ResponderError,
        T: de::DeserializeOwned;
}

impl<T> FromQuery<T, QueryPayloadError> for T {
    /// Deserializes `T` from a URL-encoded query string via `serde_urlencoded`.
    ///
    /// # Errors
    ///
    /// * `QueryPayloadError::NotExist` if `query_str` is `None`.
    /// * `QueryPayloadError::Deserialize` if the string cannot be decoded into `T`.
    fn from_query(query_str: Option<&str>) -> Result<T, QueryPayloadError>
    where
        T: de::DeserializeOwned,
    {
        let query_str = query_str.ok_or(QueryPayloadError::NotExist)?;
        serde_urlencoded::from_str::<T>(query_str).map_err(QueryPayloadError::Deserialize)
    }
}

/// Errors produced while extracting typed values from the request path.
#[derive(Debug, Display)]
pub enum PathError {
    /// The path could not be deserialized; carries a description of the failure.
    #[display(fmt = "Path deserialize error: {}", _0)]
    Deserialize(String),
    /// A required field was missing; presumably carries the field name — TODO confirm.
    #[display(fmt = "Missing field error: {}", _0)]
    Missing(String),
}

impl std::error::Error for PathError {}
impl ResponderError for PathError {}

/// Turns the error `e` into an HTTP response using `T`'s `ErrResponder`
/// implementation for `E`.
///
/// NOTE(review): the name suggests this mainly exists to assert at compile time
/// that `T` can respond to `E` — confirm against callers.
pub fn assert_respond_err<T, E>(e: E) -> Response<Body>
where
    E: std::error::Error,
    T: response::ErrResponder<E, Body>,
{
    <T as response::ErrResponder<E, Body>>::respond_err(e)
}

/// Optional query extraction built on the non-optional `FromQuery<T, _>` impl.
///
/// Any failure of the underlying extraction (missing query string or a
/// deserialization error) becomes `Ok(None)`, so this impl never returns `Err`.
impl<T> FromQuery<Option<T>, QueryPayloadError> for T
where
    T: de::DeserializeOwned,
{
    fn from_query(query_str: Option<&str>) -> Result<Option<T>, QueryPayloadError>
    where
        T: FromQuery<T, QueryPayloadError>,
    {
        // Fully-qualified to select the non-optional impl explicitly.
        let parsed = <T as FromQuery<T, QueryPayloadError>>::from_query(query_str);
        Ok(parsed.ok())
    }
}