use std::{
borrow::Cow,
env,
future::Future,
mem,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use encoding_rs::{Encoding, UTF_8};
use http::{
header::{HeaderName, HeaderValue},
Method, StatusCode, Uri,
};
use hyper::{
body::HttpBody,
client::{Client, HttpConnector},
Body,
};
use lambda_http::aws_lambda_events::serde_json;
pub use lambda_http::Error;
use lambda_http::{Body as LambdaBody, Request, RequestExt, Response};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tokio_retry::{strategy::FixedInterval, Retry};
use tower::Service;
/// Protocol used when probing the web application for readiness.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Protocol {
    /// Probe by issuing an HTTP request.
    Http,
    /// Probe by opening a TCP connection.
    Tcp,
}

impl Default for Protocol {
    /// HTTP is the default readiness protocol.
    fn default() -> Self {
        Self::Http
    }
}

impl From<&str> for Protocol {
    /// Parses a protocol name without regard to case.
    ///
    /// Only `"tcp"` selects [`Protocol::Tcp`]; every other value, including
    /// unrecognised ones, yields [`Protocol::Http`].
    fn from(value: &str) -> Self {
        let normalized = value.to_lowercase();
        if normalized == "tcp" {
            Self::Tcp
        } else {
            Self::Http
        }
    }
}
/// Settings consumed by `Adapter::new`.
///
/// The derived `Default` leaves every string empty, `base_path` as `None`
/// and `async_init` as `false`; `AdapterOptions::from_env` fills the fields
/// from environment variables instead.
#[derive(Default)]
pub struct AdapterOptions {
/// Host of the web application; used for both the request target and the
/// readiness-check URL.
pub host: String,
/// Port the web application receives forwarded requests on.
pub port: String,
/// Port probed by the readiness check.
pub readiness_check_port: String,
/// Path appended to the readiness-check URL.
pub readiness_check_path: String,
/// Protocol used to probe readiness.
pub readiness_check_protocol: Protocol,
/// Optional prefix removed from the start of the incoming request path
/// before the request is forwarded.
pub base_path: Option<String>,
/// When `true`, the initial readiness wait is time-limited and, if it did
/// not succeed, the wait is repeated when a request is handled.
pub async_init: bool,
}
impl AdapterOptions {
pub fn from_env() -> Self {
AdapterOptions {
host: env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
port: env::var("PORT").unwrap_or_else(|_| "8080".to_string()),
readiness_check_port: env::var("READINESS_CHECK_PORT")
.unwrap_or_else(|_| env::var("PORT").unwrap_or_else(|_| "8080".to_string())),
readiness_check_path: env::var("READINESS_CHECK_PATH").unwrap_or_else(|_| "/".to_string()),
readiness_check_protocol: env::var("READINESS_CHECK_PROTOCOL")
.unwrap_or_else(|_| "HTTP".to_string())
.as_str()
.into(),
base_path: env::var("REMOVE_BASE_PATH").ok(),
async_init: env::var("ASYNC_INIT")
.unwrap_or_else(|_| "false".to_string())
.parse()
.unwrap_or(false),
}
}
}
/// Forwards Lambda HTTP events to a web application over HTTP and turns the
/// application's replies into Lambda responses.
pub struct Adapter {
/// Shared hyper client used to send requests to the web application.
client: Arc<Client<HttpConnector>>,
/// URL probed by the readiness check.
healthcheck_url: Uri,
/// Protocol used by the readiness check.
healthcheck_protocol: Protocol,
/// Whether the initial readiness wait is time-limited (see `check_init_health`).
async_init: bool,
/// Records the outcome of the initial readiness check; set to `true` once
/// a later readiness wait has completed while handling a request.
ready_at_init: Arc<AtomicBool>,
/// Scheme, host and port of the web application.
domain: Uri,
/// Optional prefix removed from the incoming request path before forwarding.
base_path: Option<String>,
}
impl Adapter {
pub fn new(options: &AdapterOptions) -> Adapter {
let client = Client::builder().pool_idle_timeout(Duration::from_secs(4)).build_http();
let healthcheck_url = format!(
"http://{}:{}{}",
options.host, options.readiness_check_port, options.readiness_check_path
)
.parse()
.unwrap();
let domain = format!("http://{}:{}", options.host, options.port).parse().unwrap();
Adapter {
client: Arc::new(client),
healthcheck_url,
healthcheck_protocol: options.readiness_check_protocol,
domain,
base_path: options.base_path.clone(),
async_init: options.async_init,
ready_at_init: Arc::new(AtomicBool::new(false)),
}
}
pub fn with_client(self, client: Client<HttpConnector>) -> Self {
Adapter {
client: Arc::new(client),
..self
}
}
pub fn register_default_extension(&self) {
tokio::task::spawn(async move {
let aws_lambda_runtime_api: String =
env::var("AWS_LAMBDA_RUNTIME_API").unwrap_or_else(|_| "127.0.0.1:9001".to_string());
let client = hyper::Client::new();
let register_req = hyper::Request::builder()
.method(Method::POST)
.uri(format!("http://{aws_lambda_runtime_api}/2020-01-01/extension/register"))
.header("Lambda-Extension-Name", "lambda-adapter")
.body(Body::from("{ \"events\": [] }"))
.unwrap();
let register_res = client.request(register_req).await.unwrap();
if register_res.status() != StatusCode::OK {
panic!("extension registration failure");
}
let next_req = hyper::Request::builder()
.method(Method::GET)
.uri(format!(
"http://{aws_lambda_runtime_api}/2020-01-01/extension/event/next"
))
.header(
"Lambda-Extension-Identifier",
register_res.headers().get("Lambda-Extension-Identifier").unwrap(),
)
.body(Body::empty())
.unwrap();
client.request(next_req).await.unwrap();
});
}
pub async fn check_init_health(&mut self) {
let ready_at_init = if self.async_init {
timeout(Duration::from_secs_f32(9.8), self.check_readiness())
.await
.unwrap_or_default()
} else {
self.check_readiness().await
};
self.ready_at_init.store(ready_at_init, Ordering::SeqCst);
}
async fn check_readiness(&self) -> bool {
let url = self.healthcheck_url.clone();
let protocol = self.healthcheck_protocol;
is_web_ready(&url, &protocol).await
}
pub async fn run(self) -> Result<(), Error> {
lambda_http::run(self).await
}
}
impl Service<Request> for Adapter {
type Response = Response<LambdaBody>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
core::task::Poll::Ready(Ok(()))
}
fn call(&mut self, event: Request) -> Self::Future {
let async_init = self.async_init;
let client = self.client.clone();
let ready_at_init = self.ready_at_init.clone();
let healthcheck_url = self.healthcheck_url.clone();
let healthcheck_protocol = self.healthcheck_protocol;
let domain = self.domain.clone();
let base_path = self.base_path.clone();
Box::pin(async move {
fetch_response(
async_init,
ready_at_init,
client,
base_path,
domain,
healthcheck_url,
healthcheck_protocol,
event,
)
.await
})
}
}
/// Forwards one Lambda HTTP event to the web application and converts the
/// application's reply into a Lambda response.
///
/// Steps:
/// 1. If `async_init` is set and the application was not yet recorded as
///    ready, wait for the readiness probe and record readiness.
/// 2. Remove `base_path` from the start of the request path, at most once.
/// 3. Forward the request (method, headers, body, query string) to `domain`,
///    adding the JSON-serialised request context in the
///    `x-amzn-request-context` header.
/// 4. Copy the application's status and headers into the Lambda response and
///    convert its body with `convert_body`.
///
/// # Errors
///
/// Returns an error if the request context cannot be serialised or stored in
/// a header, if the target URI or request cannot be built, if the request to
/// the application fails, or if the response body cannot be read or the
/// Lambda response cannot be built.
// One argument per piece of adapter state cloned by `Service::call`.
#[allow(clippy::too_many_arguments)]
async fn fetch_response(
    async_init: bool,
    ready_at_init: Arc<AtomicBool>,
    client: Arc<Client<HttpConnector>>,
    base_path: Option<String>,
    domain: Uri,
    healthcheck_url: Uri,
    healthcheck_protocol: Protocol,
    event: Request,
) -> Result<Response<LambdaBody>, Error> {
    if async_init && !ready_at_init.load(Ordering::SeqCst) {
        is_web_ready(&healthcheck_url, &healthcheck_protocol).await;
        ready_at_init.store(true, Ordering::SeqCst);
    }

    let request_context = event.request_context();
    let path = event.raw_http_path();
    let mut path = path.as_str();
    let (parts, body) = event.into_parts();

    // Strip the base path only once: `trim_start_matches` would also eat
    // repeated occurrences (`/api/api/users` -> `/users`), changing which
    // application route is hit.
    if let Some(base_path) = base_path.as_deref() {
        path = path.strip_prefix(base_path).unwrap_or(path);
    }

    let mut req_headers = parts.headers;
    req_headers.append(
        HeaderName::from_static("x-amzn-request-context"),
        HeaderValue::from_bytes(serde_json::to_string(&request_context)?.as_bytes())?,
    );

    // Rebuild "path?query" on top of the application's scheme and authority.
    let mut pq = path.to_string();
    if let Some(q) = parts.uri.query() {
        pq.push('?');
        pq.push_str(q);
    }
    let mut app_parts = domain.into_parts();
    app_parts.path_and_query = Some(pq.parse()?);
    let app_url = Uri::from_parts(app_parts)?;

    tracing::debug!(app_url = %app_url, req_headers = ?req_headers, "sending request to app server");

    let mut builder = hyper::Request::builder().method(parts.method).uri(app_url);
    if let Some(headers) = builder.headers_mut() {
        headers.extend(req_headers);
    }
    let request = builder.body(hyper::Body::from(body.to_vec()))?;

    let app_response = client.request(request).await?;
    // Headers and status are captured before the response is consumed by
    // the body conversion.
    let app_headers = app_response.headers().clone();
    let status = app_response.status();
    let body = convert_body(app_response).await?;

    tracing::debug!(status = %status, body_size = body.size_hint().lower(), app_headers = ?app_headers, "responding to lambda event");

    let mut lambda_response = Response::builder();
    if let Some(headers) = lambda_response.headers_mut() {
        // Replace the builder's header map wholesale with the application's.
        let _ = mem::replace(headers, app_headers);
    }
    let resp = lambda_response.status(status).body(body).map_err(Box::new)?;
    Ok(resp)
}
/// Probes `url` with `protocol` repeatedly, pausing 10 ms between attempts,
/// and reports whether the retry sequence ended in success.
// NOTE(review): the retry strategy is not capped here, so this presumably
// keeps retrying until a probe succeeds — confirm against tokio_retry.
async fn is_web_ready(url: &Uri, protocol: &Protocol) -> bool {
    let probe = || check_web_readiness(url, protocol);
    let outcome = Retry::spawn(FixedInterval::from_millis(10), probe).await;
    outcome.is_ok()
}
/// Performs a single readiness probe.
///
/// * `Protocol::Http`: succeeds when a GET request to `url` receives any
///   response; the status code is not inspected.
/// * `Protocol::Tcp`: succeeds when a TCP connection to the URL's host and
///   port can be opened.
///
/// Returns `Err(-1)` when the probe fails.
///
/// # Panics
///
/// For `Protocol::Tcp`, panics if `url` has no host or no port.
async fn check_web_readiness(url: &Uri, protocol: &Protocol) -> Result<(), i8> {
    let reachable = match protocol {
        Protocol::Http => Client::new().get(url.clone()).await.is_ok(),
        Protocol::Tcp => {
            let address = format!("{}:{}", url.host().unwrap(), url.port().unwrap());
            TcpStream::connect(address).await.is_ok()
        }
    };

    if reachable {
        Ok(())
    } else {
        Err(-1)
    }
}
async fn convert_body(app_response: hyper::Response<hyper::Body>) -> Result<LambdaBody, Error> {
let (parts, body) = app_response.into_parts();
let body = hyper::body::to_bytes(body).await?;
if parts.headers.get(http::header::CONTENT_ENCODING).is_some() {
return if body.is_empty() {
Ok(LambdaBody::Empty)
} else {
Ok(LambdaBody::Binary(body.to_vec()))
};
}
let content_type = parts
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok());
if serialize_as_text(content_type) {
let content_type = content_type.and_then(|value| value.parse::<mime::Mime>().ok());
let encoding_name = content_type
.as_ref()
.and_then(|mime| mime.get_param("charset").map(|charset| charset.as_str()))
.unwrap_or("utf-8");
let encoding = Encoding::for_label(encoding_name.as_bytes()).unwrap_or(UTF_8);
let (text, _, _) = encoding.decode(&body);
match text {
Cow::Owned(s) => Ok(LambdaBody::Text(s)),
_ => unsafe {
Ok(LambdaBody::Text(String::from_utf8_unchecked(body.to_vec())))
},
}
} else if body.is_empty() {
Ok(LambdaBody::Empty)
} else {
Ok(LambdaBody::Binary(body.to_vec()))
}
}
/// Decides whether a response body should be returned as text.
///
/// A missing content type is treated as text. Otherwise the content type
/// must begin with `text`, `application/json`, `application/javascript` or
/// `application/xml`. The comparison ignores ASCII case because media type
/// names are case-insensitive, so `Application/JSON` is text as well.
fn serialize_as_text(content_type: Option<&str>) -> bool {
    const TEXT_PREFIXES: [&str; 4] = [
        "text",
        "application/json",
        "application/javascript",
        "application/xml",
    ];

    match content_type {
        None => true,
        Some(content_type) => {
            // Compare bytes so that slicing can never land inside a
            // multi-byte character.
            let bytes = content_type.as_bytes();
            TEXT_PREFIXES.iter().any(|prefix| {
                bytes
                    .get(..prefix.len())
                    .map_or(false, |head| head.eq_ignore_ascii_case(prefix.as_bytes()))
            })
        }
    }
}