use crate::{request::RequestContext, strmap::StrMap, Body};
use serde::{de::value::Error as SerdeError, Deserialize};
use std::{error::Error, fmt};
/// Query string parameters stored on a request as an `http` extension.
///
/// Read by `RequestExt::query_string_parameters` and written by
/// `RequestExt::with_query_string_parameters`.
pub(crate) struct QueryStringParameters(pub(crate) StrMap);
/// Path parameters stored on a request as an `http` extension.
///
/// Read by `RequestExt::path_parameters` and written by `RequestExt::with_path_parameters`.
pub(crate) struct PathParameters(pub(crate) StrMap);
/// Stage variables stored on a request as an `http` extension.
///
/// Read by `RequestExt::stage_variables`; the `with_stage_variables` setter exists only in
/// test builds.
pub(crate) struct StageVariables(pub(crate) StrMap);
/// Error returned by [`RequestExt::payload`] when the request body cannot be deserialized.
#[derive(Debug)]
pub enum PayloadError {
    /// The `Content-Type` selected JSON decoding and the body failed to deserialize.
    Json(serde_json::Error),
    /// The `Content-Type` selected `application/x-www-form-urlencoded` decoding and the
    /// body failed to deserialize.
    WwwFormUrlEncoded(SerdeError),
}
impl fmt::Display for PayloadError {
    /// Renders a single-line description of which decoder failed, followed by the
    /// underlying deserializer error.
    ///
    /// `write!` (not `writeln!`) is used deliberately: `Display` output for errors must not
    /// end with a newline, because callers routinely embed it in larger messages and logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PayloadError::Json(json) => {
                write!(f, "failed to parse payload from application/json: {}", json)
            }
            PayloadError::WwwFormUrlEncoded(form) => write!(
                f,
                "failed to parse payload from application/x-www-form-urlencoded: {}",
                form
            ),
        }
    }
}
impl Error for PayloadError {
    /// Reports the wrapped deserializer error as the cause of this error.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        // Both variants wrap a concrete error; pick it out, then wrap it once.
        let cause: &(dyn Error + 'static) = match self {
            PayloadError::WwwFormUrlEncoded(inner) => inner,
            PayloadError::Json(inner) => inner,
        };
        Some(cause)
    }
}
/// Extension methods for reading request metadata (query string parameters, path
/// parameters, stage variables, request context) and for deserializing the request body.
pub trait RequestExt {
    /// Returns the query string parameters associated with this request.
    fn query_string_parameters(&self) -> StrMap;

    /// Consumes the request and returns it with its query string parameters set to
    /// `parameters`.
    fn with_query_string_parameters<Q: Into<StrMap>>(self, parameters: Q) -> Self;

    /// Returns the path parameters associated with this request.
    fn path_parameters(&self) -> StrMap;

    /// Consumes the request and returns it with its path parameters set to `parameters`.
    fn with_path_parameters<P: Into<StrMap>>(self, parameters: P) -> Self;

    /// Returns the stage variables associated with this request.
    fn stage_variables(&self) -> StrMap;

    /// Consumes the request and returns it with its stage variables set to `variables`.
    ///
    /// Only compiled in test builds.
    #[cfg(test)]
    fn with_stage_variables<V: Into<StrMap>>(self, variables: V) -> Self;

    /// Returns the request context associated with this request.
    fn request_context(&self) -> RequestContext;

    /// Deserializes the request body into `D`, choosing a decoder from the
    /// `Content-Type` header.
    ///
    /// Returns `Ok(None)` when no supported content type is present.
    ///
    /// # Errors
    ///
    /// Returns [`PayloadError`] when a decoder was selected but the body could not be
    /// deserialized into `D`.
    fn payload<D>(&self) -> Result<Option<D>, PayloadError>
    where
        D: for<'de> Deserialize<'de>;
}
impl RequestExt for http::Request<Body> {
    /// Returns a copy of the stored query string parameters; a missing extension is
    /// reported as an empty map.
    fn query_string_parameters(&self) -> StrMap {
        self.extensions()
            .get::<QueryStringParameters>()
            .map(|ext| ext.0.clone())
            .unwrap_or_default()
    }

    /// Stores `parameters` as the request's query string parameters, replacing any
    /// previous value.
    fn with_query_string_parameters<Q>(mut self, parameters: Q) -> Self
    where
        Q: Into<StrMap>,
    {
        self.extensions_mut()
            .insert(QueryStringParameters(parameters.into()));
        self
    }

    /// Returns a copy of the stored path parameters; a missing extension is reported as
    /// an empty map.
    fn path_parameters(&self) -> StrMap {
        self.extensions()
            .get::<PathParameters>()
            .map(|ext| ext.0.clone())
            .unwrap_or_default()
    }

    /// Stores `parameters` as the request's path parameters, replacing any previous value.
    fn with_path_parameters<P>(mut self, parameters: P) -> Self
    where
        P: Into<StrMap>,
    {
        self.extensions_mut().insert(PathParameters(parameters.into()));
        self
    }

    /// Returns a copy of the stored stage variables; a missing extension is reported as
    /// an empty map.
    fn stage_variables(&self) -> StrMap {
        self.extensions()
            .get::<StageVariables>()
            .map(|ext| ext.0.clone())
            .unwrap_or_default()
    }

    /// Stores `variables` as the request's stage variables, replacing any previous value.
    #[cfg(test)]
    fn with_stage_variables<V>(mut self, variables: V) -> Self
    where
        V: Into<StrMap>,
    {
        self.extensions_mut().insert(StageVariables(variables.into()));
        self
    }

    /// Returns a copy of the stored request context.
    ///
    /// # Panics
    ///
    /// Panics if no `RequestContext` extension has been attached to the request.
    fn request_context(&self) -> RequestContext {
        self.extensions()
            .get::<RequestContext>()
            .cloned()
            .expect("Request did not contain a request context")
    }

    /// Deserializes the body as form data or JSON depending on `Content-Type`.
    ///
    /// Returns `Ok(None)` when the header is missing, cannot be read as a string, or names
    /// a media type other than `application/x-www-form-urlencoded` or `application/json`.
    fn payload<D>(&self) -> Result<Option<D>, PayloadError>
    where
        for<'de> D: Deserialize<'de>,
    {
        // Without a readable Content-Type there is no way to pick a decoder.
        let content_type = match self
            .headers()
            .get(http::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
        {
            Some(content_type) => content_type,
            None => return Ok(None),
        };

        // Media types are case-insensitive and may carry parameters such as
        // `; charset=UTF-8`, so compare only the trimmed `type/subtype` portion. An exact
        // match also avoids treating look-alikes (e.g. `application/jsonp`) as JSON.
        let essence = content_type.split(';').next().unwrap_or("").trim();
        let body: &[u8] = self.body().as_ref();

        if essence.eq_ignore_ascii_case("application/x-www-form-urlencoded") {
            serde_urlencoded::from_bytes::<D>(body)
                .map(Some)
                .map_err(PayloadError::WwwFormUrlEncoded)
        } else if essence.eq_ignore_ascii_case("application/json") {
            serde_json::from_slice::<D>(body)
                .map(Some)
                .map_err(PayloadError::Json)
        } else {
            Ok(None)
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{Body, Request, RequestExt};
    use serde::Deserialize;

    /// Target type shared by every payload-parsing test below.
    #[derive(Deserialize, PartialEq, Debug)]
    struct Payload {
        foo: String,
        baz: usize,
    }

    /// The value each well-formed test body is expected to deserialize into.
    fn expected_payload() -> Payload {
        Payload {
            foo: "bar".into(),
            baz: 2,
        }
    }

    /// Builds a request carrying `body` with the given `Content-Type` header.
    fn request_with(content_type: &str, body: &'static str) -> http::Request<Body> {
        http::Request::builder()
            .header("Content-Type", content_type)
            .body(Body::from(body))
            .expect("failed to build request")
    }

    #[test]
    fn requests_can_mock_query_string_parameters_ext() {
        let mocked = hashmap! {
            "foo".into() => vec!["bar".into()]
        };
        let request = Request::default().with_query_string_parameters(mocked.clone());
        assert_eq!(request.query_string_parameters(), mocked.into());
    }

    #[test]
    fn requests_can_mock_path_parameters_ext() {
        let mocked = hashmap! {
            "foo".into() => vec!["bar".into()]
        };
        let request = Request::default().with_path_parameters(mocked.clone());
        assert_eq!(request.path_parameters(), mocked.into());
    }

    #[test]
    fn requests_can_mock_stage_variables_ext() {
        let mocked = hashmap! {
            "foo".into() => vec!["bar".into()]
        };
        let request = Request::default().with_stage_variables(mocked.clone());
        assert_eq!(request.stage_variables(), mocked.into());
    }

    #[test]
    fn requests_have_form_post_parsable_payloads() {
        let request = request_with("application/x-www-form-urlencoded", "foo=bar&baz=2");
        // `expect` (not `unwrap_or_default`) so a parse failure reports the real error.
        let payload: Option<Payload> = request.payload().expect("failed to parse payload");
        assert_eq!(payload, Some(expected_payload()));
    }

    #[test]
    fn requests_have_json_parseable_payloads() {
        let request = request_with("application/json", r#"{"foo":"bar", "baz": 2}"#);
        let payload: Option<Payload> = request.payload().expect("failed to parse payload");
        assert_eq!(payload, Some(expected_payload()));
    }

    #[test]
    fn requests_match_form_post_content_type_with_charset() {
        let request = request_with(
            "application/x-www-form-urlencoded; charset=UTF-8",
            "foo=bar&baz=2",
        );
        let payload: Option<Payload> = request.payload().expect("failed to parse payload");
        assert_eq!(payload, Some(expected_payload()));
    }

    #[test]
    fn requests_match_json_content_type_with_charset() {
        let request = request_with(
            "application/json; charset=UTF-8",
            r#"{"foo":"bar", "baz": 2}"#,
        );
        let payload: Option<Payload> = request.payload().expect("failed to parse payload");
        assert_eq!(payload, Some(expected_payload()));
    }

    #[test]
    fn requests_omitting_content_types_do_not_support_parseable_payloads() {
        let request = http::Request::builder()
            .body(Body::from(r#"{"foo":"bar", "baz": 2}"#))
            .expect("failed to build request");
        let payload: Option<Payload> = request.payload().expect("failed to parse payload");
        assert_eq!(payload, None);
    }
}