use crate::error::{ApiError, ApiErrorKind, ErrorResponse};
use failure::ResultExt;
use hyper::{
client::HttpConnector,
header::{self, HeaderMap, HeaderValue},
rt::{Future, Stream},
Body, Client, Method, Request, Uri,
};
use log::*;
use serde_derive::*;
use serde_json;
use std::{collections::HashMap, fmt};
use tokio::runtime::Runtime;
/// Version segment inserted into every Runtime API path (`http://{host}/{version}/runtime/...`).
const RUNTIME_API_VERSION: &str = "2018-06-01";
/// `User-Agent` value used when the caller does not supply one.
const DEFAULT_AGENT: &str = "AWS_Lambda_Rust";

/// `Content-Type` for successful invocation responses.
const API_CONTENT_TYPE: &str = "application/json";
/// `Content-Type` for error payloads posted to the Runtime API.
const API_ERROR_CONTENT_TYPE: &str = "application/vnd.aws.lambda.error+json";
/// Header carrying the error classification on error posts.
const RUNTIME_ERROR_HEADER: &str = "Lambda-Runtime-Function-Error-Type";
/// Invocation metadata headers returned by the Lambda Runtime API alongside each event.
///
/// The `Display` implementation renders the on-the-wire header name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LambdaHeaders {
    /// `Lambda-Runtime-Aws-Request-Id`
    RequestId,
    /// `Lambda-Runtime-Invoked-Function-Arn`
    FunctionArn,
    /// `Lambda-Runtime-Trace-Id`
    TraceId,
    /// `Lambda-Runtime-Deadline-Ms`
    Deadline,
    /// `Lambda-Runtime-Client-Context`
    ClientContext,
    /// `Lambda-Runtime-Cognito-Identity`
    CognitoIdentity,
}
impl LambdaHeaders {
fn as_str(&self) -> &'static str {
match self {
LambdaHeaders::RequestId => "Lambda-Runtime-Aws-Request-Id",
LambdaHeaders::FunctionArn => "Lambda-Runtime-Invoked-Function-Arn",
LambdaHeaders::TraceId => "Lambda-Runtime-Trace-Id",
LambdaHeaders::Deadline => "Lambda-Runtime-Deadline-Ms",
LambdaHeaders::ClientContext => "Lambda-Runtime-Client-Context",
LambdaHeaders::CognitoIdentity => "Lambda-Runtime-Cognito-Identity",
}
}
}
impl fmt::Display for LambdaHeaders {
    /// Writes the raw header name; width/fill flags of the outer formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
/// Client application details embedded in the client-context JSON.
///
/// Keys are expected in camelCase (e.g. `installationId`, `appVersionCode`).
#[derive(Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientApplication {
    /// JSON key `installationId`.
    pub installation_id: String,
    /// JSON key `appTitle`.
    pub app_title: String,
    /// JSON key `appVersionName`.
    pub app_version_name: String,
    /// JSON key `appVersionCode`.
    pub app_version_code: String,
    /// JSON key `appPackageName`.
    pub app_package_name: String,
}
/// Client context deserialized from the `Lambda-Runtime-Client-Context` header JSON.
///
/// NOTE(review): none of the fields has `#[serde(default)]`, so a payload that omits
/// `custom` or `environment` fails to deserialize — confirm against the Runtime API payload.
#[derive(Clone, Deserialize)]
pub struct ClientContext {
    /// Details of the calling application (JSON key `client`).
    pub client: ClientApplication,
    /// String key/value pairs from the JSON key `custom`.
    pub custom: HashMap<String, String>,
    /// String key/value pairs from the JSON key `environment`.
    pub environment: HashMap<String, String>,
}
/// Cognito identity deserialized from the `Lambda-Runtime-Cognito-Identity` header JSON.
///
/// NOTE(review): field names are deserialized as-is (`identity_id`, `identity_pool_id`);
/// confirm this matches the key casing the Runtime API actually sends.
#[derive(Clone, Deserialize)]
pub struct CognitoIdentity {
    /// JSON key `identity_id`.
    pub identity_id: String,
    /// JSON key `identity_pool_id`.
    pub identity_pool_id: String,
}
/// Per-invocation metadata extracted from the headers of a `next` event response.
#[derive(Clone)]
pub struct EventContext {
    /// Value of the `Lambda-Runtime-Aws-Request-Id` header.
    pub aws_request_id: String,
    /// Value of the `Lambda-Runtime-Invoked-Function-Arn` header.
    pub invoked_function_arn: String,
    /// Value of the `Lambda-Runtime-Trace-Id` header.
    pub xray_trace_id: String,
    /// `Lambda-Runtime-Deadline-Ms` parsed as an integer
    /// (presumably milliseconds since the Unix epoch — TODO confirm).
    pub deadline: i64,
    /// Parsed `Lambda-Runtime-Client-Context` JSON, when the header is present.
    pub client_context: Option<ClientContext>,
    /// Parsed `Lambda-Runtime-Cognito-Identity` JSON, when the header is present.
    pub identity: Option<CognitoIdentity>,
}
/// Client for the AWS Lambda Runtime API.
///
/// Every method drives its HTTP future to completion with `.wait()`, so calls
/// block the current thread.
pub struct RuntimeClient {
// Runtime whose executor `http_client` was built with; owned here so it lives as
// long as the client. NOTE(review): struct fields drop in declaration order, so
// this is dropped before `http_client` — keep that in mind before reordering.
_runtime: Runtime,
// Plain-HTTP hyper client used for every Runtime API call.
http_client: Client<HttpConnector, Body>,
// Pre-parsed `.../runtime/invocation/next` URI polled for new events.
next_endpoint: Uri,
// Value sent as `User-Agent` on POST requests.
runtime_agent: String,
// Runtime API authority, interpolated as `http://{host}/...` when building URIs.
host: String,
}
impl<'ev> RuntimeClient {
pub fn new(host: &str, agent: Option<String>, runtime: Option<Runtime>) -> Result<Self, ApiError> {
debug!("Starting new HttpRuntimeClient for {}", host);
let runtime_agent = match agent {
Some(a) => a,
None => DEFAULT_AGENT.to_owned(),
};
let runtime = match runtime {
Some(r) => r,
None => Runtime::new().context(ApiErrorKind::Unrecoverable("Could not initialize runtime".to_string()))?,
};
let http_client = Client::builder().executor(runtime.executor()).build_http();
let next_endpoint = format!("http://{}/{}/runtime/invocation/next", host, RUNTIME_API_VERSION)
.parse::<Uri>()
.context(ApiErrorKind::Unrecoverable("Could not parse API uri".to_string()))?;
Ok(RuntimeClient {
_runtime: runtime,
http_client,
next_endpoint,
runtime_agent,
host: host.to_owned(),
})
}
}
impl<'ev> RuntimeClient {
pub fn next_event(&self) -> Result<(Vec<u8>, EventContext), ApiError> {
trace!("Polling for next event");
let resp = self
.http_client
.get(self.next_endpoint.clone())
.wait()
.context(ApiErrorKind::Unrecoverable("Could not fetch next event".to_string()))?;
if resp.status().is_client_error() {
error!(
"Runtime API returned client error when polling for new events: {}",
resp.status()
);
Err(ApiErrorKind::Recoverable(format!(
"Error {} when polling for events",
resp.status()
)))?;
}
if resp.status().is_server_error() {
error!(
"Runtime API returned server error when polling for new events: {}",
resp.status()
);
Err(ApiErrorKind::Unrecoverable(
"Server error when polling for new events".to_string(),
))?;
}
let ctx = self.get_event_context(&resp.headers())?;
let out = resp
.into_body()
.concat2()
.wait()
.context(ApiErrorKind::Recoverable("Could not read event boxy".to_string()))?;
let buf = out.into_bytes().to_vec();
trace!(
"Received new event for request id {}. Event length {} bytes",
ctx.aws_request_id,
buf.len()
);
Ok((buf, ctx))
}
pub fn event_response(&self, request_id: &str, output: &[u8]) -> Result<(), ApiError> {
trace!(
"Posting response for request {} to Runtime API. Response length {} bytes",
request_id,
output.len()
);
let uri = format!(
"http://{}/{}/runtime/invocation/{}/response",
self.host, RUNTIME_API_VERSION, request_id
)
.parse::<Uri>()
.context(ApiErrorKind::Unrecoverable(
"Could not generate response uri".to_owned(),
))?;
let req = self.get_runtime_post_request(&uri, output);
let resp = self
.http_client
.request(req)
.wait()
.context(ApiErrorKind::Recoverable("Could not post event response".to_string()))?;
if !resp.status().is_success() {
error!(
"Error from Runtime API when posting response for request {}: {}",
request_id,
resp.status()
);
Err(ApiErrorKind::Recoverable(format!(
"Error {} while sending response",
resp.status()
)))?;
}
trace!("Posted response to Runtime API for request {}", request_id);
Ok(())
}
pub fn event_error(&self, request_id: &str, e: &ErrorResponse) -> Result<(), ApiError> {
trace!(
"Posting error to runtime API for request {}: {}",
request_id,
e.error_message
);
let uri = format!(
"http://{}/{}/runtime/invocation/{}/error",
self.host, RUNTIME_API_VERSION, request_id
)
.parse::<Uri>()
.context(ApiErrorKind::Unrecoverable(
"Could not generate response uri".to_owned(),
))?;
let req = self.get_runtime_error_request(&uri, &e);
let resp = self.http_client.request(req).wait().context(ApiErrorKind::Recoverable(
"Could not post event error response".to_string(),
))?;
if !resp.status().is_success() {
error!(
"Error from Runtime API when posting error response for request {}: {}",
request_id,
resp.status()
);
Err(ApiErrorKind::Recoverable(format!(
"Error {} while sending response",
resp.status()
)))?;
}
trace!("Posted error response for request id {}", request_id);
Ok(())
}
pub fn fail_init(&self, e: &ErrorResponse) {
error!("Calling fail_init Runtime API: {}", e.error_message);
let uri = format!("http://{}/{}/runtime/init/error", self.host, RUNTIME_API_VERSION)
.parse::<Uri>()
.map_err(|e| {
error!("Could not parse fail init URI: {}", e);
panic!("Killing runtime");
});
let req = self.get_runtime_error_request(&uri.unwrap(), &e);
self.http_client
.request(req)
.wait()
.map_err(|e| {
error!("Error while sending init failed message: {}", e);
panic!("Error while sending init failed message: {}", e);
})
.map(|resp| {
info!("Successfully sent error response to the runtime API: {:?}", resp);
})
.expect("Could not complete init_fail request");
}
pub fn get_endpoint(&self) -> &str {
&self.host
}
fn get_runtime_post_request(&self, uri: &Uri, body: &[u8]) -> Request<Body> {
Request::builder()
.method(Method::POST)
.uri(uri.clone())
.header(header::CONTENT_TYPE, header::HeaderValue::from_static(API_CONTENT_TYPE))
.header(header::USER_AGENT, self.runtime_agent.clone())
.body(Body::from(body.to_owned()))
.unwrap()
}
fn get_runtime_error_request(&self, uri: &Uri, e: &ErrorResponse) -> Request<Body> {
let body = serde_json::to_vec(&e).expect("Could not turn error object into response JSON");
Request::builder()
.method(Method::POST)
.uri(uri.clone())
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static(API_ERROR_CONTENT_TYPE),
)
.header(header::USER_AGENT, self.runtime_agent.clone())
.header(RUNTIME_ERROR_HEADER, HeaderValue::from_static("Unhandled"))
.body(Body::from(body))
.unwrap()
}
fn get_event_context(&self, headers: &HeaderMap<HeaderValue>) -> Result<EventContext, ApiError> {
let aws_request_id = header_string(
headers.get(LambdaHeaders::RequestId.as_str()),
&LambdaHeaders::RequestId,
)?;
let invoked_function_arn = header_string(
headers.get(LambdaHeaders::FunctionArn.as_str()),
&LambdaHeaders::FunctionArn,
)?;
let xray_trace_id = header_string(headers.get(LambdaHeaders::TraceId.as_str()), &LambdaHeaders::TraceId)?;
let deadline = header_string(headers.get(LambdaHeaders::Deadline.as_str()), &LambdaHeaders::Deadline)?
.parse::<i64>()
.context(ApiErrorKind::Recoverable(
"Could not parse deadline header value to int".to_string(),
))?;
let mut ctx = EventContext {
aws_request_id,
invoked_function_arn,
xray_trace_id,
deadline,
client_context: Option::default(),
identity: Option::default(),
};
if let Some(ctx_json) = headers.get(LambdaHeaders::ClientContext.as_str()) {
let ctx_json = ctx_json.to_str().context(ApiErrorKind::Recoverable(
"Could not convert context header content to string".to_string(),
))?;
trace!("Found Client Context in response headers: {}", ctx_json);
let ctx_value: ClientContext = serde_json::from_str(&ctx_json).context(ApiErrorKind::Recoverable(
"Could not parse client context value as json object".to_string(),
))?;
ctx.client_context = Option::from(ctx_value);
};
if let Some(cognito_json) = headers.get(LambdaHeaders::CognitoIdentity.as_str()) {
let cognito_json = cognito_json.to_str().context(ApiErrorKind::Recoverable(
"Could not convert congnito context header content to string".to_string(),
))?;
trace!("Found Cognito Identity in response headers: {}", cognito_json);
let identity_value: CognitoIdentity = serde_json::from_str(&cognito_json).context(
ApiErrorKind::Recoverable("Could not parse cognito context value as json object".to_string()),
)?;
ctx.identity = Option::from(identity_value);
};
Ok(ctx)
}
}
/// Converts an optional header value into an owned `String`.
///
/// `header_type` is only used to name the header in log output and error messages.
///
/// # Errors
///
/// Returns a `Recoverable` error when the header is absent (also logged) or when its
/// value cannot be viewed as a `&str` (`HeaderValue::to_str` fails).
fn header_string(value: Option<&HeaderValue>, header_type: &LambdaHeaders) -> Result<String, ApiError> {
    let raw = match value {
        Some(raw) => raw,
        None => {
            error!("Response headers do not contain {} header", header_type);
            return Err(ApiErrorKind::Recoverable(format!("Missing {} header", header_type)).into());
        }
    };
    let text = raw
        .to_str()
        .context(ApiErrorKind::Recoverable(format!("Could not parse {} header", header_type)))?;
    Ok(text.to_owned())
}