use http::{
header::{HeaderName, HeaderValue},
Method, StatusCode,
};
use hyper::body::Incoming;
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use lambda_http::request::RequestContext;
use lambda_http::Body;
pub use lambda_http::Error;
use lambda_http::{Request, RequestExt, Response};
use std::fmt::Debug;
use std::{
env,
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_retry::{strategy::FixedInterval, Retry};
use tower::{Service, ServiceBuilder};
use tower_http::compression::CompressionLayer;
use url::Url;
/// Transport used by the readiness check to decide whether the application is up.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Protocol {
    /// Send an HTTP GET to the readiness URL and inspect the response status.
    #[default]
    Http,
    /// Only require that a TCP connection to the readiness host/port can be opened.
    Tcp,
}

impl From<&str> for Protocol {
    /// Parses a protocol name case-insensitively.
    ///
    /// `"tcp"` (any casing) selects [`Protocol::Tcp`]; every other value, including
    /// unrecognised ones, falls back to [`Protocol::Http`].
    fn from(value: &str) -> Self {
        if value.eq_ignore_ascii_case("tcp") {
            Protocol::Tcp
        } else {
            Protocol::Http
        }
    }
}
/// How the Lambda runtime should deliver the application's response.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LambdaInvokeMode {
    /// Collect the whole response before returning it (`lambda_http::run`).
    #[default]
    Buffered,
    /// Stream the response (`lambda_http::run_with_streaming_response`).
    ResponseStream,
}

impl From<&str> for LambdaInvokeMode {
    /// Parses an invoke-mode name case-insensitively.
    ///
    /// Only `"response_stream"` (any casing, with an underscore) selects
    /// [`LambdaInvokeMode::ResponseStream`]; anything else yields [`LambdaInvokeMode::Buffered`].
    fn from(value: &str) -> Self {
        match value {
            mode if mode.eq_ignore_ascii_case("response_stream") => LambdaInvokeMode::ResponseStream,
            _ => LambdaInvokeMode::Buffered,
        }
    }
}
/// Configuration for an [`Adapter`]; `Default` populates it from environment variables.
pub struct AdapterOptions {
    /// Host of the web application; used for both readiness checks and forwarded requests.
    pub host: String,
    /// Port that Lambda requests are forwarded to (kept as a string and spliced into a URL).
    pub port: String,
    /// Port probed by the readiness check.
    pub readiness_check_port: String,
    /// Path appended to `http://host:readiness_check_port` for the HTTP readiness check.
    pub readiness_check_path: String,
    /// Whether readiness is probed with an HTTP GET or a bare TCP connect.
    pub readiness_check_protocol: Protocol,
    /// Lowest HTTP status treated as unhealthy; statuses in `100..this` count as ready.
    pub readiness_check_min_unhealthy_status: u16,
    /// Optional prefix removed from incoming request paths before forwarding.
    pub base_path: Option<String>,
    /// Path used instead of the request path for `RequestContext::PassThrough` events.
    pub pass_through_path: String,
    /// When true, the init-time readiness check is bounded by a timeout and readiness is
    /// re-checked on the first request instead.
    pub async_init: bool,
    /// When true, the service is wrapped in tower-http's `CompressionLayer`.
    pub compression: bool,
    /// Selects buffered or streaming response delivery for the Lambda runtime.
    pub invoke_mode: LambdaInvokeMode,
}
impl Default for AdapterOptions {
    /// Builds options from the process environment.
    ///
    /// Each field reads its variables in priority order and uses the first one that is set
    /// (and valid Unicode), otherwise the listed default:
    ///
    /// - `host`: `AWS_LWA_HOST`, `HOST` → `127.0.0.1`
    /// - `port`: `AWS_LWA_PORT`, `PORT` → `8080`
    /// - `readiness_check_port`: `AWS_LWA_READINESS_CHECK_PORT`, `READINESS_CHECK_PORT`,
    ///   `AWS_LWA_PORT`, `PORT` → `8080`
    /// - `readiness_check_min_unhealthy_status`: `AWS_LWA_READINESS_CHECK_MIN_UNHEALTHY_STATUS` → `500`
    /// - `readiness_check_path`: `AWS_LWA_READINESS_CHECK_PATH`, `READINESS_CHECK_PATH` → `/`
    /// - `readiness_check_protocol`: `AWS_LWA_READINESS_CHECK_PROTOCOL`, `READINESS_CHECK_PROTOCOL` → `HTTP`
    /// - `base_path`: `AWS_LWA_REMOVE_BASE_PATH`, `REMOVE_BASE_PATH` → `None`
    /// - `pass_through_path`: `AWS_LWA_PASS_THROUGH_PATH` → `/events`
    /// - `async_init`: `AWS_LWA_ASYNC_INIT`, `ASYNC_INIT` → `false`
    /// - `compression`: `AWS_LWA_ENABLE_COMPRESSION` → `false`
    /// - `invoke_mode`: `AWS_LWA_INVOKE_MODE` → `buffered`
    ///
    /// Numeric and boolean values that fail to parse (booleans must be exactly `true`/`false`)
    /// fall back to the default rather than trying the next variable.
    fn default() -> Self {
        /// Value of the first variable in `names` that is set and valid Unicode.
        fn first_env(names: &[&str]) -> Option<String> {
            names.iter().find_map(|name| env::var(name).ok())
        }

        /// Like `first_env`, but falls back to `fallback` (allocated only when needed).
        fn env_or(names: &[&str], fallback: &str) -> String {
            first_env(names).unwrap_or_else(|| fallback.to_owned())
        }

        AdapterOptions {
            host: env_or(&["AWS_LWA_HOST", "HOST"], "127.0.0.1"),
            port: env_or(&["AWS_LWA_PORT", "PORT"], "8080"),
            readiness_check_port: env_or(
                &["AWS_LWA_READINESS_CHECK_PORT", "READINESS_CHECK_PORT", "AWS_LWA_PORT", "PORT"],
                "8080",
            ),
            readiness_check_min_unhealthy_status: env::var("AWS_LWA_READINESS_CHECK_MIN_UNHEALTHY_STATUS")
                .ok()
                .and_then(|status| status.parse().ok())
                .unwrap_or(500),
            readiness_check_path: env_or(&["AWS_LWA_READINESS_CHECK_PATH", "READINESS_CHECK_PATH"], "/"),
            readiness_check_protocol: env_or(
                &["AWS_LWA_READINESS_CHECK_PROTOCOL", "READINESS_CHECK_PROTOCOL"],
                "HTTP",
            )
            .as_str()
            .into(),
            base_path: first_env(&["AWS_LWA_REMOVE_BASE_PATH", "REMOVE_BASE_PATH"]),
            pass_through_path: env_or(&["AWS_LWA_PASS_THROUGH_PATH"], "/events"),
            async_init: env_or(&["AWS_LWA_ASYNC_INIT", "ASYNC_INIT"], "false")
                .parse()
                .unwrap_or(false),
            compression: env_or(&["AWS_LWA_ENABLE_COMPRESSION"], "false")
                .parse()
                .unwrap_or(false),
            invoke_mode: env_or(&["AWS_LWA_INVOKE_MODE"], "buffered").as_str().into(),
        }
    }
}
/// Forwards Lambda HTTP events to a web application listening on `domain`.
///
/// Cloned once per request by the `Service` impl; the HTTP client and the readiness flag are
/// behind `Arc`, so all clones share them.
#[derive(Clone)]
pub struct Adapter<C, B> {
    /// Pooled HTTP client used for readiness probes and forwarded requests.
    client: Arc<Client<C, B>>,
    /// Full URL probed by the readiness check (`http://host:readiness_check_port/path`).
    healthcheck_url: Url,
    /// Whether readiness is probed over HTTP or plain TCP.
    healthcheck_protocol: Protocol,
    /// Lowest HTTP status counted as unhealthy by the readiness check.
    healthcheck_min_unhealthy_status: u16,
    /// Whether the init-time readiness check is time-bounded and deferred to the first request.
    async_init: bool,
    /// Set once the application has been observed ready; shared by all clones.
    ready_at_init: Arc<AtomicBool>,
    /// Base URL of the application (`http://host:port`); each request's path/query is set on a copy.
    domain: Url,
    /// Optional prefix removed from request paths before forwarding.
    base_path: Option<String>,
    /// Path used for `RequestContext::PassThrough` events (the field name reads "path" where
    /// "pass" is meant; it is kept as-is because other code refers to it).
    path_through_path: String,
    /// Whether responses are compressed via tower-http's `CompressionLayer`.
    compression: bool,
    /// Buffered vs. streaming response delivery.
    invoke_mode: LambdaInvokeMode,
}
impl Adapter<HttpConnector, Body> {
pub fn new(options: &AdapterOptions) -> Adapter<HttpConnector, Body> {
let client = Client::builder(hyper_util::rt::TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(4))
.build(HttpConnector::new());
let schema = "http";
let healthcheck_url = format!(
"{}://{}:{}{}",
schema, options.host, options.readiness_check_port, options.readiness_check_path
)
.parse()
.unwrap();
let domain = format!("{}://{}:{}", schema, options.host, options.port)
.parse()
.unwrap();
Adapter {
client: Arc::new(client),
healthcheck_url,
healthcheck_protocol: options.readiness_check_protocol,
healthcheck_min_unhealthy_status: options.readiness_check_min_unhealthy_status,
domain,
base_path: options.base_path.clone(),
path_through_path: options.pass_through_path.clone(),
async_init: options.async_init,
ready_at_init: Arc::new(AtomicBool::new(false)),
compression: options.compression,
invoke_mode: options.invoke_mode,
}
}
}
impl Adapter<HttpConnector, Body> {
    /// Registers this process with the Lambda Extensions API from a background task.
    ///
    /// The spawned task POSTs `{ "events": [] }` to
    /// `http://$AWS_LAMBDA_RUNTIME_API/2020-01-01/extension/register` (host defaults to
    /// `127.0.0.1:9001`) under the name `lambda-adapter`. It then issues one
    /// `GET .../2020-01-01/extension/event/next` carrying the returned
    /// `Lambda-Extension-Identifier`.
    ///
    /// # Panics
    ///
    /// The spawned task panics; the caller does not. This happens if either request fails,
    /// if registration does not answer `200 OK`, or if the identifier header is missing.
    pub fn register_default_extension(&self) {
        tokio::task::spawn(async move {
            let aws_lambda_runtime_api: String =
                env::var("AWS_LAMBDA_RUNTIME_API").unwrap_or_else(|_| "127.0.0.1:9001".to_string());
            let client = Client::builder(hyper_util::rt::TokioExecutor::new()).build(HttpConnector::new());
            let register_req = hyper::Request::builder()
                .method(Method::POST)
                .uri(format!("http://{aws_lambda_runtime_api}/2020-01-01/extension/register"))
                .header("Lambda-Extension-Name", "lambda-adapter")
                .body(Body::from("{ \"events\": [] }"))
                .unwrap();
            let register_res = client.request(register_req).await.unwrap();
            if register_res.status() != StatusCode::OK {
                panic!("extension registration failure");
            }
            let next_req = hyper::Request::builder()
                .method(Method::GET)
                .uri(format!(
                    "http://{aws_lambda_runtime_api}/2020-01-01/extension/event/next"
                ))
                .header(
                    "Lambda-Extension-Identifier",
                    register_res.headers().get("Lambda-Extension-Identifier").unwrap(),
                )
                .body(Body::Empty)
                .unwrap();
            client.request(next_req).await.unwrap();
        });
    }

    /// Probes the application during init and records whether it was ready.
    ///
    /// - With `async_init`, the probe is capped at 9.8 seconds. Presumably this keeps it
    ///   inside Lambda's 10-second init phase; TODO confirm. A timeout counts as "not ready",
    ///   and the check is then repeated on the first request in `fetch_response`.
    /// - Without `async_init`, the probe keeps retrying until the application answers.
    pub async fn check_init_health(&mut self) {
        let ready_at_init = if self.async_init {
            timeout(Duration::from_secs_f32(9.8), self.check_readiness())
                .await
                .unwrap_or_default()
        } else {
            self.check_readiness().await
        };
        self.ready_at_init.store(ready_at_init, Ordering::SeqCst);
    }

    /// Runs the retrying readiness check against the configured URL and protocol.
    async fn check_readiness(&self) -> bool {
        let url = self.healthcheck_url.clone();
        let protocol = self.healthcheck_protocol;
        self.is_web_ready(&url, &protocol).await
    }

    /// Repeats `check_web_readiness` every 10 ms with no attempt limit.
    ///
    /// Returns `true` once a probe succeeds. Bound this call with a timeout if it must not
    /// wait forever.
    async fn is_web_ready(&self, url: &Url, protocol: &Protocol) -> bool {
        Retry::spawn(FixedInterval::from_millis(10), || {
            self.check_web_readiness(url, protocol)
        })
        .await
        .is_ok()
    }

    /// Performs a single readiness probe.
    ///
    /// - HTTP: GET `url` and succeed when the status is in
    ///   `100..healthcheck_min_unhealthy_status`.
    /// - TCP: succeed when a connection to the URL's host and port can be opened.
    ///
    /// Returns `Err(-1)` when the application is not (yet) ready.
    async fn check_web_readiness(&self, url: &Url, protocol: &Protocol) -> Result<(), i8> {
        match protocol {
            Protocol::Http => match self.client.get(url.to_string().parse().unwrap()).await {
                Ok(response)
                    if {
                        self.healthcheck_min_unhealthy_status > response.status().as_u16()
                            && response.status().as_u16() >= 100
                    } =>
                {
                    Ok(())
                }
                _ => {
                    tracing::debug!("app is not ready");
                    Err(-1)
                }
            },
            // `Url::port()` is `None` when the port equals the scheme default (80 for http),
            // so use the known default instead of unwrapping and panicking.
            Protocol::Tcp => match (url.host_str(), url.port_or_known_default()) {
                (Some(host), Some(port)) => match TcpStream::connect(format!("{host}:{port}")).await {
                    Ok(_) => Ok(()),
                    Err(_) => {
                        tracing::debug!("app is not ready");
                        Err(-1)
                    }
                },
                _ => {
                    tracing::debug!(url = %url, "readiness check url has no host or port");
                    Err(-1)
                }
            },
        }
    }

    /// Hands the adapter to the Lambda runtime and serves events until the runtime stops.
    ///
    /// Wraps the service in `CompressionLayer` when compression is enabled. Uses
    /// `lambda_http::run` for buffered mode and `lambda_http::run_with_streaming_response`
    /// for response streaming.
    pub async fn run(self) -> Result<(), Error> {
        let compression = self.compression;
        let invoke_mode = self.invoke_mode;

        if compression {
            let svc = ServiceBuilder::new().layer(CompressionLayer::new()).service(self);
            match invoke_mode {
                LambdaInvokeMode::Buffered => lambda_http::run(svc).await,
                LambdaInvokeMode::ResponseStream => lambda_http::run_with_streaming_response(svc).await,
            }
        } else {
            match invoke_mode {
                LambdaInvokeMode::Buffered => lambda_http::run(self).await,
                LambdaInvokeMode::ResponseStream => lambda_http::run_with_streaming_response(self).await,
            }
        }
    }

    /// Forwards one Lambda event to the application and returns its raw HTTP response.
    ///
    /// 1. With `async_init`, if init did not observe readiness, waits (retrying with no limit)
    ///    for the application, then records it as ready for all clones.
    /// 2. Removes the configured base path from the request path, once. For
    ///    `RequestContext::PassThrough` events, the pass-through path replaces the request
    ///    path instead.
    /// 3. Adds the JSON-serialized request and Lambda contexts as the
    ///    `x-amzn-request-context` and `x-amzn-lambda-context` headers.
    /// 4. Sends the request to the application's `host:port`, keeping the original method,
    ///    headers, query string and body.
    ///
    /// # Errors
    ///
    /// Fails if a context cannot be serialized into a header value, if the request cannot
    /// be built, or if the HTTP call to the application fails.
    async fn fetch_response(&self, event: Request) -> Result<Response<Incoming>, Error> {
        if self.async_init && !self.ready_at_init.load(Ordering::SeqCst) {
            self.is_web_ready(&self.healthcheck_url, &self.healthcheck_protocol)
                .await;
            self.ready_at_init.store(true, Ordering::SeqCst);
        }

        let request_context = event.request_context();
        let lambda_context = event.lambda_context();
        let raw_path = event.raw_http_path().to_string();
        let mut path = raw_path.as_str();
        let (parts, body) = event.into_parts();

        if let Some(base_path) = self.base_path.as_deref() {
            // Strip the prefix exactly once. `trim_start_matches` would also remove repeats,
            // e.g. base path `/api` would turn `/api/api/users` into `/users`.
            path = path.strip_prefix(base_path).unwrap_or(path);
        }
        if let RequestContext::PassThrough = request_context {
            path = self.path_through_path.as_str();
        }

        let mut req_headers = parts.headers;
        req_headers.insert(
            HeaderName::from_static("x-amzn-request-context"),
            HeaderValue::from_bytes(serde_json::to_string(&request_context)?.as_bytes())?,
        );
        req_headers.insert(
            HeaderName::from_static("x-amzn-lambda-context"),
            HeaderValue::from_bytes(serde_json::to_string(&lambda_context)?.as_bytes())?,
        );

        let mut app_url = self.domain.clone();
        app_url.set_path(path);
        app_url.set_query(parts.uri.query());

        tracing::debug!(app_url = %app_url, req_headers = ?req_headers, "sending request to app server");

        let mut builder = hyper::Request::builder().method(parts.method).uri(app_url.to_string());
        if let Some(headers) = builder.headers_mut() {
            headers.extend(req_headers);
        }
        let request = builder.body(Body::from(body.to_vec()))?;
        let app_response = self.client.request(request).await?;
        Ok(app_response)
    }
}
/// Exposes the adapter as a tower `Service`, the form `lambda_http::run` and
/// `lambda_http::run_with_streaming_response` accept.
impl Service<Request> for Adapter<HttpConnector, Body> {
    type Response = Response<Incoming>;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready: the adapter applies no back-pressure of its own.
    fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
        let ready: Result<(), Self::Error> = Ok(());
        core::task::Poll::Ready(ready)
    }

    /// Forwards `event` to the application.
    ///
    /// The returned future cannot borrow `self`, so it owns a clone of the adapter. The HTTP
    /// client and readiness flag inside that clone are shared through `Arc`.
    fn call(&mut self, event: Request) -> Self::Future {
        let this = self.clone();
        let response = async move { this.fetch_response(event).await };
        Box::pin(response)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use httpmock::{Method::GET, Mock, MockServer};

    /// Registers a `GET /healthcheck` mock on `server` that answers with `status`.
    fn mock_healthcheck(server: &MockServer, status: u16) -> Mock<'_> {
        server.mock(|when, then| {
            when.method(GET).path("/healthcheck");
            then.status(status).body("OK");
        })
    }

    /// Options pointing both the application and the readiness check at `server`'s `/healthcheck`.
    fn options_for(server: &MockServer) -> AdapterOptions {
        AdapterOptions {
            host: server.host(),
            port: server.port().to_string(),
            readiness_check_port: server.port().to_string(),
            readiness_check_path: "/healthcheck".to_string(),
            ..Default::default()
        }
    }

    /// Runs a single readiness probe using the adapter's own URL and protocol.
    async fn probe(adapter: &Adapter<HttpConnector, Body>) -> Result<(), i8> {
        let url = adapter.healthcheck_url.clone();
        let protocol = adapter.healthcheck_protocol;
        adapter.check_web_readiness(&url, &protocol).await
    }

    #[tokio::test]
    async fn test_status_200_is_ok() {
        let app_server = MockServer::start();
        let healthcheck = mock_healthcheck(&app_server, 200);
        let adapter = Adapter::new(&options_for(&app_server));
        assert!(probe(&adapter).await.is_ok());
        healthcheck.assert();
    }

    #[tokio::test]
    async fn test_status_500_is_bad() {
        let app_server = MockServer::start();
        let healthcheck = mock_healthcheck(&app_server, 500);
        let adapter = Adapter::new(&options_for(&app_server));
        assert!(probe(&adapter).await.is_err());
        healthcheck.assert();
    }

    #[tokio::test]
    async fn test_status_403_is_bad_when_configured() {
        let app_server = MockServer::start();
        let healthcheck = mock_healthcheck(&app_server, 403);
        let options = AdapterOptions {
            readiness_check_min_unhealthy_status: 400,
            ..options_for(&app_server)
        };
        let adapter = Adapter::new(&options);
        assert!(probe(&adapter).await.is_err());
        healthcheck.assert();
    }

    #[tokio::test]
    async fn test_tcp_check_succeeds_when_port_is_open() {
        let app_server = MockServer::start();
        let options = AdapterOptions {
            readiness_check_protocol: Protocol::Tcp,
            ..options_for(&app_server)
        };
        let adapter = Adapter::new(&options);
        assert!(probe(&adapter).await.is_ok());
    }

    #[test]
    fn test_protocol_and_invoke_mode_parsing() {
        assert_eq!(Protocol::from("TCP"), Protocol::Tcp);
        assert_eq!(Protocol::from("http"), Protocol::Http);
        assert_eq!(Protocol::from("unknown"), Protocol::Http);
        assert_eq!(LambdaInvokeMode::from("RESPONSE_STREAM"), LambdaInvokeMode::ResponseStream);
        assert_eq!(LambdaInvokeMode::from("buffered"), LambdaInvokeMode::Buffered);
        assert_eq!(LambdaInvokeMode::from("streaming"), LambdaInvokeMode::Buffered);
    }
}