use lambda_extension::Extension;
use lambda_http::{Body, Request, RequestExt, Response};
use reqwest::{redirect, Client, Url};
use std::{
env,
future::Future,
mem,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use tokio::time::timeout;
use tokio_retry::{strategy::FixedInterval, Retry};
use tower::Service;
/// Boxed error type used by the adapter; `Send + Sync + 'static` so it can cross task boundaries.
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
/// Configuration for [`Adapter`], normally populated from environment variables by
/// [`AdapterOptions::from_env`].
#[derive(Default)]
pub struct AdapterOptions {
    /// Host of the web application (`HOST`, default `127.0.0.1`).
    host: String,
    /// Port the web application serves traffic on (`PORT`, default `8080`).
    port: String,
    /// Port used for the readiness check (`READINESS_CHECK_PORT`, defaults to `port`).
    readiness_check_port: String,
    /// Path requested by the readiness check (`READINESS_CHECK_PATH`, default `/`).
    readiness_check_path: String,
    /// Prefix removed from incoming request paths (`REMOVE_BASE_PATH`, unset by default).
    base_path: Option<String>,
    /// Whether initialization waits only a bounded time for readiness (`ASYNC_INIT`).
    async_init: bool,
}

impl AdapterOptions {
    /// Reads the adapter configuration from the process environment.
    ///
    /// Missing variables fall back to the defaults documented on each field.
    /// `ASYNC_INIT` is enabled only when its value is `true`, compared case-insensitively
    /// and ignoring surrounding whitespace. Any other value, or an unset variable, disables it.
    pub fn from_env() -> Self {
        let port = env::var("PORT").unwrap_or_else(|_| "8080".to_string());
        AdapterOptions {
            host: env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
            // Evaluated before `port` is moved into the struct below.
            readiness_check_port: env::var("READINESS_CHECK_PORT").unwrap_or_else(|_| port.clone()),
            port,
            readiness_check_path: env::var("READINESS_CHECK_PATH").unwrap_or_else(|_| "/".to_string()),
            base_path: env::var("REMOVE_BASE_PATH").ok(),
            // `str::parse::<bool>` accepts only the exact lowercase literal, so `TRUE` or
            // `True` would silently disable async init. Accept any ASCII casing instead.
            async_init: env::var("ASYNC_INIT")
                .map(|value| value.trim().eq_ignore_ascii_case("true"))
                .unwrap_or(false),
        }
    }
}
/// Forwards Lambda HTTP events to a web application reachable at `domain`.
pub struct Adapter {
    /// Shared HTTP client used to forward requests to the application.
    client: Arc<Client>,
    /// Full URL polled by the readiness check.
    healthcheck_url: String,
    /// When true, the init-time readiness check is bounded by a timeout.
    async_init: bool,
    /// Whether the application has been observed ready; shared with request futures.
    ready_at_init: Arc<AtomicBool>,
    /// Base URL (`http://host:port`) that requests are forwarded to.
    domain: Url,
    /// Optional prefix removed from incoming request paths.
    base_path: Option<String>,
}

impl Adapter {
    /// Builds an adapter from `options`.
    ///
    /// The HTTP client is configured not to follow redirects and with a 4-second idle
    /// timeout for pooled connections.
    ///
    /// # Panics
    /// Panics if the HTTP client cannot be built, or if `host` and `port` do not form a
    /// valid `http://host:port` URL.
    pub fn new(options: &AdapterOptions) -> Adapter {
        let client = Client::builder()
            .redirect(redirect::Policy::none())
            .pool_idle_timeout(Duration::from_secs(4))
            .build()
            .expect("failed to build the HTTP client used to reach the web application");
        let healthcheck_url = format!(
            "http://{}:{}{}",
            options.host, options.readiness_check_port, options.readiness_check_path
        );
        // HOST and PORT come from the environment, so make a bad value easy to diagnose.
        let domain = format!("http://{}:{}", options.host, options.port)
            .parse()
            .expect("HOST and PORT must form a valid http://HOST:PORT URL");
        Adapter {
            client: Arc::new(client),
            healthcheck_url,
            domain,
            base_path: options.base_path.clone(),
            async_init: options.async_init,
            ready_at_init: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Replaces the HTTP client used to forward requests, keeping all other settings.
    pub fn with_client(self, client: Client) -> Self {
        Adapter {
            client: Arc::new(client),
            ..self
        }
    }

    /// Runs the readiness check once at startup and records the outcome in `ready_at_init`.
    ///
    /// With `async_init`, the check is abandoned after 9.8 seconds and the application is
    /// recorded as not ready. `fetch_response` then re-checks readiness on the next request.
    /// Without `async_init`, this awaits `is_web_ready` with no time limit.
    pub async fn check_init_health(&mut self) {
        let ready_at_init = if self.async_init {
            // NOTE(review): 9.8 s presumably leaves headroom under Lambda's init-phase
            // limit — confirm before changing.
            timeout(Duration::from_millis(9_800), self.check_readiness())
                .await
                .unwrap_or_default()
        } else {
            self.check_readiness().await
        };
        self.ready_at_init.store(ready_at_init, Ordering::SeqCst);
    }

    /// Polls the configured readiness URL; see `is_web_ready`.
    async fn check_readiness(&self) -> bool {
        is_web_ready(&self.healthcheck_url).await
    }
}
impl Service<Request> for Adapter {
type Response = http::Response<Body>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn poll_ready(&mut self, _cx: &mut core::task::Context<'_>) -> core::task::Poll<Result<(), Self::Error>> {
core::task::Poll::Ready(Ok(()))
}
fn call(&mut self, event: Request) -> Self::Future {
let async_init = self.async_init;
let client = self.client.clone();
let ready_at_init = self.ready_at_init.clone();
let healthcheck_url = self.healthcheck_url.clone();
let domain = self.domain.clone();
let base_path = self.base_path.clone();
Box::pin(async move {
fetch_response(
async_init,
ready_at_init,
client,
base_path,
domain,
healthcheck_url,
event,
)
.await
})
}
}
/// Forwards one Lambda HTTP event to the web application and converts its reply.
///
/// If `async_init` is set and the application has not yet been observed ready, this first
/// waits on `is_web_ready(healthcheck_url)` and then marks `ready_at_init`.
/// If `base_path` is set and the request path starts with it, that prefix is removed once.
/// The remaining path and the original query string are appended to `domain`.
///
/// # Errors
/// Returns an error if sending the request fails, if reading the response body fails, or
/// if building the Lambda response fails.
async fn fetch_response(
    async_init: bool,
    ready_at_init: Arc<AtomicBool>,
    client: Arc<Client>,
    base_path: Option<String>,
    domain: Url,
    healthcheck_url: String,
    event: Request,
) -> Result<http::Response<Body>, Error> {
    if async_init && !ready_at_init.load(Ordering::SeqCst) {
        is_web_ready(&healthcheck_url).await;
        ready_at_init.store(true, Ordering::SeqCst);
    }

    let raw_path = event.raw_http_path();
    let (parts, body) = event.into_parts();

    // Remove the base path at most once. `trim_start_matches` strips *repeated* occurrences,
    // which would turn `/api/api/users` into `/users` for `REMOVE_BASE_PATH=/api`.
    // NOTE(review): the prefix is matched textually, not per segment, so `/api` also strips
    // the start of `/apiv2/...`; decide whether segment-aware matching is wanted.
    let path = match base_path.as_deref() {
        Some(prefix) => raw_path.strip_prefix(prefix).unwrap_or(raw_path.as_str()),
        None => raw_path.as_str(),
    };

    let mut app_url = domain;
    app_url.set_path(path);
    app_url.set_query(parts.uri.query());
    tracing::debug!(app_url = %app_url, "sending request to server");

    let app_response = client
        .request(parts.method, app_url.to_string())
        .headers(parts.headers)
        .body(body.to_vec())
        .send()
        .await?;

    let mut lambda_response = Response::builder();
    // Forward the application's headers as-is. They are cloned rather than taken because
    // `convert_body` still inspects Content-Encoding / Content-Type on `app_response`.
    let _ = mem::replace(
        lambda_response
            .headers_mut()
            .expect("a freshly created response builder holds no error"),
        app_response.headers().clone(),
    );
    let status = app_response.status();
    let body = convert_body(app_response).await?;
    // `http::Error` converts into the boxed `Error` directly; no extra `Box::new` layer needed.
    let resp = lambda_response.status(status).body(body)?;
    Ok(resp)
}
/// Repeats `check_web_readiness(url)` every 10 ms until it succeeds.
///
/// Returns `true` once the check succeeds. The retry strategy is a fixed interval with no
/// explicit attempt cap; bound the wait externally (e.g. with `tokio::time::timeout`).
async fn is_web_ready(url: &str) -> bool {
    let every_10ms = FixedInterval::from_millis(10);
    let outcome = Retry::spawn(every_10ms, || check_web_readiness(url)).await;
    outcome.is_ok()
}
/// Performs a single GET against `url`.
///
/// Returns `Ok(())` when the response has a 2xx status. Returns `Err(-1)` on a non-2xx
/// status or a transport error; the error value only signals "retry" to `is_web_ready`.
async fn check_web_readiness(url: &str) -> Result<(), i8> {
    let healthy = reqwest::get(url)
        .await
        .map(|response| response.status().is_success())
        .unwrap_or(false);
    if healthy {
        Ok(())
    } else {
        Err(-1)
    }
}
/// Spawns a background Tokio task that runs a Lambda extension with an empty event list.
///
/// Must be called from within a Tokio runtime, because it uses `tokio::task::spawn`.
/// If the extension's `run` returns an error, the error is logged and the task panics.
/// NOTE(review): the spawned task's `JoinHandle` is discarded. Unless the build uses
/// `panic = "abort"`, this panic likely ends only the task, not the process — confirm.
pub fn register_default_extension() {
    tokio::task::spawn(async move {
        if let Err(err) = Extension::new().with_events(&[]).run().await {
            tracing::error!(err = err, "extension terminated unexpectedly");
            panic!("extension thread execution");
        }
    });
}
/// Converts the application's HTTP response body into a Lambda [`Body`].
///
/// Rules, in order:
/// * any `Content-Encoding` header -> `Body::Binary` with the raw bytes;
/// * a textual `Content-Type` (`text*`, `application/json`, `application/javascript`,
///   `application/xml`, compared case-insensitively) -> `Body::Text` via `Response::text`;
/// * any other `Content-Type` -> `Body::Empty` for an empty payload, otherwise `Body::Binary`;
/// * no `Content-Type` -> `Body::Text` if the payload is valid UTF-8, otherwise `Body::Binary`.
///
/// # Errors
/// Returns an error if reading the response body fails.
async fn convert_body(app_response: reqwest::Response) -> Result<Body, Error> {
    tracing::debug!(resp_headers = ?app_response.headers(), "converting response body");

    if app_response.headers().contains_key(http::header::CONTENT_ENCODING) {
        let content = app_response.bytes().await?;
        return Ok(Body::Binary(content.to_vec()));
    }

    // MIME types are case-insensitive, so normalize before prefix matching. A header value
    // that is not visible ASCII becomes "" and is treated as non-text, as before.
    let content_type = app_response
        .headers()
        .get(http::header::CONTENT_TYPE)
        .map(|value| value.to_str().unwrap_or_default().to_ascii_lowercase());

    match content_type.as_deref() {
        Some(content_type) if is_text_content_type(content_type) => Ok(Body::Text(app_response.text().await?)),
        Some(_) => {
            let content = app_response.bytes().await?;
            if content.is_empty() {
                Ok(Body::Empty)
            } else {
                Ok(Body::Binary(content.to_vec()))
            }
        }
        None => {
            // Without a Content-Type we cannot assume text. `Response::text` would replace
            // invalid UTF-8 with U+FFFD and silently corrupt binary payloads, so only
            // forward the body as text when it decodes cleanly.
            let content = app_response.bytes().await?;
            match String::from_utf8(content.to_vec()) {
                Ok(text) => Ok(Body::Text(text)),
                Err(not_utf8) => Ok(Body::Binary(not_utf8.into_bytes())),
            }
        }
    }
}

/// Returns whether a lowercase `content_type` should be forwarded as a text body.
fn is_text_content_type(content_type: &str) -> bool {
    ["text", "application/json", "application/javascript", "application/xml"]
        .iter()
        .any(|prefix| content_type.starts_with(prefix))
}