use crate::LambdaInvocation;
use futures::{future::BoxFuture, ready, FutureExt, TryFutureExt};
use hyper::body::Incoming;
use lambda_runtime_api_client::{body::Body, BoxError, Client};
use pin_project::pin_project;
use std::{future::Future, pin::Pin, sync::Arc, task};
use tower::Service;
use tracing::error;
/// Tower service that asks the wrapped service `S` to turn a [`LambdaInvocation`] into an
/// HTTP request for the Lambda Runtime API, then sends that request with a shared [`Client`].
pub struct RuntimeApiClientService<S> {
    /// Service that converts an invocation into the Runtime API request to send.
    inner: S,
    /// Runtime API client used to send the built request; shared between clones via `Arc`.
    client: Arc<Client>,
}

impl<S> RuntimeApiClientService<S> {
    /// Wraps `inner` so that every request it produces is sent through `client`.
    pub fn new(inner: S, client: Arc<Client>) -> Self {
        RuntimeApiClientService { inner, client }
    }
}
impl<S> Service<LambdaInvocation> for RuntimeApiClientService<S>
where
S: Service<LambdaInvocation, Error = BoxError>,
S::Future: Future<Output = Result<http::Request<Body>, BoxError>>,
{
type Response = ();
type Error = S::Error;
type Future = RuntimeApiClientFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: LambdaInvocation) -> Self::Future {
let request_fut = self.inner.call(req);
let client = self.client.clone();
RuntimeApiClientFuture::First(request_fut, client)
}
}
impl<S> Clone for RuntimeApiClientService<S>
where
S: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
client: self.client.clone(),
}
}
}
/// Future returned by [`RuntimeApiClientService::call`].
///
/// It moves through two states. First it awaits the request built by the inner service.
/// Then it awaits the response to sending that request through the Runtime API [`Client`].
#[pin_project(project = RuntimeApiClientFutureProj)]
pub enum RuntimeApiClientFuture<F> {
    /// Waiting for the inner service's future `F` to produce the HTTP request; also holds
    /// the client that will send it.
    First(#[pin] F, Arc<Client>),
    /// Waiting for the Runtime API to answer the request that was sent.
    Second(#[pin] BoxFuture<'static, Result<http::Response<Incoming>, BoxError>>),
}
impl<F> Future for RuntimeApiClientFuture<F>
where
    F: Future<Output = Result<http::Request<Body>, BoxError>>,
{
    type Output = Result<(), BoxError>;

    /// Drives both phases: awaiting the built request, then awaiting the Runtime API response.
    ///
    /// Resolves to `Err` if the inner service fails to build the request or if sending it
    /// fails. A response with a non-2xx status is logged but still resolves to `Ok(())`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        // `ready!` returns `Poll::Pending` from this function while the current phase is not
        // done. Looping lets the second phase be polled immediately after the first one
        // completes, instead of waiting for another wake-up.
        task::Poll::Ready(loop {
            match self.as_mut().project() {
                RuntimeApiClientFutureProj::First(fut, client) => match ready!(fut.poll(cx)) {
                    Ok(request) => {
                        // The send future is stored as `BoxFuture<'static, _>`, so it cannot
                        // borrow `client`; that is what allows replacing the state below.
                        let next_fut = client
                            .call(request)
                            .map_err(|err| {
                                // Routed through `log_or_print!` like the other failures in
                                // this function, so it also has the `eprintln!` fallback.
                                log_or_print!(
                                    tracing: error!(error = ?err, "failed to send request to Lambda Runtime API"),
                                    fallback: eprintln!("failed to send request to Lambda Runtime API: {err:?}")
                                );
                                err
                            })
                            .boxed();
                        self.set(RuntimeApiClientFuture::Second(next_fut));
                    }
                    Err(err) => {
                        log_or_print!(
                            tracing: tracing::error!(error = ?err, "failed to build Lambda Runtime API request"),
                            fallback: eprintln!("failed to build Lambda Runtime API request: {err:?}")
                        );
                        break Err(err);
                    }
                },
                RuntimeApiClientFutureProj::Second(fut) => match ready!(fut.poll(cx)) {
                    Ok(resp) if !resp.status().is_success() => {
                        let status = resp.status();
                        log_or_print!(
                            tracing: tracing::error!(status = %status, "Lambda Runtime API returned non-200 response"),
                            fallback: eprintln!("Lambda Runtime API returned non-200 response: status={status}")
                        );
                        // A 410 Gone from the Runtime API is reported as a function timeout.
                        if status == http::StatusCode::GONE {
                            log_or_print!(
                                tracing: tracing::error!("Lambda function timeout!"),
                                fallback: eprintln!("Lambda function timeout!")
                            );
                        }
                        // A non-2xx status is reported above but does not fail this future.
                        break Ok(());
                    }
                    Ok(_) => break Ok(()),
                    Err(err) => {
                        log_or_print!(
                            tracing: tracing::error!(error = ?err, "Lambda Runtime API request failed"),
                            fallback: eprintln!("Lambda Runtime API request failed: {err:?}")
                        );
                        break Err(err);
                    }
                },
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use http_body_util::Full;
    use hyper::body::Bytes;
    use std::convert::Infallible;
    use tokio::net::TcpListener;
    use tracing_test::traced_test;

    /// Spawns a single-connection HTTP server on an ephemeral localhost port that answers
    /// every request with `status`, and returns its base URL.
    ///
    /// `TcpListener::bind` leaves the socket listening, so a client connection is queued
    /// until the spawned task accepts it; no start-up sleep is needed.
    async fn start_mock_server(status: StatusCode) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("http://{}", listener.local_addr().unwrap());
        tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let io = hyper_util::rt::TokioIo::new(stream);
            let service = hyper::service::service_fn(move |_req| async move {
                Ok::<_, Infallible>(
                    http::Response::builder()
                        .status(status)
                        .body(Full::new(Bytes::from("test response")))
                        .unwrap(),
                )
            });
            let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
                .serve_connection(io, service)
                .await;
        });
        url
    }

    /// Builds a Runtime API client pointed at `endpoint`.
    fn client_for(endpoint: &str) -> Arc<Client> {
        Arc::new(
            Client::builder()
                .with_endpoint(endpoint.parse().unwrap())
                .build()
                .unwrap(),
        )
    }

    /// A request future that immediately yields a minimal, valid request.
    fn ok_request() -> impl Future<Output = Result<http::Request<Body>, BoxError>> {
        async { Ok(http::Request::builder().uri("/test").body(Body::empty()).unwrap()) }
    }

    /// Sends one request to a mock server answering with `status` and returns the outcome.
    async fn send_to_mock(status: StatusCode) -> Result<(), BoxError> {
        let url = start_mock_server(status).await;
        RuntimeApiClientFuture::First(ok_request(), client_for(&url)).await
    }

    #[tokio::test]
    #[traced_test]
    async fn test_successful_response() {
        let result = send_to_mock(StatusCode::OK).await;
        assert!(result.is_ok());
        assert!(!logs_contain("Lambda Runtime API returned non-200 response"));
    }

    #[tokio::test]
    #[traced_test]
    async fn test_410_timeout_error() {
        let result = send_to_mock(StatusCode::GONE).await;
        assert!(result.is_ok());
        assert!(logs_contain("Lambda Runtime API returned non-200 response"));
        assert!(logs_contain("Lambda function timeout!"));
    }

    #[tokio::test]
    #[traced_test]
    async fn test_500_error() {
        let result = send_to_mock(StatusCode::INTERNAL_SERVER_ERROR).await;
        assert!(result.is_ok());
        assert!(logs_contain("Lambda Runtime API returned non-200 response"));
        // Only 410 should be reported as a timeout.
        assert!(!logs_contain("Lambda function timeout!"));
    }

    #[tokio::test]
    #[traced_test]
    async fn test_404_error() {
        let result = send_to_mock(StatusCode::NOT_FOUND).await;
        assert!(result.is_ok());
        assert!(logs_contain("Lambda Runtime API returned non-200 response"));
        // Only 410 should be reported as a timeout.
        assert!(!logs_contain("Lambda function timeout!"));
    }

    #[tokio::test]
    #[traced_test]
    async fn test_request_build_error() {
        let request_fut = async { Err::<http::Request<Body>, BoxError>("Request build error".into()) };
        let result = RuntimeApiClientFuture::First(request_fut, client_for("http://localhost:9001")).await;
        let err = result.expect_err("a failed request build must surface as an error");
        assert!(err.to_string().contains("Request build error"));
        assert!(logs_contain("failed to build Lambda Runtime API request"));
    }

    #[tokio::test]
    #[traced_test]
    async fn test_network_error() {
        // Nothing is expected to listen on port 1, so the connection attempt fails.
        let result = RuntimeApiClientFuture::First(ok_request(), client_for("http://127.0.0.1:1")).await;
        assert!(result.is_err());
        assert!(logs_contain("Lambda Runtime API request failed"));
    }
}