#![deny(clippy::all, clippy::cargo)]
#![warn(missing_docs, nonstandard_style, rust_2018_idioms)]
pub use crate::types::Context;
use client::Client;
use hyper::client::{connect::Connection, HttpConnector};
use serde::{Deserialize, Serialize};
use std::{
convert::{TryFrom, TryInto},
env, fmt,
future::Future,
sync::Arc,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_stream::{Stream, StreamExt};
use tower_service::Service;
use tracing::{error, trace};
mod client;
mod requests;
#[cfg(test)]
mod simulated;
mod types;
use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEventRequest};
use types::Diagnostic;
/// Boxed, thread-safe error type used throughout the runtime.
///
/// Any `std::error::Error + Send + Sync` value (and plain strings) converts into it,
/// which is what lets the `?` operator mix error sources freely.
pub type Error = Box<dyn std::error::Error + Send + Sync>;
/// Settings describing the Lambda function this runtime serves.
///
/// Each field corresponds to one `AWS_LAMBDA_*` environment variable and is filled in
/// by `Config::from_env`.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Config {
    /// Runtime API endpoint, read from `AWS_LAMBDA_RUNTIME_API`.
    pub endpoint: String,
    /// Function name, read from `AWS_LAMBDA_FUNCTION_NAME`.
    pub function_name: String,
    /// Configured memory size, parsed from `AWS_LAMBDA_FUNCTION_MEMORY_SIZE`.
    pub memory: i32,
    /// Function version, read from `AWS_LAMBDA_FUNCTION_VERSION`.
    pub version: String,
    /// Log stream name, read from `AWS_LAMBDA_LOG_STREAM_NAME`.
    pub log_stream: String,
    /// Log group name, read from `AWS_LAMBDA_LOG_GROUP_NAME`.
    pub log_group: String,
}
impl Config {
pub fn from_env() -> Result<Self, Error> {
let conf = Config {
endpoint: env::var("AWS_LAMBDA_RUNTIME_API")?,
function_name: env::var("AWS_LAMBDA_FUNCTION_NAME")?,
memory: env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE")?.parse::<i32>()?,
version: env::var("AWS_LAMBDA_FUNCTION_VERSION")?,
log_stream: env::var("AWS_LAMBDA_LOG_STREAM_NAME")?,
log_group: env::var("AWS_LAMBDA_LOG_GROUP_NAME")?,
};
Ok(conf)
}
}
/// Asynchronous processing of one event of type `A` into a response of type `B`.
pub trait Handler<A, B> {
    /// Future returned by [`Handler::call`]; resolves to the response or an error.
    type Fut: Future<Output = Result<B, Self::Error>>;
    /// Error produced when handling an event fails.
    type Error;
    /// Handles a single `event` together with its invocation `context`.
    fn call(&self, event: A, context: Context) -> Self::Fut;
}
/// Wraps `f` in a [`HandlerFn`], which implements [`Handler`] whenever `f` is a
/// `Fn(A, Context)` returning a `Send` future of `Result<B, E>`.
pub fn handler_fn<F>(f: F) -> HandlerFn<F> {
    HandlerFn::<F> { f }
}
/// A [`Handler`] backed by a plain function or closure; create one with [`handler_fn`].
#[derive(Debug, Clone)]
pub struct HandlerFn<F> {
    /// The wrapped function, invoked once per event.
    f: F,
}
/// Forwards every call to the wrapped function and hands back its future untouched.
///
/// The generic error parameter is named `E` so it does not shadow the crate-level
/// `Error` alias.
impl<F, A, B, E, Fut> Handler<A, B> for HandlerFn<F>
where
    F: Fn(A, Context) -> Fut,
    Fut: Future<Output = Result<B, E>> + Send,
    E: Into<Box<dyn std::error::Error + Send + Sync + 'static>> + fmt::Display,
{
    type Error = E;
    type Fut = Fut;

    fn call(&self, event: A, context: Context) -> Self::Fut {
        let inner = &self.f;
        inner(event, context)
    }
}
/// Reasons `RuntimeBuilder::build` can refuse to produce a runtime.
#[non_exhaustive]
#[derive(Debug, PartialEq)]
enum BuilderError {
    /// No endpoint was supplied through `with_endpoint` before `build`.
    UnsetUri,
}

impl fmt::Display for BuilderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BuilderError::UnsetUri => f.write_str("runtime endpoint URI was not set"),
        }
    }
}

/// Lets a `BuilderError` be boxed into the crate's `Error` type and propagated with `?`.
impl std::error::Error for BuilderError {}
/// Runtime state: the client that `run` uses to send each invocation's outcome.
///
/// `C` is the connector type handed to `Client`; it defaults to hyper's `HttpConnector`.
struct Runtime<C: Service<http::Uri> = HttpConnector> {
    /// Client for the runtime endpoint; the top-level `run` also feeds it to `incoming`
    /// to fetch events.
    client: Client<C>,
}
impl Runtime {
    /// Starts building a [`Runtime`] that uses a fresh hyper `HttpConnector`.
    ///
    /// No endpoint is set yet, so `build` fails until `with_endpoint` is called.
    pub fn builder() -> RuntimeBuilder<HttpConnector> {
        let connector = HttpConnector::new();
        RuntimeBuilder { uri: None, connector }
    }
}
impl<C> Runtime<C>
where
    C: Service<http::Uri> + Clone + Send + Sync + Unpin + 'static,
    <C as Service<http::Uri>>::Future: Unpin + Send,
    <C as Service<http::Uri>>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    <C as Service<http::Uri>>::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
{
    /// Feeds every event yielded by `incoming` to `handler` and sends each outcome back
    /// through this runtime's client.
    ///
    /// For each event, the [`Context`] is built from the response headers and the body is
    /// deserialized from JSON into `A`. The handler's whole future runs on its own Tokio
    /// task. A panic anywhere in the handler, including inside the returned future, is
    /// therefore reported as an error for that invocation rather than unwinding through
    /// this loop (assuming the default unwinding panic strategy). A successful response
    /// is sent as an `EventCompletionRequest`; a handler error or panic is sent as an
    /// `EventErrorRequest`.
    ///
    /// # Errors
    ///
    /// Stops and returns an error in any of these cases:
    /// - `incoming` yields an error;
    /// - the headers cannot be converted into a [`Context`];
    /// - the body cannot be read or deserialized;
    /// - building the outgoing request fails;
    /// - sending the outgoing request fails.
    pub async fn run<F, A, B>(
        &self,
        incoming: impl Stream<Item = Result<http::Response<hyper::Body>, Error>> + Send,
        handler: F,
    ) -> Result<(), Error>
    where
        F: Handler<A, B> + Send + Sync + 'static,
        <F as Handler<A, B>>::Fut: Future<Output = Result<B, <F as Handler<A, B>>::Error>> + Send + 'static,
        <F as Handler<A, B>>::Error: fmt::Display + Send + Sync + 'static,
        A: for<'de> Deserialize<'de> + Send + Sync + 'static,
        B: Serialize + Send + Sync + 'static,
    {
        let client = &self.client;
        let handler = Arc::new(handler);
        tokio::pin!(incoming);
        while let Some(event) = incoming.next().await {
            trace!("New event arrived (run loop)");
            let event = event?;
            let (parts, body) = event.into_parts();
            let ctx: Context = Context::try_from(parts.headers)?;
            let body = hyper::body::to_bytes(body).await?;
            // Lossy decoding: the trace log must never be the reason the loop aborts.
            trace!("{}", String::from_utf8_lossy(&body));
            let body = serde_json::from_slice(&body)?;
            let handler = Arc::clone(&handler);
            // `ctx` moves into the task below, so keep our own copy of the id.
            let request_id = &ctx.request_id.clone();
            // Await inside the task so panics raised while the handler future is polled
            // are captured by the JoinHandle too, not only panics raised by `call` itself.
            let task = tokio::spawn(async move { handler.call(body, ctx).await });
            let req = match task.await {
                Ok(Ok(response)) => {
                    trace!("Ok response from handler (run loop)");
                    EventCompletionRequest {
                        request_id,
                        body: response,
                    }
                    .into_req()
                }
                Ok(Err(err)) => {
                    error!("{}", err);
                    EventErrorRequest {
                        request_id,
                        diagnostic: Diagnostic {
                            error_type: type_name_of_val(&err).to_owned(),
                            error_message: format!("{}", err),
                        },
                    }
                    .into_req()
                }
                Err(err) if err.is_panic() => {
                    error!("{:?}", err);
                    EventErrorRequest {
                        request_id,
                        diagnostic: Diagnostic {
                            error_type: type_name_of_val(&err).to_owned(),
                            error_message: format!("Lambda panicked: {}", err),
                        },
                    }
                    .into_req()
                }
                // The task is never aborted, so a JoinError can only mean a panic.
                Err(_) => unreachable!("tokio::task should not be canceled"),
            };
            let req = req?;
            // Propagate transport failures to the caller instead of panicking.
            client
                .call(req)
                .await
                .map_err(|err| format!("Unable to send response to Runtime APIs: {}", err))?;
        }
        Ok(())
    }
}
/// Builder for [`Runtime`]: collects a connector and an endpoint URI, then `build`s.
struct RuntimeBuilder<C: Service<http::Uri> = hyper::client::HttpConnector> {
    /// Connector handed to `Client::with` when `build` is called.
    connector: C,
    /// Endpoint URI; `build` returns `BuilderError::UnsetUri` while this is `None`.
    uri: Option<http::Uri>,
}
impl<C> RuntimeBuilder<C>
where
    C: Service<http::Uri> + Clone + Send + Sync + Unpin + 'static,
    <C as Service<http::Uri>>::Future: Unpin + Send,
    <C as Service<http::Uri>>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    <C as Service<http::Uri>>::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
{
    /// Swaps in a different connector, keeping whatever endpoint has already been set.
    pub fn with_connector<C2>(self, connector: C2) -> RuntimeBuilder<C2>
    where
        C2: Service<http::Uri> + Clone + Send + Sync + Unpin + 'static,
        <C2 as Service<http::Uri>>::Future: Unpin + Send,
        <C2 as Service<http::Uri>>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        <C2 as Service<http::Uri>>::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
    {
        let Self { uri, .. } = self;
        RuntimeBuilder { connector, uri }
    }

    /// Sets the endpoint URI, replacing any previously configured one.
    pub fn with_endpoint(mut self, uri: http::Uri) -> Self {
        self.uri = Some(uri);
        self
    }

    /// Produces a [`Runtime`] whose client talks to the configured endpoint.
    ///
    /// # Errors
    ///
    /// Returns `BuilderError::UnsetUri` when no endpoint was set.
    pub fn build(self) -> Result<Runtime<C>, BuilderError> {
        let uri = self.uri.ok_or(BuilderError::UnsetUri)?;
        Ok(Runtime {
            client: Client::with(uri, self.connector),
        })
    }
}
/// A builder given both a connector and an endpoint must produce a runtime.
#[test]
fn test_builder() {
    let runtime = Runtime::builder()
        .with_connector(HttpConnector::new())
        .with_endpoint(http::Uri::from_static("http://nomatter.com"))
        .build();
    runtime.unwrap();
}

/// `build` must refuse to produce a runtime when no endpoint was configured.
#[test]
fn test_builder_without_endpoint() {
    let result = Runtime::builder().build();
    assert_eq!(result.err(), Some(BuilderError::UnsetUri));
}
/// Returns a never-ending stream in which each item is the result of sending a
/// `NextEventRequest` through `client`.
///
/// A request is only sent when the consumer polls for the next item.
fn incoming<C>(client: &Client<C>) -> impl Stream<Item = Result<http::Response<hyper::Body>, Error>> + Send + '_
where
    C: Service<http::Uri> + Clone + Send + Sync + Unpin + 'static,
    <C as Service<http::Uri>>::Future: Unpin + Send,
    <C as Service<http::Uri>>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    <C as Service<http::Uri>>::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
{
    async_stream::stream! {
        loop {
            trace!("Waiting for next event (incoming loop)");
            // Building this fixed request failing would be a programming error.
            let next_event = NextEventRequest
                .into_req()
                .expect("Unable to construct request");
            yield client.call(next_event).await;
        }
    }
}
pub async fn run<A, B, F>(handler: F) -> Result<(), Error>
where
F: Handler<A, B> + Send + Sync + 'static,
<F as Handler<A, B>>::Fut: Future<Output = Result<B, <F as Handler<A, B>>::Error>> + Send + 'static,
<F as Handler<A, B>>::Error: fmt::Display + Send + Sync + 'static,
A: for<'de> Deserialize<'de> + Send + Sync + 'static,
B: Serialize + Send + Sync + 'static,
{
trace!("Loading config from env");
let config = Config::from_env()?;
let uri = config.endpoint.try_into().expect("Unable to convert to URL");
let runtime = Runtime::builder()
.with_connector(HttpConnector::new())
.with_endpoint(uri)
.build()
.expect("Unable create runtime");
let client = &runtime.client;
let incoming = incoming(client);
runtime.run(incoming, handler).await
}
/// Returns the type name of the value behind `_val`.
///
/// For example, `&err` yields the name of `err`'s type rather than the name of the
/// reference. The signature mirrors `std::any::type_name_of_val`. Like
/// `std::any::type_name`, the text is best-effort and intended for diagnostics only.
fn type_name_of_val<T: ?Sized>(_val: &T) -> &'static str {
    std::any::type_name::<T>()
}