use crate::{Error, RefConfig};
use base64::prelude::*;
use bytes::Bytes;
use http::{header::ToStrError, HeaderMap, HeaderValue, StatusCode};
use lambda_runtime_api_client::body::Body;
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
collections::HashMap,
fmt::{Debug, Display},
time::{Duration, SystemTime},
};
use tokio_stream::Stream;
/// Error report with a type label and a message.
///
/// Serialized with camelCase keys, so the fields appear as `errorType` and
/// `errorMessage`.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Diagnostic<'a> {
    /// Label identifying the kind of error (serialized as `errorType`).
    pub error_type: Cow<'a, str>,
    /// Human-readable description (serialized as `errorMessage`).
    pub error_message: Cow<'a, str>,
}
impl<'a, T> From<T> for Diagnostic<'a>
where
T: Display,
{
fn from(value: T) -> Self {
Diagnostic {
error_type: Cow::Borrowed(std::any::type_name::<T>()),
error_message: Cow::Owned(format!("{value}")),
}
}
}
/// Client context attached to an invocation.
///
/// [`Context::new`] parses it as JSON from the `lambda-runtime-client-context`
/// header. Every field falls back to its default when it is missing from the
/// JSON, so `{}` is accepted.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
pub struct ClientContext {
    /// Client application details.
    #[serde(default)]
    pub client: ClientApplication,
    /// Free-form string key/value pairs under `custom`.
    #[serde(default)]
    pub custom: HashMap<String, String>,
    /// Free-form string key/value pairs under `environment`.
    #[serde(default)]
    pub environment: HashMap<String, String>,
}
/// Client application details carried inside a [`ClientContext`].
///
/// Serialized with camelCase keys. Deserialization also accepts the
/// snake_case spelling of every field.
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientApplication {
    /// `installationId` (also read from `installation_id`).
    #[serde(alias = "installation_id")]
    pub installation_id: String,
    /// `appTitle` (also read from `app_title`).
    #[serde(alias = "app_title")]
    pub app_title: String,
    /// `appVersionName` (also read from `app_version_name`).
    #[serde(alias = "app_version_name")]
    pub app_version_name: String,
    /// `appVersionCode` (also read from `app_version_code`).
    #[serde(alias = "app_version_code")]
    pub app_version_code: String,
    /// `appPackageName` (also read from `app_package_name`).
    #[serde(alias = "app_package_name")]
    pub app_package_name: String,
}
/// Identity attached to an invocation.
///
/// [`Context::new`] parses it as JSON from the
/// `lambda-runtime-cognito-identity` header. Serialized as `identityId` and
/// `identityPoolId`. Deserialization also accepts the `cognitoIdentity…` and
/// snake_case spellings. Neither field has a default, so both are required.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CognitoIdentity {
    /// `identityId` (also read from `cognitoIdentityId` / `identity_id`).
    #[serde(alias = "cognitoIdentityId", alias = "identity_id")]
    pub identity_id: String,
    /// `identityPoolId` (also read from `cognitoIdentityPoolId` / `identity_pool_id`).
    #[serde(alias = "cognitoIdentityPoolId", alias = "identity_pool_id")]
    pub identity_pool_id: String,
}
/// Per-invocation metadata, normally built from runtime headers by
/// [`Context::new`].
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)]
pub struct Context {
    /// Request id passed to [`Context::new`].
    pub request_id: String,
    /// Deadline in milliseconds since the Unix epoch; see [`Context::deadline`].
    pub deadline: u64,
    /// Invoked function ARN. When built by [`Context::new`] without the ARN
    /// header, this holds a placeholder message instead.
    pub invoked_function_arn: String,
    /// Trace id from the `lambda-runtime-trace-id` header, if present.
    pub xray_trace_id: Option<String>,
    /// Parsed client context, if the header was present.
    pub client_context: Option<ClientContext>,
    /// Parsed identity, if the header was present.
    pub identity: Option<CognitoIdentity>,
    /// Shared runtime configuration.
    pub env_config: RefConfig,
}
impl Default for Context {
    /// Returns a context with empty strings, a zero deadline, no optional
    /// metadata, and a freshly defaulted [`crate::Config`].
    fn default() -> Self {
        Self {
            request_id: String::new(),
            deadline: 0,
            invoked_function_arn: String::new(),
            xray_trace_id: None,
            client_context: None,
            identity: None,
            env_config: std::sync::Arc::new(crate::Config::default()),
        }
    }
}
impl Context {
    /// Builds a [`Context`] from an invocation's runtime headers.
    ///
    /// Optional headers (`lambda-runtime-client-context`,
    /// `lambda-runtime-cognito-identity`, `lambda-runtime-trace-id`) map to
    /// `None` when absent. A missing `lambda-runtime-invoked-function-arn`
    /// header yields a placeholder message rather than an error.
    ///
    /// # Errors
    /// Returns an error in three cases:
    /// - a header value cannot be read as a string (`HeaderValue::to_str`);
    /// - the client-context or identity header is not valid JSON for its type;
    /// - `lambda-runtime-deadline-ms` does not parse as a `u64`.
    ///
    /// # Panics
    /// Panics if the `lambda-runtime-deadline-ms` header is absent.
    pub fn new(request_id: &str, env_config: RefConfig, headers: &HeaderMap) -> Result<Self, Error> {
        // Headers are read in a fixed order, so the first failure reported is
        // deterministic: client context, identity, deadline, then function ARN.
        let client_context: Option<ClientContext> = match headers.get("lambda-runtime-client-context") {
            Some(raw) => serde_json::from_str(raw.to_str()?)?,
            None => None,
        };
        let identity: Option<CognitoIdentity> = match headers.get("lambda-runtime-cognito-identity") {
            Some(raw) => serde_json::from_str(raw.to_str()?)?,
            None => None,
        };

        let deadline_header = headers
            .get("lambda-runtime-deadline-ms")
            .expect("missing lambda-runtime-deadline-ms header");
        let deadline = deadline_header.to_str()?.parse::<u64>()?;

        let arn_placeholder = HeaderValue::from_static("No header lambda-runtime-invoked-function-arn found.");
        let invoked_function_arn = headers
            .get("lambda-runtime-invoked-function-arn")
            .unwrap_or(&arn_placeholder)
            .to_str()?
            .to_owned();

        // Decoded lossily, so an unusual trace id never causes an error.
        let xray_trace_id = headers
            .get("lambda-runtime-trace-id")
            .map(|raw| String::from_utf8_lossy(raw.as_bytes()).into_owned());

        Ok(Self {
            request_id: request_id.to_owned(),
            deadline,
            invoked_function_arn,
            xray_trace_id,
            client_context,
            identity,
            env_config,
        })
    }

    /// Converts the stored `deadline` (milliseconds since the Unix epoch)
    /// into a [`SystemTime`].
    pub fn deadline(&self) -> SystemTime {
        let since_epoch = Duration::from_millis(self.deadline);
        SystemTime::UNIX_EPOCH + since_epoch
    }
}
/// Returns the value of the `lambda-runtime-aws-request-id` header.
///
/// # Errors
/// Returns [`ToStrError`] if the header value cannot be read as a string.
///
/// # Panics
/// Panics if the header is absent.
pub(crate) fn invoke_request_id(headers: &HeaderMap) -> Result<&str, ToStrError> {
    let maybe_id = headers.get("lambda-runtime-aws-request-id");
    maybe_id.expect("missing lambda-runtime-aws-request-id header").to_str()
}
/// An invocation payload paired with its [`Context`].
#[derive(Debug, Clone)]
pub struct LambdaEvent<T> {
    /// The event payload.
    pub payload: T,
    /// Metadata for this invocation.
    pub context: Context,
}
impl<T> LambdaEvent<T> {
pub fn new(payload: T, context: Context) -> Self {
Self { payload, context }
}
pub fn into_parts(self) -> (T, Context) {
(self.payload, self.context)
}
}
/// Status code, headers and cookies for a streaming response.
///
/// Serialized with camelCase keys. The status code and header map go through
/// the `http_serde` helpers.
#[derive(Default, Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct MetadataPrelude {
    /// HTTP status code (serialized as `statusCode`).
    #[serde(with = "http_serde::status_code")]
    pub status_code: StatusCode,
    /// Response headers.
    #[serde(with = "http_serde::header_map")]
    pub headers: HeaderMap,
    /// Cookie strings.
    pub cookies: Vec<String>,
}
/// Produces trailer text describing an error in a streamed response.
pub trait ToStreamErrorTrailer {
    /// Returns the trailer text.
    ///
    /// The method name is spelled `to_tailer` (sic). It is part of the public
    /// trait, so renaming it would break implementors and callers.
    fn to_tailer(&self) -> String;
}
impl ToStreamErrorTrailer for Error {
    /// Emits two CRLF-terminated lines:
    /// - `Lambda-Runtime-Function-Error-Type`, always `Runtime.StreamError`;
    /// - `Lambda-Runtime-Function-Error-Body`, the error's `Display` text
    ///   base64-encoded with the standard alphabet.
    fn to_tailer(&self) -> String {
        let encoded_body = BASE64_STANDARD.encode(self.to_string());
        format!(
            "Lambda-Runtime-Function-Error-Type: Runtime.StreamError\r\nLambda-Runtime-Function-Error-Body: {encoded_body}\r\n"
        )
    }
}
/// A streaming function response: response metadata plus a body stream.
#[derive(Debug)]
pub struct StreamResponse<S> {
    /// Status code, headers and cookies for the response.
    pub metadata_prelude: MetadataPrelude,
    /// The body stream.
    pub stream: S,
}
/// A function response: either a single buffered value or a stream.
pub enum FunctionResponse<B, S> {
    /// A complete response value of type `B`.
    BufferedResponse(B),
    /// A streamed response together with its metadata.
    StreamingResponse(StreamResponse<S>),
}
/// Conversion into a [`FunctionResponse`].
pub trait IntoFunctionResponse<B, S> {
    /// Converts `self` into a buffered or streaming [`FunctionResponse`].
    fn into_response(self) -> FunctionResponse<B, S>;
}
/// A [`FunctionResponse`] is already in final form and is returned unchanged.
impl<B, S> IntoFunctionResponse<B, S> for FunctionResponse<B, S> {
    fn into_response(self) -> Self {
        self
    }
}
/// Any serializable value becomes a buffered response. The stream type
/// parameter is fixed to [`Body`].
impl<B: Serialize> IntoFunctionResponse<B, Body> for B {
    fn into_response(self) -> FunctionResponse<B, Body> {
        FunctionResponse::BufferedResponse(self)
    }
}
impl<S, D, E> IntoFunctionResponse<(), S> for StreamResponse<S>
where
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
D: Into<Bytes> + Send,
E: Into<Error> + Send + Debug,
{
fn into_response(self) -> FunctionResponse<(), S> {
FunctionResponse::StreamingResponse(self)
}
}
impl<S, D, E> From<S> for StreamResponse<S>
where
S: Stream<Item = Result<D, E>> + Unpin + Send + 'static,
D: Into<Bytes> + Send,
E: Into<Error> + Send + Debug,
{
fn from(value: S) -> Self {
StreamResponse {
metadata_prelude: Default::default(),
stream: value,
}
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::Config;
    use std::sync::Arc;

    // Header fixtures shared by the tests below, as (name, value) pairs.
    const REQUEST_ID: (&str, &str) = ("lambda-runtime-aws-request-id", "my-id");
    const DEADLINE: (&str, &str) = ("lambda-runtime-deadline-ms", "123");
    const FUNCTION_ARN: (&str, &str) = ("lambda-runtime-invoked-function-arn", "arn::myarn");
    const TRACE_ID: (&str, &str) = ("lambda-runtime-trace-id", "arn::myarn");

    /// Builds a header map from static `(name, value)` pairs.
    fn headers_from(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// Calls `Context::new` with request id `"id"` and a default config.
    fn build_context(headers: &HeaderMap) -> Result<Context, Error> {
        let config = Arc::new(Config::default());
        Context::new("id", config, headers)
    }

    #[test]
    fn round_trip_lambda_error() {
        use serde_json::{json, Value};
        let diagnostic = Diagnostic {
            error_type: Cow::Borrowed("InvalidEventDataError"),
            error_message: Cow::Borrowed("Error parsing event data."),
        };
        let serialized: Value = serde_json::to_value(diagnostic).expect("failed to serialize diagnostic");
        assert_eq!(
            json!({
                "errorType": "InvalidEventDataError",
                "errorMessage": "Error parsing event data.",
            }),
            serialized
        );
    }

    #[test]
    fn context_with_expected_values_and_types_resolves() {
        let headers = headers_from(&[REQUEST_ID, DEADLINE, FUNCTION_ARN, TRACE_ID]);
        assert!(build_context(&headers).is_ok());
    }

    #[test]
    fn context_with_certain_missing_headers_still_resolves() {
        let headers = headers_from(&[REQUEST_ID, DEADLINE]);
        assert!(build_context(&headers).is_ok());
    }

    #[test]
    fn context_with_client_context_resolves() {
        let client_context = ClientContext {
            client: ClientApplication::default(),
            custom: HashMap::from([("key".to_string(), "value".to_string())]),
            environment: HashMap::from([("key".to_string(), "value".to_string())]),
        };
        let encoded = serde_json::to_string(&client_context).unwrap();
        let mut headers = headers_from(&[REQUEST_ID, DEADLINE]);
        headers.insert(
            "lambda-runtime-client-context",
            HeaderValue::from_str(&encoded).unwrap(),
        );

        let resolved = build_context(&headers).expect("context should resolve");
        assert_eq!(resolved.client_context, Some(client_context));
    }

    #[test]
    fn context_with_empty_client_context_resolves() {
        let headers = headers_from(&[REQUEST_ID, DEADLINE, ("lambda-runtime-client-context", "{}")]);
        let resolved = build_context(&headers).expect("context should resolve");
        assert!(resolved.client_context.is_some());
    }

    #[test]
    fn context_with_identity_resolves() {
        let cognito_identity = CognitoIdentity {
            identity_id: String::new(),
            identity_pool_id: String::new(),
        };
        let encoded = serde_json::to_string(&cognito_identity).unwrap();
        let mut headers = headers_from(&[REQUEST_ID, DEADLINE]);
        headers.insert(
            "lambda-runtime-cognito-identity",
            HeaderValue::from_str(&encoded).unwrap(),
        );

        let resolved = build_context(&headers).expect("context should resolve");
        assert_eq!(resolved.identity, Some(cognito_identity));
    }

    #[test]
    fn context_with_bad_deadline_type_is_err() {
        let headers = headers_from(&[
            REQUEST_ID,
            ("lambda-runtime-deadline-ms", "BAD-Type,not <u64>"),
            FUNCTION_ARN,
            TRACE_ID,
        ]);
        assert!(build_context(&headers).is_err());
    }

    #[test]
    fn context_with_bad_client_context_is_err() {
        let headers = headers_from(&[
            REQUEST_ID,
            DEADLINE,
            ("lambda-runtime-client-context", "BAD-Type,not JSON"),
        ]);
        assert!(build_context(&headers).is_err());
    }

    #[test]
    fn context_with_empty_identity_is_err() {
        // `CognitoIdentity` has no defaults, so an empty object must fail.
        let headers = headers_from(&[REQUEST_ID, DEADLINE, ("lambda-runtime-cognito-identity", "{}")]);
        assert!(build_context(&headers).is_err());
    }

    #[test]
    fn context_with_bad_identity_is_err() {
        let headers = headers_from(&[
            REQUEST_ID,
            DEADLINE,
            ("lambda-runtime-cognito-identity", "BAD-Type,not JSON"),
        ]);
        assert!(build_context(&headers).is_err());
    }

    #[test]
    #[should_panic]
    fn context_with_missing_deadline_should_panic() {
        let headers = headers_from(&[REQUEST_ID, FUNCTION_ARN, TRACE_ID]);
        let _ = build_context(&headers);
    }

    #[test]
    fn invoke_request_id_should_not_panic() {
        let headers = headers_from(&[REQUEST_ID, DEADLINE, FUNCTION_ARN, TRACE_ID]);
        let _ = invoke_request_id(&headers);
    }

    #[test]
    #[should_panic]
    fn invoke_request_id_should_panic() {
        let headers = headers_from(&[DEADLINE, FUNCTION_ARN, TRACE_ID]);
        let _ = invoke_request_id(&headers);
    }
}