use crate::{
body::Body,
ext::{PathParameters, QueryStringParameters, StageVariables},
strmap::StrMap,
};
use serde::{
de::{Deserializer, Error as DeError, MapAccess, Visitor},
Deserialize,
};
use serde_json::{error::Error as JsonError, Value};
use std::{borrow::Cow, collections::HashMap, fmt, io::Read, mem};
/// Raw Lambda HTTP event as delivered by API Gateway (v2 or non-v2 payload shape) or by an
/// Application Load Balancer, before it is converted into an `http::Request<Body>`.
///
/// `#[serde(untagged)]` makes serde try the variants in declaration order and keep the first one
/// that deserializes successfully. The variant order and the set of required fields are
/// therefore what distinguish the event shapes.
#[doc(hidden)]
#[derive(Deserialize, Debug)]
#[serde(untagged)]
pub enum LambdaRequest<'a> {
    /// Event carrying `version`, `routeKey`, `rawPath` and `rawQueryString` (API Gateway v2
    /// payload shape). Tried first.
    #[serde(rename_all = "camelCase")]
    ApiGatewayV2 {
        /// Payload version string; ignored when converting to `http::Request`.
        version: Cow<'a, str>,
        /// Route key; ignored when converting to `http::Request`.
        route_key: Cow<'a, str>,
        /// Request path, placed verbatim into the URI.
        raw_path: Cow<'a, str>,
        /// Query string, appended to the URI after a `?` when non-empty.
        raw_query_string: Cow<'a, str>,
        /// Cookies carried outside `headers`; joined with `;` into a `Cookie` header on conversion.
        cookies: Option<Vec<Cow<'a, str>>>,
        /// Single-valued headers. Required: this field has no `#[serde(default)]`.
        #[serde(deserialize_with = "deserialize_headers")]
        headers: http::HeaderMap,
        /// Empty when the key is absent.
        #[serde(default)]
        query_string_parameters: StrMap,
        /// Empty when the key is absent.
        #[serde(default)]
        path_parameters: StrMap,
        /// Empty when the key is absent.
        #[serde(default)]
        stage_variables: StrMap,
        /// Request body, handed to `Body::from_maybe_encoded` together with `is_base64_encoded`.
        body: Option<Cow<'a, str>>,
        /// `false` when the key is absent.
        #[serde(default)]
        is_base64_encoded: bool,
        request_context: ApiGatewayV2RequestContext,
    },
    /// Application Load Balancer event. Tried second; an API Gateway event fails here because
    /// `AlbRequestContext` requires an `elb` key.
    #[serde(rename_all = "camelCase")]
    Alb {
        path: Cow<'a, str>,
        #[serde(deserialize_with = "deserialize_method")]
        http_method: http::Method,
        /// Single-valued headers; required (no `default`).
        #[serde(deserialize_with = "deserialize_headers")]
        headers: http::HeaderMap,
        /// Multi-valued headers; empty when the key is absent.
        #[serde(default, deserialize_with = "deserialize_multi_value_headers")]
        multi_value_headers: http::HeaderMap,
        /// `null` becomes an empty map, but without `default` the key itself must be present.
        #[serde(deserialize_with = "nullable_default")]
        query_string_parameters: StrMap,
        /// `null` or absent becomes empty; when non-empty it is preferred over
        /// `query_string_parameters` during conversion.
        #[serde(default, deserialize_with = "nullable_default")]
        multi_value_query_string_parameters: StrMap,
        body: Option<Cow<'a, str>>,
        #[serde(default)]
        is_base64_encoded: bool,
        request_context: AlbRequestContext,
    },
    /// Non-v2 API Gateway proxy event. Tried last.
    #[serde(rename_all = "camelCase")]
    ApiGateway {
        path: Cow<'a, str>,
        #[serde(deserialize_with = "deserialize_method")]
        http_method: http::Method,
        /// Single-valued headers; required (no `default`).
        #[serde(deserialize_with = "deserialize_headers")]
        headers: http::HeaderMap,
        /// Multi-valued headers; empty when the key is absent.
        #[serde(default, deserialize_with = "deserialize_multi_value_headers")]
        multi_value_headers: http::HeaderMap,
        /// `null` becomes an empty map, but without `default` the key itself must be present.
        #[serde(deserialize_with = "nullable_default")]
        query_string_parameters: StrMap,
        /// `null` or absent becomes empty; when non-empty it is preferred over
        /// `query_string_parameters` during conversion.
        #[serde(default, deserialize_with = "nullable_default")]
        multi_value_query_string_parameters: StrMap,
        /// `null` or absent becomes empty.
        #[serde(default, deserialize_with = "nullable_default")]
        path_parameters: StrMap,
        /// `null` or absent becomes empty.
        #[serde(default, deserialize_with = "nullable_default")]
        stage_variables: StrMap,
        body: Option<Cow<'a, str>>,
        #[serde(default)]
        is_base64_encoded: bool,
        request_context: ApiGatewayRequestContext,
    },
}
impl LambdaRequest<'_> {
    /// Reports which integration produced this event, based solely on the variant it
    /// deserialized into.
    pub fn request_origin(&self) -> RequestOrigin {
        match *self {
            Self::Alb { .. } => RequestOrigin::Alb,
            Self::ApiGateway { .. } => RequestOrigin::ApiGateway,
            Self::ApiGatewayV2 { .. } => RequestOrigin::ApiGatewayV2,
        }
    }
}
/// Integration an event came from, as returned by `LambdaRequest::request_origin`.
#[doc(hidden)]
#[derive(Debug)]
pub enum RequestOrigin {
    /// Event deserialized as `LambdaRequest::ApiGatewayV2`.
    ApiGatewayV2,
    /// Event deserialized as `LambdaRequest::ApiGateway` (the non-v2 API Gateway shape).
    ApiGateway,
    /// Event deserialized as `LambdaRequest::Alb`.
    Alb,
}
/// `requestContext` object of an API Gateway v2 event (camelCase keys in the JSON).
///
/// Every field except `authorizer` is required. An event missing any of them does not
/// deserialize as `LambdaRequest::ApiGatewayV2`.
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayV2RequestContext {
    pub account_id: String,
    pub api_id: String,
    /// Arbitrary authorizer output; empty when the key is absent.
    #[serde(default)]
    pub authorizer: HashMap<String, Value>,
    /// Used as the URI host when the request's `Host` header is missing or unreadable.
    pub domain_name: String,
    pub domain_prefix: String,
    /// Carries the HTTP method used for the converted request.
    pub http: Http,
    pub request_id: String,
    pub route_key: String,
    pub stage: String,
    pub time: String,
    /// `timeEpoch` from the event (units not established here — TODO confirm).
    pub time_epoch: usize,
}
/// `requestContext` object of a non-v2 API Gateway event (camelCase keys in the JSON).
///
/// Every field except `authorizer` is required. An event missing any of them does not
/// deserialize as `LambdaRequest::ApiGateway`.
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct ApiGatewayRequestContext {
    pub account_id: String,
    pub resource_id: String,
    pub stage: String,
    pub request_id: String,
    pub resource_path: String,
    /// Kept as a plain string; the converted request's method comes from the event's top-level
    /// `httpMethod` instead.
    pub http_method: String,
    /// Arbitrary authorizer output; empty when the key is absent.
    #[serde(default)]
    pub authorizer: HashMap<String, Value>,
    pub api_id: String,
    pub identity: Identity,
}
/// `requestContext` object of an ALB event.
///
/// The required `elb` key is what makes non-ALB events fail the `LambdaRequest::Alb` variant
/// during untagged matching.
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct AlbRequestContext {
    pub elb: Elb,
}
/// Request context attached to converted requests as an extension; the variant matches the
/// event's origin.
///
/// When deserialized directly, `untagged` makes serde try the variants in the order listed here.
#[derive(Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum RequestContext {
    ApiGatewayV2(ApiGatewayV2RequestContext),
    ApiGateway(ApiGatewayRequestContext),
    Alb(AlbRequestContext),
}
/// `requestContext.elb` object of an ALB event.
#[derive(Deserialize, Debug, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Elb {
    /// `targetGroupArn` value from the event; required.
    pub target_group_arn: String,
}
/// `requestContext.http` object of an API Gateway v2 event. All fields are required.
#[derive(Deserialize, Debug, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Http {
    /// Parsed from a string by `deserialize_method`; becomes the converted request's method.
    #[serde(deserialize_with = "deserialize_method")]
    pub method: http::Method,
    pub path: String,
    pub protocol: String,
    pub source_ip: String,
    pub user_agent: String,
}
/// `requestContext.identity` object of a non-v2 API Gateway event.
///
/// Only `sourceIp` is required; every `Option` field becomes `None` when absent or `null`.
#[derive(Deserialize, Debug, Default, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Identity {
    pub source_ip: String,
    pub cognito_identity_id: Option<String>,
    pub cognito_identity_pool_id: Option<String>,
    pub cognito_authentication_provider: Option<String>,
    pub cognito_authentication_type: Option<String>,
    pub account_id: Option<String>,
    pub caller: Option<String>,
    pub api_key: Option<String>,
    pub access_key: Option<String>,
    pub user: Option<String>,
    pub user_agent: Option<String>,
    pub user_arn: Option<String>,
}
/// Deserializes an HTTP method from a string token (e.g. `"GET"`).
///
/// # Errors
/// Fails when the input is not a string or the token is not a valid method.
fn deserialize_method<'de, D>(deserializer: D) -> Result<http::Method, D::Error>
where
    D: Deserializer<'de>,
{
    // Accepts string input only and turns it into an `http::Method`.
    struct MethodTokenVisitor;

    impl<'de> Visitor<'de> for MethodTokenVisitor {
        type Value = http::Method;

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("a Method")
        }

        fn visit_str<E: DeError>(self, token: &str) -> Result<Self::Value, E> {
            // `Method::from_str` is defined as `from_bytes` on the same bytes.
            http::Method::from_bytes(token.as_bytes()).map_err(E::custom)
        }
    }

    deserializer.deserialize_str(MethodTokenVisitor)
}
fn deserialize_multi_value_headers<'de, D>(deserializer: D) -> Result<http::HeaderMap, D::Error>
where
D: Deserializer<'de>,
{
struct HeaderVisitor;
impl<'de> Visitor<'de> for HeaderVisitor {
type Value = http::HeaderMap;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "a multi valued HeaderMap<HeaderValue>")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut headers = map
.size_hint()
.map(http::HeaderMap::with_capacity)
.unwrap_or_else(http::HeaderMap::new);
while let Some((key, values)) = map.next_entry::<Cow<'_, str>, Vec<Cow<'_, str>>>()? {
if !key.is_empty() {
for value in values {
let header_name = key.parse::<http::header::HeaderName>().map_err(A::Error::custom)?;
let header_value = http::header::HeaderValue::from_maybe_shared(value.into_owned())
.map_err(A::Error::custom)?;
headers.append(header_name, header_value);
}
}
}
Ok(headers)
}
}
deserializer.deserialize_map(HeaderVisitor)
}
fn deserialize_headers<'de, D>(deserializer: D) -> Result<http::HeaderMap, D::Error>
where
D: Deserializer<'de>,
{
struct HeaderVisitor;
impl<'de> Visitor<'de> for HeaderVisitor {
type Value = http::HeaderMap;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "a HeaderMap<HeaderValue>")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut headers = map
.size_hint()
.map(http::HeaderMap::with_capacity)
.unwrap_or_else(http::HeaderMap::new);
while let Some((key, value)) = map.next_entry::<Cow<'_, str>, Cow<'_, str>>()? {
let header_name = key.parse::<http::header::HeaderName>().map_err(A::Error::custom)?;
let header_value =
http::header::HeaderValue::from_maybe_shared(value.into_owned()).map_err(A::Error::custom)?;
headers.append(header_name, header_value);
}
Ok(headers)
}
}
deserializer.deserialize_map(HeaderVisitor)
}
/// Deserializes a `T`, mapping an explicit `null` to `T::default()`.
///
/// This only covers `null`; whether the key may be missing altogether is decided by
/// `#[serde(default)]` on the field.
fn nullable_default<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
    D: Deserializer<'de>,
    T: Default + Deserialize<'de>,
{
    Option::<T>::deserialize(deserializer).map(Option::unwrap_or_default)
}
/// Builds the URI of a converted request.
///
/// The host is the `Host` header when present, readable as a string and non-empty; otherwise
/// it is `default_host` when that is non-empty. The scheme is `X-Forwarded-Proto`, defaulting
/// to `https`. A non-empty `query` is appended after a `?`.
///
/// When no host is available, only the path is returned (origin-form, or `/` for an empty
/// path). `http::Uri` rejects an absolute URI with an empty authority, so formatting
/// `https://{path}` made request building panic.
fn build_request_uri(
    headers: &http::HeaderMap,
    default_host: Option<&str>,
    path: &str,
    query: Option<&str>,
) -> String {
    let host = headers
        .get(http::header::HOST)
        .and_then(|val| val.to_str().ok())
        .filter(|host| !host.is_empty())
        .or_else(|| default_host.filter(|host| !host.is_empty()));
    let mut url = match host {
        Some(host) => {
            let scheme = headers
                .get("X-Forwarded-Proto")
                .and_then(|val| val.to_str().ok())
                .unwrap_or("https");
            format!("{}://{}{}", scheme, host, path)
        }
        None if path.is_empty() => String::from("/"),
        None => path.to_owned(),
    };
    if let Some(query) = query.filter(|query| !query.is_empty()) {
        url.push('?');
        url.push_str(query);
    }
    url
}

/// Folds single-valued `headers` into `multi_value_headers` and returns the result.
///
/// A name already present in `multi_value_headers` keeps its multi-valued entries. Otherwise the
/// first single value for that name is added. Further values `HeaderMap` iteration yields for
/// the same name (reported with a `None` key) are dropped.
fn merge_single_value_headers(
    mut multi_value_headers: http::HeaderMap,
    headers: http::HeaderMap,
) -> http::HeaderMap {
    for (key, value) in headers {
        if let Some(name) = key {
            if !multi_value_headers.contains_key(&name) {
                multi_value_headers.append(name, value);
            }
        }
    }
    multi_value_headers
}

impl<'a> From<LambdaRequest<'a>> for http::Request<Body> {
    /// Converts a raw event into an `http::Request<Body>`.
    ///
    /// Method, URI, headers and body come from the event. The query-string map, plus path
    /// parameters, stage variables (API Gateway only) and the request context, are attached as
    /// request extensions.
    ///
    /// # Panics
    /// Panics with "failed to build request" if the builder rejects the method or the assembled
    /// URI.
    fn from(value: LambdaRequest<'_>) -> Self {
        match value {
            LambdaRequest::ApiGatewayV2 {
                raw_path,
                raw_query_string,
                mut headers,
                query_string_parameters,
                path_parameters,
                stage_variables,
                body,
                is_base64_encoded,
                request_context,
                cookies,
                ..
            } => {
                // The event carries cookies in a separate array; fold them into `Cookie`.
                if let Some(cookies) = cookies {
                    if let Ok(header_value) = http::header::HeaderValue::from_str(&cookies.join(";")) {
                        headers.append(http::header::COOKIE, header_value);
                    }
                }
                let uri = build_request_uri(
                    &headers,
                    Some(request_context.domain_name.as_str()),
                    &raw_path,
                    Some(&*raw_query_string),
                );
                let builder = http::Request::builder()
                    .method(request_context.http.method.as_ref())
                    .uri(uri)
                    .extension(QueryStringParameters(query_string_parameters))
                    .extension(PathParameters(path_parameters))
                    .extension(StageVariables(stage_variables))
                    .extension(RequestContext::ApiGatewayV2(request_context));
                let mut req = builder
                    .body(body.map_or_else(Body::default, |b| Body::from_maybe_encoded(is_base64_encoded, b)))
                    .expect("failed to build request");
                let _ = mem::replace(req.headers_mut(), headers);
                req
            }
            LambdaRequest::ApiGateway {
                path,
                http_method,
                headers,
                multi_value_headers,
                query_string_parameters,
                multi_value_query_string_parameters,
                path_parameters,
                stage_variables,
                body,
                is_base64_encoded,
                request_context,
            } => {
                let uri = build_request_uri(&headers, None, &path, None);
                let builder = http::Request::builder()
                    .method(http_method)
                    .uri(uri)
                    // Prefer the multi-valued map whenever the event supplied one.
                    .extension(QueryStringParameters(
                        if multi_value_query_string_parameters.is_empty() {
                            query_string_parameters
                        } else {
                            multi_value_query_string_parameters
                        },
                    ))
                    .extension(PathParameters(path_parameters))
                    .extension(StageVariables(stage_variables))
                    .extension(RequestContext::ApiGateway(request_context));
                let mut req = builder
                    .body(body.map_or_else(Body::default, |b| Body::from_maybe_encoded(is_base64_encoded, b)))
                    .expect("failed to build request");
                let _ = mem::replace(
                    req.headers_mut(),
                    merge_single_value_headers(multi_value_headers, headers),
                );
                req
            }
            LambdaRequest::Alb {
                path,
                http_method,
                headers,
                multi_value_headers,
                query_string_parameters,
                multi_value_query_string_parameters,
                body,
                is_base64_encoded,
                request_context,
            } => {
                let uri = build_request_uri(&headers, None, &path, None);
                let builder = http::Request::builder()
                    .method(http_method)
                    .uri(uri)
                    // Prefer the multi-valued map whenever the event supplied one.
                    .extension(QueryStringParameters(
                        if multi_value_query_string_parameters.is_empty() {
                            query_string_parameters
                        } else {
                            multi_value_query_string_parameters
                        },
                    ))
                    .extension(RequestContext::Alb(request_context));
                let mut req = builder
                    .body(body.map_or_else(Body::default, |b| Body::from_maybe_encoded(is_base64_encoded, b)))
                    .expect("failed to build request");
                let _ = mem::replace(
                    req.headers_mut(),
                    merge_single_value_headers(multi_value_headers, headers),
                );
                req
            }
        }
    }
}
pub fn from_reader<R>(rdr: R) -> Result<crate::Request, JsonError>
where
R: Read,
{
serde_json::from_reader(rdr).map(LambdaRequest::into)
}
pub fn from_str(s: &str) -> Result<crate::Request, JsonError> {
serde_json::from_str(s).map(LambdaRequest::into)
}
#[cfg(test)]
mod tests {
    // Fixture-driven checks that each supported event shape deserializes and converts into the
    // expected `http::Request`.
    use super::*;
    use crate::RequestExt;
    use std::{collections::HashMap, fs::File};

    #[test]
    fn deserializes_apigw_request_events_from_readables() {
        let result = from_reader(File::open("tests/data/apigw_proxy_request.json").expect("expected file"));
        assert!(result.is_ok(), "event was not parsed as expected {:?}", result);
    }

    #[test]
    fn deserializes_minimal_apigw_v2_request_events() {
        let input = include_str!("../tests/data/apigw_v2_proxy_request_minimal.json");
        let result = from_str(input);
        assert!(
            result.is_ok(),
            "event was not parsed as expected {:?} given {}",
            result,
            input
        );
        let req = result.expect("failed to parse request");
        assert_eq!(req.method(), "GET");
        assert_eq!(req.uri(), "https://xxx.execute-api.us-east-1.amazonaws.com/");
    }

    #[test]
    fn deserializes_apigw_v2_request_events() {
        let input = include_str!("../tests/data/apigw_v2_proxy_request.json");
        let result = from_str(input);
        assert!(
            result.is_ok(),
            "event was not parsed as expected {:?} given {}",
            result,
            input
        );
        let req = result.expect("failed to parse request");
        let cookie_header = req
            .headers()
            .get(http::header::COOKIE)
            .ok_or_else(|| "Cookie header not found".to_string())
            .and_then(|v| v.to_str().map_err(|e| e.to_string()));
        assert_eq!(req.method(), "POST");
        assert_eq!(
            req.uri(),
            "https://id.execute-api.us-east-1.amazonaws.com/my/path?parameter1=value1&parameter1=value2&parameter2=value"
        );
        assert_eq!(cookie_header, Ok("cookie1=value1;cookie2=value2"));
    }

    #[test]
    fn deserializes_apigw_request_events() {
        let input = include_str!("../tests/data/apigw_proxy_request.json");
        let result = from_str(input);
        assert!(
            result.is_ok(),
            "event was not parsed as expected {:?} given {}",
            result,
            input
        );
        let req = result.expect("failed to parse request");
        assert_eq!(req.method(), "GET");
        assert_eq!(
            req.uri(),
            "https://wt6mne2s9k.execute-api.us-west-2.amazonaws.com/test/hello"
        );
    }

    #[test]
    fn deserializes_alb_request_events() {
        let input = include_str!("../tests/data/alb_request.json");
        let result = from_str(input);
        assert!(
            result.is_ok(),
            "event was not parsed as expected {:?} given {}",
            result,
            input
        );
        let req = result.expect("failed to parse request");
        assert_eq!(req.method(), "GET");
        assert_eq!(req.uri(), "https://lambda-846800462-us-east-2.elb.amazonaws.com/");
    }

    #[test]
    fn deserializes_apigw_multi_value_request_events() {
        let input = include_str!("../tests/data/apigw_multi_value_proxy_request.json");
        let result = from_str(input);
        assert!(
            result.is_ok(),
            "event was not parsed as expected {:?} given {}",
            result,
            input
        );
        let request = result.expect("failed to parse request");
        assert!(!request.query_string_parameters().is_empty());
        assert_eq!(
            request.query_string_parameters().get_all("multivalueName"),
            Some(vec!["you", "me"])
        );
    }

    #[test]
    fn deserializes_alb_multi_value_request_events() {
        let input = include_str!("../tests/data/alb_multi_value_request.json");
        let result = from_str(input);
        assert!(
            result.is_ok(),
            "event was not parsed as expected {:?} given {}",
            result,
            input
        );
        let request = result.expect("failed to parse request");
        assert!(!request.query_string_parameters().is_empty());
        assert_eq!(
            request.query_string_parameters().get_all("myKey"),
            Some(vec!["val1", "val2"])
        );
    }

    #[test]
    fn deserialize_with_null() {
        #[derive(Debug, PartialEq, Deserialize)]
        struct Test {
            #[serde(deserialize_with = "nullable_default")]
            foo: HashMap<String, String>,
        }
        assert_eq!(
            serde_json::from_str::<Test>(r#"{"foo":null}"#).expect("failed to deserialize"),
            Test { foo: HashMap::new() }
        )
    }
}