use crate::{Config, Error};
use http::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, convert::TryFrom};
/// Error payload consisting of an error type name and a message.
///
/// Serialized and deserialized with camelCase keys (`errorType`,
/// `errorMessage`) because of the `rename_all` attribute below.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct Diagnostic {
/// Name of the error kind, e.g. `"InvalidEventDataError"`.
pub(crate) error_type: String,
/// Human-readable description of the error.
pub(crate) error_message: String,
}
/// Checks that a `Diagnostic` survives a JSON decode/encode round trip
/// without any key or value changing.
#[test]
fn round_trip_lambda_error() -> Result<(), Error> {
    use serde_json::{json, Value};

    // Wire representation using the camelCase keys `Diagnostic` is declared with.
    let wire = json!({
        "errorType": "InvalidEventDataError",
        "errorMessage": "Error parsing event data.",
    });

    let decoded = serde_json::from_value::<Diagnostic>(wire.clone())?;
    let reencoded: Value = serde_json::to_value(decoded)?;

    assert_eq!(wire, reencoded);
    Ok(())
}
/// Newtype wrapping a request identifier string.
///
/// NOTE(review): not referenced elsewhere in the visible code; presumably the
/// same value stored in `Context::request_id` — confirm against callers.
#[derive(Debug, Clone, PartialEq)]
pub struct RequestId(pub String);
/// Newtype wrapping an invocation deadline as a `u64`.
///
/// NOTE(review): the unit is not stated here; presumably milliseconds, like
/// the `lambda-runtime-deadline-ms` header — confirm against callers.
#[derive(Debug, Clone, PartialEq)]
pub struct InvocationDeadline(pub u64);
/// Newtype wrapping a function ARN string.
///
/// NOTE(review): not referenced elsewhere in the visible code; presumably the
/// same value stored in `Context::invoked_function_arn` — confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct FunctionArn(pub String);
/// Newtype wrapping an X-Ray trace identifier string.
///
/// NOTE(review): not referenced elsewhere in the visible code; presumably the
/// same value stored in `Context::xray_trace_id` — confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct XRayTraceId(pub String);
/// Module-private newtype wrapping a string.
///
/// NOTE(review): never constructed or read in the visible code; presumably a
/// raw, not-yet-parsed form of `ClientContext` — confirm it is still needed.
#[derive(Debug, Clone, PartialEq)]
struct MobileClientContext(String);
/// Module-private newtype wrapping a string.
///
/// NOTE(review): never constructed or read in the visible code; presumably a
/// raw, not-yet-parsed form of `CognitoIdentity` — confirm it is still needed.
#[derive(Debug, Clone, PartialEq)]
struct MobileClientIdentity(String);
/// Information about the calling client, carried in `Context::client_context`.
///
/// No `rename_all` is applied, so the keys are serialized exactly as the
/// field names are written (`client`, `custom`, `environment`).
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct ClientContext {
/// Details of the client application.
pub client: ClientApplication,
/// Free-form string key/value pairs.
pub custom: HashMap<String, String>,
/// String key/value pairs describing the client environment.
pub environment: HashMap<String, String>,
}
/// Identifying details of a client application.
///
/// Serialized and deserialized with camelCase keys (`installationId`,
/// `appTitle`, `appVersionName`, `appVersionCode`, `appPackageName`).
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ClientApplication {
/// Identifier of the application installation.
pub installation_id: String,
/// Title of the application.
pub app_title: String,
/// Version name of the application.
pub app_version_name: String,
/// Version code of the application.
pub app_version_code: String,
/// Package name of the application.
pub app_package_name: String,
}
/// Cognito identity information, carried in `Context::identity`.
///
/// Unlike `ClientApplication`, this type has no `rename_all` attribute, so
/// the keys are serialized as `identity_id` and `identity_pool_id`.
/// NOTE(review): confirm these snake_case keys match the incoming payload.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct CognitoIdentity {
/// Identifier of the identity.
pub identity_id: String,
/// Identifier of the identity pool the identity belongs to.
pub identity_pool_id: String,
}
/// Per-invocation context.
///
/// The `TryFrom<HeaderMap>` implementation fills the first four fields from
/// `lambda-runtime-*` headers and leaves the rest at their `Default` values;
/// `with_config` replaces `env_config`.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with
/// a struct literal or match it exhaustively.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Default, Serialize, Deserialize)]
pub struct Context {
/// Value of the `lambda-runtime-aws-request-id` header.
pub request_id: String,
/// Value of the `lambda-runtime-deadline-ms` header parsed as a `u64`
/// (presumably milliseconds, going by the header name — confirm the epoch).
pub deadline: u64,
/// Value of the `lambda-runtime-invoked-function-arn` header, or the text
/// `"No header lambda-runtime-invoked-function-arn found."` when it is absent.
pub invoked_function_arn: String,
/// Value of the `lambda-runtime-trace-id` header, or an empty string when
/// it is absent.
pub xray_trace_id: String,
/// Client information, if any; not populated by `TryFrom<HeaderMap>`.
pub client_context: Option<ClientContext>,
/// Cognito identity, if any; not populated by `TryFrom<HeaderMap>`.
pub identity: Option<CognitoIdentity>,
/// Configuration attached to this context; set through `with_config`.
pub env_config: Config,
}
impl TryFrom<HeaderMap> for Context {
type Error = Error;
fn try_from(headers: HeaderMap) -> Result<Self, Self::Error> {
let ctx = Context {
request_id: headers
.get("lambda-runtime-aws-request-id")
.expect("missing lambda-runtime-aws-request-id header")
.to_str()?
.to_owned(),
deadline: headers
.get("lambda-runtime-deadline-ms")
.expect("missing lambda-runtime-deadline-ms header")
.to_str()?
.parse::<u64>()?,
invoked_function_arn: headers
.get("lambda-runtime-invoked-function-arn")
.unwrap_or(&HeaderValue::from_static(
"No header lambda-runtime-invoked-function-arn found.",
))
.to_str()?
.to_owned(),
xray_trace_id: headers
.get("lambda-runtime-trace-id")
.unwrap_or(&HeaderValue::from_static(""))
.to_str()?
.to_owned(),
..Default::default()
};
Ok(ctx)
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn context_with_expected_values_and_types_resolves() {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from(headers);
assert!(tried.is_ok());
}
#[test]
fn context_with_certain_missing_headers_still_resolves() {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
let tried = Context::try_from(headers);
assert!(tried.is_ok());
}
#[test]
fn context_with_bad_deadline_type_is_err() {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert(
"lambda-runtime-deadline-ms",
HeaderValue::from_static("BAD-Type,not <u64>"),
);
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
let tried = Context::try_from(headers);
assert!(tried.is_err());
}
#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_request_id_should_panic() {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-aws-request-id", HeaderValue::from_static("my-id"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from(headers);
}
#[test]
#[should_panic]
#[allow(unused_must_use)]
fn context_with_missing_deadline_should_panic() {
let mut headers = HeaderMap::new();
headers.insert("lambda-runtime-deadline-ms", HeaderValue::from_static("123"));
headers.insert(
"lambda-runtime-invoked-function-arn",
HeaderValue::from_static("arn::myarn"),
);
headers.insert("lambda-runtime-trace-id", HeaderValue::from_static("arn::myarn"));
Context::try_from(headers);
}
}
impl Context {
    /// Consumes this context and returns it with `env_config` replaced by a
    /// clone of `config`; every other field is carried over unchanged.
    pub fn with_config(self, config: &Config) -> Self {
        let mut updated = self;
        updated.env_config = config.clone();
        updated
    }
}