use http::header::CONTENT_TYPE;
use http::Request as HttpRequest;
use serde::de::value::Error as SerdeError;
use serde::Deserialize;
use serde_json;
use serde_urlencoded;
use request::RequestContext;
use strmap::StrMap;
/// Request extension holding the query string parameters that
/// `RequestExt::query_string_parameters` reads back (cloned) from `http::Request::extensions`.
/// NOTE(review): presumably inserted when a gateway event is converted into an
/// `http::Request` — confirm against that conversion.
pub(crate) struct QueryStringParameters(pub(crate) StrMap);
/// Request extension holding the path parameters that `RequestExt::path_parameters` reads
/// back (cloned) from `http::Request::extensions`.
pub(crate) struct PathParameters(pub(crate) StrMap);
/// Request extension holding the stage variables that `RequestExt::stage_variables` reads
/// back (cloned) from `http::Request::extensions`.
pub(crate) struct StageVariables(pub(crate) StrMap);
#[derive(Debug, Fail)]
pub enum PayloadError {
#[fail(display = "failed to parse payload from application/json")]
Json(serde_json::Error),
#[fail(display = "failed to parse payload application/x-www-form-urlencoded")]
WwwFormUrlEncoded(SerdeError),
}
/// Extension methods for reading data attached to an `http::Request` (query string and path
/// parameters, stage variables, request context) and for deserializing its body.
/// NOTE(review): this data is presumably populated from an API Gateway event — confirm.
pub trait RequestExt {
    /// Returns the request's query string parameters.
    fn query_string_parameters(&self) -> StrMap;
    /// Returns the request's path parameters.
    fn path_parameters(&self) -> StrMap;
    /// Returns the request's stage variables.
    fn stage_variables(&self) -> StrMap;
    /// Returns the request context associated with the request.
    fn request_context(&self) -> RequestContext;
    /// Deserializes the request body into `D`, choosing the format from the request's
    /// `Content-Type` header.
    ///
    /// `application/json` and `application/x-www-form-urlencoded` bodies are supported.
    /// Returns `Ok(None)` when the request does not declare a supported content type, and
    /// `Err(PayloadError)` when a body of a supported type fails to deserialize.
    fn payload<D>(&self) -> Result<Option<D>, PayloadError>
    where
        for<'de> D: Deserialize<'de>;
}
impl RequestExt for HttpRequest<super::Body> {
    /// Returns a clone of the `QueryStringParameters` extension, or `StrMap::default()` when
    /// the request carries none.
    fn query_string_parameters(&self) -> StrMap {
        self.extensions()
            .get::<QueryStringParameters>()
            .map(|ext| ext.0.clone())
            .unwrap_or_default()
    }

    /// Returns a clone of the `PathParameters` extension, or `StrMap::default()` when the
    /// request carries none.
    fn path_parameters(&self) -> StrMap {
        self.extensions()
            .get::<PathParameters>()
            .map(|ext| ext.0.clone())
            .unwrap_or_default()
    }

    /// Returns a clone of the `StageVariables` extension, or `StrMap::default()` when the
    /// request carries none.
    fn stage_variables(&self) -> StrMap {
        self.extensions()
            .get::<StageVariables>()
            .map(|ext| ext.0.clone())
            .unwrap_or_default()
    }

    /// Returns a clone of the `RequestContext` extension, or `RequestContext::default()` when
    /// the request carries none.
    fn request_context(&self) -> RequestContext {
        self.extensions()
            .get::<RequestContext>()
            .cloned()
            .unwrap_or_default()
    }

    /// Deserializes the body according to the request's `Content-Type` header.
    ///
    /// Only the media type is compared: parameters such as `; charset=utf-8`, surrounding
    /// whitespace and ASCII case are ignored, so `application/json; charset=UTF-8` is
    /// treated as JSON.
    ///
    /// Returns `Ok(None)` when the header is absent, is not visible ASCII, or names an
    /// unsupported media type.
    ///
    /// # Errors
    ///
    /// `PayloadError::WwwFormUrlEncoded` or `PayloadError::Json` when the body of a supported
    /// content type cannot be deserialized into `D`.
    fn payload<D>(&self) -> Result<Option<D>, PayloadError>
    where
        for<'de> D: Deserialize<'de>,
    {
        let content_type = self
            .headers()
            .get(CONTENT_TYPE)
            .and_then(|ct| ct.to_str().ok());
        let content_type = match content_type {
            Some(ct) => ct,
            None => return Ok(None),
        };

        if media_type_is(content_type, "application/x-www-form-urlencoded") {
            serde_urlencoded::from_bytes::<D>(self.body().as_ref())
                .map_err(PayloadError::WwwFormUrlEncoded)
                .map(Some)
        } else if media_type_is(content_type, "application/json") {
            serde_json::from_slice::<D>(self.body().as_ref())
                .map_err(PayloadError::Json)
                .map(Some)
        } else {
            Ok(None)
        }
    }
}

/// Returns `true` when the media type of a `Content-Type` value (the part before any `;`
/// parameters) equals `expected`. Whitespace around the media type is ignored, and so is
/// ASCII case, because type/subtype names are case-insensitive (RFC 7231 §3.1.1.1).
fn media_type_is(content_type: &str, expected: &str) -> bool {
    content_type
        .split(';')
        .next()
        .map_or(false, |essence| essence.trim().eq_ignore_ascii_case(expected))
}
#[cfg(test)]
mod tests {
    use http::HeaderMap;
    use http::Request as HttpRequest;
    use std::collections::HashMap;
    use {GatewayRequest, RequestExt, StrMap};

    #[test]
    fn requests_have_query_string_ext() {
        let mut headers = HeaderMap::new();
        headers.insert("Host", "www.rust-lang.org".parse().unwrap());
        let mut query = HashMap::new();
        query.insert("foo".to_owned(), "bar".to_owned());
        let gwr: GatewayRequest = GatewayRequest {
            path: "/foo".into(),
            headers,
            query_string_parameters: StrMap(query.clone().into()),
            ..GatewayRequest::default()
        };
        let actual = HttpRequest::from(gwr);
        assert_eq!(actual.query_string_parameters(), StrMap(query.into()));
    }

    #[test]
    fn requests_have_form_post_parseable_payloads() {
        let mut headers = HeaderMap::new();
        headers.insert("Host", "www.rust-lang.org".parse().unwrap());
        headers.insert(
            "Content-Type",
            "application/x-www-form-urlencoded".parse().unwrap(),
        );
        #[derive(Deserialize, PartialEq, Debug)]
        struct Payload {
            foo: String,
            baz: usize,
        }
        let gwr: GatewayRequest = GatewayRequest {
            path: "/foo".into(),
            headers,
            body: Some("foo=bar&baz=2".into()),
            ..GatewayRequest::default()
        };
        let actual = HttpRequest::from(gwr);
        // `expect` (not `unwrap_or_default`) so a parse error fails loudly with its cause
        // instead of masquerading as a missing payload.
        let payload: Option<Payload> = actual
            .payload()
            .expect("form-urlencoded payload should deserialize");
        assert_eq!(
            payload,
            Some(Payload {
                foo: "bar".into(),
                baz: 2
            })
        );
    }

    #[test]
    fn requests_have_form_post_parseable_payloads_for_hashmaps() {
        let mut headers = HeaderMap::new();
        headers.insert("Host", "www.rust-lang.org".parse().unwrap());
        headers.insert(
            "Content-Type",
            "application/x-www-form-urlencoded".parse().unwrap(),
        );
        let gwr: GatewayRequest = GatewayRequest {
            path: "/foo".into(),
            headers,
            body: Some("foo=bar&baz=2".into()),
            ..GatewayRequest::default()
        };
        let actual = HttpRequest::from(gwr);
        let mut expected = HashMap::new();
        expected.insert("foo".to_string(), "bar".to_string());
        expected.insert("baz".to_string(), "2".to_string());
        let payload: Option<HashMap<String, String>> = actual
            .payload()
            .expect("form-urlencoded payload should deserialize into a map");
        assert_eq!(payload, Some(expected));
    }

    #[test]
    fn requests_have_json_parseable_payloads() {
        let mut headers = HeaderMap::new();
        headers.insert("Host", "www.rust-lang.org".parse().unwrap());
        headers.insert("Content-Type", "application/json".parse().unwrap());
        #[derive(Deserialize, PartialEq, Debug)]
        struct Payload {
            foo: String,
            baz: usize,
        }
        let gwr: GatewayRequest = GatewayRequest {
            path: "/foo".into(),
            headers,
            body: Some(r#"{"foo":"bar", "baz": 2}"#.into()),
            ..GatewayRequest::default()
        };
        let actual = HttpRequest::from(gwr);
        let payload: Option<Payload> = actual
            .payload()
            .expect("JSON payload should deserialize");
        assert_eq!(
            payload,
            Some(Payload {
                foo: "bar".into(),
                baz: 2
            })
        );
    }

    #[test]
    fn requests_without_content_type_have_no_payload() {
        let mut headers = HeaderMap::new();
        headers.insert("Host", "www.rust-lang.org".parse().unwrap());
        let gwr: GatewayRequest = GatewayRequest {
            path: "/foo".into(),
            headers,
            body: Some("foo=bar".into()),
            ..GatewayRequest::default()
        };
        let actual = HttpRequest::from(gwr);
        let payload: Option<HashMap<String, String>> = actual
            .payload()
            .expect("a request without Content-Type is not an error");
        assert_eq!(payload, None);
    }

    #[test]
    fn requests_with_unsupported_content_type_have_no_payload() {
        let mut headers = HeaderMap::new();
        headers.insert("Host", "www.rust-lang.org".parse().unwrap());
        headers.insert("Content-Type", "text/plain".parse().unwrap());
        let gwr: GatewayRequest = GatewayRequest {
            path: "/foo".into(),
            headers,
            body: Some("foo=bar".into()),
            ..GatewayRequest::default()
        };
        let actual = HttpRequest::from(gwr);
        let payload: Option<HashMap<String, String>> = actual
            .payload()
            .expect("an unsupported Content-Type is not an error");
        assert_eq!(payload, None);
    }

    #[test]
    fn requests_with_malformed_json_report_an_error() {
        let mut headers = HeaderMap::new();
        headers.insert("Host", "www.rust-lang.org".parse().unwrap());
        headers.insert("Content-Type", "application/json".parse().unwrap());
        let gwr: GatewayRequest = GatewayRequest {
            path: "/foo".into(),
            headers,
            body: Some("not json".into()),
            ..GatewayRequest::default()
        };
        let actual = HttpRequest::from(gwr);
        let result: Result<Option<HashMap<String, String>>, _> = actual.payload();
        assert!(result.is_err(), "malformed JSON must surface as an error");
    }
}